Dimitris Poulopoulos

Reputation: 1159

What is the first parameter (gradients) of the backward method in PyTorch?

We have the following code from the pytorch documentation:

import torch
from torch.autograd import Variable

x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)

What exactly is the gradients parameter that we pass into the backward method, and on what basis do we initialize it?

Upvotes: 4

Views: 1381

Answers (1)

cleros

Reputation: 4333

Fully answering your question would require a somewhat longer explanation that revolves around the details of how backpropagation or, more fundamentally, the chain rule works.

The short programmatic answer is that the backward function of a Variable computes the gradients of all Variables in the computation graph attached to that Variable (to clarify: if you have a = b + c, the computation graph points recursively first to b, then to c, then to whatever computed those, and so on) and cumulatively stores (sums) these gradients in the .grad attribute of those Variables. When you then call opt.step(), i.e. a step of your optimizer, it updates those Variables by a fraction of that gradient (for plain SGD, subtracting the learning rate times the accumulated gradient).
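
As a rough sketch of that accumulation behaviour (written against the current tensor API rather than Variable; the parameter w and the SGD settings are just made up for the example):

import torch
import torch.optim as optim

w = torch.nn.Parameter(torch.FloatTensor([1.0, 2.0]))
opt = optim.SGD([w], lr=0.1)

loss = torch.sum(w * w)
loss.backward()
print(w.grad)    # d(loss)/dw = 2 * w  ->  [2., 4.]

loss = torch.sum(w * w)
loss.backward()  # gradients are summed into .grad, not overwritten
print(w.grad)    # [4., 8.]

opt.step()       # w <- w - lr * w.grad
opt.zero_grad()  # reset the accumulated gradients before the next iteration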

That said, there are two answers when you look at it conceptually. If you want to train a machine learning model, you normally want the gradient of some scalar loss function with respect to the model's parameters; in that case the gradients are computed such that the overall loss (a scalar value) decreases when the step function is applied. In this special case, the gradient we want at the root of the graph is simply the unit value 1 (the learning rate then determines what fraction of the gradients actually gets applied). This means that if you have a scalar loss and you call loss.backward(), it computes exactly the same thing as loss.backward(torch.FloatTensor([1.])).
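
A minimal sketch of that equivalence (using torch.ones_like for the explicit unit gradient so the shapes match exactly; the tensor values here are made up):

import torch

x = torch.FloatTensor([1.0, 2.0, 3.0])
x.requires_grad_(True)

loss = torch.sum(x * x)               # a scalar loss
loss.backward()                       # implicit unit gradient at the root
print(x.grad)                         # 2 * x  ->  [2., 4., 6.]

x.grad = None                         # clear the accumulated gradient
loss = torch.sum(x * x)
loss.backward(torch.ones_like(loss))  # explicit unit gradient, same result
print(x.grad)                         # [2., 4., 6.]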

While this is the common use case for backprop in DNNs, it is only a special case of general differentiation of functions. More generally, an automatic differentiation package (autograd in this case, as part of PyTorch) can compute gradients of earlier parts of the computation graph with respect to any gradient you supply at the root of whatever subgraph you choose. This is where the keyword argument gradient comes in useful, since you can provide that "root-level" gradient there, even for non-scalar functions!

To illustrate, here's a small example:

import torch
import torch.nn as nn

a = nn.Parameter(torch.FloatTensor([[1, 1], [2, 2]]))
b = nn.Parameter(torch.FloatTensor([[1, 2], [1, 2]]))
c = torch.sum(a - b)
c.backward(None)  # could be c.backward(torch.FloatTensor([1.])) for the same result
print(a.grad, b.grad)

prints:

Variable containing:
 1  1
 1  1
[torch.FloatTensor of size 2x2]
 Variable containing:
-1 -1
-1 -1
[torch.FloatTensor of size 2x2]

While

a = nn.Parameter(torch.FloatTensor([[1, 1], [2, 2]]))
b = nn.Parameter(torch.FloatTensor([[1, 2], [1, 2]]))
c = torch.sum(a - b)
c.backward(torch.FloatTensor([[1, 2], [3, 4]]))
print(a.grad, b.grad)

prints:

Variable containing:
 1  2
 3  4
[torch.FloatTensor of size 2x2]
 Variable containing:
-1 -2
-3 -4
[torch.FloatTensor of size 2x2]

and

a = nn.Parameter(torch.FloatTensor([[0, 0], [2, 2]]))
b = nn.Parameter(torch.FloatTensor([[1, 2], [1, 2]]))
c = torch.matmul(a, b)
c.backward(torch.FloatTensor([[1, 1], [1, 1]]))  # c is a non-scalar variable, so the gradient supplied cannot be a scalar either!
print(a.grad, b.grad)

prints:

Variable containing:
 3  3
 3  3
[torch.FloatTensor of size 2x2]
 Variable containing:
 2  2
 2  2
[torch.FloatTensor of size 2x2]

and

a = nn.Parameter(torch.FloatTensor([[0, 0], [2, 2]]))
b = nn.Parameter(torch.FloatTensor([[1, 2], [1, 2]]))
c = torch.matmul(a, b)
c.backward(torch.FloatTensor([[1, 2], [3, 4]]))  # c is a non-scalar variable, so the gradient supplied cannot be a scalar either!
print(a.grad, b.grad)

prints:

Variable containing:
  5   5
 11  11
[torch.FloatTensor of size 2x2]
 Variable containing:
 6  8
 6  8
[torch.FloatTensor of size 2x2]
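
These numbers are just the vector-Jacobian products from the chain rule: for c = matmul(a, b), the gradient flowing into a is matmul(grad_c, b transposed) and the gradient flowing into b is matmul(a transposed, grad_c). A small sketch that checks this against the last example (using the current tensor API, so no Variable wrapper; detach() just keeps the hand-computed products out of the graph):

import torch
import torch.nn as nn

a = nn.Parameter(torch.FloatTensor([[0, 0], [2, 2]]))
b = nn.Parameter(torch.FloatTensor([[1, 2], [1, 2]]))
grad_c = torch.FloatTensor([[1, 2], [3, 4]])

c = torch.matmul(a, b)
c.backward(grad_c)

# the same values computed by hand via the chain rule
print(torch.matmul(grad_c, b.detach().t()))  # equals a.grad: [[5, 5], [11, 11]]
print(torch.matmul(a.detach().t(), grad_c))  # equals b.grad: [[6, 8], [6, 8]]
print(a.grad, b.grad)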

Upvotes: 1
