Reputation: 785
I can see what the code below, from this video, is trying to do. But the sum in y = torch.sum(x**2) confuses me. With the sum operation, y becomes a tensor with one single value. As I understand it, .backward() calculates derivatives, so why would we want to use sum and reduce y to a single value?
import torch
import matplotlib.pyplot as plt
x = torch.linspace(-10.0,10.0,10, requires_grad=True)
Y = x**2
y = torch.sum(x**2)
y.backward()
plt.plot(x.detach().numpy(), Y.detach().numpy(), label="Y")
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label="derivatives")
plt.legend()
Upvotes: 10
Views: 6663
Reputation: 913
You have a tensor Y, which has been computed directly or indirectly from tensor X.

Y.backward() would calculate the derivative of each element of Y w.r.t. each element of X. This gives us N_out (the number of elements in Y) gradient masks, each with shape X.shape.
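As a side note (not part of the original answer), if you want to see all N_out of those per-element derivatives explicitly, one option is torch.autograd.functional.jacobian:

import torch
from torch.autograd.functional import jacobian

x = torch.linspace(-10.0, 10.0, 10)
# derivative of every element of Y = x**2 w.r.t. every element of x
J = jacobian(lambda t: t**2, x)
print(J.shape)  # torch.Size([10, 10]): N_out rows, each shaped like x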
However, backward() enforces by default that the gradient stored in X.grad has the same shape as X. If N_out = 1, there is no problem, as we have only one mask. That is why you want to reduce Y to a single value.
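For example (my addition, not from the original answer), calling backward() directly on a non-scalar Y without a gradient argument fails:

x = torch.linspace(-10.0, 10.0, 10, requires_grad=True)
Y = x**2
Y.backward()  # RuntimeError: grad can be implicitly created only for scalar outputs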
If N_out > 1, PyTorch wants to take a weighted sum over the N_out gradient masks, but you need to supply the weights for this weighted sum. You can do this with the gradient argument:

Y.backward(gradient=weights_shaped_like_Y)

If you give every element of Y weight 1, you will get the same behaviour as using torch.sum(Y).backward().
Hence, the following two programs are equivalent:
x = torch.linspace(-10.0,10.0,10, requires_grad=True)
Y = x**2
y = torch.sum(x**2)
y.backward()
and
x = torch.linspace(-10.0,10.0,10, requires_grad=True)
Y = x**2
Y.backward(gradient=torch.ones_like(Y))
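A quick sanity check (my addition, not from the answer) that both versions leave the same values in x.grad, namely 2*x:

x1 = torch.linspace(-10.0, 10.0, 10, requires_grad=True)
torch.sum(x1**2).backward()

x2 = torch.linspace(-10.0, 10.0, 10, requires_grad=True)
(x2**2).backward(gradient=torch.ones_like(x2))

print(torch.allclose(x1.grad, x2.grad))           # True
print(torch.allclose(x1.grad, 2 * x1.detach()))   # True: d(x**2)/dx = 2x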
Upvotes: 3
Reputation: 114786
You can only compute partial derivatives of a scalar function. What backward() gives you is d loss/d parameter, and you expect a single gradient value per parameter/variable.
Had your loss function been a vector function, i.e., mapping from multiple inputs to multiple outputs, you would have ended up with multiple gradients per parameter/variable.
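For instance, a minimal sketch (the tensors and names here are illustrative assumptions, not from the answer): a scalar loss yields exactly one gradient value per parameter:

import torch

w = torch.tensor(3.0, requires_grad=True)   # single parameter
x = torch.tensor([1.0, 2.0, 3.0])           # inputs
t = torch.tensor([2.0, 4.0, 6.0])           # targets

loss = ((w * x - t) ** 2).mean()            # scalar-valued loss
loss.backward()
print(w.grad)                               # one gradient value for w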
Please see this answer for more information.
Upvotes: 12