Reputation: 153
I'm new to PyTorch and I'm having trouble understanding how loss knows which gradients to compute when I call loss.backward().
Sure, I understand that the parameters need to have requires_grad=True, and I understand that backward() sets x.grad to the appropriate gradient so that the optimizer can later perform the gradient update.
The optimizer is linked to the model parameters when it's instantiated, but the loss is never linked to the model.
I've been going through this thread, but I don't think anyone answered it clearly and the person that started the thread seems to have the same issue as I do.
What happens when I have two different networks with two different loss functions and two different optimizers? I can easily link each optimizer to its network, but how will each loss function know how to compute the gradients for its own network if I never link them together?
Upvotes: 5
Views: 3258
Reputation: 33
torch.Tensor objects (loss functions return torch.Tensor objects as well) store a history of the computations that produced them (the computational graph).
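A minimal sketch of that history, assuming a toy nn.Linear model and an MSE loss (both are illustrative choices, not from the question): the loss tensor's grad_fn attribute is the entry point into the graph, and following next_functions backwards reaches the model's parameters. The helper collect_leaves is hypothetical, written only for this illustration; it relies on the graph nodes that accumulate gradients into leaf tensors exposing those tensors through a variable attribute.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)            # toy model, for illustration only
x = torch.randn(4, 3)
target = torch.randn(4, 1)

loss = nn.functional.mse_loss(model(x), target)


def collect_leaves(node, seen=None):
    """Hypothetical helper: walk the graph backwards from `node`
    and return the leaf tensors (with requires_grad=True) it reaches."""
    if seen is None:
        seen = set()
    leaves = []
    if node is None or node in seen:   # None marks inputs that need no gradient
        return leaves
    seen.add(node)
    # AccumulateGrad nodes wrap a leaf tensor such as a model parameter
    if hasattr(node, "variable"):
        leaves.append(node.variable)
    for next_node, _ in node.next_functions:
        leaves.extend(collect_leaves(next_node, seen))
    return leaves


leaves = collect_leaves(loss.grad_fn)
# The leaves are model.weight and model.bias: the loss "knows" the model
# only through this chain of references.
print([tuple(t.shape) for t in leaves])
```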
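.backward() is a method of torch.Tensor objects. The backward() method performs the following steps:
1. Starting from this tensor, traverse the computational graph backwards and compute, using the chain rule, the gradient of this tensor with respect to every tensor that was used to compute it.
2. For each leaf tensor in the graph (such as a model parameter) that has requires_grad=True, accumulate (add) that component of the gradient into the tensor's grad attribute.
So .backward() updates the grad attribute of all the related tensors. The optimizer looks at the grad attributes of the model parameters (which are torch.Tensor objects) and does its job.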
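To illustrate that last point, here is a sketch assuming plain torch.optim.SGD with no momentum or weight decay, where the update reduces to p = p - lr * p.grad. optimizer.step() only reads each parameter's .grad, so it never needs to see the loss. The model, data and learning rate are arbitrary choices for the demo.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lr = 0.1
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)  # no momentum, no weight decay

x = torch.randn(4, 3)
target = torch.randn(4, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), target)
loss.backward()  # fills model.weight.grad and model.bias.grad

# What plain SGD should produce, computed by hand from each parameter's .grad
expected = [(p - lr * p.grad).detach().clone() for p in model.parameters()]

optimizer.step()  # reads p.grad of every registered parameter; the loss is never passed in

matches = [torch.allclose(p, e) for p, e in zip(model.parameters(), expected)]
print(matches)  # [True, True]
```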
Play with the code below to understand more:
import torch
# Almost everything in PyTorch is a torch.Tensor, and every computation is between
# torch.Tensors. Each computation results in a new torch.Tensor object, and this object
# stores something called a computational graph. This graph records all the operations
# that were performed to produce the tensor.
# When you call .backward() on a tensor, it looks at the whole computational graph and
# computes the gradients of this tensor with respect to all the tensors that have
# requires_grad=True and that were used to compute this tensor. It then stores these
# gradients in the .grad attribute of the respective tensors.
x1 = torch.tensor(2, requires_grad=True, dtype=torch.float32)
x2 = torch.tensor(3, requires_grad=True, dtype=torch.float32)
f = x1 * 2 + x2 * 3 + 4
f.backward()
print(f"gradient of x1 = {x1.grad}") # 2
print(f"gradient of x2 = {x2.grad}") # 3
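One detail the snippet above does not show: backward() adds to whatever is already stored in .grad rather than overwriting it. The sketch below uses a product instead of a sum (so each gradient equals the other input) to make this visible; resetting .grad to None here plays the role that optimizer.zero_grad() plays in a training loop.

```python
import torch

x1 = torch.tensor(2.0, requires_grad=True)
x2 = torch.tensor(3.0, requires_grad=True)

# First pass: df/dx1 = x2 = 3, df/dx2 = x1 = 2
f = x1 * x2
f.backward()
first = (x1.grad.item(), x2.grad.item())
print(first)   # (3.0, 2.0)

# Rebuild the graph (backward() freed the old one) and call backward() again:
# the new gradients are ADDED to the values already in .grad
f = x1 * x2
f.backward()
second = (x1.grad.item(), x2.grad.item())
print(second)  # (6.0, 4.0)

# Clearing .grad first gives fresh gradients again
x1.grad = None
x2.grad = None
f = x1 * x2
f.backward()
third = (x1.grad.item(), x2.grad.item())
print(third)   # (3.0, 2.0)
```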
Upvotes: 0
Reputation: 22214
Loss is itself a tensor which is derived from the parameters of the network. A graph is implicitly constructed where each new tensor, including loss, points back to the tensors which were involved in its construction. When you call loss.backward(), PyTorch follows the graph backwards and populates the .grad member of each tensor (more precisely, each leaf tensor with requires_grad=True, such as the network's parameters) with the partial derivative of loss with respect to that tensor, using the chain rule (i.e. backpropagation).
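Applied to the two-network case from the question, here is a sketch with two hypothetical toy networks (net_a, net_b), two different loss functions and two optimizers. Each loss's graph only reaches the parameters of the network that produced it, so loss_a.backward() leaves net_b's .grad untouched, and each optimizer then steps only its own parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two independent toy networks, each with its own optimizer
net_a = nn.Linear(2, 1)
net_b = nn.Linear(2, 1)
opt_a = torch.optim.SGD(net_a.parameters(), lr=0.1)
opt_b = torch.optim.SGD(net_b.parameters(), lr=0.1)

x = torch.randn(8, 2)
y = torch.randn(8, 1)

# Each loss is built from its own network's output, so its graph
# only points back to that network's parameters
loss_a = nn.functional.mse_loss(net_a(x), y)
loss_b = nn.functional.l1_loss(net_b(x), y)

opt_a.zero_grad()
opt_b.zero_grad()

loss_a.backward()
after_a = ([p.grad is not None for p in net_a.parameters()],
           [p.grad is not None for p in net_b.parameters()])
print(after_a)  # ([True, True], [False, False])

loss_b.backward()
after_b = [p.grad is not None for p in net_b.parameters()]
print(after_b)  # [True, True]

opt_a.step()  # updates net_a using the gradients from loss_a
opt_b.step()  # updates net_b using the gradients from loss_b
```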
Upvotes: 4
Reputation: 2680
The question appears to be very general, so I can only provide suggestions to get you started (I hope):
If the above does not answer your question, please clarify your question with example code.
Upvotes: 0