Reputation: 13
I get the error in the title. I found some answers suggesting retain_graph=True, but it doesn't work. Maybe my code has another problem (the error occurs at loss_actor.backward(retain_graph=...)):
# Critic update: build TD targets and current Q-values per sample
q = torch.zeros(len(reward))
q_target = torch.zeros(len(reward))
for j, r in enumerate(reward):
    q_target[j] = self.critic_network(torch.transpose(next_state[j], 0, 1),
                                      self.actor_network(torch.transpose(next_state[j], 0, 1)).view(1, 1))
    q_target[j] = r + (done[j] * gamma * q_target[j]).detach()
    q[j] = self.critic_network(torch.transpose(state[j], 0, 1), action[j].view(1, 1))

loss_critic = F.mse_loss(q, q_target)
self.critic_optimizer.zero_grad()
loss_critic.backward()
self.critic_optimizer.step()

# Actor update: maximize the critic's value of the actor's action
b = torch.zeros(len(reward))
for j, r in enumerate(reward):
    b[j] = self.critic_network(torch.transpose(state[j], 0, 1),
                               self.actor_network(torch.transpose(state[j], 0, 1)).view(1, 1))

loss_actor = -torch.mean(b)
self.actor_optimizer.zero_grad()
loss_actor.backward(retain_graph=True)
self.actor_optimizer.step()
Upvotes: 0
Views: 770
Reputation: 8981
Based on the provided info about part of your computational graph, I assume that loss_actor and loss_critic share some part of it; I think it's state (not sure):
state ---> q --> loss_critic    <-- backward 1
  |
  +------> b --> loss_actor     <-- backward 2
To reproduce your example:
# Some computations that produce state
state = torch.ones((2, 2), requires_grad=True) ** 2
# Compute the first loss
q = torch.zeros((1))
q[0] = state[0, 0]
l1 = torch.sum(2 * q)
l1.backward()
# Compute the second loss
b = torch.zeros((1))
b[0] = state[1, 1]
l2 = torch.mean(2 * b)
l2.backward()
RuntimeError Traceback (most recent call last)
<ipython-input-28-2ab509bedf7a> in <module>
10 b[0] = state[1, 1]
11 l2 = torch.mean(2 * b)
---> 12 l2.backward()
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
Trying
...
l2.backward(retain_graph=True)
doesn't help, because you have to
Specify retain_graph=True when calling backward the first time.
Here, that means the first backward call (the one for l1):
l1.backward(retain_graph=True)
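For completeness, here is a minimal sketch (my own addition, reusing the same toy example from above) showing that the error goes away once retain_graph=True is moved to the first backward call:
import torch

# Some computations that produce state (shared by both losses)
state = torch.ones((2, 2), requires_grad=True) ** 2

# First loss: retain the graph so the shared intermediates survive this backward
q = torch.zeros((1))
q[0] = state[0, 0]
l1 = torch.sum(2 * q)
l1.backward(retain_graph=True)

# Second loss: the saved intermediates are still available, so this succeeds
b = torch.zeros((1))
b[0] = state[1, 1]
l2 = torch.mean(2 * b)
l2.backward()  # no RuntimeError
Applied to your code, the first backward call is loss_critic.backward(), so that is presumably where retain_graph=True has to go, not on loss_actor.backward().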
Upvotes: 1