secret

Reputation: 515

PyTorch load _IncompatibleKeys

I trained an EfficientNet-B6 model (the architecture is as follows):

https://github.com/lukemelas/EfficientNet-PyTorch

Now, I tried to load a model I trained with it:

checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)

but then I got the following error:

_IncompatibleKeys
missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]

Please let me know how can I fix that, what am I missing? Thank you!

Upvotes: 5

Views: 5348

Answers (1)

Wasi Ahmad

Reputation: 37681

If you compare the missing_keys and unexpected_keys, you may realize what is happening.

missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]

As you can see, the model weights were saved with a `module.` prefix. This is because you trained the model wrapped in DataParallel.
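For context, here is a minimal sketch (using a small placeholder module instead of EfficientNet-B6, and an illustrative file name) of how wrapping a model in DataParallel adds that prefix to every key in the state dict:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                    # placeholder for EfficientNet-B6
print(list(model.state_dict().keys()))     # ['weight', 'bias']

wrapped = nn.DataParallel(model)
print(list(wrapped.state_dict().keys()))   # ['module.weight', 'module.bias']

# Saving wrapped.state_dict() writes the `module.`-prefixed keys to disk,
# which is why they later show up as unexpected_keys.
torch.save(wrapped.state_dict(), 'model.pth')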

Now, to load the model weights without using DataParallel, you can do the following.

# original saved file with DataParallel
import torch
from collections import OrderedDict

checkpoint = torch.load(path, map_location=torch.device('cpu'))

# create a new OrderedDict whose keys do not carry the `module.` prefix
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k.replace("module.", "")  # strip `module.` from each key
    new_state_dict[name] = v

# load params
model.load_state_dict(new_state_dict, strict=False)
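Once the `module.` prefix is stripped, the keys should match the model exactly, so you can also drop strict=False; the default strict=True will then raise an error if anything is still mismatched.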

Alternatively, if you wrap the model in DataParallel before loading, you do not need the above approach.

checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))
model = torch.nn.DataParallel(model)
model.load_state_dict(checkpoint, strict=False)

The second approach is not encouraged, though, since you may not need DataParallel in many cases.
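If you retrain, you can avoid the problem altogether by saving the underlying module's state dict instead of the wrapper's. A minimal sketch (placeholder model, illustrative file name):

import torch
import torch.nn as nn

# placeholder for the real EfficientNet-B6 model
model = nn.Linear(4, 2)
model = nn.DataParallel(model)

# ... training loop ...

# save the underlying module's weights so the keys carry no `module.` prefix
torch.save(model.module.state_dict(), 'model.pth')

A checkpoint saved this way loads directly into a plain (unwrapped) model with model.load_state_dict(torch.load('model.pth', map_location='cpu')), no key remapping needed.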

Upvotes: 11
