Joey

Reputation: 13

What is the difference between "model_zoo.load_url" and "state_dict"?

from torch.utils import model_zoo

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

if args.dataset == 'cifar100' or args.dataset == 'cifar10':
    args.stride = [2, 2]
resnet = resnet18(args, pretrained=False, num_classes=args.num_classes)
initial_weight = model_zoo.load_url(model_urls['resnet18'])
local_model = resnet 
initial_weight_1 = local_model.state_dict() 

for key in initial_weight.keys():
    if key[0:3] == 'fc.' or key[0:5]=='conv1' or key[0:3]=='bn1':
        initial_weight[key] = initial_weight_1[key] 

local_model.load_state_dict(initial_weight)

I don't understand this line: "initial_weight[key] = initial_weight_1[key]"

Could you please tell me why we need to do this?

Thanks

Upvotes: 0

Views: 940

Answers (1)

Ivan

Reputation: 40738

The function torch.utils.model_zoo.load_url loads a serialized torch object from the given URL. In this case, the URL hosts the weight dictionary (state dict) of a pretrained ResNet18 network.

Therefore initial_weight is the dictionary containing the weights of a pretrained ResNet18, while initial_weight_1 is the state dict of the model currently in memory, resnet. That model was created by resnet18 with pretrained=False, so it holds freshly initialized weights.

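The loop then goes through the keys of the pretrained dictionary and, whenever the condition key[0:3] == 'fc.' or key[0:5]=='conv1' or key[0:3]=='bn1' is met, replaces the pretrained entry with the corresponding entry from the current model. In other words, the top-level conv1, bn1 and fc layers keep their fresh initialization, and every other layer receives the pretrained weights once local_model.load_state_dict(initial_weight) is called. Note that keys such as layer1.0.conv1.weight do not match, because they start with "layer1". The override is presumably needed because those layers no longer match the pretrained checkpoint: fc is built with num_classes=args.num_classes rather than the 1000 ImageNet classes, and the first layers appear to be modified through args.stride, so loading their pretrained tensors could fail with a size mismatch.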
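To make the direction of the copy concrete, here is a minimal sketch of the same merge in plain Python. It is not the asker's code: the helper name merge_state_dicts is hypothetical, the keys are a small made-up subset of ResNet18's, and strings stand in for tensors so the snippet runs without PyTorch or a download. Unlike the loop in the question, it returns a new dictionary instead of modifying the pretrained one in place.

```python
# Hypothetical helper illustrating the merge; strings stand in for tensors.
def merge_state_dicts(pretrained, local, skip_prefixes=("fc.", "conv1", "bn1")):
    """Copy `pretrained`, but take entries whose key starts with one of
    `skip_prefixes` from `local` instead."""
    merged = dict(pretrained)  # shallow copy, the input dict is left untouched
    for key in merged:
        # str.startswith accepts a tuple of prefixes
        if key.startswith(skip_prefixes):
            merged[key] = local[key]
    return merged


# Made-up subset of keys; the values only record where each entry came from.
pretrained = {
    "conv1.weight": "pretrained",
    "bn1.weight": "pretrained",
    "layer1.0.conv1.weight": "pretrained",
    "layer1.0.bn1.weight": "pretrained",
    "fc.weight": "pretrained",
}
local = {key: "fresh init" for key in pretrained}

merged = merge_state_dicts(pretrained, local)
for key, source in merged.items():
    print(f"{key:24s} <- {source}")
```

Only the three top-level entries (conv1, bn1, fc) end up with the fresh initialization; the layer1 entries keep the pretrained values, which is what load_state_dict would then copy into the model.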

Upvotes: 0
