Reputation: 75
I have a class with this forward method:
def forward(self, x):
    r = torch.sqrt(x[:, 0:1]**2 + x[:, 1:2]**2)
    value = torch.where(r <= 1, torch.sqrt(1. - r**2), -1)
    return value
This causes a problem with NaNs when it comes to taking gradients.
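For illustration, a minimal standalone check along these lines (input values are made up, and -1 is written as a tensor rather than a scalar to stay compatible with older torch.where overloads) shows the NaN gradients for points outside the unit disk:

import torch

# One point inside the unit disk, one outside (values chosen only for illustration).
x = torch.tensor([[0.1, 0.2],
                  [1.0, 1.0]], requires_grad=True)

r = torch.sqrt(x[:, 0:1]**2 + x[:, 1:2]**2)
value = torch.where(r <= 1, torch.sqrt(1. - r**2), torch.full_like(r, -1.))
value.sum().backward()

# The row with r > 1 comes back as NaN: the unselected sqrt branch already
# produced NaN in the forward pass, and that NaN leaks into the backward pass.
print(x.grad)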
To circumvent this, one can use:
def forward(self, x):
    i = x[:, 0:1]**2 + x[:, 1:2]**2
    r = torch.sqrt(torch.relu(i))  # relu keeps the argument of sqrt non-negative
    alternative = torch.sqrt(torch.relu(1. - r**2))
    value = torch.where(r <= 1, alternative, -1)
    return value
This gets rid of the NaN problems. Is there a better way to do this (with minimal changes to the original code if possible)?
Upvotes: 0
Views: 288
Reputation: 1219
You can use torch.clamp to keep r from exceeding 1, which turns the previously NaN gradients into zeros.
def forward2(x):
    r = torch.sqrt(x[:, 0:1] ** 2 + x[:, 1:2] ** 2)
    r = torch.clamp(r, 0.0, 1.0)  # only one additional line
    value = torch.where(r <= 1, torch.sqrt(1.0 - r**2), -1)
    return value
Gradients with respect to x, tested with a batch size of 5 and normally distributed input:
Without clamp:
tensor([[-0.0069, -0.0278],
[ 0.0203, -0.0501],
[ 0.0126, -0.0309],
[ nan, nan],
[ nan, nan]])
With clamp:
tensor([[-0.0069, -0.0278],
[ 0.0203, -0.0501],
[ 0.0126, -0.0309],
[ 0.0000, 0.0000],
[ 0.0000, 0.0000]])
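A test harness along these lines (random normal input, sum the output, then backpropagate; the seed and tensor names are assumptions) reproduces gradients of this shape:

import torch

torch.manual_seed(0)  # arbitrary seed for the sketch
x = torch.randn(5, 2, requires_grad=True)

out = forward2(x)     # forward2 as defined above (clamped version)
out.sum().backward()

# Rows whose radius had to be clamped come back with zero gradient instead of
# NaN, matching the "With clamp" output above.
print(x.grad)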
Upvotes: 2