zafirzarya

Reputation: 153

How does loss.backward() relate to the appropriate parameters of the model?

I'm new to PyTorch and I'm having trouble understanding how the loss knows which gradients to compute through loss.backward().

Sure, I understand that the parameters need to have requires_grad=True, and I understand that loss.backward() sets x.grad to the appropriate gradient so that the optimizer can later perform the gradient update.

The optimizer is linked to the model parameters when it's instantiated, but the loss is never linked to the model.

I've been going through this thread, but I don't think anyone answered it clearly, and the person who started the thread seems to have the same issue I do.

What happens when I have two different networks with two different loss functions and two different optimizers? I can easily link each optimizer to its network, but how will the loss functions know how to compute the gradients for their respective networks if I never link them together?
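
For example, something like this (just a rough sketch of what I mean; the layer sizes and loss functions are made up):

import torch
import torch.nn as nn

net_a = nn.Linear(10, 1)
net_b = nn.Linear(10, 1)

# each optimizer is explicitly given one network's parameters
opt_a = torch.optim.SGD(net_a.parameters(), lr=0.01)
opt_b = torch.optim.SGD(net_b.parameters(), lr=0.01)

x = torch.randn(4, 10)
target = torch.randn(4, 1)

loss_a = nn.MSELoss()(net_a(x), target)
loss_b = nn.L1Loss()(net_b(x), target)

loss_a.backward()  # how does this know to fill only net_a's gradients?
loss_b.backward()  # and this only net_b's?

opt_a.step()
opt_b.step()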

Upvotes: 5

Views: 3258

Answers (3)

Sina Atalay

Reputation: 33

torch.Tensor objects (loss functions return torch.Tensor objects as well) store a history of the computations that produced them (the computational graph).

.backward() is a method of torch.Tensor objects. The backward() method performs the following steps:

  1. Look at the computational graph and find all other tensors that have played a role in this computation.
  2. If the tensors that played a role in this computation have requires_grad=True, then write the corresponding component of the gradient into each of those tensors' grad attributes.

So .backward() updates all the related tensors' grad attributes. The optimizer then looks at the model parameters' (torch.Tensor objects) grad attributes and does its job.

Play with the code below to understand more:

import torch

# Almost everything in PyTorch is a torch.Tensor, and every computation is between
# torch.Tensor s. Each computation results in a new torch.Tensor object and this object
# stores something called a computational graph. This graph contains all the operations
# that were performed on the tensor.

# When you call .backward() on a tensor, it looks at the whole computational graph and
# computes the gradients of this tensor with respect to all the tensors that have
# requires_grad=True and that were used to compute this tensor. It then stores these
# gradients in the .grad attribute of the respective tensors.

x1 = torch.tensor(2, requires_grad=True, dtype=torch.float32)
x2 = torch.tensor(3, requires_grad=True, dtype=torch.float32)

f = x1 * 2 + x2 * 3 + 4
f.backward()
print(f"gradient of x1 = {x1.grad}")  # 2
print(f"gradient of x2 = {x2.grad}")  # 3

Upvotes: 0

jodag

Reputation: 22214

Loss is itself a tensor which is derived from the parameters of the network. A graph is implicitly constructed where each new tensor, including loss, points back to the tensors which were involved in its construction. When you apply loss.backward(), PyTorch follows the graph backwards and populates the .grad member of each tensor with the partial derivative of loss with respect to that tensor using the chain rule (i.e. backpropagation).
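
A minimal sketch of that idea (the model and data here are made up):

import torch
import torch.nn as nn

model = nn.Linear(3, 1)              # parameters: model.weight and model.bias
x = torch.randn(5, 3)
y = torch.randn(5, 1)

loss = ((model(x) - y) ** 2).mean()  # loss is a tensor derived from the parameters

print(model.weight.grad)             # None, nothing has been populated yet
loss.backward()                      # follows the graph back to weight and bias
print(model.weight.grad)             # now holds d(loss)/d(weight)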

Upvotes: 4

Vimal Thilak

Reputation: 2680

The question appears to be very general, so I can only provide suggestions to get you started (I hope):

  • Plot the graph to understand how data flows in your computation graph
  • Look at torch.autograd documentation to see how the framework records all operations that it will use for computing gradients ("backward") https://pytorch.org/docs/stable/autograd.html
  • You can use hooks (available with Python 3 + PyTorch) to figure out the gradient values; they should also give you a sense of how gradients are flowing in your graph (see the sketch after this list). Could you please post pictures of the graph(s) so that you can get concrete answers?
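
For example, a rough sketch of a gradient hook on a parameter (the model here is just a placeholder):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)

# the hook is called with the gradient of this tensor when backward() reaches it
model.weight.register_hook(lambda grad: print("weight grad:", grad))

x = torch.randn(2, 4)
loss = model(x).sum()
loss.backward()  # prints the gradient of the weight as it is computed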

If the above does not answer your question, I ask that you clarify your question with example code.

Upvotes: 0
