BBB

Reputation: 75

How can I implement this function, which uses torch.where and torch.sqrt, in a better way so that autograd works?

I have a class with this forward method:

def forward(self, x):
    r = torch.sqrt(x[:, 0:1]**2 + x[:, 1:2]**2)    
    value = torch.where(r <= 1, torch.sqrt(1.-r**2), -1)
    return value

This produces NaN gradients: even though torch.where selects -1 wherever r > 1, autograd still backpropagates through the unselected torch.sqrt(1. - r**2) branch, and the negative argument there yields NaN.
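
For example, here is a minimal standalone sketch (hypothetical input values, mirroring the code above) that reproduces the NaN gradients:

import torch

# Hypothetical example: the second row lies outside the unit disk,
# so 1. - r**2 is negative there.
x = torch.tensor([[0.1, 0.2], [2.0, 2.0]], requires_grad=True)
r = torch.sqrt(x[:, 0:1]**2 + x[:, 1:2]**2)
value = torch.where(r <= 1, torch.sqrt(1. - r**2), -1.)
value.sum().backward()
print(x.grad)  # first row is finite, second row is NaN even though -1 was selected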

To circumvent this, one can use:

def forward(self, x):
    i = x[:, 0:1]**2 + x[:, 1:2]**2
    r = torch.sqrt(torch.relu(i))
    alternative = torch.sqrt(torch.relu(1. - r**2))
    value = torch.where(r <= 1, alternative, -1)

    return value

This gets rid of the NaN problem. Is there a better way to do this, ideally with minimal changes to the original code?

Upvotes: 0

Views: 288

Answers (1)

Valentin Goldité

Reputation: 1219

You can use clamp to restrict r to [0, 1]; the values that used to be NaN then come out as zero.

def forward2(x):
    r = torch.sqrt(x[:, 0:1] ** 2 + x[:, 1:2] ** 2)
    r = torch.clamp(r, 0.0, 1.0)  # Only one additional line
    value = torch.where(r <= 1, torch.sqrt(1.0 - r**2), -1)
    return value

Test with a batch size of 5 and normally distributed input; the tensors below are the gradients with respect to x (a sketch of the test setup follows the outputs).

Without clamp:

tensor([[-0.0069, -0.0278],
        [ 0.0203, -0.0501],
        [ 0.0126, -0.0309],
        [    nan,     nan],
        [    nan,     nan]])

With clamp:

tensor([[-0.0069, -0.0278],
        [ 0.0203, -0.0501],
        [ 0.0126, -0.0309],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000]])
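
For reference, a sketch of a test setup that could produce gradients like these (an assumption, not part of the original answer; it uses the forward2 function defined above):

import torch

# Assumed test: a random batch of 5 points; rows with r > 1 exercise
# the branch that used to produce NaN gradients.
torch.manual_seed(0)
x = torch.randn(5, 2, requires_grad=True)
forward2(x).sum().backward()
print(x.grad)  # rows with r > 1 (if any) give zero gradients instead of NaN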

Upvotes: 2
