Reputation: 13
import torch.utils.model_zoo as model_zoo

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

if args.dataset == 'cifar100' or args.dataset == 'cifar10':
    args.stride = [2, 2]

resnet = resnet18(args, pretrained=False, num_classes=args.num_classes)
initial_weight = model_zoo.load_url(model_urls['resnet18'])
local_model = resnet
initial_weight_1 = local_model.state_dict()
for key in initial_weight.keys():
    if key[0:3] == 'fc.' or key[0:5] == 'conv1' or key[0:3] == 'bn1':
        initial_weight[key] = initial_weight_1[key]
local_model.load_state_dict(initial_weight)
I don't understand this line: "initial_weight[key] = initial_weight_1[key]".
Could you please tell me why we need to do this?
Thanks.
Upvotes: 0
Views: 940
Reputation: 40738
The function torch.utils.model_zoo.load_url will load the serialized torch object from the given URL. In this particular case, the URL hosts the weight dictionary for the ResNet18 network.
Therefore initial_weight is the dictionary containing the weights of a pretrained ResNet18, while initial_weight_1 is the dictionary of the weights of the current in-memory model resnet, as initialized by resnet18.
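The following lines go through the keys of the pretrained dictionary and, wherever the condition key[0:3] == 'fc.' or key[0:5]=='conv1' or key[0:3]=='bn1' is met, replace the pretrained entry with the value from the current model. In other words, the first convolution (conv1), the first batch norm (bn1) and the final fully connected layer (fc) keep the current model's own initialization, while every other layer takes the pretrained weights. Keys such as layer1.0.conv1.weight are not affected, because they start with "layer" rather than "conv1". The final load_state_dict call then loads this merged dictionary into the model.
This is presumably needed because those layers differ from the standard ResNet18 (for example, fc is built with num_classes=args.num_classes), so the pretrained tensors for them may not have matching shapes.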
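To make the effect of that loop concrete, here is a minimal sketch of the same key filtering on plain Python dictionaries. The dictionaries only stand in for real state dicts: the values are placeholder strings rather than tensors, and the names merge_state_dicts and KEEP_PREFIXES are made up for this illustration and do not come from PyTorch or from the question.

```python
# Plain dicts stand in for state dicts; values are placeholder strings, not tensors.
pretrained = {
    "conv1.weight": "pretrained",
    "bn1.weight": "pretrained",
    "layer1.0.conv1.weight": "pretrained",
    "fc.weight": "pretrained",
    "fc.bias": "pretrained",
}
current = {key: "fresh" for key in pretrained}

# Same prefixes as the condition in the question (hypothetical constant name).
KEEP_PREFIXES = ("fc.", "conv1", "bn1")


def merge_state_dicts(pretrained, current, keep_prefixes=KEEP_PREFIXES):
    """Take pretrained values, except for keys that start with a kept prefix."""
    merged = dict(pretrained)  # shallow copy, so the input dict is left untouched
    for key in merged:
        if key.startswith(keep_prefixes):
            # Keep the current model's own value for this layer.
            merged[key] = current[key]
    return merged


merged = merge_state_dicts(pretrained, current)
for key, value in merged.items():
    print(f"{key}: {value}")
```

Only the top-level conv1, bn1 and fc entries end up with the "fresh" value; layer1.0.conv1.weight stays pretrained because its name starts with "layer". With real models, the merged dictionary would then be passed to load_state_dict, as in the question.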
Upvotes: 0