PuffedRiceCrackers

Reputation: 785

Why torch.sum() before doing .backward()?

I can see what the code below (from this video) is trying to do. But the sum in y = torch.sum(x**2) confuses me. With the sum operation, y becomes a tensor with a single value. Since I understand .backward() as calculating derivatives, why would we want to use sum and reduce y to one value?

import torch
import matplotlib.pyplot as plt 
x = torch.linspace(-10.0,10.0,10, requires_grad=True)
Y = x**2
y = torch.sum(x**2)     
y.backward()

plt.plot(x.detach().numpy(), Y.detach().numpy(), label="Y")
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label="derivatives")
plt.legend()

Upvotes: 10

Views: 6663

Answers (2)

Jonas De Schouwer

Reputation: 913

You have a tensor Y, which has been computed directly or indirectly from tensor X.

Y.backward() would calculate the derivative of each element of Y w.r.t. each element of X. This gives us N_out (the number of elements in Y) gradient masks, each with shape X.shape.

However, Tensor.backward() requires that the gradient stored in X.grad have the same shape as X. If N_out=1, there is no problem, since we have only one mask. That is why you want to reduce Y to a single value first.

If N_out>1, PyTorch takes a weighted sum over the N_out gradient masks. But you need to supply the weights for this weighted sum! You can do this with the gradient argument:
Y.backward(gradient=weights_shaped_like_Y)

If you give every element of Y weight 1, you will get the same behaviour as using torch.sum(Y).backward().
Hence, the following two programs are equivalent:

x = torch.linspace(-10.0,10.0,10, requires_grad=True)
Y = x**2
y = torch.sum(x**2)     
y.backward()

and

x = torch.linspace(-10.0,10.0,10, requires_grad=True)
Y = x**2   
Y.backward(gradient=torch.ones_like(Y))
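As a quick sanity check (a minimal sketch, not part of the original answer), you can run both variants on fresh tensors and compare the resulting x.grad against the analytic derivative 2x:

```python
import torch

# Run 1: reduce to a scalar with sum, then backward
x1 = torch.linspace(-10.0, 10.0, 10, requires_grad=True)
torch.sum(x1 ** 2).backward()

# Run 2: keep the output vector-valued and supply unit weights
x2 = torch.linspace(-10.0, 10.0, 10, requires_grad=True)
(x2 ** 2).backward(gradient=torch.ones(10))

# Both match the analytic derivative d(x^2)/dx = 2x
print(torch.allclose(x1.grad, x2.grad))          # True
print(torch.allclose(x1.grad, 2 * x1.detach()))  # True
```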

Upvotes: 3

Shai

Reputation: 114786

You can only compute partial derivatives of a scalar function. What backward() gives you is d loss/d parameter, and you expect a single gradient value per parameter/variable.
Had your loss function been a vector function, i.e., mapping from multiple inputs to multiple outputs, you would have ended up with multiple gradients per parameter/variable.
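To make this concrete (a small sketch, assuming a toy vector output y = x**2): calling backward() on a non-scalar without supplying weights raises a RuntimeError, while reducing to a scalar first yields one gradient value per input:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2  # vector-valued: one output per input

try:
    y.backward()  # no weights supplied for a non-scalar output
except RuntimeError as e:
    print("backward() on a non-scalar raises:", e)

# Reducing to a scalar first works: x.grad holds one value per input
y.sum().backward()
print(x.grad)  # tensor([2., 4., 6.])
```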

Please see this answer for more information.

Upvotes: 12
