Reputation: 1736
I am trying to understand how PyTorch autograd works. Suppose I have the functions y = 2x and z = y**2. Differentiating by hand, dz/dx at x = 1 is 8 (dz/dx = dz/dy * dy/dx = 2y * 2 = 2(2x) * 2 = 8x). Equivalently, z = (2x)**2 = 4x^2, so dz/dx = 8x, which is 8 at x = 1.
If I do the same with PyTorch autograd, I get 4:
import torch

x = torch.ones(1, requires_grad=True)
y = 2*x
z = y**2
x.backward(z)
print(x.grad)
which prints
tensor([4.])
Where am I going wrong?
Upvotes: 3
Views: 718
Reputation: 6618
If you still have some confusion about autograd in PyTorch, refer to this example. It is a basic XOR gate representation:
import torch
import torch.nn.functional as F

# XOR truth table; use floats because F.linear expects floating-point inputs
inputs = torch.tensor(
    [
        [0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 1.],
    ]
)
outputs = torch.tensor(
    [
        0.,
        1.,
        1.,
        0.,
    ]
)
weights = torch.randn(1, 2)
weights.requires_grad = True  # set it as True for gradient computation
bias = torch.randn(1, requires_grad=True)  # set it as True for gradient computation
preds = F.linear(inputs, weights, bias)  # a basic linear model; shape (4, 1)
loss = (outputs - preds.squeeze(1)).mean()  # squeeze preds so the shapes match
loss.backward()
print(weights.grad)  # this prints the gradient of the loss w.r.t. the weights
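A small follow-up, assuming the variables from the snippet above: loss.backward() populates .grad on every leaf tensor with requires_grad=True, so the bias gradient is available as well, and because gradients accumulate across backward calls they should be cleared before the next pass:
print(bias.grad)      # gradient of the loss with respect to the bias
weights.grad.zero_()  # gradients accumulate, so clear them before the next backward pass
bias.grad.zero_()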
Upvotes: 0
Reputation: 22184
You're using Tensor.backward wrong. To get the result you asked for, you should use:
x = torch.ones(1,requires_grad=True)
y = 2*x
z = y**2
z.backward() # <-- fixed
print(x.grad)
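which prints
tensor([8.])
matching the hand-computed dz/dx = 8x at x = 1.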
The call to z.backward() invokes the back-propagation algorithm, starting at z and working back to each leaf node in the computation graph. In this case x is the only leaf node. After calling z.backward(), the computation graph is freed and the .grad member of each leaf node is updated with the gradient of z with respect to that leaf node (in this case dz/dx).
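To make the leaf-node point concrete, here is a quick check (a minimal sketch using the same x, y and z as above; is_leaf and grad_fn are standard tensor attributes):
import torch

x = torch.ones(1, requires_grad=True)
y = 2 * x
z = y**2
print(x.is_leaf, y.is_leaf, z.is_leaf)  # True False False: only x is a leaf
print(z.grad_fn)                        # a PowBackward0 node, since z was produced by an operation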
What's actually happening in your original code? Well, what you've done is apply back-propagation starting at x. With no arguments, x.backward() would simply result in x.grad being set to 1, since dx/dx = 1. The additional argument (gradient) is effectively a scale to apply to the resulting gradient. In this case z = 4, so you get x.grad = z * dx/dx = 4 * 1 = 4. If interested, you can check out this for more information on what the gradient argument does.
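For a quick illustration of the gradient argument (a minimal sketch, not part of the original code): when backward() is called on a non-scalar tensor, the argument supplies the vector for the vector-Jacobian product, which here simply scales the per-element derivative dy/dx = 2:
import torch

x = torch.ones(3, requires_grad=True)
y = 2 * x                  # non-scalar output; dy/dx = 2 for each element
y.backward(torch.ones(3))  # gradient argument = the vector in the vector-Jacobian product
print(x.grad)              # tensor([2., 2., 2.])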
Upvotes: 6