Paul_0

Reputation: 378

Torch backward PowBackward0 causes nan gradient where it shouldn't

I have a PyTorch tensor with NaN values inside. When I calculate a simple MSE loss, the gradient becomes NaN even though I mask out the NaN values.

Weirdly, this happens only when the mask is applied after calculating the loss, and only when the loss contains a pow operation. The various cases follow:

import torch
torch.autograd.set_detect_anomaly(True)

x = torch.rand(10, 10) 
y = torch.rand(10, 10)
w = torch.rand(10, 10, requires_grad=True)
y[y > 0.5] = torch.nan


o = w @ x
l = (y - o)**2
l = l[~y.isnan()]

try:
    l.mean().backward(retain_graph=True)
except RuntimeError:
    print('(y-o)**2 caused nan gradient')

l = (y - o)
l = l[~y.isnan()]

try:
    l.mean().backward(retain_graph=True)
except RuntimeError:
    pass
else:
    print('y-o does not cause nan gradient')

l = (y[~y.isnan()] - o[~y.isnan()])**2
l.mean().backward()
print('masking before pow does not propagate nan gradient')

What makes NaN gradients propagate when passing through the backward pass of the pow function?

Upvotes: 0

Views: 121

Answers (1)

Karl

Reputation: 5473

The nans don't come from the gradient; they come from the forward pass. Those forward-pass nan values get multiplied by the gradient values in the backward pass (chain rule).
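A minimal standalone sketch of that mechanism (not using the question's tensors): pow saves its forward input and multiplies by it in the backward pass, so a nan input makes the local derivative nan, and a zero upstream gradient cannot cancel it because 0 * nan is still nan. Subtraction, by contrast, has a constant local derivative and stays finite.

import torch

# pow's backward uses the saved forward value: d/dz (z**2) = 2*z.
# A nan forward value makes the local derivative nan, and 0 * nan is nan.
z = torch.tensor([1.0, float('nan')], requires_grad=True)
(z ** 2).backward(gradient=torch.tensor([1.0, 0.0]))  # zero upstream grad for the nan element
print(z.grad)  # tensor([2., nan])

# Subtraction's local derivative is the constant -1; it never touches the forward value.
z2 = torch.tensor([1.0, float('nan')], requires_grad=True)
(5.0 - z2).backward(gradient=torch.tensor([1.0, 0.0]))
print(z2.grad)  # tensor([-1., -0.]) -- finite, no nan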

Take a simpler example. Set exactly one value in y to nan:

x = torch.rand(10, 10) 
y = torch.rand(10, 10)
w = torch.rand(10, 10, requires_grad=True)
y[0,0] = torch.nan

Now compute your intermediates and retain gradients

o = w@x
o.retain_grad()

l = (y - o).pow(2)
l.retain_grad()

l_nonnan = l[~y.isnan()]
l_nonnan.retain_grad()

l_nonnan.mean().backward()

Inspect the gradients

  • l_nonnan has full gradients
  • l has full gradients except for l.grad[0,0] which is 0
  • o has a nan gradient at o.grad[0,0]
  • w has nan gradients for the entire first row
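Continuing the snippet above, these four points can be checked directly (expected results in the comments):

print(torch.isnan(l_nonnan.grad).any())  # tensor(False) -- all finite
print(l.grad[0, 0])                      # tensor(0.) -- masked element gets a zero upstream gradient
print(o.grad[0, 0])                      # tensor(nan)
print(torch.isnan(w.grad[0]).all())      # tensor(True) -- whole first row of w.grad is nan
print(torch.isnan(w.grad[1:]).any())     # tensor(False) -- the other rows are fine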

This is due to how the computation propagates. We set y[0,0] = torch.nan, so when we compute l = (y - o).pow(2), the forward value l[0,0] is nan because it directly interacts with the nan from y. The backward of pow needs that forward value (the local derivative is 2*(y - o)), so the gradient flowing back into o[0,0] is nan even though the upstream gradient l.grad[0,0] is 0: multiplying 0 by nan still gives nan.

o is created via o = w @ x, which means o[0,0] = (w[0] * x[:,0]).sum(). When we run the computation in reverse during backprop, the gradient of o[0,0] (which we know to be nan) propagates back to every element of w[0]. This is why the entire first row has nan gradients.
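As a sanity check (assuming the tensors from the snippet above are still in scope), the matmul backward can be reproduced by hand: for o = w @ x, the gradient with respect to w is o.grad @ x.T, so the single nan at o.grad[0, 0] spreads across the whole first row of w.grad.

manual_w_grad = o.grad @ x.T                          # backward of o = w @ x with respect to w
print(torch.isnan(manual_w_grad[0]).all())            # tensor(True) -- first row contaminated
print(torch.allclose(manual_w_grad[1:], w.grad[1:]))  # True -- matches autograd on the clean rows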

When you set a bunch of nans randomly, you get the same effect on more elements.

You can avoid this with l = (y[~y.isnan()] - o[~y.isnan()])**2, because masking before the subtraction and pow prevents the nans in y from entering the computation in the first place.
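For completeness, a self-contained version of that masking-first approach; w ends up with fully finite gradients:

import torch

x = torch.rand(10, 10)
y = torch.rand(10, 10)
w = torch.rand(10, 10, requires_grad=True)
y[y > 0.5] = torch.nan

o = w @ x
mask = ~y.isnan()
loss = ((y[mask] - o[mask]) ** 2).mean()  # nans never enter the graph
loss.backward()
print(torch.isnan(w.grad).any())          # tensor(False)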

Upvotes: 0
