user3656142

Reputation: 467

Need help understanding the gradient function in PyTorch

Consider the following code:


import numpy as np
import torch

w = np.array([[2., 2.], [2., 2.]])
x = np.array([[3., 3.], [3., 3.]])
b = np.array([[4., 4.], [4., 4.]])
w = torch.tensor(w, requires_grad=True)
x = torch.tensor(x, requires_grad=True)
b = torch.tensor(b, requires_grad=True)


y = w*x + b 
print(y)
# tensor([[10., 10.],
#         [10., 10.]], dtype=torch.float64, grad_fn=<AddBackward0>)

y.backward(torch.ones_like(y))  # all-ones upstream gradient with y's shape and dtype

print(w.grad)
# tensor([[3., 3.],
#         [3., 3.]], dtype=torch.float64)

print(x.grad)
# tensor([[2., 2.],
#         [2., 2.]], dtype=torch.float64)

print(b.grad)
# tensor([[1., 1.],
#         [1., 1.]], dtype=torch.float64)

Since the tensor argument passed to backward() is an all-ones tensor with the same shape as y, my understanding is that:

  1. w.grad is the derivative of y w.r.t. w, and should produce b,

  2. x.grad is the derivative of y w.r.t. x, and should produce b, and

  3. b.grad is the derivative of y w.r.t. b, and should produce all ones.

Of these, only the answer for point 3 matches my expected result. Can someone help me understand the first two answers? I think I understand the accumulation part, but I don't think that is what is happening here.
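For reference, a minimal sketch of what the all-ones argument does: it is equivalent to calling backward() on y.sum(), since the derivative of y.sum() with respect to each element of y is 1 (the tensors here mirror the ones above):

import torch

# Sketch: summing y to a scalar and differentiating gives the same
# gradients as y.backward(torch.ones_like(y)).
w = torch.full((2, 2), 2.0, requires_grad=True)
x = torch.full((2, 2), 3.0, requires_grad=True)
b = torch.full((2, 2), 4.0, requires_grad=True)

y = w * x + b
y.sum().backward()

print(w.grad)  # tensor([[3., 3.], [3., 3.]]) -- matches the output above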

Upvotes: 2

Views: 486

Answers (1)

Michael Jungo

Reputation: 33020

To find the correct derivatives in this example, we need to take the sum and product rules into consideration.

Sum rule:

$\frac{d}{dx}(u + v) = \frac{du}{dx} + \frac{dv}{dx}$

Product rule:

$\frac{d}{dx}(uv) = u\,\frac{dv}{dx} + v\,\frac{du}{dx}$

That means the derivatives of your equation are calculated as follows.

With respect to x:

$\frac{\partial y}{\partial x} = \frac{\partial (wx)}{\partial x} + \frac{\partial b}{\partial x} = w + 0 = w$

With respect to w:

$\frac{\partial y}{\partial w} = \frac{\partial (wx)}{\partial w} + \frac{\partial b}{\partial w} = x + 0 = x$

With respect to b:

$\frac{\partial y}{\partial b} = \frac{\partial (wx)}{\partial b} + \frac{\partial b}{\partial b} = 0 + 1 = 1$

The gradients reflect exactly that:

torch.equal(w.grad, x)  # => True
torch.equal(x.grad, w)  # => True
torch.equal(b.grad, torch.tensor([[1, 1], [1, 1]], dtype=torch.float64))  # => True
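To make the role of the gradient argument itself concrete, here is a short sketch (the values in v are illustrative, not from the post above): backward(v) computes a vector-Jacobian product, so every gradient is scaled elementwise by v, and with all ones it reduces to the case above.

import torch

w = torch.full((2, 2), 2.0, requires_grad=True)
x = torch.full((2, 2), 3.0, requires_grad=True)
b = torch.full((2, 2), 4.0, requires_grad=True)

y = w * x + b

# Non-uniform upstream gradient (illustrative values).
v = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y.backward(v)

print(w.grad)  # v * x -> tensor([[ 3.,  6.], [ 9., 12.]])
print(x.grad)  # v * w -> tensor([[ 2.,  4.], [ 6.,  8.]])
print(b.grad)  # v * 1 -> tensor([[1., 2.], [3., 4.]])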

Upvotes: 5
