Reputation: 1247
See the code snippet:
import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)
The output is tensor([0.]), but
import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
    y = x
else:
    y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)
The output is None.
I'm confused about why the output of torch.where is tensor([0.]). Here is another example:
import torch
a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([-1., -1.], requires_grad=True)
a[:, 0] = b
(a[0, 0] * a[0, 1]).backward()
print(b.grad)
The output is tensor([2., 0.]). (a[0, 0] * a[0, 1]) is not in any way related to b[1], but the gradient of b[1] is 0, not None.
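For reference, printing a.grad_fn right after the assignment (the exact class name may differ across PyTorch versions) suggests the slice assignment itself is recorded:

import torch
a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([-1., -1.], requires_grad=True)
a[:, 0] = b
print(a.grad_fn)  # e.g. <CopySlices object at 0x...>, so a is now part of b's graph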
Upvotes: 5
Views: 2983
Reputation: 20980
Tracking-based AD, like PyTorch, works by tracking: you can't track through anything that is not a function call intercepted by the library. With an if statement like this, there is no connection between x and y, whereas with where, x and y are linked in the expression tree.
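You can see this directly by inspecting grad_fn (a minimal sketch; the exact grad_fn class names vary between PyTorch versions):

import torch

x = torch.tensor([-1.], requires_grad=True)

# where is an intercepted call, so the result carries a graph node linking it to x
y_where = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
print(y_where.grad_fn)  # e.g. <WhereBackward0 ...>

# the if/else version just rebinds y to a fresh leaf tensor in the else branch
y_if = torch.tensor([2.], requires_grad=True)
print(y_if.grad_fn)     # None: a leaf with no recorded connection to x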
Now, for the differences:

- 0 is the correct derivative of the function x ↦ x > 0 ? x : 2 at the point -1 (since the negative side is constant).
- x is not in any way related to y (in the else branch). Therefore, the derivative of y with respect to x is undefined, which is represented as None.

(You can do such things even in Python, but that requires more sophisticated technology like source transformation. I don't think it is possible with PyTorch.)
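As a sanity check on the first point, evaluating the same where expression on both sides of 0 recovers the piecewise derivative (the helper f is just a name I'm using for illustration):

import torch

def f(x):
    # the function x ↦ (x if x > 0 else 2), expressed with where
    return torch.where(x > 0., x, torch.tensor([2.]))

for val in (-1., 3.):
    x = torch.tensor([val], requires_grad=True)
    f(x).backward()
    print(val, x.grad)  # -1.0 -> tensor([0.]), 3.0 -> tensor([1.])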
Upvotes: 5