Sijan Bhandari

Reputation: 3051

Saving and Loading Pytorch Model Checkpoint for inference not working

I have a trained LSTM model. The model was trained on a GPU (on Google Colaboratory). I need to save the model for inference, which I will run on a CPU. Once trained, I saved the model checkpoint as follows:

torch.save({'model_state_dict': model.state_dict()},'lstmmodelgpu.tar')

And, for inference, I loaded the model as:

# model definition
vocab_size = len(vocab_to_int)+1 
output_size = 1
embedding_dim = 300
hidden_dim = 256
n_layers = 2

model = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)

# loading model
device = torch.device('cpu')
checkpoint = torch.load('lstmmodelgpu.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

But, it is raising the following error:

model.load_state_dict(checkpoint['model_state_dict'])
  File "workspace/envs/envdeeplearning/lib/python3.5/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SentimentLSTM:
    Missing key(s) in state_dict: "embedding.weight". 
    Unexpected key(s) in state_dict: "encoder.weight".

Is there anything I missed while saving the checkpoint?

Upvotes: 1

Views: 3440

Answers (1)

Mukul Kumar Jha

Reputation: 1082

There are two things to be considered here.

  1. You mentioned that you're training your model on a GPU and using it for inference on a CPU, so you need to pass a `map_location` parameter of `torch.device('cpu')` to the `torch.load` function. You are already doing this, so point 1 is covered.

  2. There is a mismatch of `state_dict` keys (indicated in your output message), which can be caused by the `state_dict` you are loading having missing or extra keys relative to the model you are currently using. You can pass the parameter `strict=False` to `load_state_dict`, which makes the method ignore the mismatched keys. Note, however, that the mismatched entries are then simply skipped, not loaded, so the corresponding layer keeps its initial weights.
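In this particular case the mismatch looks like a simple rename (`encoder` in the saved model vs `embedding` in the current definition), so an alternative to `strict=False` is to rewrite the checkpoint keys before calling `load_state_dict`. A minimal sketch of that renaming step, using a plain dict with placeholder strings standing in for the tensors (the helper name is my own, not a PyTorch API):

```python
def rename_state_dict_keys(state_dict, old_prefix, new_prefix):
    """Return a copy of state_dict with old_prefix replaced by new_prefix."""
    renamed = {}
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        renamed[key] = value
    return renamed

# Placeholder values stand in for the actual tensors:
ckpt = {"encoder.weight": "tensor", "lstm.weight_ih_l0": "tensor"}
fixed = rename_state_dict_keys(ckpt, "encoder.", "embedding.")
print(sorted(fixed))  # ['embedding.weight', 'lstm.weight_ih_l0']
```

You would then call `model.load_state_dict(fixed)` with the default `strict=True`, so any remaining mismatch still raises an error instead of being silently ignored.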

Side note: try to use the `.pt` or `.pth` extension for checkpoint files, as that is the convention.

Upvotes: 3
