Jack Walters

Reputation: 51

PyTorch Lightning complex-valued CNN training outputs NaN after 1 batch

I have built a complex-valued CNN using ComplexPyTorch, with the layers wrapped in a torch.nn.ModuleList. When I run the network, it gets through the validation sanity check and one training batch, and then my loss outputs NaNs. Logging gradients in on_after_backward shows NaNs immediately. Does anyone have suggestions for how I can troubleshoot this?

I have a real-valued version of the network that doesn't use ComplexPyTorch, and everything works fine there, so I can't help feeling that the problem during the backward pass comes from my layers being in a torch.nn.ModuleList. I also hard-coded the network without a torch.nn.ModuleList and didn't get this issue either.
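For context, the gradient logging mentioned above can be done with a small helper like the following (a sketch, not the actual model code; the function name is illustrative). It is meant to be called after loss.backward(), for example from Lightning's on_after_backward hook:

```python
import torch

def find_nan_grads(model: torch.nn.Module) -> list:
    # Return the names of parameters whose gradients contain NaN.
    # Call after loss.backward(), e.g. from on_after_backward.
    return [
        name
        for name, param in model.named_parameters()
        if param.grad is not None and torch.isnan(param.grad).any()
    ]
```

If this returns a non-empty list on the very first batch, the NaN is being produced during backpropagation rather than accumulating over training.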

Upvotes: 1

Views: 1071

Answers (1)

Jack Walters

Reputation: 51

For anyone interested: I set detect_anomaly=True in the Trainer and was then able to trace the torch function that was outputting NaNs during backpropagation. In my case it was torch.atan2, so I added a tiny epsilon to its denominator argument, which fixed it. As a general point, I've always found denominator epsilons really helpful in preventing NaNs from division-based functions!
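A minimal sketch of the failure mode and the fix (Lightning's detect_anomaly flag enables PyTorch's torch.autograd anomaly detection under the hood; the epsilon value below is illustrative, not the one from the original code):

```python
import torch

# The gradient of torch.atan2(y, x) with respect to y is x / (x**2 + y**2),
# which evaluates to 0/0 = NaN at the origin, even though the forward pass
# returns a perfectly valid 0.0.
y = torch.zeros(1, requires_grad=True)
x = torch.zeros(1)
torch.atan2(y, x).sum().backward()
# y.grad now contains NaN

# A tiny epsilon on the denominator argument keeps the gradient finite:
eps = 1e-8
y_fixed = torch.zeros(1, requires_grad=True)
torch.atan2(y_fixed, x + eps).sum().backward()
# y_fixed.grad is finite
```

With torch.autograd.set_detect_anomaly(True) (or detect_anomaly=True in the Trainer), the backward pass raises an error pointing at the offending op instead of silently propagating NaNs into the loss.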

Upvotes: 2
