Reputation: 11
This is my main code, but I don't know how to fix the problem:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('./checkpoints/fcn_model_5.pth')  # load the model
model = model.to(device)
Upvotes: 1
Views: 7352
Reputation: 105
The source of your problem is simply that you are loading your model as a dict instead of as an nn.Module. Here is another approach you can employ without converting to the nn.Module bloat, adapted from here:
for k, v in model.items():
    model[k] = v.to(device)  # move each tensor in the state dict to the device
Now you have an ordered dict with every tensor on the correct device.
Note that you still have an ordered dict instead of an nn.Module: you cannot run a forward pass through an ordered dict.
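To see the mismatch for yourself, you can inspect what torch.load actually returns. A minimal sketch, assuming the checkpoint from the question contains a state dict (the printed key names are hypothetical and depend on your model):

import torch

# loading a state-dict checkpoint returns an OrderedDict of tensors,
# not a callable nn.Module
state_dict = torch.load('./checkpoints/fcn_model_5.pth', map_location='cpu')

print(type(state_dict))      # <class 'collections.OrderedDict'>
print(list(state_dict)[:3])  # e.g. ['conv1.weight', 'conv1.bias', ...]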
Upvotes: 0
Reputation: 188
You are loading the checkpoint as a state dict; it is not an nn.Module object. Rebuild the model first, then load the weights into it:
checkpoint = './checkpoints/fcn_model_5.pth'
model = your_model()  # a torch.nn.Module object
model.load_state_dict(torch.load(checkpoint))
model = model.to(device)
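If you also control the saving side, you can avoid the mismatch up front. A hedged sketch of both conventions, reusing your_model from above (the fcn_model_full.pth filename is hypothetical); saving only the state dict is the convention PyTorch's documentation recommends:

# Option A (recommended): save only the weights; rebuild the model to load
torch.save(model.state_dict(), './checkpoints/fcn_model_5.pth')

# Option B: pickle the whole nn.Module; torch.load then returns a module
# directly, but the original class definition must still be importable
torch.save(model, './checkpoints/fcn_model_full.pth')
model = torch.load('./checkpoints/fcn_model_full.pth')
model = model.to(device)

In either case, call model.eval() before inference so dropout and batch norm layers run in evaluation mode.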
Upvotes: 4