Reputation: 51
I have built a complex-valued CNN using ComplexPyTorch, where the layers are wrapped in a `torch.nn.ModuleList`. When running the network I get through the validation sanity check and one batch of training, then my loss outputs NaNs. Logging gradients in `on_after_backward` shows NaNs immediately. Does anyone have suggestions for how I can troubleshoot this?
I have a real-valued version of the network that doesn't use ComplexPyTorch, and everything works fine there, so I can't help feeling that something goes wrong in the backward pass because my layers are in a `torch.nn.ModuleList`. Also, when I hard-coded the network without a `torch.nn.ModuleList`, I didn't get this issue either.
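
Roughly, the gradient logging looks like this (a simplified sketch; the class name is a placeholder and the actual layers are omitted):

```python
import torch
import pytorch_lightning as pl

class ComplexNet(pl.LightningModule):
    # ... layers built into a torch.nn.ModuleList, training_step, etc. ...

    def on_after_backward(self):
        # Runs right after loss.backward(); flag any parameter whose
        # gradient already contains NaNs.
        for name, param in self.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"NaN gradient in {name}")
```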
Upvotes: 1
Views: 1071
Reputation: 51
For anyone interested: I set `detect_anomaly=True` in `Trainer`, and was then able to trace the torch function outputting NaNs during backpropagation. In my case it was `torch.atan2`, so I added a tiny epsilon to its denominator and that fixed it. As a general point, I've always found denominator epsilons really helpful in preventing NaNs in anything involving a division!
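
A minimal sketch of both steps (the tensor names and `eps` value are illustrative; in plain PyTorch the equivalent of the `Trainer` flag is `torch.autograd.set_detect_anomaly(True)`):

```python
import torch
from pytorch_lightning import Trainer

# Anomaly mode makes backward() raise at the exact op that produced the
# NaN, with a traceback pointing to its forward call.
trainer = Trainer(detect_anomaly=True)

# torch.atan2(y, x) has a NaN gradient at (0, 0); nudging the second
# ("denominator") argument keeps the backward pass finite.
eps = 1e-7
y = torch.zeros(3, requires_grad=True)
x = torch.zeros(3, requires_grad=True)
phase = torch.atan2(y, x + eps)   # instead of torch.atan2(y, x)
phase.sum().backward()            # y.grad / x.grad are finite, not NaN
```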
Upvotes: 2