Reputation: 38285
Can I change a custom resnet 18 architecture and still use it in pre-trained = true mode? I am making a subtle change to the architecture of a custom resnet18, and when I run it, I get the error shown below.
This is how the custom resnet18 is called:
model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)
The new change in the custom resnet18:
self.layer_attend1 = nn.Sequential(
    nn.Conv2d(layers[0], layers[0], stride=2, padding=1, kernel_size=3),
    nn.AdaptiveAvgPool2d(1),
    nn.Softmax(1),
)
I am loading the checkpoint using:
checkpoint = torch.load(args.resume, encoding='latin1')
args.start_epoch = checkpoint['epoch']
best_acc = checkpoint['best_prec1']
tnet.load_state_dict(checkpoint['state_dict'])
The output of running the model is:
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
=> loading checkpoint 'runs/nondisjoint_l2norm/model_best.pth.tar'
Traceback (most recent call last):
File "main.py", line 352, in <module>
main()
File "main.py", line 145, in main
tnet.load_state_dict(checkpoint['state_dict'])
File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1406, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Tripletnet:
Missing key(s) in state_dict: "embeddingnet.embeddingnet.layer_attend1.0.weight", "embeddingnet.embeddingnet.layer_attend1.0.bias".
So, how can you implement small architectural changes without retraining from scratch every time?
P.S.: Cross-posting here: https://discuss.pytorch.org/t/can-i-change-a-custom-resnet-18-architecture-subtly-and-still-use-it-in-pre-trained-true-mode/130783 Thanks a lot to Rodrigo Berriel for teaching me about https://meta.stackexchange.com/a/141824/913043
Upvotes: 0
Views: 734
Reputation: 5289
If you really want to do this, you should construct the model and then call load_state_dict with the argument strict=False (https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict).
Keep in mind that A) you should explicitly initialize any new layers you added, because the state dict will not initialize them, and B) the model will probably not work out of the box because the new layers have not been trained, but it should train faster than a randomly initialized model.
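A minimal sketch of this approach follows. `OldNet` and `NewNet` are hypothetical toy stand-ins, not the Tripletnet or custom resnet18 from the question, and the Kaiming initialization is only one reasonable choice for the new convolution. It shows that `load_state_dict(..., strict=False)` copies the matching weights and reports the keys it could not find, after which the new layer is initialized explicitly.

```python
import torch
import torch.nn as nn


class OldNet(nn.Module):
    """Hypothetical stand-in for the network that produced the checkpoint."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        x = self.conv1(x)
        return self.fc(x.mean(dim=(2, 3)))


class NewNet(OldNet):
    """Same network plus an attention branch that the checkpoint lacks."""

    def __init__(self):
        super().__init__()
        self.layer_attend1 = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        x = self.conv1(x)
        attn = self.layer_attend1(x)  # shape (N, 8, 1, 1), broadcast over H and W
        return self.fc((x * attn).mean(dim=(2, 3)))


torch.manual_seed(0)
old_state = OldNet().state_dict()  # plays the role of checkpoint['state_dict']

model = NewNet()
# strict=False skips keys that are absent from the checkpoint instead of raising.
result = model.load_state_dict(old_state, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# Explicitly initialize the layers the checkpoint did not cover.
for module in model.layer_attend1.modules():
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        nn.init.zeros_(module.bias)

out = model(torch.randn(2, 3, 16, 16))
```

It is worth checking `result.missing_keys` before training: it should list only the layers you added on purpose, since `strict=False` would also silently skip keys that are missing because of a typo or a renamed module.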
Upvotes: 1