Roua Rouatbi

Reputation: 5

Issue while exporting a PyTorch model to ONNX format

I'm trying to export my PyTorch model to ONNX format, but I keep getting this error:

TypeError: forward() missing 1 required positional argument: 'text'

This is my code:

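import torch
from model import Model  # ViTSTR repo's model.py (assumed module path); opt comes from its argparse options

model = Model(opt)
dummy_input = torch.randn(1, 3, 224, 224)
file_path='/content/drive/MyDrive/VitSTR/vitstr_tiny_patch16_224_aug.pth'
# note: this overwrites file_path with the model's current weights before reloading them
torch.save(model.state_dict(), file_path)
model.load_state_dict(torch.load(file_path))
#model = torch.nn.DataParallel(model).to(device)
#print(model)
torch.onnx.export(model, dummy_input, "vitstr.onnx", verbose=True)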

Upvotes: 0

Views: 1515

Answers (1)

Berriel

Reputation: 13601

ViTSTR's Model.forward() requires two positional arguments, input and text; the remaining parameters have defaults:

def forward(self, input, text, is_train=True, seqlen=25):
    # ...
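If you are unsure which inputs a model's forward() expects, the standard-library inspect module can list the parameters that have no default value; each of those needs a dummy tensor. This is a minimal sketch: StandInModel is a hypothetical class that copies only the signature quoted above, and with the real model you would pass model.forward instead.

```python
import inspect


def required_positional_args(fn):
    """Return names of parameters of fn that must be supplied (positional, no default)."""
    params = inspect.signature(fn).parameters.values()
    return [
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        and p.default is inspect.Parameter.empty
    ]


class StandInModel:
    """Hypothetical stand-in: mirrors only the forward() signature of ViTSTR's Model."""

    def forward(self, input, text, is_train=True, seqlen=25):
        raise NotImplementedError


# On a bound method, self is already bound and is not listed
print(required_positional_args(StandInModel().forward))  # ['input', 'text']
```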

Therefore, you need to pass the text argument as well, bundling both dummy inputs into a tuple:

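# ...
# placeholder for the text argument; shape and dtype are an assumption based on
# the repo's test code: (batch_size, opt.batch_max_length + 1), long
dummy_text = torch.zeros(1, opt.batch_max_length + 1, dtype=torch.long)
torch.onnx.export(model, (dummy_input, dummy_text), "vitstr.onnx", verbose=True)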
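The torch.onnx.export documentation describes a tuple args as inputs for which model(*args) is a valid call, so making that call yourself is a cheap check before exporting. A minimal sketch, using a hypothetical TwoInputModel in place of ViTSTR's Model; the dummy_text shape (1, 26) and long dtype are assumptions (26 standing for a batch_max_length of 25 plus one), so match them to your own opt.

```python
import torch
import torch.nn as nn


class TwoInputModel(nn.Module):
    """Hypothetical stand-in for ViTSTR's Model: forward() needs both input and text."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(3, 10)

    def forward(self, input, text, is_train=True, seqlen=25):
        # text is accepted but ignored in this sketch (assumption)
        pooled = input.mean(dim=(2, 3))  # (N, 3, H, W) -> (N, 3)
        return self.head(pooled)         # (N, 3) -> (N, 10)


model = TwoInputModel().eval()
dummy_input = torch.randn(1, 3, 224, 224)
dummy_text = torch.zeros(1, 26, dtype=torch.long)  # assumed shape: (batch, batch_max_length + 1)

# Passing a bare tensor reproduces the error from the question
try:
    model(dummy_input)
except TypeError as err:
    print(err)  # ...forward() missing 1 required positional argument: 'text'

# A tuple is unpacked into positional arguments: model(dummy_input, dummy_text)
args = (dummy_input, dummy_text)
with torch.no_grad():
    out = model(*args)
print(out.shape)  # torch.Size([1, 10])

# With the check passing, the same tuple can be handed to the exporter:
# torch.onnx.export(model, args, "vitstr.onnx", verbose=True)
```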

Upvotes: 1
