Reputation: 8623
I was under the (evidently wrong) impression from the documentation that torch.no_grad(), as a context manager, was supposed to make everything requires_grad=False. Indeed, that's what I intended to use torch.no_grad() for: as a convenient context manager for instantiating a bunch of things that I want to stay constant through training. But that only seems to be the case for torch.Tensors; it doesn't seem to affect torch.nn.Modules, as the following example code shows:
    with torch.no_grad():
        linear = torch.nn.Linear(2, 3)
        for p in linear.parameters():
            print(p.requires_grad)
This will output:
    True
    True
That's a bit counterintuitive in my opinion. Is this the intended behaviour? If so, why? And is there a similarly convenient context manager in which I can be assured that anything I instantiate under it will not require gradient?
Upvotes: 2
Views: 854
Reputation: 16480
This is expected behavior, but I agree it is somewhat unclear from the documentation. Note that the documentation says:
In this mode, the result of every computation will have requires_grad=False, even when the inputs have requires_grad=True.
This context manager disables gradient tracking for the output of any computation performed within it. Technically, declaring/creating a layer is not a computation, so the parameters' requires_grad stays True. However, for any calculation you do inside this context, you won't be able to compute gradients: the requires_grad of the calculation's output will be False.
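To see the quoted sentence in action with a plain tensor, here is a minimal sketch (the variable names are just illustrative): even an input created with requires_grad=True produces an output with requires_grad=False when the computation happens inside the context:

    import torch

    x = torch.ones(3, requires_grad=True)  # input that tracks gradients

    with torch.no_grad():
        y = x * 2            # computation performed inside the context
    print(y.requires_grad)   # False: the result is detached from autograd

    z = x * 2                # same computation outside the context
    print(z.requires_grad)   # True: gradients can flow back to x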
This is probably best explained by extending your code snippet as below:
    with torch.no_grad():
        linear = torch.nn.Linear(2, 3)
        for p in linear.parameters():
            print(p.requires_grad)
        out = linear(torch.rand(10, 2))
        print(out.requires_grad)

    out = linear(torch.rand(10, 2))
    print(out.requires_grad)

This will output:

    True
    True
    False
    True
Even though requires_grad is True for the layer's parameters, you won't be able to compute gradients through them, because any output produced inside the context has requires_grad set to False.
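To make "you won't be able to compute the gradient" concrete, here is a minimal sketch (again just an illustration of the behaviour described above): calling backward() on an output produced under torch.no_grad() raises a RuntimeError, since no autograd graph was recorded for it:

    import torch

    with torch.no_grad():
        linear = torch.nn.Linear(2, 3)
        out = linear(torch.rand(10, 2))

    print(out.requires_grad)   # False: out is not attached to any graph

    try:
        out.sum().backward()   # attempt to backpropagate through out
    except RuntimeError as err:
        # e.g. "element 0 of tensors does not require grad and does not have a grad_fn"
        print(err)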
Upvotes: 2