vlyubin

Reputation: 648

Is there a way to use an external loss function in PyTorch?

A typical skeleton of a PyTorch neural network has a forward() method; we then compute a loss from the outputs of the forward pass and call backward() on that loss to compute the gradients. What if my loss is determined externally (e.g. by running a simulation in some RL environment)? Can I still use this typical structure?

Thank you!

Upvotes: 1

Views: 847

Answers (1)

mbpaulus

Reputation: 7691

In this case, it seems easiest to me to separate the forward pass (your policy?) from the loss computation. This is because (as you note) in most scenarios you will need to obtain a state from your environment, then compute an action (essentially the forward pass), then feed that action back to the environment to obtain a reward/loss.

Of course, you could probably call your environment within the forward pass once you have computed an action, and then calculate the resulting loss. But why bother? It will get even more complicated (though still possible) once you are taking several steps in your environment before you obtain a reward/loss.

I would suggest you take a look at the following RL example, which applies policy gradients within OpenAI Gym: https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py#L43

The essential ideas are:

  • Create a policy (as an nn.Module) that takes in a state and returns a stochastic policy, i.e. a probability distribution over actions.
  • Wrap the computation of a policy and the sampling of an action from the policy into one function.
  • Call this function repeatedly to take steps through your environment, record actions and rewards.
  • Once an episode is finished, collect the rewards and only then perform back-propagation and the gradient update.
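To make these steps concrete, here is a minimal sketch of how they might fit together. It is not the code from the linked example: ToyEnv is a hypothetical stand-in for your external simulator, and the names Policy, select_action and run_episode, the network sizes and the hyperparameters are all assumptions chosen for illustration. The point it tries to show is that the reward comes back as a plain Python number with no gradient attached, and the gradient reaches the network only through the log-probabilities of the sampled actions.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class ToyEnv:
    """Hypothetical external simulator (an assumption, not a real library)."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return torch.zeros(2)

    def step(self, action):
        # The reward is computed outside autograd: just a Python float.
        reward = 1.0 if action == 1 else 0.0
        self.t += 1
        state = torch.tensor([float(self.t), float(action)])
        return state, reward, self.t >= self.horizon


class Policy(nn.Module):
    """Maps a state to a distribution over two discrete actions."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))

    def forward(self, state):
        return Categorical(logits=self.net(state))


def select_action(policy, state):
    # Forward pass plus sampling, wrapped into one function.
    dist = policy(state)
    action = dist.sample()
    return action.item(), dist.log_prob(action)


def run_episode(policy, env, optimizer, gamma=0.99):
    state, done = env.reset(), False
    log_probs, rewards = [], []
    while not done:
        action, log_prob = select_action(policy, state)
        state, reward, done = env.step(action)
        log_probs.append(log_prob)
        rewards.append(reward)

    # Discounted return from each time step to the end of the episode.
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.insert(0, running)
    returns = torch.tensor(returns)

    # REINFORCE surrogate loss: only the log-probabilities carry gradients.
    loss = -(torch.stack(log_probs) * returns).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return sum(rewards)


torch.manual_seed(0)
policy = Policy()
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)
env = ToyEnv()
totals = [run_episode(policy, env, optimizer) for _ in range(200)]
```

Note that backward() is called once per episode, after all the environment steps, and never inside forward(). This sketch leaves out return normalization and other variance-reduction tricks, so treat it as a structural outline rather than a tuned implementation.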

While this example is specific to REINFORCE, the general idea of structuring your code is applicable to other RL algorithms. Besides, you'll find two other examples in the same repo.

Hope this helps.

Upvotes: 1
