Reputation: 859
I am trying to load a model I trained with PyTorch, but I keep getting the following error:
  File "convert.py", line 12, in <module>
    model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 490, in load_state_dict
    .format(name))
KeyError: 'unexpected key "module.features.0.weight" in state_dict'
Below is my code:
import torch.onnx
import torch.nn as nn
from torch.autograd import Variable

class TempModel(nn.Module):
    def __init__(self):
        super(TempModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, (3, 3))

    def forward(self, inp):
        return self.conv1(inp)

model = nn.DataParallel(TempModel())
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
dummy_input = Variable(torch.randn(1, 3, 224, 224))
torch.onnx.export(model, dummy_input, "model_onnx/model_vgg2d_0.onnx")
I am working on the same machine I used to train the model (it has multiple GPUs). Any ideas what I am doing wrong?
Upvotes: 0
Views: 933
Reputation: 114926
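When loading a state_dict, it has to come from a model with the same architecture: the keys in the checkpoint (such as "module.features.0.weight") must match the parameter names of the model you load it into. You cannot load the state_dict of a VGG model into a completely different model like TempModel.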
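One way to see this mismatch directly is to compare the keys stored in the checkpoint with the keys the target model expects. The sketch below uses a hypothetical helper, `diff_state_dict_keys`, and a small stand-in model, `TinyVGG` (both illustrative, not from the question); with the real file you would pass `torch.load('model/model_vgg2d_2.pth')` as the checkpoint instead of the in-memory dict built here.

```python
import torch.nn as nn


class TempModel(nn.Module):
    def __init__(self):
        super(TempModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, (3, 3))

    def forward(self, inp):
        return self.conv1(inp)


class TinyVGG(nn.Module):
    """Hypothetical stand-in for a VGG-style model with a `features` block."""

    def __init__(self):
        super(TinyVGG, self).__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

    def forward(self, x):
        return self.features(x)


def diff_state_dict_keys(model, state_dict):
    """Hypothetical helper: list keys that do not line up between a model and a checkpoint."""
    expected = set(model.state_dict().keys())
    found = set(state_dict.keys())
    missing = sorted(expected - found)     # parameters the model needs but the checkpoint lacks
    unexpected = sorted(found - expected)  # checkpoint entries the model has no slot for
    return missing, unexpected


# Assumption: simulate a checkpoint saved from a DataParallel-wrapped VGG-style model.
checkpoint = nn.DataParallel(TinyVGG()).state_dict()

missing, unexpected = diff_state_dict_keys(TempModel(), checkpoint)
print("missing:", missing)
print("unexpected:", unexpected)
```

A non-empty `unexpected` list whose names (e.g. `features.0`) appear nowhere in the target model suggests the checkpoint belongs to a different architecture; if the two lists differ only by a `module.` prefix, the problem is the DataParallel wrapping instead.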
Old answer:
You saved the model without nn.DataParallel applied, and now you are trying to load the weights after wrapping the model in it. Try loading first and wrapping afterwards:
model = TempModel()
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
model = nn.DataParallel(model) # parallel AFTER load
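The `module.` prefix in the error message is added by `nn.DataParallel` to every key of the wrapped model's state_dict, so saving and loading have to agree on whether the model is wrapped. A minimal sketch, assuming the checkpoint was saved from a wrapped model and the architectures otherwise match: strip the prefix before loading into a plain `TempModel`, then wrap it. The helper `strip_module_prefix` is hypothetical, not part of PyTorch.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class TempModel(nn.Module):
    def __init__(self):
        super(TempModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, (3, 3))

    def forward(self, inp):
        return self.conv1(inp)


def strip_module_prefix(state_dict, prefix="module."):
    """Hypothetical helper: drop the DataParallel prefix from every key that has it."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Weights saved from a DataParallel-wrapped model carry the "module." prefix.
trained = nn.DataParallel(TempModel())
saved = trained.state_dict()  # keys: 'module.conv1.weight', 'module.conv1.bias'

# Loading the prefixed keys straight into an unwrapped model fails
# (KeyError in old PyTorch versions, RuntimeError in newer ones).
plain = TempModel()
try:
    plain.load_state_dict(saved)
    prefixed_load_ok = True
except (KeyError, RuntimeError):
    prefixed_load_ok = False

# Strip the prefix, load into the plain model, then wrap it AFTER loading.
plain.load_state_dict(strip_module_prefix(saved))
model = nn.DataParallel(plain)
```

Alternatively, saving `trained.module.state_dict()` in the first place produces unprefixed keys. Neither approach helps when the checkpoint comes from a different architecture, as in the question.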
Upvotes: 0