Jakub Bielan

Reputation: 605

Problem with missing and unexpected keys while loading my model in Pytorch

I'm trying to load the model using this tutorial: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference . Unfortunately I'm a complete beginner and I've run into some problems.

I created a checkpoint:

checkpoint = {'epoch': epochs,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': loss}
torch.save(checkpoint, 'checkpoint.pth')
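
To double-check what actually got saved, I can read the file back and list its keys (just a sanity check on the dictionary created above):

saved = torch.load('checkpoint.pth')
print(saved.keys())  # dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
print(saved['model_state_dict'].keys())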

Then I wrote a class for my network so I could load the file:

import torch
from torch import nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 1000)
        self.fc3 = nn.Linear(1000, 102)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        # log-softmax over the class dimension
        x = F.log_softmax(x, dim=1)
        return x

Like this:

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = Network()
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model

model = load_checkpoint('checkpoint.pth')

I got this error (edited to show the full message):

RuntimeError: Error(s) in loading state_dict for Network:
    Missing key(s) in state_dict: "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "fc3.weight", "fc3.bias". 
    Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.3.weight", "features.3.bias", "features.6.weight", "features.6.bias", "features.8.weight", "features.8.bias", "features.10.weight", "features.10.bias", "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias", "classifier.fc3.weight", "classifier.fc3.bias". 

This is my model.state_dict().keys():

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 
'features.3.bias', 'features.6.weight', 'features.6.bias', 
'features.8.weight', 'features.8.bias', 'features.10.weight', 
'features.10.bias', 'classifier.fc1.weight', 'classifier.fc1.bias', 
'classifier.fc2.weight', 'classifier.fc2.bias', 'classifier.fc3.weight', 
'classifier.fc3.bias'])

This is my model:

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (fc1): Linear(in_features=9216, out_features=4096, bias=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=4096, out_features=1000, bias=True)
    (relu2): ReLU()
    (fc3): Linear(in_features=1000, out_features=102, bias=True)
    (output): LogSoftmax()
  )
)

It's my first network ever and I'm blundering along. Thanks for steering me in the right direction!

Upvotes: 16

Views: 27966

Answers (2)

Ali Waqas

Reputation: 335

In my case, I had to remove the "module." prefix from the state dict before it would load:

    model = Model()
    state_dict = torch.load(model_path)
    # The saved keys look like 'module.fc1.weight'; strip the 'module.' prefix
    # so they match the unwrapped model's keys.
    remove_prefix = 'module.'
    state_dict = {k[len(remove_prefix):] if k.startswith(remove_prefix) else k: v
                  for k, v in state_dict.items()}

After that,

    model.load_state_dict(state_dict)

worked!
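
For context, the "module." prefix usually comes from saving a model that was wrapped in nn.DataParallel. A minimal sketch showing where it comes from (Model here is a hypothetical stand-in for whatever network was saved):

    from torch import nn

    class Model(nn.Module):  # hypothetical stand-in for the saved network
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(4, 2)

    # DataParallel registers the original network under the attribute 'module',
    # so every state_dict key gets a 'module.' prefix.
    wrapped = nn.DataParallel(Model())
    print(list(wrapped.state_dict().keys()))
    # ['module.fc1.weight', 'module.fc1.bias']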

Upvotes: 5

Jatentaki

Reputation: 13103

So your Network is essentially the classifier part of AlexNet, and you're looking to load pretrained AlexNet weights into it. The problem is that the keys in the state_dict are "fully qualified": if you look at your network as a tree of nested modules, each key is the path of module names from the root down to the parameter, joined with dots, like grandparent.parent.child. You want to

  1. Keep only the tensors whose names start with "classifier."
  2. Strip the "classifier." prefix from those keys

so try

model = Network()
loaded_dict = checkpoint['model_state_dict']
prefix = 'classifier.'
n_clip = len(prefix)
adapted_dict = {k[n_clip:]: v for k, v in loaded_dict.items()
                if k.startswith(prefix)}
model.load_state_dict(adapted_dict)
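
As a quick sanity check, you can confirm that the adapted keys now line up with what Network expects:

# After stripping the prefix, the adapted keys should exactly match the
# keys of a freshly constructed Network, so load_state_dict() succeeds.
assert set(adapted_dict.keys()) == set(model.state_dict().keys())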

Upvotes: 9
