Valentin Mercier

Reputation: 385

Get gradients of parameters w.r.t. a loss term in PyTorch

In my PyTorch training I use a composite loss function, defined as a weighted sum of several loss terms (the formula was given as an image that did not carry over). In order to update the weights alpha and beta, I need to compute three values (also given as an image): the means of the gradients of the loss terms w.r.t. all the weights in the network.

Is there an efficient way to write this in PyTorch?

My training code looks like:

for epoch in range(1, nb_epochs + 1):
    # init
    optimizer_fo.zero_grad()
    # get the current loss
    loss_total = mynet_fo.loss(tensor_xy_dirichlet, g_boundaries_d, tensor_xy_inside, tensor_f_inter, tensor_xy_neuman, g_boundaries_n)
    # compute gradients
    loss_total.backward(retain_graph=True)
    # optimize
    optimizer_fo.step()

Here my .loss() function directly returns the sum of the terms. I've thought of making a second forward pass and calling backward on each loss term independently, but that would be very expensive.

Upvotes: 2

Views: 3691

Answers (2)

Ivan

Reputation: 40648

1- Using torch.autograd.grad

You can only get the different terms of your gradient by back-propagating multiple times through your network. To avoid performing multiple inferences on your input, you can use the torch.autograd.grad utility function instead of a conventional backward pass. This way you won't pollute the .grad attributes with gradients coming from the different terms.

Here is a minimal example that shows the basic idea:

>>> import torch
>>> x = torch.rand(1, 10, requires_grad=True)
>>> lossA = x.pow(2).sum()
>>> lossB = x.mean()

Then perform one backward pass on each term out of place. You have to retain the graph on all calls but the last:

>>> gradA = torch.autograd.grad(lossA, x, retain_graph=True)
(tensor([[1.5810, 0.6684, 0.1467, 0.6618, 0.5067, 0.2368, 0.0971, 0.4533, 0.3511,
          1.9858]]),)

>>> gradB = torch.autograd.grad(lossB, x)
(tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]]),)

This method has some limitations since you receive your parameters' gradients as a tuple, which is not that convenient.
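To relate this back to the question: here is a minimal sketch of how one could get the mean of each term's gradient over all the network weights with torch.autograd.grad (mynet_fo, loss_e and loss_b stand in for your model and your individual loss terms):

>>> params = [p for p in mynet_fo.parameters() if p.requires_grad]
>>> grads_e = torch.autograd.grad(loss_e, params, retain_graph=True)
>>> grads_b = torch.autograd.grad(loss_b, params, retain_graph=True)
>>> # mean of each term's gradient over all weights in the network
>>> mean_grad_e = torch.cat([g.flatten() for g in grads_e]).mean()
>>> mean_grad_b = torch.cat([g.flatten() for g in grads_b]).mean()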


2- Caching the results of backward

An alternative solution consists in caching the gradient after each successive backward call:

>>> lossA = x.pow(2).sum()
>>> lossB = x.mean()

>>> lossA.backward(retain_graph=True)

Store the gradient and clear the .grad attribute (don't forget to do so, otherwise the gradient of lossA will pollute gradB). You will have to adapt this to the general case when handling multiple tensor parameters:

>>> x.gradA = x.grad
>>> x.grad = None

Backward pass on the next loss term:

>>> lossB.backward()
>>> x.gradB = x.grad

Then you can interact with each gradient term locally (i.e. on each parameter separately):

>>> x.gradA, x.gradB
(tensor([[1.5810, 0.6684, 0.1467, 0.6618, 0.5067, 0.2368, 0.0971, 0.4533, 0.3511,
          1.9858]]),
 tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
          0.1000]]))

The latter method seems more practical.
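To adapt it to a whole network rather than a single tensor x, one possibility (sketched here for a generic nn.Module called model and two scalar terms lossA and lossB) is to stash a copy of each parameter's .grad between backward calls:

grads_A, grads_B = {}, {}

lossA.backward(retain_graph=True)
for n, p in model.named_parameters():
    if p.grad is not None:
        grads_A[n] = p.grad.detach().clone()
model.zero_grad()   # clear .grad so lossA does not leak into gradB

lossB.backward()
for n, p in model.named_parameters():
    if p.grad is not None:
        grads_B[n] = p.grad.detach().clone()
model.zero_grad()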


This essentially comes down to torch.autograd.grad vs torch.autograd.backward, i.e. out-of-place vs in-place... and it will ultimately depend on your needs. You can read more about these two functions here.

Upvotes: 5

Edoardo Guerriero

Reputation: 1250

The best way to get each loss component separately is to define the loss outside your model (I assume the loss is currently a method of your model, since you're calling it as one).

So you should change your code to look something like this:

class ModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        # ... define layers ...

    def forward(self, x):
        # ... compute the model output ...
        return output_of_the_model

# you can backpropagate inside the loss class directly
class LossClass(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, model_output, target):
        loss_score_e = ...  # compute first component
        loss_score_e.backward(retain_graph=True)
        # same for the b and i components
        return loss_score_e, loss_score_b, loss_score_i
 

Then the training script stays basically the same:

loss = LossClass()
for epoch in range(1, nb_epochs + 1):
    # init
    optimizer_fo.zero_grad()
    # get the current loss
    loss_e, loss_b, loss_i = loss(tensor_xy_dirichlet, g_boundaries_d, tensor_xy_inside, tensor_f_inter, tensor_xy_neuman, g_boundaries_n)
    # optimize
    optimizer_fo.step()
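For completeness, here is roughly how the pieces could fit together once the model's forward pass is made explicit (input_batch, target and the Adam optimizer are just placeholders for your actual data and optimizer):

model = ModelClass()
loss = LossClass()
optimizer_fo = torch.optim.Adam(model.parameters())

for epoch in range(1, nb_epochs + 1):
    optimizer_fo.zero_grad()
    output = model(input_batch)
    # the backward calls happen inside LossClass.forward
    loss_e, loss_b, loss_i = loss(output, target)
    optimizer_fo.step()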

Upvotes: 0
