user1322801

Reputation: 859

Problem loading PyTorch 0.3.0 model: unexpected key "module.features.0.weight" in state_dict

I am trying to load a model I trained with PyTorch, but I keep getting the following error:

File "convert.py", line 12, in <module>
    model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 490, in load_state_dict
    .format(name))
KeyError: 'unexpected key "module.features.0.weight" in state_dict'

Below is my code:

import torch
import torch.nn as nn
import torch.onnx
from torch.autograd import Variable  # used for dummy_input below

class TempModel(nn.Module):
    def __init__(self):
        super(TempModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, (3, 3))
    def forward(self, inp):
        return self.conv1(inp)

model = nn.DataParallel(TempModel())
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
dummy_input = Variable(torch.randn(1, 3, 224, 224))
torch.onnx.export(model, dummy_input, "model_onnx/model_vgg2d_0.onnx")

I am working on the same machine I used to train the model (it has multiple GPUs). Any ideas what I am doing wrong?

Upvotes: 0

Views: 933

Answers (1)

Shai

Reputation: 114926

When loading a state_dict, it must come from the same model architecture: you cannot load the state_dict of a VGG model into a completely different model such as the TempModel in your code.
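To check which keys disagree before calling load_state_dict, one can compare the key sets of the checkpoint and of the model. Below is a minimal sketch using a hypothetical helper, `diff_state_dict_keys` (not a PyTorch function); it works on any mapping, so placeholder key lists that mimic the question stand in for real tensors.

```python
from collections import OrderedDict
from typing import Mapping, Set, Tuple


def diff_state_dict_keys(saved: Mapping, expected: Mapping) -> Tuple[Set[str], Set[str]]:
    """Compare the key sets of two state_dict-like mappings.

    Returns (missing, unexpected):
      missing:    keys the model expects but the checkpoint lacks
      unexpected: keys in the checkpoint that the model does not have
    """
    saved_keys = set(saved.keys())
    expected_keys = set(expected.keys())
    missing = expected_keys - saved_keys
    unexpected = saved_keys - expected_keys
    return missing, unexpected


# Illustrative keys only (values would be tensors in a real checkpoint):
# a VGG-style checkpoint saved from a DataParallel model, versus the
# keys of nn.DataParallel(TempModel()) from the question.
vgg_checkpoint = OrderedDict([
    ("module.features.0.weight", None),
    ("module.features.0.bias", None),
])
temp_model_state = OrderedDict([
    ("module.conv1.weight", None),
    ("module.conv1.bias", None),
])

missing, unexpected = diff_state_dict_keys(vgg_checkpoint, temp_model_state)
print(sorted(missing))     # keys TempModel needs but the file lacks
print(sorted(unexpected))  # keys the file has but TempModel lacks
```

With real files one would call `diff_state_dict_keys(torch.load('model/model_vgg2d_2.pth'), model.state_dict())`. If the unexpected keys differ from the missing ones by more than a `module.` prefix (here `features.0` versus `conv1`), the checkpoint was saved from a different architecture, and no key renaming will make it load.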


Old answer:
You saved the model without nn.DataParallel applied, and now you are trying to load it after wrapping the model in nn.DataParallel. Try:

model = TempModel()
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
model = nn.DataParallel(model)  # parallel AFTER load
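The `module.` prefix in the error comes from nn.DataParallel, which stores the wrapped network as its `module` attribute. When the mismatch is only this prefix (for example, a checkpoint saved from a DataParallel model loaded into an unwrapped one), a common workaround is to rename the keys before loading; saving `model.module.state_dict()` in the first place avoids the prefix altogether. A minimal sketch, with `strip_module_prefix` as a hypothetical helper name:

```python
from collections import OrderedDict
from typing import Mapping


def strip_module_prefix(state_dict: Mapping, prefix: str = "module.") -> "OrderedDict":
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    A state_dict saved from an nn.DataParallel model has keys such as
    "module.conv1.weight", while the unwrapped model expects "conv1.weight".
    Keys that do not start with the prefix are copied unchanged.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Illustrative keys only; values would be tensors in a real checkpoint.
saved = OrderedDict([("module.conv1.weight", "w"), ("module.conv1.bias", "b")])
print(list(strip_module_prefix(saved).keys()))  # ['conv1.weight', 'conv1.bias']
```

Usage would look like `model = TempModel()` followed by `model.load_state_dict(strip_module_prefix(torch.load(path)))`. This only fixes a prefix mismatch; it cannot make a VGG checkpoint fit a different network.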

Upvotes: 0
