Reputation: 3051
I have a trained LSTM model. The model was trained on GPU (on Google Colaboratory), and I need to save it for inference, which I will run on CPU. Once trained, I saved the model checkpoint as follows:
torch.save({'model_state_dict': model.state_dict()},'lstmmodelgpu.tar')
And, for inference, I loaded the model as :
# model definition
vocab_size = len(vocab_to_int)+1
output_size = 1
embedding_dim = 300
hidden_dim = 256
n_layers = 2
model = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)
# loading model
device = torch.device('cpu')
checkpoint = torch.load('lstmmodelgpu.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
But, it is raising the following error:
model.load_state_dict(checkpoint['model_state_dict'])
File "workspace/envs/envdeeplearning/lib/python3.5/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SentimentLSTM:
Missing key(s) in state_dict: "embedding.weight".
Unexpected key(s) in state_dict: "encoder.weight".
Is there anything I missed while saving the checkpoint?
Upvotes: 1
Views: 3440
Reputation: 1082
There are two things to be considered here.
You mentioned that you're training your model on GPU and using it for inference on CPU, so you need to pass the map_location parameter to the load function, with the value torch.device('cpu').
There is a mismatch of state_dict keys (shown in your error message): the checkpoint you are loading contains a key the current model does not expect ("encoder.weight"), while the model expects a key the checkpoint is missing ("embedding.weight"). This typically happens when a module's attribute name changed between the model that was saved and the one you are loading into. To make the load go through, pass strict=False to load_state_dict; this makes the method ignore mismatched keys. Be aware that skipped keys are simply not loaded, so the corresponding parameters keep their freshly initialized values.
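A minimal sketch of both behaviors, using a tiny stand-in module (TinyModel is hypothetical, standing in for your SentimentLSTM) and a fake checkpoint saved under the old attribute name:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for SentimentLSTM, with the module stored
# under the attribute name "embedding"
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)

model = TinyModel()

# Simulate a checkpoint saved when the attribute was named "encoder"
state = {'encoder.weight': torch.randn(10, 4)}

# strict=False skips mismatched keys instead of raising; the return
# value reports exactly what was not matched
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['embedding.weight']
print(result.unexpected_keys)  # ['encoder.weight']

# Caveat: embedding.weight was NOT loaded above, so it still holds its
# random initialization. Renaming the checkpoint keys instead loads
# the trained weights properly:
fixed = {k.replace('encoder', 'embedding'): v for k, v in state.items()}
model.load_state_dict(fixed)   # strict load now succeeds
```

If the only problem is a renamed attribute, the key-renaming approach recovers the trained weights, whereas strict=False merely silences the error.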
Side note: use the .pt or .pth extension for checkpoint files, as that is the PyTorch convention.
Upvotes: 3