Reputation: 31
I have the following if/else branch, which is evaluated many times in my training loop:
if torch.logical_or(d0 < 1e-5, d1 < 1e-5):
    h0 = torch.tensor(1e-6, dtype=dtype, device=device)
else:
    h0 = 0.01 * d0 / d1
where d0 and d1 are both torch.float32 tensors on CUDA. I would like to avoid the CPU synchronization caused by evaluating the if condition on a GPU tensor.
I've tried:
ifelse = torch.logical_or(d0 < 1e-5, d1 < 1e-5)
h0 = torch.tensor(1e-6, dtype=dtype, device=device) * ifelse + (0.01 * d0 / d1) * ~ifelse
Unfortunately, as soon as ifelse is true, this leads to NaN gradients and my training fails. I assume the problem is that the gradient will be 0 in that case and not None: 0.01 * d0 / d1 is still evaluated even when masked out, and if it produces inf/NaN, multiplying by 0 in the backward pass still yields NaN.
Is there a way to avoid the CPU sync, for example with a custom autograd function?
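For context, one common workaround (a sketch of the general masking technique, not necessarily the intended solution here) is torch.where with a "safe" denominator: the tiny d1 is replaced by 1 before the division, so the unselected branch never produces inf/NaN that would poison the backward pass. The step helper and the sample values below are hypothetical:

```python
import torch

# Stand-ins for the question's dtype/device; falls back to CPU without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

def step(d0, d1):
    cond = torch.logical_or(d0 < 1e-5, d1 < 1e-5)
    # Mask the denominator BEFORE dividing: where cond is true, divide by 1
    # instead of a tiny d1, so 0.01 * d0 / safe_d1 stays finite and its
    # gradient contribution (zeroed by torch.where) cannot become NaN.
    safe_d1 = torch.where(cond, torch.ones_like(d1), d1)
    return torch.where(cond, torch.full_like(d0, 1e-6), 0.01 * d0 / safe_d1)

# Normal branch: both inputs above the threshold.
d0 = torch.tensor(2e-4, dtype=dtype, device=device, requires_grad=True)
d1 = torch.tensor(4e-4, dtype=dtype, device=device, requires_grad=True)
step(d0, d1).backward()

# Degenerate branch: a tiny d1 that would have made 0.01 * d0 / d1 explode.
e0 = torch.tensor(2e-4, dtype=dtype, device=device, requires_grad=True)
e1 = torch.tensor(1e-7, dtype=dtype, device=device, requires_grad=True)
step(e0, e1).backward()
```

Nothing here calls .item() or converts a tensor to a Python bool, so no CPU synchronization is triggered; the gradients in the degenerate branch come out as finite zeros rather than NaN.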
Upvotes: 0
Views: 87