Reputation: 378
I have a PyTorch tensor with NaN values inside. When I calculate the loss with a simple MSE loss, the gradient becomes NaN even if I mask out the NaN values.
Weirdly, this happens only when the mask is applied after calculating the loss, and only when the loss has a pow operation inside. The various cases follow:
import torch
torch.autograd.set_detect_anomaly(True)
x = torch.rand(10, 10)
y = torch.rand(10, 10)
w = torch.rand(10, 10, requires_grad=True)
y[y > 0.5] = torch.nan
o = w @ x
l = (y - o)**2
l = l[~y.isnan()]
try:
    l.mean().backward(retain_graph=True)
except RuntimeError:
    print('(y-o)**2 caused nan gradient')

l = (y - o)
l = l[~y.isnan()]
try:
    l.mean().backward(retain_graph=True)
except RuntimeError:
    pass
else:
    print('y-o does not cause nan gradient')
l = (y[~y.isnan()] - o[~y.isnan()])**2
l.mean().backward()
print('masking before pow does not propagate nan gradient')
What makes NaN gradients propagate when passing through the backward pass of the pow function?
Upvotes: 0
Views: 121
Reputation: 5473
The nans don't come from the gradient, the nans come from the forward pass. These are multiplied by gradient values in the backward pass (chain rule).
Take a simpler example. Set exactly one value in y to nan:
x = torch.rand(10, 10)
y = torch.rand(10, 10)
w = torch.rand(10, 10, requires_grad=True)
y[0,0] = torch.nan
Now compute your intermediates and retain gradients:
o = w@x
o.retain_grad()
l = (y - o).pow(2)
l.retain_grad()
l_nonnan = l[~y.isnan()]
l_nonnan.retain_grad()
l_nonnan.mean().backward()
Inspect the gradients:
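For reference, here is one way to print them after the backward call above (a small sketch reusing the same variable names):
print(l_nonnan.grad)  # all finite: 1/99 for every element of the mean
print(l.grad)         # 0 at [0, 0], 1/99 everywhere else
print(o.grad)         # nan at [0, 0], finite elsewhere
print(w.grad)         # nan across the entire first row w.grad[0]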
- l_nonnan has full gradients
- l has full gradients except for l.grad[0,0], which is 0
- o has a nan gradient at o.grad[0,0]
- w has nan gradients for the entire first row

This is due to how the computation propagates. We set y[0,0] = torch.nan. We compute l = (y - o).pow(2); this means l[0,0] is nan because it directly interacts with the nan from y. In the backward pass of pow, the gradient flowing into (y - o) at [0,0] is the upstream gradient l.grad[0,0] = 0 multiplied by the local derivative 2 * (y[0,0] - o[0,0]), which is nan, and 0 * nan = nan.
o is created via o = w@x. This means the value at o[0,0] = (w[0] * x[:,0]).sum(). When we run the computation in reverse in backprop, the gradient of o[0,0] (which we know to be nan) propagates back to all elements of w[0]. This is why the entire row has nan gradients.
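The key point is that a zero upstream gradient does not neutralize a nan local derivative. A one-line illustration (not part of the original setup):
print(torch.tensor(0.0) * torch.tensor(float('nan')))  # tensor(nan)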
When you set a bunch of nans randomly, you get the same effect on more elements.
You can avoid this via l = (y[~y.isnan()] - o[~y.isnan()])**2, because when you do that you prevent the nans in y from entering the computation in the first place.
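As a quick sanity check on the one-nan example above (a sketch; it recomputes o because the earlier graph was freed by backward()):
w.grad = None                             # reset the gradient from the earlier backward
o = w @ x                                 # rebuild the graph
l = (y[~y.isnan()] - o[~y.isnan()])**2    # mask before the subtraction and pow
l.mean().backward()
print(w.grad.isnan().any())               # tensor(False): no nan reaches w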
Upvotes: 0