afsara_ben

Calculate FLOPs in Whisper model

I am trying to calculate the FLOP count of a single pass through the Whisper forward() function. I am using the thop library, but it reports a FLOP count of 0. I believe it's because, in its register_hooks, Sequential is mapped to zero_ops. Here I am passing input_features along with the initial decoder_input_ids. What other way is there to get the FLOP count? I also found the ptflops library, but I can't figure out how to call it properly. Any solution is appreciated.

import os

import torch
from thop import profile
from transformers import AutoFeatureExtractor, WhisperForConditionalGeneration
from datasets import load_dataset
import whisper_timestamped as whisper

# Hugging Face model: only used here for its config (decoder_start_token_id)
hf_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")

# One dummy LibriSpeech sample -> log-mel features of shape (1, 80, 3000)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = ds[0]["audio"]
inputs = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
input_features = inputs.input_features

# Two copies of the decoder start token as the initial decoder input, shape (1, 2)
decoder_input_ids = torch.tensor([[1, 1]]) * hf_model.config.decoder_start_token_id

model_path = "<local_path_to_saved_model>"  # placeholder: directory containing tiny.pt
# Vanilla model; its forward() takes (mel, tokens) positionally
model = whisper.load_model(os.path.join(model_path, "tiny.pt"), device="cpu").float()
model.eval()

# thop calls model(*inputs) and returns (MACs, parameter count)
with torch.no_grad():
    macs, params = profile(model, inputs=(input_features, decoder_input_ids))
print(macs, params)
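As a sanity check on whatever a profiler reports, the dominant cost can also be estimated by hand. The sketch below counts only the multiply-accumulates (MACs) in the convolutions and matrix multiplications of an openai-whisper style encoder-decoder, and ignores LayerNorm, GELU, softmax, biases and positional embeddings. The whisper-tiny dimensions used as defaults (80 mel bins, 3000 input frames, width 384, 4 encoder and 4 decoder layers, MLP width 1536, vocabulary 51865) are assumptions taken from the published tiny configuration, and `whisper_macs` and its helpers are hypothetical, not part of any library.

```python
def conv1d_macs(c_in: int, c_out: int, kernel: int, out_len: int) -> int:
    """MACs of a 1-D convolution: one (c_in * kernel) dot product per output element."""
    return out_len * c_out * c_in * kernel


def attention_macs(q_len: int, kv_len: int, d: int) -> int:
    """Q and output projections on the queries, K and V projections on the keys/values,
    plus the two attention matmuls (Q @ K^T and weights @ V)."""
    projections = 2 * q_len * d * d + 2 * kv_len * d * d
    matmuls = 2 * q_len * kv_len * d
    return projections + matmuls


def mlp_macs(n: int, d: int, d_ff: int) -> int:
    """Two linear layers: d -> d_ff -> d."""
    return 2 * n * d * d_ff


def whisper_macs(n_tokens: int, n_mels: int = 80, n_frames: int = 3000, d: int = 384,
                 n_enc: int = 4, n_dec: int = 4, d_ff: int = 1536,
                 vocab: int = 51865) -> dict:
    """Hypothetical helper: approximate MACs of one forward pass
    (no KV cache, all n_tokens decoder tokens processed at once)."""
    n_audio = n_frames // 2  # conv2 has stride 2: 3000 frames -> 1500 positions
    encoder = conv1d_macs(n_mels, d, 3, n_frames) + conv1d_macs(d, d, 3, n_audio)
    encoder += n_enc * (attention_macs(n_audio, n_audio, d) + mlp_macs(n_audio, d, d_ff))

    decoder_block = (attention_macs(n_tokens, n_tokens, d)   # self-attention
                     + attention_macs(n_tokens, n_audio, d)  # cross-attention over audio
                     + mlp_macs(n_tokens, d, d_ff))
    decoder = n_dec * decoder_block + n_tokens * d * vocab   # + tied output projection
    return {"encoder": encoder, "decoder": decoder, "total": encoder + decoder}


macs = whisper_macs(n_tokens=2)  # two decoder tokens, like decoder_input_ids above
print(f"{macs['total'] / 1e9:.2f} GMACs, roughly {2 * macs['total'] / 1e9:.2f} GFLOPs")
```

Under these assumptions the encoder dominates (it processes 1500 positions, the decoder only 2 tokens). thop reports MACs, so a FLOP figure is roughly twice its output; a result of 0, or one orders of magnitude below this estimate, suggests modules are being skipped.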
