Reputation: 798
I'm implementing DDPG with PyTorch (0.4) and got stuck backpropagating the loss. First, my code performing the update:
def update_nets(self, transitions):
    """
    Performs one update step
    :param transitions: list of sampled transitions
    """
    # get batches
    batch = transition(*zip(*transitions))
    states = torch.stack(batch.state)
    actions = torch.stack(batch.action)
    next_states = torch.stack(batch.next_state)
    rewards = torch.stack(batch.reward)
    # zero gradients
    self._critic.zero_grad()
    # compute critic's loss
    y = rewards.view(-1, 1) + self._gamma * \
        self.critic_target(next_states, self.actor_target(next_states))
    loss_critic = F.mse_loss(y, self._critic(states, actions),
                             size_average=True)
    # backpropagate it
    loss_critic.backward()
    self._optim_critic.step()
    # zero gradients
    self._actor.zero_grad()
    # compute actor's loss
    loss_actor = ((-1.) * self._critic(states, self._actor(states))).mean()
    # backpropagate it
    loss_actor.backward()
    self._optim_actor.step()
    # do soft updates
    self.perform_soft_update(self.actor_target, self._actor)
    self.perform_soft_update(self.critic_target, self._critic)
Here self._actor, self._critic, self.actor_target and self.critic_target are networks (torch.nn.Module instances).
If I run this, I get the following error in the second iteration:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
at

line 221, in update_nets
    loss_critic.backward()
line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
line 89, in backward
    allow_unreachable=True)  # allow_unreachable flag
and I don't know what is causing it.
What I know so far is that the loss_critic.backward() call is what raises the error. I have already debugged loss_critic and it holds a valid value.
If I replace the loss computation with a simple tensor containing the value 1,

loss_critic = torch.tensor(1., device=self._device, dtype=torch.float, requires_grad=True)

it works fine.
I have also checked that I'm not saving any results that could cause the error. Additionally, updating the actor with loss_actor doesn't cause any problems.
Does anyone know what is going wrong here?
Thanks!
Update
I replaced

# zero gradients
self._critic.zero_grad()

and

# zero gradients
self._actor.zero_grad()

with

# zero gradients
self._critic.zero_grad()
self._actor.zero_grad()
self.critic_target.zero_grad()
self.actor_target.zero_grad()

(at both call sites), but it still fails with the same error. Additionally, here is the code performing the soft update at the end of each iteration:
def perform_soft_update(self, target, trained):
    """
    Performs the soft update
    :param target: net to be updated
    :param trained: trained net - used for update
    """
    for param_target, param_trained in \
            zip(target.parameters(), trained.parameters()):
        param_target.data.copy_(
            param_target.data * (1.0 - self._tau)
            + param_trained.data * self._tau
        )
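For reference, as far as I can tell the same update can also be written with torch.no_grad() instead of going through .data; this sketch should behave identically (torch.no_grad() exists since PyTorch 0.4):

def perform_soft_update(self, target, trained):
    """
    Soft update: target = (1 - tau) * target + tau * trained
    """
    with torch.no_grad():
        for param_target, param_trained in \
                zip(target.parameters(), trained.parameters()):
            param_target.mul_(1.0 - self._tau)
            param_target.add_(self._tau * param_trained)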
Upvotes: 1
Views: 4566
Reputation: 798
I found the solution.
I was saving tensors in my replay_buffer for training purposes and reusing them in every iteration, which resulted in this code snippet:
# get batches
batch = transition(*zip(*transitions))
states = torch.stack(batch.state)
actions = torch.stack(batch.action)
next_states = torch.stack(batch.next_state)
rewards = torch.stack(batch.reward)
This "saving" of tensors was the cause of the problem: the stored tensors were still attached to the computation graph that produced them, so the next backward() call tried to traverse a graph whose buffers had already been freed. So I changed my code to save only the plain data (tensor.data.numpy().tolist()) and to build a fresh tensor whenever I need one.
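To illustrate why (a minimal sketch, not my actual code): a tensor that is still attached to the graph that produced it drags that old graph into every new loss, and the second backward() then hits the already-freed buffers:

import torch

w = torch.ones(1, requires_grad=True)  # e.g. a network parameter
stored = w * w                         # e.g. an action kept in the replay buffer

for _ in range(2):
    loss = ((stored - 1.) ** 2).mean()  # loss built from the stored tensor
    loss.backward()  # 1st call frees the graph behind `stored`;
                     # 2nd call raises: RuntimeError: Trying to backward
                     # through the graph a second time, ...

Storing plain data (or calling .detach()) breaks this link, so every iteration builds its own complete graph.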
In more detail: in DDPG I evaluate the policy in every iteration and do one learning step with a batch. I now save the evaluation results in the replay buffer via:
action = self.action(state)
...
self.replay_buffer.push(state.data.numpy().tolist(), action.data.numpy().tolist(), ...)
And read them back like this:
batch = transition(*zip(*transitions))
states = self.to_tensor(batch.state)
actions = self.to_tensor(batch.action)
next_states = self.to_tensor(batch.next_state)
rewards = self.to_tensor(batch.reward)
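Here to_tensor is just a small helper that builds a fresh tensor (with no graph attached) from the stored plain data; roughly like this, assuming float data and the self._device from above:

def to_tensor(self, data):
    """
    Builds a fresh, graph-free tensor from plain Python data
    :param data: nested list of floats, as stored in the replay buffer
    """
    return torch.tensor(data, device=self._device, dtype=torch.float)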
Upvotes: 4
Reputation: 16
Didn't you call zero_grad() on self.actor_target and self.critic_target? Or is it called in self.perform_soft_update()?
Upvotes: 0