Mr_qiaozhi

Reputation: 13

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed

I get the error in the title. I found some answers suggesting retain_graph=True and tried it, but it doesn't work, so maybe my code has another problem. The error occurs in loss_actor.backward(retain_graph=True).

# Critic update: build TD targets and Q estimates per sample
q = torch.zeros(len(reward))
q_target = torch.zeros(len(reward))
for j, r in enumerate(reward):
    # Target Q-value from the next state and the actor's action at the next state
    q_target[j] = self.critic_network(torch.transpose(next_state[j], 0, 1), self.actor_network(torch.transpose(next_state[j], 0, 1)).view(1, 1))
    q_target[j] = r + (done[j] * gamma * q_target[j]).detach()
    # Current Q-value for the action actually taken
    q[j] = self.critic_network(torch.transpose(state[j], 0, 1), action[j].view(1, 1))
loss_critic = F.mse_loss(q, q_target)
self.critic_optimizer.zero_grad()
loss_critic.backward()
self.critic_optimizer.step()

# Actor update: maximize the critic's value of the actor's actions
b = torch.zeros(len(reward))
for j, r in enumerate(reward):
    b[j] = self.critic_network(torch.transpose(state[j], 0, 1), self.actor_network(torch.transpose(state[j], 0, 1)).view(1, 1))
loss_actor = -torch.mean(b)
self.actor_optimizer.zero_grad()
loss_actor.backward(retain_graph=True)  # the error is raised here
self.actor_optimizer.step()

Upvotes: 0

Views: 770

Answers (1)

trsvchn

Reputation: 8981

Based on the provided part of your computational graph, I assume that loss_actor and loss_critic share some portion of it; I think it's state (not sure):

state --> q --> loss_critic   <-- backward 1
  |
  +-----> b --> loss_actor    <-- backward 2

To reproduce your example:

import torch

# Some computations that produce state
state = torch.ones((2, 2), requires_grad=True) ** 2

# Compute the first loss
q = torch.zeros((1))
q[0] = state[0, 0]
l1 = torch.sum(2 * q)
l1.backward()

# Compute the second loss
b = torch.zeros((1))
b[0] = state[1, 1]
l2 = torch.mean(2 * b)
l2.backward()

RuntimeError                              Traceback (most recent call last)
<ipython-input-28-2ab509bedf7a> in <module>
     10 b[0] = state[1, 1]
     11 l2 = torch.mean(2 * b)
---> 12 l2.backward()
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

Trying

...
l2.backward(retain_graph=True)

doesn't help, because, as the error message says, you have to

Specify retain_graph=True when calling backward the first time.

that is, on the first backward call, the one for l1:

l1.backward(retain_graph=True)
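
Putting it together, here is the toy example again with the fix applied (a minimal sketch: retain_graph=True on the first backward keeps the shared part of the graph alive for the second):

import torch

# Some computations that produce state (shared by both losses)
state = torch.ones((2, 2), requires_grad=True) ** 2

# First loss: retain the graph so the shared part survives this backward
q = torch.zeros((1))
q[0] = state[0, 0]
l1 = torch.sum(2 * q)
l1.backward(retain_graph=True)

# Second loss: backward now succeeds through the retained graph
b = torch.zeros((1))
b[0] = state[1, 1]
l2 = torch.mean(2 * b)
l2.backward()

In your code that means moving retain_graph=True from loss_actor.backward() to loss_critic.backward(), since the critic's backward runs first. Alternatively, if the actor update is not meant to backpropagate into whatever produced state, you could detach the shared tensor (e.g. feed state.detach() into the second loop) and avoid retaining the graph at all; that is an assumption about your setup, not something visible from the snippet.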

Upvotes: 1
