Simon B

Reputation: 329

Jax - Debugging NaN-values

Good evening everyone,

I spent the last six hours trying to debug seemingly randomly occurring NaN values in JAX. I have narrowed it down to the NaNs initially stemming from either the loss function or its gradient.

A minimal notebook that reproduces the error is available here: https://colab.research.google.com/drive/1uXa-igMm9QBOOl8ZNdK1OkwxRFlLqvZD?usp=sharing

This might also be interesting as a use case for JAX. I use JAX to solve an orientation-estimation task when only a limited number of gyroscope/accelerometer measurements is available. An efficient implementation of quaternion operations comes in handy here.
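For context, this is roughly the kind of operation I mean: a Hamilton product written with jax.numpy (the function name and the (w, x, y, z) convention below are mine; the notebook's implementation may differ).

import jax.numpy as jnp

def qmul(q, r):
    # Hamilton product of two quaternions given as (w, x, y, z) arrays.
    w1, x1, y1, z1 = q
    w2, x2, y2, z2 = r
    return jnp.array([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2,
    ])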

The training loop starts off fine but eventually diverges:

Step 0| Loss: 4.550444602966309 | Time: 13.910547971725464s
Step 1| Loss: 4.110116481781006 | Time: 5.478027105331421s
Step 2| Loss: 3.7159230709075928 | Time: 5.476970911026001s
Step 3| Loss: 3.491917371749878 | Time: 5.474078416824341s
Step 4| Loss: 3.232130765914917 | Time: 5.433410406112671s
Step 5| Loss: 3.095140218734741 | Time: 5.433837413787842s
Step 6| Loss: 2.9580295085906982 | Time: 5.429029941558838s
Step 7| Loss: nan | Time: 5.427825689315796s
Step 8| Loss: nan | Time: 5.463077545166016s
Step 9| Loss: nan | Time: 5.479652643203735s

This can be traced back to diverging gradients, as can be seen from the following snippet:

(loss, _), grads = loss_fn(params, X[0], y[0], rnn.reset_carry(bs=2))

grads["params"]["Dense_0"]["bias"] # shape=(bs, out_features)
DeviceArray([[-0.38666773,         nan, -1.0433975 ,         nan],
             [ 0.623061  , -0.20950513,  0.8459796 , -0.42356613]], dtype=float32)
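For reference, a quick way to see which gradient leaves contain NaNs at all (a sketch using jax.tree_util.tree_map on the grads from the snippet above):

import jax
import jax.numpy as jnp

# Flag every leaf of the gradient pytree that contains at least one NaN.
nan_leaves = jax.tree_util.tree_map(lambda g: bool(jnp.isnan(g).any()), grads)
print(nan_leaves)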

My question is: how do I debug this?

Enabling NaN debugging

Enabling NaN debugging did not really help, as it only led to huge stack traces with many hidden frames:

import jax
jax.config.update("jax_debug_nans", True)
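If I read the docs correctly, running a single step eagerly should make the NaN check raise right at the offending operation with a readable traceback. A sketch, continuing from the snippet above (loss_fn, params, X, y and rnn as defined in the notebook):

# Run one step without jit so the FloatingPointError points at the
# Python line that produced the NaN.
with jax.disable_jit():
    (loss, _), grads = loss_fn(params, X[0], y[0], rnn.reset_carry(bs=2))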

Any help would be much appreciated! Thanks :)

Upvotes: 7

Views: 6715

Answers (1)

fr_andres

Reputation: 6687

A few approaches (decently documented in the main docs) may work:

  1. As a hotfix, switching to float64 can do the trick: jax.config.update("jax_enable_x64", True). More info in the docs.
  2. Gradient Clipping is All You Need (docs); a sketch with optax follows this list.
  3. You can sometimes implement your own backprop; this can help when, for example, you combine two functions that saturate into one that doesn't, or when you need to enforce values at singularities (see the custom-VJP sketch after this list).
  4. Diagnose your backprop by inspecting the computational graph. Usually you would look for divisions, signalled by the div token:
from jax import jit, make_jaxpr, value_and_grad

# If grad_fn(x) gives you trouble, you can inspect the computation as follows:
grad_fn = jit(value_and_grad(my_forward_prop, argnums=0))
make_jaxpr(grad_fn)(x)
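Regarding point 2, a minimal sketch of global-norm gradient clipping with optax (assuming the training loop uses an optax optimizer; the dummy pytrees below only stand in for the real params and grads):

import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,))}        # stand-in for the real parameter pytree
grads = {"w": jnp.full((3,), 1e6)}    # pretend these gradients exploded

tx = optax.chain(
    optax.clip_by_global_norm(1.0),   # cap the global gradient norm (value to tune)
    optax.adam(1e-3),
)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)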
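For point 3, a sketch of a custom backward pass with jax.custom_vjp. The example (a vector norm given a finite gradient at zero, where the usual x / ||x|| rule yields NaN) is hypothetical and not taken from the question:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def safe_norm(x):
    return jnp.linalg.norm(x)

def safe_norm_fwd(x):
    n = jnp.linalg.norm(x)
    return n, (x, n)

def safe_norm_bwd(res, g):
    x, n = res
    # Guard the division so no NaN is ever created, then pick 0 at the singularity.
    grad = jnp.where(n > 0, x / jnp.where(n > 0, n, 1.0), 0.0)
    return (g * grad,)

safe_norm.defvjp(safe_norm_fwd, safe_norm_bwd)

print(jax.grad(safe_norm)(jnp.zeros(3)))  # [0. 0. 0.] instead of [nan nan nan]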

Note that the community is quite active, and some support for diagnosing NaNs has been added, with more on the way.

Hope this helps!
Andres

Upvotes: 6
