Ray

Reputation: 8623

PyTorch `torch.no_grad()` doesn't affect modules

I was under the (evidently wrong) impression from the documentation that torch.no_grad(), used as a context manager, was supposed to make everything requires_grad=False. That is exactly what I intended to use it for: a convenient context manager for instantiating a bunch of things that I want to stay constant throughout training. But it seems that only holds for torch.Tensor; it doesn't appear to affect torch.nn.Module, as the following example shows:

import torch

with torch.no_grad():
    linear = torch.nn.Linear(2, 3)

for p in linear.parameters():
    print(p.requires_grad)

This will output:

True
True

That's a bit counterintuitive in my opinion. Is this the intended behaviour? If so, why? And is there a similarly convenient context manager in which I can be assured that anything I instantiate under it will not require gradient?

Upvotes: 2

Views: 854

Answers (1)

Umang Gupta

Reputation: 16480

This is expected behavior, but I agree it is somewhat unclear from the documentation. Note that the documentation says:

In this mode, the result of every computation will have requires_grad=False, even when the inputs have requires_grad=True.

This context disables gradient tracking for the output of any computation done within it. Technically, declaring/creating a layer is not a computation, so the parameters' requires_grad stays True. However, for any calculation you do inside this context, you won't be able to compute gradients: the requires_grad of the output will be False. This is probably best explained by extending your code snippet as below:

import torch

with torch.no_grad():
    linear = torch.nn.Linear(2, 3)
    for p in linear.parameters():
        print(p.requires_grad)  # creating the layer is not a computation
    out = linear(torch.rand(10, 2))
    print(out.requires_grad)    # output computed inside no_grad
out = linear(torch.rand(10, 2))
print(out.requires_grad)        # output computed outside no_grad

This prints:

True
True
False
True

Even though requires_grad is True for the layer's parameters, you won't be able to compute gradients through an output produced inside the context, because that output has requires_grad=False.
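As a minimal sketch of what this means in practice (my own extension, not part of the snippet above): calling backward() on an output produced under no_grad() raises an error, and if the goal is to keep a module constant during training, freezing its parameters with requires_grad_(False) does what the question is after:

import torch

linear = torch.nn.Linear(2, 3)
with torch.no_grad():
    out = linear(torch.rand(10, 2))

# out was produced under no_grad(), so it has no grad_fn and
# backpropagation through it raises a RuntimeError
try:
    out.sum().backward()
except RuntimeError as e:
    print(e)  # "element 0 of tensors does not require grad ..."

# To keep a module constant during training, freeze its parameters
# explicitly rather than relying on no_grad() at creation time:
for p in linear.parameters():
    p.requires_grad_(False)
print(any(p.requires_grad for p in linear.parameters()))  # False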

Upvotes: 2
