hazrmard

Reputation: 3661

In PyTorch, how do I make certain module `Parameters` static during training?

Context:

In PyTorch, a Parameter is a special kind of Tensor. When assigned as a module attribute, a Parameter is automatically registered and is returned by the module's parameters() method.

During training, I will pass m.parameters() to the Optimizer instance so they can be updated.
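For example, a minimal version of my usual setup (SGD is just an arbitrary choice here):

import torch
from torch import nn

m = nn.Linear(2, 3)   # .weight and .bias are Parameters, registered automatically
print([name for name, _ in m.named_parameters()])   # ['weight', 'bias']

# These are the tensors the optimizer will update:
optimizer = torch.optim.SGD(m.parameters(), lr=0.01)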


Question: For a built-in PyTorch module, how do I prevent certain parameters from being modified by the optimizer?

s = nn.Sequential(
        nn.Linear(2,2),
        nn.Linear(2,3),   # I want this one's .weight and .bias to be constant
        nn.Linear(3,1)
    )

Upvotes: 0

Views: 2835

Answers (1)

hazrmard

Reputation: 3661

Parameters can be made static by setting their requires_grad attribute to False.

In my example case:

params = list(s.parameters())  # .parameters() returns a generator
# Each linear layer has 2 parameters (.weight and .bias),
# Skipping first layer's parameters (indices 0, 1):
params[2].requires_grad = False
params[3].requires_grad = False
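
Another option is to not hand the frozen parameters to the optimizer at all. A sketch continuing from the code above (SGD is an arbitrary choice):

import torch

# Pass only the still-trainable parameters to the optimizer:
trainable = [p for p in s.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)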

When a mix of requires_grad=True and requires_grad=False tensors is used in a calculation, the result has requires_grad=True.

According to the PyTorch autograd mechanics documentation:

If there’s a single input to an operation that requires gradient, its output will also require gradient. Conversely, only if all inputs don’t require gradient, the output also won’t require it. Backward computation is never performed in the subgraphs, where all Tensors didn’t require gradients.
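
A quick check of that rule (a standalone sketch):

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)                # requires_grad=False by default

c = a * b
print(c.requires_grad)            # True: one input requires grad, so the result does too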


My concern was that if I disabled gradient tracking for the middle layer, the first layer wouldn't receive backpropagated gradients. That turned out to be a faulty understanding.
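
This can be checked directly (a sketch rebuilding the model from the question):

import torch
from torch import nn

s = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 3), nn.Linear(3, 1))
s[1].weight.requires_grad = False
s[1].bias.requires_grad = False

loss = s(torch.randn(4, 2)).sum()
loss.backward()

print(s[0].weight.grad is None)   # False: the first layer still receives gradients
print(s[1].weight.grad is None)   # True: the frozen layer's parameters do not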

Edge Case: If I disable gradients for all parameters in a module and try to train, an exception is raised, because there is not a single tensor left for the backward() pass to propagate through.

This edge case is why I was getting errors. I was trying to test requires_grad=False on the parameters of a module with a single nn.Linear layer. That meant I disabled tracking for all of its parameters, which is what triggered the exception.
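
A minimal reproduction of that edge case (a sketch; the exact exception and message can vary with the PyTorch version):

import torch
from torch import nn

tiny = nn.Sequential(nn.Linear(2, 1))
for p in tiny.parameters():
    p.requires_grad = False          # freeze every parameter in the module

loss = tiny(torch.randn(4, 2)).sum()
loss.backward()                      # raises an error: nothing in the graph requires grad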

Upvotes: 2
