ROBOT AI

Reputation: 1227

PyTorch can backward twice without setting retain_graph=True

As indicated in the PyTorch tutorial,

if you ever want to do the backward on some part of the graph twice, you need to pass in retain_graph = True during the first pass.

However, I found that the following code snippet actually works without doing so. I'm using PyTorch 0.4.

import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
y.backward(torch.ones(2, 2)) # Note I do not set retain_graph=True
y.backward(torch.ones(2, 2)) # But it still works!
print(x.grad)

output:

tensor([[ 2.,  2.], 
        [ 2.,  2.]]) 

Could anyone explain? Thanks in advance!

Upvotes: 4

Views: 5460

Answers (1)

trsvchn

Reputation: 8981

The reason it works without retain_graph=True in your case is that your graph is very simple: it has no internal intermediate buffers that would need to be freed, so there is nothing for retain_graph=True to keep.
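As a side note, here is a minimal sketch (my own addition, not from the question, assuming PyTorch 0.4+) that makes this explicit: the backward of an addition needs nothing saved from the forward pass, so both calls succeed and x.grad simply accumulates.

Code:

import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2

y.backward(torch.ones(2, 2))
print(x.grad)  # tensor of ones after the first backward

y.backward(torch.ones(2, 2))
print(x.grad)  # tensor of twos: gradients accumulate across backward calls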

But everything changes when you add one more computation to your graph:

Code:

import torch

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2

y.backward(torch.ones(2, 2))

print('Backward 1st time w/o retain')
print('x.grad:', x.grad)

print('Backward 2nd time w/o retain')

try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)

print('x.grad:', x.grad)

Output:

Backward 1st time w/o retain
x.grad: tensor([[3., 3.],
                [3., 3.]])
Backward 2nd time w/o retain
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
x.grad: tensor([[3., 3.],
                [3., 3.]])

In this case there is an additional internal intermediate result v, but torch does not store intermediate values (saved inputs, intermediate gradients, etc.): with retain_graph=False these buffers are freed after the first backward, so a second backward can no longer be computed.

So, if you want to backprop a second time, you need to specify retain_graph=True to "keep" the graph.

Code:

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2

y.backward(torch.ones(2, 2), retain_graph=True)

print('Backward 1st time w/ retain')
print('x.grad:', x.grad)

print('Backward 2nd time w/ retain')

try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)
print('x.grad:', x.grad)

Output:

Backward 1st time w/ retain
x.grad: tensor([[3., 3.],
                [3., 3.]])
Backward 2nd time w/ retain
x.grad: tensor([[6., 6.],
                [6., 6.]])
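Note that the second result is 6 because backward accumulates gradients into x.grad (3 + 3). If you want each pass to report the gradient of a single backward, you can zero it in between. A minimal sketch of that (my own addition, not part of the original code):

Code:

import torch

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2

y.backward(torch.ones(2, 2), retain_graph=True)
print('x.grad after 1st backward:', x.grad)  # tensor of threes

x.grad.zero_()  # clear the accumulated gradient before the next pass

y.backward(torch.ones(2, 2))
print('x.grad after 2nd backward:', x.grad)  # threes again, not sixes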

Upvotes: 11
