Reputation: 87
I am using a ResNet for classification and want to compare pre-trained and non-pre-trained networks. However, I also want the convolution layers to use bias terms, which is not the default in torchvision's ResNet implementation (its convolutions are created with bias=False).
Is there a way to load a pre-trained model and still use bias terms on top of it?
A brief snippet of my current code: I redefine the ResNet architecture from https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html with bias=True in the convolution layers, then build the network:
    net = resnet18(pretrained=True)
    net.fc = nn.Linear(512, num_classes)
The obvious error right now is
Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: "conv1.bias", "layer1.0.conv1.bias", "layer1.0.conv2.bias", "layer1.1.conv1.bias", "layer1.1.conv2.bias", "layer2.0.conv1.bias", "layer2.0.conv2.bias", "layer2.0.downsample.0.bias", "layer2.1.conv1.bias", "layer2.1.conv2.bias", "layer3.0.conv1.bias", "layer3.0.conv2.bias", "layer3.0.downsample.0.bias", "layer3.1.conv1.bias", "layer3.1.conv2.bias", "layer4.0.conv1.bias", "layer4.0.conv2.bias", "layer4.0.downsample.0.bias", "layer4.1.conv1.bias", "layer4.1.conv2.bias".
Upvotes: 0
Views: 308
Reputation: 87
The default _resnet function should be changed as follows:
    def _resnet(arch, block, layers, pretrained, progress, **kwargs):
        model = ResNet(block, layers, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=progress)
            model.load_state_dict(state_dict, strict=False)
        return model
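With strict=False, load_state_dict loads every pre-trained weight whose key matches and skips the keys that are missing from the checkpoint (here, the new bias terms), which keep their initial values.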
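To make the effect of strict=False concrete, here is a minimal sketch that avoids downloading any weights. The two small nn.Sequential models are hypothetical stand-ins (not the real ResNet): one plays the role of the pre-trained network without bias, the other the modified network with bias.

```python
import torch.nn as nn

# Toy stand-in for the pre-trained network: a conv layer without bias.
pretrained = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, bias=False))
state_dict = pretrained.state_dict()

# Toy stand-in for the modified network: the same layer, but with a bias term.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, bias=True))

# strict=False skips keys that are absent from the checkpoint and
# reports them in the return value instead of raising a RuntimeError.
result = model.load_state_dict(state_dict, strict=False)

print(result.missing_keys)     # ['0.bias']
print(result.unexpected_keys)  # []
```

Checking result.missing_keys after loading is a cheap way to confirm that only the bias terms were skipped and that no weight was silently left out because of a typo in a layer name.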
Upvotes: 0
Reputation: 358
You should change the _resnet function given in the snippet as follows:
    def _resnet(arch, block, layers, pretrained, progress, **kwargs):
        model = ResNet(block, layers, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=progress)
            # strict belongs to load_state_dict, not load_state_dict_from_url
            model.load_state_dict(state_dict, strict=False)
        return model
Passing strict=False to model.load_state_dict makes it ignore the non-matching keys instead of raising an error.
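One thing to keep in mind is that the ignored bias terms are not pre-trained; they keep whatever values the layer constructor gave them. If the goal is to start from exactly the pre-trained function, one option (an assumption on my part, not something the question asks for) is to zero the missing biases after loading. The sketch below again uses hypothetical toy models rather than the real ResNet:

```python
import torch
import torch.nn as nn

# Toy stand-ins: "pretrained" has no bias, "model" adds one.
pretrained = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, bias=False))
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, bias=True))

result = model.load_state_dict(pretrained.state_dict(), strict=False)

# Zero every bias parameter that the checkpoint did not provide,
# so the new terms contribute nothing until training updates them.
params = dict(model.named_parameters())
with torch.no_grad():
    for key in result.missing_keys:
        if key in params and key.endswith(".bias"):
            params[key].zero_()

# With zeroed biases the two networks compute the same output.
x = torch.randn(1, 3, 16, 16)
same = torch.allclose(model(x), pretrained(x))
print(same)
```

Whether this matters in practice depends on the architecture: in ResNet each convolution is followed by batch normalization, which has its own learnable shift, so the extra bias may add little.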
Upvotes: 0