Yuna

Reputation: 798

Pytorch, `backward` RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed

I'm implementing DDPG with PyTorch (0.4) and got stuck backpropagating the loss. First, here is my code performing the update:

def update_nets(self, transitions):
    """
    Performs one update step
    :param transitions: list of sampled transitions
    """
    # get batches
    batch = transition(*zip(*transitions))
    states = torch.stack(batch.state)
    actions = torch.stack(batch.action)
    next_states = torch.stack(batch.next_state)
    rewards = torch.stack(batch.reward)

    # zero gradients
    self._critic.zero_grad()

    # compute critic's loss
    y = rewards.view(-1, 1) + self._gamma * \
        self.critic_target(next_states, self.actor_target(next_states))

    loss_critic = F.mse_loss(y, self._critic(states, actions),
                             size_average=True)

    # backpropagate it
    loss_critic.backward()
    self._optim_critic.step()

    # zero gradients
    self._actor.zero_grad()

    # compute actor's loss
    loss_actor = ((-1.) * self._critic(states, self._actor(states))).mean()

    # backpropagate it
    loss_actor.backward()
    self._optim_actor.step()

    # do soft updates
    self.perform_soft_update(self.actor_target, self._actor)
    self.perform_soft_update(self.critic_target, self._critic)

Here self._actor, self._critic, self.actor_target and self.critic_target are networks.
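For reference, a minimal sketch of what such networks could look like; this is a hypothetical stand-in, not my actual architecture, just to make the snippet above easier to follow:

    import torch
    import torch.nn as nn

    class Actor(nn.Module):
        # hypothetical actor: maps a state to an action in [-1, 1]
        def __init__(self, state_dim, action_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim, 64), nn.ReLU(),
                nn.Linear(64, action_dim), nn.Tanh())

        def forward(self, state):
            return self.net(state)

    class Critic(nn.Module):
        # hypothetical critic: maps a (state, action) pair to a Q-value
        def __init__(self, state_dim, action_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim + action_dim, 64), nn.ReLU(),
                nn.Linear(64, 1))

        def forward(self, state, action):
            return self.net(torch.cat([state, action], dim=1))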

If I run this, I get the following error in the second iteration:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

at:

line 221, in update_nets
    loss_critic.backward()
line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
line 89, in backward
    allow_unreachable=True)  # allow_unreachable flag

and I don't know what is causing it.

What I know so far is that the loss_critic.backward() call causes the error. I have already debugged loss_critic and it holds a valid value. If I replace the loss computation with a simple tensor containing just the value 1:

loss_critic = torch.tensor(1., device=self._device, dtype=torch.float, requires_grad=True)

it works fine. Also, I have already checked that I'm not saving any results that could cause the error. Additionally, updating the actor with loss_actor doesn't cause any problems.

Does anyone know what is going wrong here?

Thanks!

Update

I replaced

    # zero gradients
    self._critic.zero_grad()

and

    # zero gradients
    self._actor.zero_grad()

with

    # zero gradients
    self._critic.zero_grad()
    self._actor.zero_grad()
    self.critic_target.zero_grad()
    self.actor_target.zero_grad()

(in both places), but it still fails with the same error. Additionally, this is the code performing the soft update at the end of each iteration:

def perform_soft_update(self, target, trained):
    """
    Performs the soft update
    :param target: Net to be updated
    :param trained: Trained net - used for update
    """
    for param_target, param_trained in \
            zip(target.parameters(), trained.parameters()):
        param_target.data.copy_(
            param_target.data * (
                    1.0 - self._tau) + param_trained * self._tau
        )

Upvotes: 1

Views: 4566

Answers (2)

Yuna

Reputation: 798

I found the solution. I was saving tensors in my replay_buffer for training purposes and reusing them in every iteration, which resulted in this code snippet:

    # get batches
    batch = transition(*zip(*transitions))
    states = torch.stack(batch.state)
    actions = torch.stack(batch.action)
    next_states = torch.stack(batch.next_state)
    rewards = torch.stack(batch.reward)

This "saving" of tensors is the cause of the problem. So I changed my code to save only the data (tensor.data.numpy().tolist()) and only put it into a tensor when I need it.

In more detail: in DDPG I evaluate the policy every iteration and do one learning step with a batch. Now I save the evaluation in the replay buffer via:

action = self.action(state)
...
self.replay_buffer.push(state.data.numpy().tolist(), action.data.numpy().tolist(), ...)

And used it like:

batch = transition(*zip(*transitions))
states = self.to_tensor(batch.state)
actions = self.to_tensor(batch.action)
next_states = self.to_tensor(batch.next_state)
rewards = self.to_tensor(batch.reward)
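The to_tensor helper is not shown above; a minimal sketch of what it could look like, assuming the buffer stores plain Python lists and the same self._device as before:

    def to_tensor(self, data):
        # hypothetical helper: builds a fresh tensor with no autograd history
        # from the plain Python lists stored in the replay buffer
        return torch.tensor(data, device=self._device, dtype=torch.float)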

Upvotes: 4

Escaton

Reputation: 16

Did you call zero_grad() on self.actor_target and self.critic_target? Or is that done in self.perform_soft_update()?

Upvotes: 0
