Ausrada404

Reputation: 599

load and freeze one model and train others in PyTorch

I have a model A that includes three submodels: model1, model2, and model3.

The model flow: model1 --> model2 --> model3

I have trained model1 in an independent project.

The question is: how do I use the pre-trained model1 when training model A?

Now, I try to implement this as follows:

I load the checkpoint of model1 with `model1.load_state_dict(torch.load('model1.pth'))` and then set the `requires_grad` of model1's parameters to `False`.
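In code, my attempt looks roughly like this (the submodel definitions and the checkpoint path are placeholders; my real models differ):

import torch
import torch.nn as nn

# Placeholder submodels; the real model1/model2/model3 are more complex.
model1 = nn.Linear(10, 20)
model2 = nn.Linear(20, 20)
model3 = nn.Linear(20, 5)

# Load the pre-trained weights of model1 (checkpoint path assumed).
model1.load_state_dict(torch.load('model1.pth'))

# Freeze model1 so only model2 and model3 receive gradient updates.
for param in model1.parameters():
    param.requires_grad = False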

Is it right?

Upvotes: 1

Views: 4205

Answers (1)

yanarp

Reputation: 175

Yes, that is correct.

When you structure your model the way you explained, what you are doing is correct.

ModelA consists of three submodels - model1, model2, model3.

Then you load the weights of each individual model with `model*.load_state_dict(torch.load('model*.pth'))`, where * stands for 1, 2, or 3.

Then set `requires_grad=False` for the parameters of the model you want to freeze (in your case, model1):

# Freeze all of model1's parameters so the optimizer never updates them.
for param in model1.parameters():
    param.requires_grad = False
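Putting it all together, a minimal sketch might look like the following (the submodel classes, the checkpoint path, and the optimizer choice are assumptions for illustration, not your actual code):

import torch
import torch.nn as nn

class ModelA(nn.Module):
    def __init__(self, model1, model2, model3):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3

    def forward(self, x):
        # model flow: model1 --> model2 --> model3
        return self.model3(self.model2(self.model1(x)))

# Placeholder submodels; replace with the real architectures.
model1 = nn.Linear(10, 20)
model2 = nn.Linear(20, 20)
model3 = nn.Linear(20, 5)

model1.load_state_dict(torch.load('model1.pth'))  # path assumed

# Freeze the pre-trained submodel.
for param in model1.parameters():
    param.requires_grad = False

model_a = ModelA(model1, model2, model3)
model_a.model1.eval()  # keep dropout/batch-norm in the frozen part in eval mode

# Give the optimizer only the parameters that are still trainable.
optimizer = torch.optim.SGD(
    (p for p in model_a.parameters() if p.requires_grad), lr=0.01)

One caveat: calling model_a.train() in your training loop switches model1 back to training mode as well, so if model1 contains dropout or batch-norm layers, call model_a.model1.eval() again afterwards.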

You can also freeze the weights of particular layers by accessing the submodules. For example, if model1 has a layer named `fc`, you can freeze its weights with `model1.fc.weight.requires_grad = False`.
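For instance (assuming model1 really has a submodule named fc; note that a layer usually has a bias parameter in addition to its weight):

model1.fc.weight.requires_grad = False
model1.fc.bias.requires_grad = False

# Equivalently, freeze every parameter of that one submodule:
for param in model1.fc.parameters():
    param.requires_grad = False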

Upvotes: 3
