Reputation: 1227
As indicated in the PyTorch tutorial,
if you want to run the backward pass on some part of the graph twice, you need to pass retain_graph=True during the first pass.
However, I found that the following code snippet actually works without doing so. I'm using PyTorch 0.4:
x = torch.ones(2, 2, requires_grad=True)
y = x + 2
y.backward(torch.ones(2, 2)) # Note I do not set retain_graph=True
y.backward(torch.ones(2, 2)) # But it still works!
print(x.grad)
Output:
tensor([[ 2., 2.],
[ 2., 2.]])
Could anyone explain? Thanks in advance!
Upvotes: 4
Views: 5460
Reputation: 8981
The reason this works without retain_graph=True in your case is that your graph is very simple: it probably has no internal intermediate buffers, so no buffers get freed and there is no need for retain_graph=True.
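For example (a minimal sketch, assuming PyTorch 0.4+; the extra z = y + 3 step is my own addition, not from your snippet), even a longer chain of additions can be backpropagated twice without retain_graph=True, because addition of a constant saves no input tensors for its backward pass:
import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2   # addition saves no input tensors for backward
z = y + 3   # still nothing that would need to be freed

z.backward(torch.ones(2, 2))  # no retain_graph=True
z.backward(torch.ones(2, 2))  # still works: there were no buffers to free
print(x.grad)  # gradients accumulate: tensor([[2., 2.], [2., 2.]])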
But everything changes once you add one more computation to your graph:
Code:
x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2
y.backward(torch.ones(2, 2))
print('Backward 1st time w/o retain')
print('x.grad:', x.grad)
print('Backward 2nd time w/o retain')
try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)
print('x.grad:', x.grad)
Output:
Backward 1st time w/o retain
x.grad: tensor([[3., 3.],
[3., 3.]])
Backward 2nd time w/o retain
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
x.grad: tensor([[3., 3.],
[3., 3.]])
In this case the gradient also has to flow through the intermediate v, but torch doesn't store intermediate values (intermediate gradients, saved inputs, etc.): with retain_graph=False the buffers needed for that computation are freed after the first backward().
So, if you want to backprop a second time, you need to specify retain_graph=True to "keep" the graph.
Code:
x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2
y.backward(torch.ones(2, 2), retain_graph=True)
print('Backward 1st time w/ retain')
print('x.grad:', x.grad)
print('Backward 2nd time w/ retain')
try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)
print('x.grad:', x.grad)
Output:
Backward 1st time w/ retain
x.grad: tensor([[3., 3.],
[3., 3.]])
Backward 2nd time w/ retain
x.grad: tensor([[6., 6.],
[6., 6.]])
Upvotes: 11