Siddeshwar Raghavan

Reputation: 87

Bias terms in Pre-trained ResNet models are not available?

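I am using a ResNet for classification and want to compare pre-trained and non-pre-trained networks. However, I also want the convolutional layers to have bias terms, which torchvision's ResNet does not use by default: its convolutions are created with bias=False, because each one is followed by a BatchNorm layer whose own shift makes a conv bias redundant.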

Is there a way to load a pre-trained model and still use bias terms on top of it?

Here is a brief snippet of my current code. I redefined the ResNet architecture from https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html with bias=True in the convolution layers:

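import torch.nn as nn
# resnet18 comes from the locally modified copy of torchvision's resnet.py (convs with bias=True)
net = resnet18(pretrained=True)
net.fc = nn.Linear(512, num_classes)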

As expected, loading the pre-trained weights now fails with:

Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: "conv1.bias", "layer1.0.conv1.bias", "layer1.0.conv2.bias", "layer1.1.conv1.bias", "layer1.1.conv2.bias", "layer2.0.conv1.bias", "layer2.0.conv2.bias", "layer2.0.downsample.0.bias", "layer2.1.conv1.bias", "layer2.1.conv2.bias", "layer3.0.conv1.bias", "layer3.0.conv2.bias", "layer3.0.downsample.0.bias", "layer3.1.conv1.bias", "layer3.1.conv2.bias", "layer4.0.conv1.bias", "layer4.0.conv2.bias", "layer4.0.downsample.0.bias", "layer4.1.conv1.bias", "layer4.1.conv2.bias".
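The error follows directly from the key names: the checkpoint was saved from a model without conv biases, so keys such as "conv1.bias" simply do not exist in it. Below is a minimal sketch that reproduces the same failure on a tiny model, so no download is needed. make_block is a hypothetical stand-in for a ResNet conv + BatchNorm pair, not part of torchvision.

```python
import torch.nn as nn


def make_block(bias: bool) -> nn.Sequential:
    # Hypothetical stand-in for one ResNet conv + BatchNorm pair.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=bias),
        nn.BatchNorm2d(8),
    )


pretrained = make_block(bias=False)  # plays the role of the downloaded checkpoint
modified = make_block(bias=True)     # plays the role of the redefined bias=True ResNet

print(sorted(pretrained.state_dict()))  # has no "0.bias" key
print(sorted(modified.state_dict()))    # expects "0.bias" as well

error_message = ""
try:
    # load_state_dict is strict by default, so the missing key raises
    modified.load_state_dict(pretrained.state_dict())
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```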

Upvotes: 0

Views: 308

Answers (2)

Siddeshwar Raghavan

Reputation: 87

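Change the _resnet function in your copy of resnet.py so that the state dict is loaded with strict=False:

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        # strict=False: skip the conv bias keys that the checkpoint does not contain
        model.load_state_dict(state_dict, strict=False)
    return model

This loads every pre-trained weight whose key matches and skips the non-matching keys instead of raising. The conv biases, which are not in the checkpoint, keep their default (random) initialization.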
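A minimal sketch of what strict=False does, again using a hypothetical toy model instead of the real checkpoint. load_state_dict returns the keys it skipped, which is a handy check that only the biases were left out. One choice in this sketch is an assumption, not part of the answer: it zeroes the skipped biases, so that (in eval mode) the modified model computes the same outputs as the pre-trained one until the biases are trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def make_block(bias: bool) -> nn.Sequential:
    # Hypothetical toy model standing in for ResNet: conv followed by BatchNorm.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=bias),
        nn.BatchNorm2d(8),
    )


# eval mode: BatchNorm uses running stats, so a conv bias would change outputs
pretrained = make_block(bias=False).eval()
modified = make_block(bias=True).eval()

# strict=False loads every matching key and reports the rest instead of raising
result = modified.load_state_dict(pretrained.state_dict(), strict=False)
print("missing:", result.missing_keys)        # the new conv bias
print("unexpected:", result.unexpected_keys)  # nothing extra in the checkpoint

# The skipped biases keep PyTorch's default random init; zero them so the
# modified model starts out matching the pre-trained one.
with torch.no_grad():
    for name, param in modified.named_parameters():
        if name in result.missing_keys:
            param.zero_()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    same = torch.allclose(modified(x), pretrained(x))
print("outputs match:", same)
```

With the real ResNet, result.missing_keys should list exactly the bias keys from the error message in the question; any other key there would point to a mismatch in the redefined architecture.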

Upvotes: 0

Paul_0

Reputation: 358

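You should change the _resnet function from the linked source:

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        # strict is an argument of load_state_dict, not of load_state_dict_from_url
        model.load_state_dict(state_dict, strict=False)
    return model

Passing strict=False to model.load_state_dict makes it ignore non-matching keys instead of crashing. It has to go on that call: load_state_dict_from_url does not accept a strict argument and raises a TypeError if you pass one.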
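A different approach, not taken by either answer: keep torchvision's ResNet unmodified, load the pre-trained weights normally (strict loading, so no key errors), and attach zero-initialized bias parameters to the convolutions afterwards. add_zero_conv_biases is a hypothetical helper written for this sketch; it is demonstrated on a small toy model so it runs without downloading weights.

```python
import torch
import torch.nn as nn


def add_zero_conv_biases(model: nn.Module) -> int:
    """Give every bias-free Conv2d in `model` a trainable, zero-initialized bias.

    Returns the number of convolutions modified. Zero init leaves the model's
    outputs unchanged until the new biases are trained.
    """
    count = 0
    for module in model.modules():
        if isinstance(module, nn.Conv2d) and module.bias is None:
            weight = module.weight
            # Assigning a Parameter registers it under the existing "bias" slot
            module.bias = nn.Parameter(
                torch.zeros(module.out_channels, dtype=weight.dtype, device=weight.device)
            )
            count += 1
    return count


# Toy model standing in for a pre-trained ResNet (convs built with bias=False)
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=1, bias=False),
).eval()

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    before = toy(x)
n_added = add_zero_conv_biases(toy)
with torch.no_grad():
    after = toy(x)
print(n_added, torch.allclose(before, after))
```

For the real network, the call would be add_zero_conv_biases(net) right after building net with the stock torchvision resnet18(pretrained=True) from the version linked in the question. Create the optimizer after this call; otherwise the new bias parameters will not be in its parameter list.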

Upvotes: 0
