Reputation: 515
I trained an EfficientNet-B6 model (the architecture is from here):
https://github.com/lukemelas/EfficientNet-PyTorch
Now, I tried to load a model I trained with it:
checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)
but then I got the following error:
_IncompatibleKeys
missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]
Please let me know how I can fix this. What am I missing? Thank you!
Upvotes: 5
Views: 5348
Reputation: 37681
If you compare the missing_keys and unexpected_keys, you can see what is happening.
missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]
As you can see, the model weights are saved with a module. prefix. This is because you trained the model with DataParallel.
Now, to load the model weights without using DataParallel, you can do the following.
# original saved file with DataParallel
from collections import OrderedDict

checkpoint = torch.load(path, map_location=torch.device('cpu'))

# create a new OrderedDict that does not contain the `module.` prefix
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k.replace("module.", "")  # remove `module.`
    new_state_dict[name] = v

# load params
model.load_state_dict(new_state_dict, strict=False)
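As a side note, once the prefix is removed the checkpoint keys should match the model exactly, so you can usually drop strict=False and let load_state_dict raise on any remaining mismatch (a small sketch, assuming your model definition matches the checkpoint):

# strict=True is the default and will surface any leftover key mismatches
model.load_state_dict(new_state_dict)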
Alternatively, if you wrap the model with DataParallel, then you do not need the above approach.
checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))
model = torch.nn.DataParallel(model)
model.load_state_dict(checkpoint, strict=False)
The second approach is not encouraged, though, since you may not need DataParallel in many cases.
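For future training runs, one way to avoid the prefix problem entirely is to save the state dict of the wrapped model's underlying module instead of the wrapper (a minimal sketch, assuming the model is still wrapped in DataParallel at save time):

# sketch: model is wrapped in DataParallel during training;
# model.module is the plain network, so its state_dict has no `module.` prefix
torch.save(model.module.state_dict(), 'model.pth')

# later, load it into an unwrapped model on CPU
checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)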
Upvotes: 11