Dr. Prof. Patrick

Reputation: 1374

Difference between model.train(False) and requires_grad = False

I use the PyTorch library and I'm looking for a way to freeze the weights and biases in my model.

I saw these 2 options:

  1. model.train(False)

  2. for param in model.parameters(): param.requires_grad = False

What is the difference (if there is any) and which one should I use to freeze the current state of my model?

Upvotes: 5

Views: 3604

Answers (2)

prosti

Reputation: 46401

There are two ways to freeze parameters in PyTorch when training:

  • setting requires_grad to False
  • setting the learning rate lr to zero

model.train(False), on the other hand, is just a way not to train. ;)
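A minimal sketch of both options (the backbone/head split is just for illustration; in practice you would pick one of the two):

import torch
import torch.nn as nn

backbone = nn.Linear(4, 2)
head = nn.Linear(2, 1)

# Option 1: turn off gradient tracking for the frozen part
for param in backbone.parameters():
    param.requires_grad = False

# Option 2: keep gradients, but give the frozen group a zero learning rate
# so the optimizer never changes those weights
optimizer = torch.optim.SGD(
    [
        {"params": backbone.parameters(), "lr": 0.0},
        {"params": head.parameters()},  # uses the default lr below
    ],
    lr=1e-2,
)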

Upvotes: 0

trialNerror

Reputation: 3563

They are very different.

Independently of the backprop process, some layers behave differently when you are training or evaluating a model. In PyTorch, there are only 2 of them: BatchNorm (which I think stops updating its running mean and deviation when you are evaluating) and Dropout (which only drops values in training mode). So model.train() and model.eval() (equivalently model.train(False)) just set a boolean flag that tells these 2 layers "freeze yourselves". Note that these two layers do not have any parameters affected by the backward operation (the BatchNorm buffer tensors are changed during the forward pass, I think).
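For instance (a minimal sketch; the layer sizes are arbitrary):

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(p=0.5))

model.train()          # training mode: Dropout drops values, BatchNorm updates its running stats
print(model.training)  # True

model.eval()           # equivalent to model.train(False)
print(model.training)  # False
print(model[1].training, model[2].training)  # the flag is propagated to the BatchNorm and Dropout submodules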

On the other hand, setting all your parameters to requires_grad=False just tells PyTorch to stop recording gradients for backprop. That will not affect the BatchNorm and the Dropout layers.
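A quick way to see this (again a minimal sketch):

import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 3), nn.Dropout(p=0.5))

for param in model.parameters():
    param.requires_grad = False   # autograd stops recording operations on these tensors

print(all(not p.requires_grad for p in model.parameters()))  # True
print(model.training)  # still True: Dropout keeps dropping values in the forward pass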

How to freeze your model kinda depends on your use-case, but I'd say the easiest way is to use torch.jit.trace. This will create a frozen copy of your model, exactly in the state it was in when you called trace. Your original model remains unaffected.

Usually, you would call

model.eval()  # put Dropout/BatchNorm layers in inference mode before tracing
traced_model = torch.jit.trace(model, input)  # input: an example tensor (or tuple of tensors) with the shapes your model expects
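If you then want to persist the frozen copy, the traced module can be saved and loaded back without the original Python class (the file name here is just an example):

traced_model.save("traced_model.pt")
loaded = torch.jit.load("traced_model.pt")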

Upvotes: 4
