Jerry

Reputation: 97

PyTorch backward() on a tensor element affected by nan in other elements

Consider the following two examples:

import torch

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z0 = x * y / y  # 0/0 produces nan
z1 = x + y
print(z0, z1) # tensor(nan, grad_fn=<DivBackward0>) tensor(1., grad_fn=<AddBackward0>)
z1.backward()
print(x.grad) # tensor(1.)


x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2,), float("nan"))
z[0] = x * y / y  # in-place slice assignment
z[1] = x + y
print(z) # tensor([nan, 1.], grad_fn=<CopySlices>)
z[1].backward()
print(x.grad) # tensor(nan)

In example 1, z0 does not affect z1: z1.backward() executes as expected and x.grad is not nan. In example 2, however, the backward() of z[1] seems to be affected by z[0], and x.grad is nan.

How do I prevent this (example 1 shows the desired behaviour)? In particular, I need to retain the nan in z[0], so adding an epsilon to the division does not help.

Upvotes: 4

Views: 2443

Answers (1)

iacob

Reputation: 24201

When you index the tensor in the assignment, PyTorch accesses all elements of the tensor (under the hood it uses binary multiplicative masking to maintain differentiability), and this is where it picks up the nan of the other element (since 0 * nan -> nan).
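A minimal sketch of why that zero-weighted branch still matters (this reproduces the arithmetic of the claim, not PyTorch's actual kernel): even an upstream gradient of exactly zero turns into nan once it is multiplied by the nan local derivative of the division.

import torch

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)

z0 = x * y / y  # nan; its local derivative w.r.t. x is y/y = nan
# Send a *zero* upstream gradient into the nan branch, as the mask would:
z0.backward(gradient=torch.tensor(0.))
print(x.grad)  # tensor(nan), because 0 * nan = nan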

We can see this in the computational graph:

torchviz.make_dot(z1, params={'x': x, 'y': y})
torchviz.make_dot(z[1], params={'x': x, 'y': y})

(computational graph images omitted)

If you wish to avoid this behaviour, either mask out the nans or, as you have done in the first example, keep the two computations in separate objects.
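For instance, one possible workaround (a sketch, assuming you never need gradients to flow through z[0] itself) is to detach the nan-producing expression before the slice assignment. This keeps the nan value in z[0] but cuts that element out of the autograd graph:

import torch

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(0., requires_grad=True)
z = torch.full((2,), float("nan"))
z[0] = (x * y / y).detach()  # retains the nan value, but no grad path
z[1] = x + y
z[1].backward()
print(z)       # tensor([nan, 1.], grad_fn=<CopySlices>)
print(x.grad)  # tensor(1.)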

Upvotes: 3
