ksivakumar

Reputation: 481

Calling .backward() function for two different neural networks but getting retain_graph=True error

I have an Actor-Critic neural network where the Actor and the Critic are each their own class, with their own neural network and .forward() function. I am then creating an object of each of these classes in a larger Model class. My setup is as follows:

self.actor = Actor().to(device)
self.actor_opt = optim.Adam(self.actor.parameters(), lr=lr)
self.critic = Critic().to(device)
self.critic_opt = optim.Adam(self.critic.parameters(), lr=lr)

I then calculate two different loss functions and want to update each neural network separately. For the critic:

loss_critic = F.smooth_l1_loss(value, expected)
self.critic_opt.zero_grad()
loss_critic.backward()
self.critic_opt.step()

and for the actor:

loss_actor = -self.critic(state, action)
self.actor_opt.zero_grad()
loss_actor.backward()
self.actor_opt.step()

However, when doing this, I get the following error:

RuntimeError: Trying to backward through the graph a second time, but the saved 
intermediate results have already been freed. Specify retain_graph=True when
calling backward the first time.

When reading up on this, I understood that retain_graph=True is only needed when calling backward twice on the same graph, and that in most cases setting it to True is not a good idea because it will exhaust GPU memory. Moreover, when I comment out one of the .backward() calls, the error goes away, leading me to believe that the code somehow thinks both backward() calls are made on the same neural network, even though I believe I am doing it separately. What could be the reason for this? Is there a way to specify which neural network I am calling the backward function on?

Edit: For reference, the optimize() function in this code here https://github.com/wudongming97/PyTorch-DDPG/blob/master/train.py uses backward() twice with no issue (I've cloned the repo and tested it). I'd like my code to operate similarly where I backprop through critic and actor separately.

Upvotes: 1

Views: 1209

Answers (1)

Szymon Maszke

Reputation: 24691

Yes, you shouldn't do it like that. What you should do instead is propagate through parts of the graph.

What the graph contains

Right now, the graph contains both the actor and the critic. If the computation passes through the same part of the graph twice (say, twice through the actor), PyTorch will raise this error.

  • And they will, as you use the actor and the critic jointly when computing the loss (this line: loss_actor = -self.critic(state, action))

  • Using different optimizers changes nothing here, as this is a backward problem (optimizers simply apply the calculated gradients to the models)
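A minimal sketch of the mechanism, using toy tensors rather than the question's networks: two losses share one intermediate graph node, and the second backward pass through that node raises exactly this error.

```python
import torch

# Toy graph: `shared` is an intermediate node used by both losses.
w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(3)

shared = w @ x                 # this node saves tensors needed for its backward
loss_a = shared.sum()
loss_b = shared.pow(2).sum()

loss_a.backward()              # frees the intermediates saved along `shared`'s path
try:
    loss_b.backward()          # walks through `shared`'s node a second time
except RuntimeError as e:
    print(e)                   # "Trying to backward through the graph a second time..."
```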

Trying to fix it

  • This is how you would fix it in GANs, but not in this case; see the Actual solution paragraph below, and read on if you are curious about the topic

If part of a neural network (critic in this case) does not take part in the current optimization step, it should be treated as a constant (and vice versa).

To do that, you could disable gradient using torch.no_grad context manager (documentation) and set critic to eval mode (documentation), something along those lines:

self.critic.eval()
with torch.no_grad():
    loss_actor = -self.critic(state, action)
...

But, here is a problem:

We are turning off gradient (tape recording) for action and breaking the graph!

hence this is not a viable solution.
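A quick sketch of why, with a hypothetical stand-in critic: under torch.no_grad the critic's output carries no autograd graph, so there is nothing to backpropagate through.

```python
import torch

critic = torch.nn.Linear(6, 1)         # hypothetical stand-in for the real critic
state_action = torch.randn(1, 6, requires_grad=True)

with torch.no_grad():
    loss_actor = -critic(state_action)

print(loss_actor.requires_grad)        # False: the graph is broken
# loss_actor.backward() would now fail, since loss_actor has no grad_fn
```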

Actual solution

It is much simpler than you think; one can see it in PyTorch's repository as well:

  • Do not backpropagate after each critic/actor loss
  • Calculate both losses (for the critic and the actor)
  • Sum them together
  • Call zero_grad on both optimizers
  • Backpropagate through the summed value
  • Call critic_opt.step() and actor_opt.step() at this point

Something along those lines:

self.critic_opt.zero_grad()
self.actor_opt.zero_grad()

loss_critic = F.smooth_l1_loss(value, expected)
loss_actor = -self.critic(state, action)

total_loss = loss_actor + loss_critic
total_loss.backward()

self.critic_opt.step()
self.actor_opt.step()
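Put together with toy networks (hypothetical layer shapes and a dummy target, just to show the update runs), the whole thing goes through in a single backward pass with no retain_graph needed:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Hypothetical stand-ins for the question's Actor and Critic classes.
actor = nn.Linear(4, 2)                 # state -> action
critic = nn.Linear(4 + 2, 1)            # (state, action) -> value
actor_opt = optim.Adam(actor.parameters(), lr=1e-3)
critic_opt = optim.Adam(critic.parameters(), lr=1e-3)

state = torch.randn(8, 4)
action = actor(state)                   # actor output feeds the critic
value = critic(torch.cat([state, action], dim=1))
expected = torch.zeros(8, 1)            # dummy target for illustration

critic_opt.zero_grad()
actor_opt.zero_grad()

loss_critic = F.smooth_l1_loss(value, expected)
loss_actor = -critic(torch.cat([state, action], dim=1)).mean()

total_loss = loss_actor + loss_critic
total_loss.backward()                   # single pass over the whole graph

critic_opt.step()
actor_opt.step()
```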

Upvotes: 2
