Reputation: 329
Good evening everyone,
I spent the last 6 hours trying to debug seemingly randomly occurring NaN values in JAX. I have narrowed it down to the NaNs initially stemming from either the loss function or its gradient.
A minimal notebook that reproduces the error is available here: https://colab.research.google.com/drive/1uXa-igMm9QBOOl8ZNdK1OkwxRFlLqvZD?usp=sharing
This might also be interesting as a use case for JAX: I use JAX to solve an orientation-estimation task when only a limited number of gyroscope/accelerometer measurements is available. An efficient implementation of quaternion operations is helpful here.
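For a flavor of the kind of operation involved, here is a minimal sketch of a quaternion (Hamilton) product in jax.numpy (quat_mul is an illustrative name, not the exact code from my notebook):
import jax.numpy as jnp

def quat_mul(q, p):
    # Hamilton product of two quaternions stored as (w, x, y, z).
    w1, x1, y1, z1 = q
    w2, x2, y2, z2 = p
    return jnp.array([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])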
The training loop starts off fine but eventually diverges:
Step 0| Loss: 4.550444602966309 | Time: 13.910547971725464s
Step 1| Loss: 4.110116481781006 | Time: 5.478027105331421s
Step 2| Loss: 3.7159230709075928 | Time: 5.476970911026001s
Step 3| Loss: 3.491917371749878 | Time: 5.474078416824341s
Step 4| Loss: 3.232130765914917 | Time: 5.433410406112671s
Step 5| Loss: 3.095140218734741 | Time: 5.433837413787842s
Step 6| Loss: 2.9580295085906982 | Time: 5.429029941558838s
Step 7| Loss: nan | Time: 5.427825689315796s
Step 8| Loss: nan | Time: 5.463077545166016s
Step 9| Loss: nan | Time: 5.479652643203735s
This can be traced back to diverging gradients, as can be seen in the following snippet:
(loss, _), grads = loss_fn(params, X[0], y[0], rnn.reset_carry(bs=2))
grads["params"]["Dense_0"]["bias"] # shape=(bs, out_features)
DeviceArray([[-0.38666773,         nan, -1.0433975 ,         nan],
             [ 0.623061  , -0.20950513,  0.8459796 , -0.42356613]], dtype=float32)
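For reference, one can map over the gradient pytree to see which parameter groups are affected (a small sketch; jax.tree_util.tree_map is the standard utility):
import jax
import jax.numpy as jnp

# Report, per leaf of the gradient pytree, whether it contains any NaNs.
nan_report = jax.tree_util.tree_map(lambda g: jnp.isnan(g).any(), grads)
print(nan_report)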
Enabling NaN debugging did not really help, as it only produced huge stack traces with many hidden frames:
from jax.config import config
config.update("jax_debug_nans", True)
Any help would be much appreciated! Thanks :)
Upvotes: 7
Views: 6715
Reputation: 6687
A few approaches (decently documented in the main docs) may work:
Enabling float64 can do the trick:
jax.config.update("jax_enable_x64", True)
More info is in the main docs.
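A minimal sketch of that; note that the flag has to be set before any arrays are created:
import jax
jax.config.update("jax_enable_x64", True)  # set before any JAX computation

import jax.numpy as jnp
print(jnp.ones(3).dtype)  # float64 instead of the default float32
The extra precision often gives the loss and its gradients more headroom before an underflow/overflow turns into a NaN.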
Inspecting the jaxpr of the gradient computation can also help:
from jax import jit, value_and_grad, make_jaxpr

# If grad_fn(x) gives you trouble, you can inspect the computation as follows:
grad_fn = jit(value_and_grad(my_forward_prop, argnums=0))
make_jaxpr(grad_fn)(x)
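As a toy illustration of what to look for (a made-up function, not from the question): the printed jaxpr lists the primitives of the backward pass, and ops applied near zero are a common NaN source:
import jax.numpy as jnp
from jax import grad, make_jaxpr

def f(x):
    # d/dx sqrt(x) = 1 / (2 * sqrt(x)), which blows up at x = 0.
    return jnp.sqrt(x).sum()

print(make_jaxpr(grad(f))(jnp.zeros(3)))
# Scan the printed jaxpr for primitives (e.g. div) whose inputs can hit zero.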
Note that the community is quite active, and support for diagnosing NaNs has been, and is still being, added.
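For instance, combining jax_debug_nans with jax.disable_jit tends to give a more pointed stack trace, since the checker then fails at the exact primitive that produced the NaN rather than inside one large fused computation (a sketch reusing the loss_fn from the question):
import jax
jax.config.update("jax_debug_nans", True)

# Un-jitted, op-by-op execution raises at the first primitive returning NaN.
with jax.disable_jit():
    (loss, _), grads = loss_fn(params, X[0], y[0], rnn.reset_carry(bs=2))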
Hope this helps!
Andres
Upvotes: 6