I am trying to calculate the FLOP count of a single pass through the Whisper forward() function. I am using the thop library and getting a FLOP count of 0. I believe it's because, in its register_hooks, Sequential is set to zero_ops. Here I am passing the input_features along with the initial decoder_input_ids. What other way is there to get the FLOP count? I also found the ptflops library but can't seem to call it properly. Any solution is appreciated.
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import whisper_timestamped as whisper
from datasets import load_dataset
from ptflops import get_model_complexity_info
from thop import profile
from transformers import AutoFeatureExtractor, WhisperModel, WhisperForConditionalGeneration

# Hugging Face model, used here only for its config (decoder_start_token_id)
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")

# Build log-mel input features from one dummy LibriSpeech sample
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
input_features = inputs.input_features

# Initial decoder tokens: two copies of the decoder start token
decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id

# model_path = <local_path_to_saved_model>
# Note: `model` is reassigned here to the vanilla checkpoint loaded via whisper_timestamped
model = whisper.load_model(os.path.join(model_path, 'tiny.pt'), device='cpu').float() # vanilla model

# thop returns (MACs, parameter count); this is the call that reports 0
macs, params = profile(model, inputs=(input_features, decoder_input_ids))