Kyle

Reputation: 31

How to set up a loss function in PyTorch for Soft Actor-Critic

I'm trying to implement a custom loss function for a soft Q-learning, actor-critic policy gradient algorithm in PyTorch. It comes from the paper Learning from Imperfect Demonstrations. The structure of the algorithm is similar to deep Q-learning: we use a network to estimate Q-values, and a target network to stabilize results. Unlike DQN, however, we calculate V(s) from Q(s) by:

V(s) = α log Σ_a exp(Q(s, a) / α)
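For concreteness, here is roughly how I compute V from Q (the shapes and the name alpha are just my setup, not from the paper):

import torch

alpha = 0.1                    # temperature hyperparameter
Q = torch.randn(32, 4)         # Q(s, a) for a batch of 32 states, 4 discrete actions

# V(s) = alpha * log sum_a exp(Q(s, a) / alpha), using logsumexp for numerical stability
V = alpha * torch.logsumexp(Q / alpha, dim=1)   # shape: (32,)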

That part is simple enough to calculate with PyTorch. My main question has to do with how to set up the loss function. Part of the update equation is expressed as:

(∇_θ Q(s, a) − ∇_θ V(s)) · (Q(s, a) − Q_hat(s, a))

Note that Q_hat comes from the target network. How can I go about putting something like this into a loss function? I can compute values for V and Q, but how do I handle the gradients in this case? If anyone can point me towards a similar example, that would be much appreciated.

Upvotes: 0

Views: 463

Answers (1)

Kyle

Reputation: 31

This turns out to be fairly simple, assuming you can already calculate V, Q, and Q_hat. After discussing this with some people offline, I was able to get PyTorch to calculate this loss by setting it up as:

# Gradients flow only through (Q - V); the TD error (Q - Q_hat)
# is detached so it acts as a constant weight.
loss = ((Q - V) * (Q - Q_hat).detach()).mean()  # reduce to a scalar for backward()
optimizer.zero_grad()
loss.backward()
optimizer.step()
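In case more context helps, here is a minimal end-to-end sketch of one update step. The network, shapes, and hyperparameters are placeholders I made up; only the loss construction is the part that matters:

import torch
import torch.nn as nn

alpha, gamma, num_actions = 0.1, 0.99, 4

q_net = nn.Linear(8, num_actions)        # stand-in for the real Q-network
target_net = nn.Linear(8, num_actions)   # target network, same architecture
target_net.load_state_dict(q_net.state_dict())
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)

def soft_v(q):
    # V(s) = alpha * log sum_a exp(Q(s, a) / alpha)
    return alpha * torch.logsumexp(q / alpha, dim=1)

# One update on a dummy batch of transitions (s, a, r, s')
s = torch.randn(32, 8)
a = torch.randint(num_actions, (32,))
r = torch.randn(32)
s2 = torch.randn(32, 8)

q_all = q_net(s)
Q = q_all.gather(1, a.unsqueeze(1)).squeeze(1)   # Q(s, a)
V = soft_v(q_all)                                # V(s)
with torch.no_grad():                            # no gradient through the target
    Q_hat = r + gamma * soft_v(target_net(s2))

loss = ((Q - V) * (Q - Q_hat)).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

Because Q_hat is computed under torch.no_grad() here, the explicit .detach() isn't needed; either way, the gradient of this loss is (∇Q − ∇V)(Q − Q_hat), which is exactly the update term in the question.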

Upvotes: 1
