gaussclb

Reputation: 1247

What is the difference between an if-else statement and torch.where in PyTorch?

See the code snippet:

import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)

The output is tensor([0.]), but

import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
    y = x
else:
    y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)

The output is None.

I'm confused: why is the output of the torch.where version tensor([0.])?

update

import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b

(a[0, 0] * a[0, 1]).backward()
print(b.grad)

The output is tensor([2., 0.]). (a[0, 0] * a[0, 1]) is not in any way related to b[1], yet the gradient of b[1] is 0, not None.

Upvotes: 5

Views: 2983

Answers (1)

phipsgabler

Reputation: 20980

Tracking-based AD, like PyTorch's, works by tracking the operations it sees: you can't track through things that are not function calls intercepted by the library. With an if statement like this, there is no connection between x and y, whereas with torch.where, x and y are linked in the expression tree.
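
One quick way to see this is to inspect y.grad_fn in both cases. This is a minimal sketch; the exact grad_fn class name printed depends on your PyTorch version, but the point is only whether it is None or not:

import torch

# Case 1: torch.where keeps both branches in the autograd graph.
x = torch.tensor([-1.], requires_grad=True)
y_where = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
print(y_where.grad_fn)   # a backward node (name varies by version): y_where is connected to x

# Case 2: the if/else picks a branch in plain Python, outside the graph.
x2 = torch.tensor([-1.], requires_grad=True)
if x2 > 0.:
    y_if = x2
else:
    y_if = torch.tensor([2.], requires_grad=True)
print(y_if.grad_fn)      # None: y_if is a fresh leaf tensor, not connected to x2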

Now, for the differences:

  • In the first snippet, 0 is the correct derivative of the function x ↦ (x > 0 ? x : 2) at the point -1, since the negative branch is constant (see the sketch after this list).
  • In the second snippet, as I said, x is not in any way related to y (in the else branch). Therefore, the derivative of y with respect to x is undefined, which is represented as None.
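
As a sketch of the first point, evaluating the same where expression at a negative and at a positive input shows the gradient switching between the two branches: 0 for the constant branch, 1 for the identity branch.

import torch

# Negative input: the constant branch is selected, so dy/dx is 0.
x_neg = torch.tensor([-1.], requires_grad=True)
torch.where(x_neg > 0., x_neg, torch.tensor([2.])).backward()
print(x_neg.grad)   # tensor([0.])

# Positive input: the identity branch is selected, so dy/dx is 1.
x_pos = torch.tensor([3.], requires_grad=True)
torch.where(x_pos > 0., x_pos, torch.tensor([2.])).backward()
print(x_pos.grad)   # tensor([1.])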

(You can do such things even in Python, but that requires more sophisticated technology, like source transformation. I don't think it is possible with PyTorch.)

Upvotes: 5
