Reputation: 5
I'm trying to export my PyTorch model to ONNX format, but I keep getting this error:
TypeError: forward() missing 1 required positional argument: 'text'
This is my code:
model = Model(opt)
dummy_input = torch.randn(1, 3, 224, 224)
file_path='/content/drive/MyDrive/VitSTR/vitstr_tiny_patch16_224_aug.pth'
torch.save(model.state_dict(), file_path)
model.load_state_dict(torch.load(file_path))
#model = torch.nn.DataParallel(model).to(device)
#print(model)
torch.onnx.export(model, dummy_input, "vitstr.onnx", verbose=True)
Upvotes: 0
Views: 1515
Reputation: 13601
ViTSTR's forward requires two positional arguments, input and text:
def forward(self, input, text, is_train=True, seqlen=25):
    # ...
Therefore, you need to pass an additional argument:
# ...
dummy_text = ...  # placeholder: create a dummy text tensor as well, with the appropriate shape
torch.onnx.export(model, (dummy_input, dummy_text), "vitstr.onnx", verbose=True)
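To see why the tuple matters, here is a minimal, torch-free sketch. It relies on the documented behavior that torch.onnx.export treats a single tensor passed as args like a 1-tuple and then calls the model with the arguments unpacked. ToyModel and call_like_exporter are hypothetical stand-ins written for illustration only; they are not part of PyTorch or ViTSTR, and the strings stand in for real tensors.

```python
class ToyModel:
    """Hypothetical stand-in with the same forward signature as in the answer."""

    def forward(self, input, text, is_train=True, seqlen=25):
        return f"input={input}, text={text}"

    def __call__(self, *args, **kwargs):
        # Like nn.Module, calling the model passes the arguments on to forward().
        return self.forward(*args, **kwargs)


def call_like_exporter(model, args):
    """Assumed simplification: a non-tuple args is wrapped in a 1-tuple, then unpacked."""
    if not isinstance(args, tuple):
        args = (args,)
    return model(*args)


model = ToyModel()

# One argument only: forward() never receives 'text', so Python raises TypeError.
try:
    call_like_exporter(model, "dummy_input")
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Two arguments in a tuple: both positional parameters are filled.
result = call_like_exporter(model, ("dummy_input", "dummy_text"))

print(error_message)
print(result)
```

With the real model, the same reasoning applies: the second element of the tuple fills the text parameter, so it needs whatever shape and dtype the model's forward expects.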
Upvotes: 1