Reputation: 982
I can run this file vit_jax.ipynb on Colab, perform training, and run my experiments, but when I try to replicate it on my cluster I get the error below during training. The forward pass to calculate accuracy works fine on the cluster.
My cluster has 4 GTX 1080 GPUs with CUDA 10.1, and I am using tensorflow==2.4.0 and jax[cuda101]==0.2.18. I am running this as a Jupyter notebook from inside a Docker container.
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-57-176d6124ae02> in <module>()
11 opt_repl, loss_repl, update_rng_repl = update_fn_repl(
---> 12 opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
13 losses.append(loss_repl[0])
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
182 try:
--> 183 return fun(*args, **kwargs)
184 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in f_pmapped(*args, **kwargs)
1638 name=flat_fun.__name__, donated_invars=tuple(donated_invars),
-> 1639 global_arg_shapes=tuple(global_arg_shapes_flat))
1640 return tree_unflatten(out_tree(), out)
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
1620 assert len(params['in_axes']) == len(args)
-> 1621 return call_bind(self, fun, *args, **params)
1622
/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1551 tracers = map(top_trace.full_raise, args)
-> 1552 outs = primitive.process(top_trace, fun, tracers, params)
1553 return map(full_lower, apply_todos(env_trace_todo(), outs))
/usr/local/lib/python3.7/dist-packages/jax/core.py in process(self, trace, fun, tracers, params)
1623 def process(self, trace, fun, tracers, params):
-> 1624 return trace.process_map(self, fun, tracers, params)
1625
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
606 def process_call(self, primitive, f, tracers, params):
--> 607 return primitive.impl(f, *tracers, **params)
608 process_map = process_call
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in xla_pmap_impl(fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *args)
636 ("fingerprint", fingerprint))
--> 637 return compiled_fun(*args)
638
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in execute_replicated(compiled, backend, in_handler, out_handler, *args)
1159 input_bufs = in_handler(args)
-> 1160 out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
1161 if xla.needs_check_special():
UnfilteredStackTrace: RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:203: NCCL operation ncclGroupEnd() failed: unhandled system error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-57-176d6124ae02> in <module>()
10
11 opt_repl, loss_repl, update_rng_repl = update_fn_repl(
---> 12 opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
13 losses.append(loss_repl[0])
14 lrs.append(lr_fn(step))
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in execute_replicated(compiled, backend, in_handler, out_handler, *args)
1158 def execute_replicated(compiled, backend, in_handler, out_handler, *args):
1159 input_bufs = in_handler(args)
-> 1160 out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
1161 if xla.needs_check_special():
1162 for bufs in out_bufs:
RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:203: NCCL operation ncclGroupEnd() failed: unhandled system error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
Has anyone faced this issue before? Is there any way to resolve it?
Upvotes: 0
Views: 551
Reputation: 86328
It is hard to know for sure without more information, but this error can be caused by running out of GPU memory. Depending on your local settings, you may be able to remedy it by increasing the proportion of GPU memory reserved by XLA, e.g. by setting the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable to 0.9 or something similarly high.
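For example, in a notebook kernel you could set the variable before JAX initializes its GPU backend (a minimal sketch; 0.9 is just the example fraction mentioned above):

    # Must be set before JAX initializes its GPU backend, i.e. before the
    # first import of jax in this process / notebook kernel.
    import os
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"  # let XLA reserve ~90% of GPU memory

    import jax
    print(jax.devices())  # confirm the GPUs are still visible after the change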
Alternatively, you could try running your code on a smaller problem that fits into memory on your local hardware.
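As a rough illustration of why this helps: pmap shards the leading batch axis across your local GPUs, so per-device memory scales with the global batch size divided by the device count. A hypothetical sketch (the batch size and image shape below are placeholders, not values taken from the notebook):

    import jax
    import jax.numpy as jnp

    # Halving the global batch (or the image resolution) roughly halves the
    # activation memory each GTX 1080 has to hold during the pmapped update.
    global_batch = 512                      # hypothetical value that runs out of memory
    n_devices = jax.local_device_count()
    per_device = global_batch // n_devices
    print(f"{n_devices} devices -> {per_device} examples per GPU")

    # Sharded dummy batch of 224x224 RGB images, shaped [devices, per_device, ...]
    images = jnp.zeros((n_devices, per_device, 224, 224, 3))
    print(images.shape)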
Upvotes: 1