vinsent paramanantham

Reputation: 951

PyTorch model loading and prediction: AttributeError: 'dict' object has no attribute 'predict'

    model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
    results, labels = predict_function(model, dev_data, version)

    > /home/ofsdms/san_mrc/my_utils/data_utils.py(34)predict_squad()
    -> phrase, spans, scores = model.predict(batch)
    (Pdb) n
    AttributeError: 'dict' object has no attribute 'predict'

How do I load a saved checkpoint of a PyTorch model and use it for prediction? The model is saved as a file with the .pt extension.

Upvotes: 1

Views: 5047

Answers (1)

Shai

Reputation: 114926

The checkpoint you saved is usually a state_dict: a dictionary holding the values of the trained weights, but not the architecture of the network. That is why torch.load returns a dict here, and a dict has no predict method. The actual computational graph/architecture of the network is defined by a Python class (derived from nn.Module).
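A minimal sketch of why this happens, using a hypothetical toy class `TinyNet` as a stand-in for the real model class and a temporary file in place of the real checkpoint path: saving `state_dict()` and loading it back gives a plain mapping from parameter names to tensors, with no model methods such as `predict`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real architecture class."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


net = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny_checkpoint.pt")

# Saving the state_dict stores only the parameter tensors, keyed by name.
torch.save(net.state_dict(), path)

loaded = torch.load(path, map_location="cpu")
print(isinstance(loaded, dict))    # True: an (Ordered)dict, not a model
print(list(loaded.keys()))         # ['fc.weight', 'fc.bias']
print(hasattr(loaded, "predict"))  # False: only tensors, no methods
```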
To use a trained model, you need to:

  1. Instantiate a model from the class implementing the computational graph.
  2. Load the saved state_dict into that instance:

    model.load_state_dict(torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu'))
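
The two steps can be sketched as follows. `TinyNet`, `load_for_inference` and the temporary checkpoint file are hypothetical stand-ins; in practice you would import the same class the checkpoint was trained with and construct it with the same hyperparameters. The check for a `"state_dict"` key is an assumption about how some training scripts wrap their checkpoints, not something this checkpoint is known to do; printing `checkpoint.keys()` shows what yours actually contains.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical architecture; substitute the class the checkpoint was trained with."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def load_for_inference(model, path):
    """Load saved weights into an already-constructed model and switch it to eval mode."""
    checkpoint = torch.load(path, map_location="cpu")
    # Assumption: some scripts save {"state_dict": ..., "epoch": ...}.
    # If so, unwrap it; otherwise treat the whole object as the state_dict.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint)
    model.eval()  # disables dropout, uses running stats in batch norm
    return model


# Simulate a saved checkpoint (in practice this file already exists).
path = os.path.join(tempfile.mkdtemp(), "best_checkpoint.pt")
trained = TinyNet()
torch.save({"state_dict": trained.state_dict(), "epoch": 3}, path)

# Step 1: instantiate the architecture. Step 2: load the weights into it.
model = load_for_inference(TinyNet(), path)

with torch.no_grad():  # no gradient tracking needed for prediction
    batch = torch.randn(5, 4)
    preds = model(batch).argmax(dim=1)
print(preds.shape)  # torch.Size([5])
```

Note that `predict` is not a method of `nn.Module`, so the `model.predict(batch)` call in the question relies on a method of the project's own model class; that class is what has to be instantiated before the weights are loaded into it.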
    

Upvotes: 1
