Ayush Gupta

Reputation: 15

PyTorch: call to backward() on two different functions in sequence gives unexpected output the second time

I am trying to differentiate two functions, y and z, which are both based on the same input x. To do so, I call y.backward() and z.backward(). As per my understanding, both functions are defined before any operation is performed on the input, so y and z should be independent and give independent differentiation results. However, the second call to backward() gives an incorrect result. If I call y.backward() first and then z.backward(), the gradient after z.backward() is 14 instead of 12. If I call z.backward() first and then y.backward(), the gradient after y.backward() is 14 instead of 2. Only the first result is correct in both cases. I cannot understand how it gives 14 the second time.

import torch

x = torch.tensor(2.0, requires_grad=True)
y = 2 * x + 3    # dy/dx = 2
z = x**3 + 1     # dz/dx = 3 * x**2 = 12 at x = 2

y.backward()
print('grad attribute of the tensor::', x.grad)

z.backward()
print('grad attribute of the tensor::', x.grad)

Output:

grad attribute of the tensor:: tensor(2.)
grad attribute of the tensor:: tensor(14.)

Upvotes: 0

Views: 1112

Answers (1)

Alexander Guyer

Reputation: 2203

backward() does not overwrite a tensor's grad attribute; it accumulates into it. If you don't zero out the gradients between backpropagations, the resulting gradient is the sum of the gradients from each backpropagation: in your example, dy/dx = 2 and dz/dx = 12, so the second printout shows 2 + 12 = 14. The reason it's implemented like this is to better support recurrent neural networks.
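
Since your x is a plain leaf tensor with no optimizer or module involved, a minimal sketch of one way to reset the gradient by hand between the two calls (zeroing x.grad in place) could look like this:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = 2 * x + 3
z = x**3 + 1

y.backward()
print(x.grad)    # tensor(2.)  -> dy/dx = 2

x.grad.zero_()   # reset the accumulated gradient before the next backward()
z.backward()
print(x.grad)    # tensor(12.) -> dz/dx = 12, no longer 2 + 12 = 14

Setting x.grad = None instead of zeroing it in place has the same effect here.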

The most conventional way to zero out the gradients is by calling torch.optim.Optimizer.zero_grad(); this zeroes out the gradients of all of the parameters passed to the optimizer on construction. This works well when you're just using the gradients for an optimizer step. There's also torch.nn.Module.zero_grad(), which zeroes out the gradients of the module's parameters (I believe it's recursive, so it should also cover the sub-modules' parameters, and so on).
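
For illustration, a small sketch of both calls; the Linear model, optimizer, and shapes here are made up purely for the example and are not part of your code:

import torch

model = torch.nn.Linear(3, 1)                        # hypothetical toy model
opt = torch.optim.SGD(model.parameters(), lr=0.1)

inp = torch.randn(4, 3)
loss = model(inp).sum()
loss.backward()     # populates .grad on model.weight and model.bias

opt.zero_grad()     # clears every parameter's .grad before the next backward()
# model.zero_grad() would have the same effect here, recursing into sub-modules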

Also note that if you need to backpropagate through the same graph more than once (i.e. reuse its intermediate results), then you need to pass retain_graph=True when calling backward(); this is usually the case when working with recurrent neural networks. Otherwise, PyTorch frees those intermediate buffers to conserve memory.
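
As a rough sketch of the retain_graph case (the variables here are hypothetical, not from your code): backpropagating a second time through the same graph fails unless the first call keeps the graph alive:

import torch

x = torch.tensor(2.0, requires_grad=True)
w = x**2                           # intermediate (non-leaf) tensor
out = 3 * w

out.backward(retain_graph=True)    # keep the graph's saved buffers alive
print(x.grad)                      # tensor(12.) -> d(3*x**2)/dx = 6*x = 12

x.grad.zero_()
w.backward()                       # would raise a RuntimeError without retain_graph=True above
print(x.grad)                      # tensor(4.)  -> dw/dx = 2*x = 4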

Upvotes: 1
