Reputation: 1374
I use the PyTorch library and I'm looking for a way to freeze the weights and biases in my model.
I saw these 2 options:
model.train(False)
for param in model.parameters(): param.requires_grad = False
What is the difference (if there is any), and which one should I use to freeze the current state of my model?
Upvotes: 5
Views: 3604
Reputation: 46401
Setting requires_grad to False is the way to freeze parameters in PyTorch when training, while model.train(False) is a way not to train. ;)
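A quick sketch to make that concrete (the layer here is just a placeholder):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # any small module, just for illustration

# eval mode alone does not freeze anything: gradients are still recorded
model.train(False)
model(torch.randn(1, 4)).sum().backward()
print(model.weight.grad is None)   # False - a gradient was computed

# requires_grad=False is what actually stops gradient tracking
for param in model.parameters():
    param.requires_grad = False
print(model(torch.randn(1, 4)).requires_grad)   # False - nothing left to train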
Upvotes: 0
Reputation: 3563
They are very different.
Independently of the backprop process, some layers behave differently when you are training or evaluating a model. In PyTorch there are only two of them: BatchNorm (which, I think, stops updating its running mean and variance when you are evaluating) and Dropout (which only drops values in training mode). So model.train() and model.eval() (equivalently model.train(False)) just set a boolean flag to tell these two layers to "freeze themselves". Note that these two layers do not have any parameters that are affected by the backward pass (the BatchNorm buffer tensors are changed during the forward pass, I think).
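For instance, a small sketch showing the Dropout flag in action (the probability here is arbitrary):

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()        # training mode: roughly half the values are zeroed, the rest rescaled
print(drop(x))

drop.eval()         # eval mode (same as drop.train(False)): dropout does nothing
print(drop(x))      # all ones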
On the other hand, setting all your parameters to requires_grad=False just tells PyTorch to stop recording gradients for backprop. That will not affect the BatchNorm and Dropout layers.
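This is the usual way to freeze part of a model for fine-tuning; a rough sketch (the architecture below is made up):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

# freeze the first linear layer only
for param in model[0].parameters():
    param.requires_grad = False

# give the optimizer only the parameters that still require gradients
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.01
)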
How to freeze your model kinda depends on your use case, but I'd say the easiest way is to use torch.jit.trace. This will create a frozen copy of your model, exactly in the state it was in when you called trace
. Your original model remains unaffected.
Usually, you would call
model.eval()
traced_model = torch.jit.trace(model, input)
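After that you can save and reload the traced copy independently of the original Python class (the file name below is just an example):

traced_model.save("frozen_model.pt")        # serialize the frozen copy
loaded = torch.jit.load("frozen_model.pt")  # can be loaded without the model's source code
output = loaded(input)                      # behaves exactly as the model did when traced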
Upvotes: 4