Ripleys

Reputation: 31

Avoiding CPU/GPU synchronization from Python control flow by using a constant as the alternative leads to incorrect gradients

I have the following if/else statement, which is called multiple times in my training loop:

```python
if torch.logical_or(d0 < 1e-5, d1 < 1e-5):
    h0 = torch.tensor(1e-6, dtype=dtype, device=device)
else:
    h0 = (0.01 * d0 / d1)
```

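where d0 and d1 are both torch.float32 tensors on CUDA. I would like to avoid the CPU/GPU synchronization caused by the if statement: evaluating a CUDA tensor as a Python bool has to copy its value to the host and wait for the GPU.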
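For illustration, here is a minimal sketch contrasting the branching version with a branch-free one built on `torch.where`, which selects element-wise on the device so Python never has to read the condition. The function names `h0_branch` and `h0_where` are hypothetical, and the comparison runs on CPU unless CUDA is available. It only shows that the forward values agree; `torch.where` on its own does not prevent a NaN gradient coming out of the discarded `0.01 * d0 / d1` branch when `d1` is zero.

```python
import torch

def h0_branch(d0, d1):
    # Python `if` calls Tensor.__bool__, which reads the value on the host;
    # for a CUDA tensor this blocks until the GPU has produced it.
    if torch.logical_or(d0 < 1e-5, d1 < 1e-5):
        return torch.tensor(1e-6, dtype=d0.dtype, device=d0.device)
    return 0.01 * d0 / d1

def h0_where(d0, d1):
    # Both candidates are computed on the device and torch.where picks one
    # element-wise, so the result stays a tensor and nothing is read back.
    cond = torch.logical_or(d0 < 1e-5, d1 < 1e-5)
    return torch.where(cond, torch.full_like(d0, 1e-6), 0.01 * d0 / d1)

device = "cuda" if torch.cuda.is_available() else "cpu"
for a, b in [(0.5, 2.0), (0.0, 2.0), (0.5, 0.0)]:
    d0 = torch.tensor(a, device=device)
    d1 = torch.tensor(b, device=device)
    # .item() syncs by itself; it is only used here to compare the results.
    print(h0_branch(d0, d1).item(), h0_where(d0, d1).item())
```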

I've tried:

```python
ifelse = torch.logical_or(d0 < 1e-5, d1 < 1e-5)
h0 = torch.tensor(1e-6, dtype=dtype, device=device) * ifelse + (0.01 * d0 / d1) * ~ifelse
```

Unfortunately, as soon as ifelse is true, this leads to NaN gradients and my training fails. I assume the problem is that the gradient will be 0 in that case rather than None.
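A minimal CPU reproduction of the masked version, assuming (a guess, not something established above) that `d1` can be exactly zero when `ifelse` is true. Both branches are always evaluated, so `0.01 * d0 / d1` becomes `inf`, and multiplying it by the `False` mask gives `0 * inf = nan`; in the backward pass the zero gradient flowing into the division is likewise turned into NaN (`0 / 0`). The helper names `masked_h0` and `grads_at` are hypothetical.

```python
import torch

def masked_h0(d0, d1):
    # The masked pattern from the question: both branches are computed,
    # then combined with a boolean mask instead of a Python `if`.
    ifelse = torch.logical_or(d0 < 1e-5, d1 < 1e-5)
    const = torch.tensor(1e-6, dtype=d0.dtype, device=d0.device)
    return const * ifelse + (0.01 * d0 / d1) * ~ifelse

def grads_at(d0_val, d1_val):
    # Returns (h0, dh0/dd0, dh0/dd1) for scalar float32 inputs.
    d0 = torch.tensor(d0_val, requires_grad=True)
    d1 = torch.tensor(d1_val, requires_grad=True)
    h0 = masked_h0(d0, d1)
    h0.backward()
    return h0.detach(), d0.grad, d1.grad

print(grads_at(0.5, 2.0))  # ordinary case: finite values
print(grads_at(0.5, 0.0))  # d1 == 0: forward value and both gradients are nan
```

If `d1` only becomes very small rather than exactly zero, the forward value may stay finite while the gradient with respect to `d1` can still overflow in float32, depending on the actual magnitudes.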

Is it possible to avoid the CPU sync with a custom autograd function?
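One possible direction, sketched under the same assumption that the NaNs come from the discarded division: the commonly used "double where" pattern, which needs no custom `torch.autograd.Function`. Wherever the fallback is selected, `d1` is first replaced by a harmless value (1.0 here, an arbitrary choice), so the unused branch stays finite and its zero gradient stays zero. Everything remains on the device, so there is no sync. `h0_safe`, `eps` and `fallback` are hypothetical names; the thresholds mirror the question's code.

```python
import torch

def h0_safe(d0, d1, eps=1e-5, fallback=1e-6):
    # Same condition as the original if/else, evaluated on the device.
    cond = torch.logical_or(d0 < eps, d1 < eps)
    # First where: swap in 1.0 for d1 wherever the fallback will be used,
    # so the division below never sees a zero or tiny denominator there.
    safe_d1 = torch.where(cond, torch.ones_like(d1), d1)
    ratio = 0.01 * d0 / safe_d1
    # Second where: pick the constant or the ratio element-wise.
    return torch.where(cond, torch.full_like(ratio, fallback), ratio)

for a, b in [(0.5, 2.0), (0.5, 0.0), (0.0, 0.0)]:
    d0 = torch.tensor(a, requires_grad=True)
    d1 = torch.tensor(b, requires_grad=True)
    h0 = h0_safe(d0, d1)
    h0.backward()
    print(h0.item(), d0.grad.item(), d1.grad.item())
```

In the training loop this would replace the if/else directly, e.g. `h0 = h0_safe(d0, d1)`. A zero gradient in the fallback case is expected, since the constant depends on neither `d0` nor `d1`, just as in the original if/else.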

Upvotes: 0

Views: 87
