Yingqiang Gao

Reputation: 999

PyTorch: can't load CNN model and make a prediction: TypeError: 'collections.OrderedDict' object is not callable

I trained a CNN model on the MNIST dataset and now want to predict the class of an image that contains the digit 3.

But when I try to use this CNN for prediction, PyTorch gives me this error:

TypeError: 'collections.OrderedDict' object is not callable

Here's what I wrote:

cnn = torch.load("/usr/prakt/w153/Desktop/score_detector.pkl")
img = scipy.ndimage.imread("/usr/prakt/w153/Desktop/resize_num_three.png")
test_x = Variable(torch.unsqueeze(torch.FloatTensor(img), dim=1), volatile=True).type(torch.FloatTensor).cuda()
test_output, last_layer = cnn(test_x)
pred = torch.max(test_output, 1)[1].cuda().data.squeeze()
print(pred)

Some explanation: img is the image to be classified, of size 28x28, and score_detector.pkl is the trained CNN model.

Any help would be appreciated!

Upvotes: 7

Views: 21589

Answers (2)

Oleg O

Reputation: 441

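Indeed, you are loading a state_dict (a dictionary that maps parameter names to tensors) rather than the model itself, and a dictionary cannot be called like a model.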
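The error can be reproduced in a few lines. This is a minimal sketch, assuming a tiny nn.Linear as a stand-in for the trained CNN and a file in the system temp directory (both are illustrative, not from the question): whatever torch.load returns for a saved state_dict is an OrderedDict, and calling it raises the same TypeError.

```python
import collections
import os
import tempfile

import torch
import torch.nn as nn

# Illustrative path; the question's file was score_detector.pkl.
path = os.path.join(tempfile.gettempdir(), "model_state.pth")

model = nn.Linear(4, 2)                  # tiny stand-in for the trained CNN
torch.save(model.state_dict(), path)     # saves only name -> tensor pairs

loaded = torch.load(path)
print(type(loaded))                      # <class 'collections.OrderedDict'>
print(list(loaded.keys()))               # ['weight', 'bias']

try:
    loaded(torch.randn(1, 4))            # a dict of tensors has no forward()
except TypeError as err:
    message = str(err)
print(message)                           # 'collections.OrderedDict' object is not callable
```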

The usual way to save a model is to save its state_dict:

torch.save(model.state_dict(), 'model_state.pth')

To load it, you first need to instantiate the model and then load the saved state into it:

model = Model()
model.load_state_dict(torch.load('model_state.pth'))
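A complete save/load round trip might look like the sketch below. The Model class here is hypothetical (a small MNIST-sized classifier); in practice it must be the exact class that was used during training, because the state_dict only stores weights, not the architecture. The final check confirms that the restored model produces the same outputs as the original.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Model(nn.Module):
    """Hypothetical small MNIST classifier; must match the class used in training."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)  # 1x28x28 -> 8x28x28
        self.fc = nn.Linear(8 * 28 * 28, 10)                    # 10 digit classes

    def forward(self, x):
        x = torch.relu(self.conv(x)).flatten(1)
        return self.fc(x)


path = os.path.join(tempfile.gettempdir(), "model_state.pth")

# Save: only the parameters go to disk, not the class definition.
trained = Model()
trained.eval()
torch.save(trained.state_dict(), path)

# Load: build the architecture first, then copy the saved weights into it.
model = Model()
model.load_state_dict(torch.load(path))
model.eval()  # inference mode (matters for dropout / batch-norm layers)

x = torch.randn(1, 1, 28, 28)
with torch.no_grad():
    match = torch.allclose(trained(x), model(x))
print(match)  # True
```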

If you trained your model on a GPU but want to load it on a machine without CUDA (a laptop, for example), you need to pass one more argument, map_location:

model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))
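When the same script has to run on machines with and without a GPU, the device can be chosen at run time and passed to map_location, which also accepts a torch.device. A minimal sketch, assuming a hypothetical nn.Linear stand-in for the real model and a temporary file path:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Pick the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

path = os.path.join(tempfile.gettempdir(), "model_state.pth")
torch.save(nn.Linear(784, 10).state_dict(), path)  # stand-in for the trained weights

model = nn.Linear(784, 10)                     # hypothetical stand-in architecture
state = torch.load(path, map_location=device)  # remap every saved tensor onto `device`
model.load_state_dict(state)
model.to(device)                               # keep the module's parameters on the same device
```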

Upvotes: 13

Jens Petersen

Reputation: 349

I'm pretty sure score_detector.pkl is actually a state_dict and not the model itself. You will need to instantiate the model first and then load the state_dict, so your first line should be replaced by something like this:

cnn = MyModel()
cnn.load_state_dict(torch.load("/usr/prakt/w153/Desktop/score_detector.pkl"))

and then the rest should work. See this link for more information.
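For the prediction step itself, here is a sketch in current PyTorch, where Variable and volatile=True are no longer needed (torch.no_grad() replaces them since version 0.4). One detail worth checking: a single 28x28 grayscale image needs shape (1, 1, 28, 28) (batch, channel, height, width), while torch.unsqueeze(..., dim=1) on a 28x28 tensor gives (28, 1, 28). The CNN class, the predict_digit helper, and the [0, 1] scaling (which assumes training used torchvision's ToTensor) are all hypothetical; the random array stands in for the question's PNG.

```python
import numpy as np
import torch
import torch.nn as nn


class CNN(nn.Module):
    """Hypothetical stand-in for the question's MNIST model (returns logits and features)."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),   # -> 16x14x14
            nn.Conv2d(16, 32, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),  # -> 32x7x7
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.features(x).flatten(1)
        return self.out(x), x  # same (output, last_layer) pair the question unpacks


def predict_digit(cnn, img):
    """Classify one 28x28 grayscale image given as a NumPy array of 0-255 values."""
    device = next(cnn.parameters()).device
    x = torch.from_numpy(img.astype(np.float32) / 255.0)  # assumed [0, 1] scaling
    x = x.unsqueeze(0).unsqueeze(0).to(device)             # (28, 28) -> (1, 1, 28, 28)
    cnn.eval()
    with torch.no_grad():
        logits, _ = cnn(x)
    return int(logits.argmax(dim=1).item())


cnn = CNN()  # in practice: cnn.load_state_dict(torch.load(path)) first
img = np.random.randint(0, 256, size=(28, 28)).astype(np.uint8)  # placeholder image
pred = predict_digit(cnn, img)
print(pred)  # a class index in 0..9 (meaningless here, the weights are untrained)
```

scipy.ndimage.imread has been removed from recent SciPy releases; one alternative for reading the real file is np.array(Image.open(path).convert("L")) with Pillow.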

Upvotes: 2
