Mona Jalal

Reputation: 38285

Changing a custom resnet 18 architecture subtly and still using it in pre-trained mode

Can I change a custom resnet18 architecture and still use it with pretrained=True? I made a subtle change to the architecture of a custom resnet18, and when I run it I get the error shown further down. This is how the custom resnet18 is called:

model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)

The new change in the custom resnet18:

self.layer_attend1 =  nn.Sequential(nn.Conv2d(layers[0], layers[0], stride=2, padding=1, kernel_size=3),
                                    nn.AdaptiveAvgPool2d(1),
                                    nn.Softmax(1))

I am loading the checkpoint using:

checkpoint = torch.load(args.resume, encoding='latin1')
args.start_epoch = checkpoint['epoch']
best_acc = checkpoint['best_prec1']
tnet.load_state_dict(checkpoint['state_dict'])

The output of running the model is:

/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  warnings.warn("The use of the transforms.Scale transform is deprecated, " +
=> loading checkpoint 'runs/nondisjoint_l2norm/model_best.pth.tar'
Traceback (most recent call last):
  File "main.py", line 352, in <module>
    main()    
  File "main.py", line 145, in main
    tnet.load_state_dict(checkpoint['state_dict'])
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1406, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Tripletnet:
        Missing key(s) in state_dict: "embeddingnet.embeddingnet.layer_attend1.0.weight", "embeddingnet.embeddingnet.layer_attend1.0.bias".

So, how can you implement small architectural changes without retraining from scratch every time?

P.S.: Cross-posted here: https://discuss.pytorch.org/t/can-i-change-a-custom-resnet-18-architecture-subtly-and-still-use-it-in-pre-trained-true-mode/130783 Thanks a lot to Rodrigo Berriel for teaching me about https://meta.stackexchange.com/a/141824/913043

Upvotes: 0

Views: 734

Answers (1)

mlucy

Reputation: 5289

If you really want to do this, you should construct the model and then call load_state_dict with the argument strict=False (https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict).
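As a minimal sketch of the strict=False call, assuming a toy model rather than the asker's Tripletnet: OldNet and NewNet below are hypothetical stand-ins for the checkpointed network and the modified one, and old_state plays the role of checkpoint['state_dict']. The call returns an object listing the keys it could not match, which is worth printing to confirm that only the new layer was skipped.

```python
import torch
import torch.nn as nn


class OldNet(nn.Module):
    """Hypothetical stand-in for the network that produced the checkpoint."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, 4)


class NewNet(OldNet):
    """Same network plus one extra block, like layer_attend1 in the question."""

    def __init__(self):
        super().__init__()
        self.layer_attend1 = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Softmax(dim=1),
        )


old_state = OldNet().state_dict()  # plays the role of checkpoint['state_dict']
model = NewNet()

# strict=False copies every matching key and reports the rest
# instead of raising RuntimeError.
result = model.load_state_dict(old_state, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

If unexpected keys show up, or missing keys other than the new layer's, the checkpoint and the model disagree in more places than intended.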

Keep in mind that A) you should explicitly initialize any new layers you added, because the state dict won't initialize them, and B) the model will probably not work out of the box because the new layers' weights are untrained, but it should train faster than a randomly initialized model.
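Point A could look roughly like the sketch below. The toy model and the fake checkpoint are hypothetical, and Kaiming initialization for the conv weights with zero biases is just one common choice, not something the checkpoint or the question prescribes. The idea is to use the missing keys reported by load_state_dict to find exactly the parameters that still need initializing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: "backbone" stands in for the pretrained layers,
# "layer_attend1" for the newly added block from the question.
model = nn.ModuleDict({
    "backbone": nn.Conv2d(3, 8, kernel_size=3, padding=1),
    "layer_attend1": nn.Sequential(
        nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=1),
        nn.AdaptiveAvgPool2d(1),
        nn.Softmax(dim=1),
    ),
})

# Fake checkpoint that only knows about the old layers.
pretrained = {
    "backbone.weight": torch.randn(8, 3, 3, 3),
    "backbone.bias": torch.zeros(8),
}
result = model.load_state_dict(pretrained, strict=False)

# Explicitly initialize every parameter the checkpoint did not cover.
params = dict(model.named_parameters())
for key in result.missing_keys:
    if key.endswith("weight"):
        nn.init.kaiming_normal_(params[key], mode="fan_out", nonlinearity="relu")
    else:
        nn.init.zeros_(params[key])
```

After this, the backbone holds the checkpoint values and only the new block starts from a fresh initialization, so fine-tuning is still needed before the model is useful.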

Upvotes: 1
