Reputation: 951
model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
results, labels = predict_function(model, dev_data, version)
> /home/ofsdms/san_mrc/my_utils/data_utils.py(34)predict_squad()
-> phrase, spans, scores = model.predict(batch)
(Pdb) n
AttributeError: 'dict' object has no attribute 'predict'
How do I load a saved checkpoint of a PyTorch model and use it for prediction? I have the model saved with a .pt extension.
Upvotes: 1
Views: 5047
Reputation: 114926
The checkpoint you save is usually a state_dict: a dictionary containing the values of the trained weights, but not the actual architecture of the net. The actual computational graph/architecture of the net is described as a Python class (derived from nn.Module). That is why torch.load gives you back a plain dict, which has no predict method.
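To see the difference concretely, here is a minimal sketch that saves and reloads a state_dict. TinyNet is a hypothetical stand-in for the real network class (the question does not show it), and an in-memory io.BytesIO buffer replaces the .pt path so the snippet runs anywhere:

```python
import io

import torch
import torch.nn as nn


# Hypothetical stand-in for the real network class, which is not shown
# in the question.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


net = TinyNet()

# A state_dict maps parameter/buffer names to tensors; it holds no code.
state = net.state_dict()
print(list(state.keys()))  # ['linear.weight', 'linear.bias']

# Saving and reloading it gives back a dictionary, not a model object.
buffer = io.BytesIO()  # in-memory file instead of a .pt path
torch.save(state, buffer)
buffer.seek(0)
loaded = torch.load(buffer, map_location='cpu')

print(isinstance(loaded, dict))    # True
print(hasattr(loaded, 'predict'))  # False, hence the AttributeError
```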
To use a trained model you need to (1) create a model instance from the class implementing the computational graph, and (2) load the saved state_dict into that instance:
model.load_state_dict(torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu'))
model.eval()  # inference mode: disables dropout, BatchNorm uses running stats
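Putting both steps together, here is a minimal end-to-end sketch. TinyNet is again hypothetical: in practice, import the class used for training and construct it with the same arguments, otherwise load_state_dict raises an error about missing, unexpected or mismatched keys. The sketch also assumes the checkpoint may wrap the weights under a 'state_dict' key (an assumed name; some training scripts save optimizer state or config alongside the weights), so printing checkpoint.keys() on your own file is the way to find the right entry.

```python
import io

import torch
import torch.nn as nn


# Hypothetical model class; use the same class (and constructor arguments)
# that was used during training.
class TinyNet(nn.Module):
    def __init__(self, in_dim=4, out_dim=2):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.linear(x)


# Training side (simulated): save the weights plus some extra metadata.
trained = TinyNet()
buffer = io.BytesIO()  # stands in for a path such as 'best_v1_checkpoint.pt'
torch.save({'state_dict': trained.state_dict(), 'epoch': 3}, buffer)
buffer.seek(0)

# Inference side.
checkpoint = torch.load(buffer, map_location='cpu')

# Use the wrapped weights if present, else assume the file is a bare
# state_dict. 'state_dict' is an assumed key name.
state_dict = checkpoint.get('state_dict', checkpoint)

model = TinyNet()                  # step 1: rebuild the architecture
model.load_state_dict(state_dict)  # step 2: copy the trained weights in
model.eval()                       # inference mode

with torch.no_grad():              # skip autograd bookkeeping at inference
    scores = model(torch.randn(1, 4))
print(scores.shape)  # torch.Size([1, 2])
```

Calling model(x) runs forward; if your class defines its own predict method, as the code in the question does (model.predict(batch)), call that on the restored instance instead.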
Upvotes: 1