Reputation: 357
I ran into problems when loading the weights of my model. Here are some parts of the model:
class InceptionV4(nn.Module):
    def __init__(self, num_classes=1001):
        super(InceptionV4, self).__init__()
        # Special attributes
        self.input_space = None
        self.input_size = (299, 299, 3)
        self.mean = None
        self.std = None
        # Modules
        self.features = nn.Sequential(
            BasicConv2d(3, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed_3a(),
            Mixed_4a(),
            Mixed_5a(),
            Inception_A(),
            Inception_A(),
            Inception_A(),
            ...
        )
        self.avg_pool = nn.AvgPool2d(8, count_include_pad=False)
        self.last_linear = nn.Linear(1536, num_classes)
I saved the weights with torch.save(model.state_dict(), weight_name)
and then reloaded them with model.load_state_dict(torch.load(weight_name)),
but got these errors:
Missing key(s) in state_dict: "features.0.conv.weight", "features.0.bn.weight", "features.0.bn.bias", "features.0.bn.running_mean", "features.0.bn.running_var", "features.1.conv.weight", "features.1.bn.weight", "features.1.bn.bias", "features.1.bn.running_mean", "features.1.bn.running_var", "features.2.conv.weight", "features.2.bn.weight
and also:
Unexpected key(s) in state_dict: "conv.0.conv1.0.weight", "conv.0.conv1.0.bias", "conv.0.conv1.2.weight", "conv.0.conv1.2.bias", "conv.0.conv1.2.running_mean", "conv.0.conv1.2.running_var", "conv.0.conv1.2.num_batches_tracked", "conv.0.conv2.0.weight", "conv.0.conv2.0.bias", "conv.0.conv2.2.weight", "conv.0.conv2.2.bias", "conv.0.conv2.2.running_mean", "conv.0.conv2.2.running_var", "conv.0.conv2.2.num_batches_tracked", "conv.1.conv1.0.weight", "conv.1.conv1.0.bias", "conv.1.conv1.2.weight", "conv.1.conv1.2.bias", "conv.1.conv1.2.running_mean", "conv.1.conv1.2.running_var", "conv.1.conv1.2.num_batches_tracked
Any hints on this? Thanks in advance.
Upvotes: 0
Views: 1482
Reputation: 1708
I have run into this problem several times. The error means that the parameter names in your model's state_dict differ from the names stored in the weight file you are loading.
I don't see a pretrained model for Inception_v4 in the torchvision model zoo, so it is difficult to tell exactly where your InceptionV4 class diverges from the saved weights.
Wherever the pre-trained file comes from, the key point is to define your model exactly as it is defined in the code that produced that file. Then the weight file loads without errors.
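To see exactly which names disagree before changing the class, the two sets of keys can be compared directly. Below is a minimal sketch of that comparison. The helper name diff_state_dict_keys is hypothetical, and the sample key lists are just a few entries copied from the error messages in the question, so the code runs without PyTorch installed.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, similar to load_state_dict's report.

    missing:    keys the model expects but the checkpoint does not contain
    unexpected: keys the checkpoint contains but the model does not define
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# With PyTorch, the inputs would come from the real objects (assumed names
# `model` and `weight_name`, as in the question):
#   checkpoint = torch.load(weight_name, map_location="cpu")
#   missing, unexpected = diff_state_dict_keys(model.state_dict().keys(),
#                                              checkpoint.keys())

# Illustrative sample: a few keys taken from the error messages.
model_keys = ["features.0.conv.weight", "features.0.bn.weight"]
checkpoint_keys = ["conv.0.conv1.0.weight", "conv.0.conv1.2.num_batches_tracked"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If nearly every key shows up in both lists, as it does here, the file was most likely saved from a differently structured model rather than from this class.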
Here are some places where your code differs from the model that produced the weights:
# Rename self.features -> self.conv so the top-level prefix matches the saved keys.
# The nested names must match as well: the file has conv.0.conv1.0.*, your model has features.0.conv.*.
self.conv = nn.Sequential(...)
# Check how BatchNorm differs between your current PyTorch version
# and the PyTorch version in which the pretrained model was saved.
conv.1.conv1.2.num_batches_tracked # this buffer was added to BatchNorm in PyTorch 0.4.1; files saved by older versions do not contain it
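When the two architectures really are the same and only the names differ, the saved keys can be renamed instead of editing the class. The following is a sketch under that assumption. The helper remap_state_dict and the sample keys are hypothetical, and it works on a plain dict so it can be tried without PyTorch. In the question the nested names differ too (conv1.0 versus conv), so a prefix rename alone would not be enough there.

```python
from collections import OrderedDict


def remap_state_dict(state_dict, prefix_map=None, drop_suffixes=()):
    """Return a copy of state_dict with key prefixes renamed and some keys dropped.

    prefix_map:    {old_prefix: new_prefix}; the first matching prefix is applied
    drop_suffixes: keys ending with any of these suffixes are skipped
    """
    prefix_map = prefix_map or {}
    drop_suffixes = tuple(drop_suffixes)
    remapped = OrderedDict()
    for key, value in state_dict.items():
        if drop_suffixes and key.endswith(drop_suffixes):
            continue  # e.g. buffers the target PyTorch version does not know
        for old, new in prefix_map.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        remapped[key] = value
    return remapped


# Hypothetical saved keys whose inner names already match the model.
saved = OrderedDict([
    ("conv.0.conv.weight", "w0"),
    ("conv.0.bn.running_mean", "m0"),
    ("conv.0.bn.num_batches_tracked", "n0"),
])

fixed = remap_state_dict(
    saved,
    prefix_map={"conv.": "features."},
    drop_suffixes=(".num_batches_tracked",),
)
print(list(fixed.keys()))

# With PyTorch, the result would then be passed to the model:
#   model.load_state_dict(fixed)
```

Dropping num_batches_tracked only matters when the weights are loaded by a PyTorch version that predates that buffer. Leave drop_suffixes empty otherwise.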
The hint is:
# Define your model (or parts you want to reuse) the same as the original
Hope this helps :)
Upvotes: 1