
Reputation: 796

How to test distributed layers on Tensorflow?

I am trying to test a layer that I will add later in a distributed model however I want to be sure that it works before.

This is the layer in question:

class BNShuffler(tf.Module):
    def __init__(
        global_batch_size: int=64
        super(BNShuffler, self).__init__()
        self.global_batch_size = global_batch_size
        self.idx = tf.Variable(tf.range(global_batch_size), trainable=False)

    def __call__(self, x, shuffle=True):
        batch_size = tf.shape(x)[0]
        replica_context = tf.distribute.get_replica_context()
        if replica_context is not None:
            replica_id   = replica_context.replica_id_in_sync_group
            num_replicas = replica_context.num_replicas_in_sync
            x_target = _cross_replica_concat(x, replica_id, replica_context, num_replicas)
            x_target = x
            num_replicas = 1
            replica_id = 0

        if shuffle:
            x_shuffled = tf.gather(x_target, self.idx)
            return x_shuffled[replica_id * batch_size: (replica_id + 1) * batch_size]
            unshuffled_idx = tf.math.invert_permutation(self.idx)
            x_unshuffled   = tf.gather(x_target, unshuffled_idx)
            return x_unshuffled[replica_id * batch_size: (replica_id + 1) * batch_size]

def _cross_replica_concat(x, replica_id, replica_context, num_replicas):
    x_shape = tf.shape(x)
    result_tensor = tf.scatter_nd(
        shape=tf.concat([[num_replicas], x_shape], axis=0),
    result_tensor = replica_context.all_reduce(
        tf.distribute.ReduceOp.SUM, result_tensor

    return tf.reshape(result_tensor, x_shape)

The goal of this layer is to shuffle data across all gpu when shuffle=True and put them back when shuffle=False so it will be applied twice.

In order to test it I tried to generate a simple distributed dataset and apply my shuffler based on this tutorial but it throws me an error.


global_batch_size = 6
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
train_dataset =, drop_remainder=False)
dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
with strategy.scope():

def train_step(x):
    x = tf.reshape(x, [-1, 1, 1])
    x = shuffler(x, True)
    x = shuffler(x, False)
    return strategy.reduce(tf.distribute.ReduceOp.SUM, x,

for epoch in range(1):
    total_loss = 0.0
    num_batches = 0
    for x in dist_dataset:, args=(x,))


tf.Tensor([0 1 2], shape=(3,), dtype=int64)
tf.Tensor([3 4 5], shape=(3,), dtype=int64)
INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Error reported to Coordinator: 2 root error(s) found.
  (0) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
  (1) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference__all_reduce_1284]

Function call stack:
_all_reduce -> _all_reduce
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/", line 297, in stop_on_exception
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/", line 228, in _call_for_each_replica
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/", line 572, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/", line 3080, in batch_all_reduce
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/", line 2374, in batch_reduce_to
    return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/", line 697, in _batch_reduce_to
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/", line 426, in batch_reduce
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/", line 819, in batch_reduce_implementation
    [v[0] for v in value_destination_pairs])
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/", line 831, in _batch_all_reduce
    dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/", line 860, in _do_batch_all_reduce
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/", line 45, in aggregate_gradients_using_nccl
    agg_grads = nccl_ops.all_sum(single_grads)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/", line 47, in all_sum
    return _apply_all_reduce('sum', tensors)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/", line 234, in _apply_all_reduce
    return def_function.function(_all_reduce)()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/", line 895, in _call
    filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/", line 560, in call
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
  (0) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
  (1) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference__all_reduce_1284]

Function call stack:
_all_reduce -> _all_reduce

InternalError                             Traceback (most recent call last)
<ipython-input-22-0888591f3d27> in <module>
     15     num_batches = 0
     16     for x in dist_dataset:
---> 17, args=(x,))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in run(***failed resolving arguments***)
   1257       fn = autograph.tf_convert(
   1258           fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
-> 1259       return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
   1261   def reduce(self, reduce_op, value, axis):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in call_for_each_replica(self, fn, args, kwargs)
   2728       kwargs = {}
   2729     with self._container_strategy().scope():
-> 2730       return self._call_for_each_replica(fn, args, kwargs)
   2732   def _call_for_each_replica(self, fn, args, kwargs):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in _call_for_each_replica(self, fn, args, kwargs)
    627   def _call_for_each_replica(self, fn, args, kwargs):
    628     return mirrored_run.call_for_each_replica(
--> 629         self._container_strategy(), fn, args, kwargs)
    631   def _configure(self,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in call_for_each_replica(strategy, fn, args, kwargs)
     91     fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
---> 93   return _call_for_each_replica(strategy, fn, args, kwargs)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in _call_for_each_replica(distribution, fn, args, kwargs)
    232     for t in threads:
    233       t.should_run.set()
--> 234     coord.join(threads)
    236   return distribute_utils.regroup(tuple(t.main_result for t in threads))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/ in join(self, threads, stop_grace_period_secs, ignore_live_threads)
    387       self._registered_threads = set()
    388       if self._exc_info_to_raise:
--> 389         six.reraise(*self._exc_info_to_raise)
    390       elif stragglers:
    391         if ignore_live_threads:

/usr/local/lib/python3.6/dist-packages/ in reraise(tp, value, tb)
    701             if value.__traceback__ is not tb:
    702                 raise value.with_traceback(tb)
--> 703             raise value
    704         finally:
    705             value = None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/ in stop_on_exception(self)
    295     """
    296     try:
--> 297       yield
    298     except:  # pylint: disable=bare-except
    299       self.request_stop(ex=sys.exc_info())

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in _call_for_each_replica(distribution, fn, args, kwargs)
    226               variable_scope.variable_scope(mtt_captured_var_scope):
    227             merge_result = threads[0].merge_fn(distribution, *merge_args,
--> 228                                                **merge_kwargs)
    229           for r, t in enumerate(threads):
    230             t.merge_result = distribute_utils.select_replica(r, merge_result)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/ in wrapper(*args, **kwargs)
    570   def wrapper(*args, **kwargs):
    571     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
--> 572       return func(*args, **kwargs)
    574   if inspect.isfunction(func) or inspect.ismethod(func):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in batch_all_reduce(strategy, *value_flat)
   3078       return strategy.extended.batch_reduce_to(
   3079           reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
-> 3080           options)
   3082     if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in batch_reduce_to(self, reduce_op, value_destination_pairs, options)
   2372     if isinstance(reduce_op, six.string_types):
   2373       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
-> 2374     return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
   2376   def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in _batch_reduce_to(self, reduce_op, value_destination_pairs, options)
    695         reduce_op,
    696         value_destination_pairs,
--> 697         options=self._communication_options.merge(options))
    699   def _update(self, var, fn, args, kwargs, group):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in batch_reduce(self, reduce_op, value_destination_pairs, options)
    424       options = collective_util.Options()
    425     return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
--> 426                                             options)
    428   def broadcast(self, tensor, destinations):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in batch_reduce_implementation(self, reduce_op, value_destination_pairs, options)
    817     if _all_devices_match(value_destination_pairs):
    818       return self._batch_all_reduce(reduce_op,
--> 819                                     [v[0] for v in value_destination_pairs])
    820     else:
    821       return [

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in _batch_all_reduce(self, reduce_op, per_replica_values)
    829         cross_device_utils.split_by_sparsity(per_replica_values))
    830     if dense_values:
--> 831       dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    832     else:
    833       dense_results = []

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in _do_batch_all_reduce(self, reduce_op, dense_values)
    858       # TODO(yuefengz): merge this into the all-reduce library.
    859       reduced = cross_device_utils.aggregate_gradients_using_nccl(
--> 860           device_grad_packs)
    861     else:
    862       # TODO(yuefengz): check that gpu ids in `destinations` are in ascending

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/ in aggregate_gradients_using_nccl(replica_grads)
     43   for single_g_and_v in zip(*replica_grads):
     44     single_grads = [g for g, _ in single_g_and_v]
---> 45     agg_grads = nccl_ops.all_sum(single_grads)
     46     agg_all_g_and_v.append(
     47         [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/ in all_sum(tensors)
     45     the same device as `tensors[i]`.
     46   """
---> 47   return _apply_all_reduce('sum', tensors)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/ in _apply_all_reduce(reduction, tensors)
    232     # Nccl ops will block unless they are executed concurrently such as in a
    233     # graph or a defun.
--> 234     return def_function.function(_all_reduce)()
    235   else:
    236     return _all_reduce()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/ in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/ in _call(self, *args, **kwds)
    893       # If we did not create any variables the trace we have is good enough.
    894       return self._concrete_stateful_fn._call_flat(
--> 895           filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
    897     def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/ in _call_flat(self, args, captured_inputs, cancellation_manager)
   1917       # No tape is watching; skip to running the function.
   1918       return self._build_call_outputs(
-> 1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(
   1921         args,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/ in call(self, ctx, args, cancellation_manager)
    558               inputs=args,
    559               attrs=attrs,
--> 560               ctx=ctx)
    561         else:
    562           outputs = execute.execute_with_cancellation(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/ in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InternalError: 2 root error(s) found.
  (0) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
  (1) Internal:  NCCL: invalid usage. Set NCCL_DEBUG=WARN for detail.
     [[node NcclAllReduce_1 (defined at <ipython-input-22-0888591f3d27>:17) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference__all_reduce_1284]

Function call stack:
_all_reduce -> _all_reduce

How should I use the strategy to make my test ?

Upvotes: 1

Views: 368

Answers (1)

Laplace Ricky
Laplace Ricky

Reputation: 1687

The major reason why you got the error messages may be because tf.distribute.get_replica_context().all_reduce() does not always work in eager mode. It will work properly in graph mode.(See example codes below)

There are also some other potential problems in your codes.

  1. pass aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA to tf.Variable to make sure it is synchronized across replicas.
  2. strategy.reduce() shouldn't be called inside train_step

Example codes:

strategy = tf.distribute.MirroredStrategy()
print(f'using distribution strategy\nnumber of gpus:{strategy.num_replicas_in_sync}')

global_batch_size = 6
with strategy.scope():
train_dataset =, drop_remainder=True)
options =
options.experimental_distribute.auto_shard_policy =
train_dataset = train_dataset.with_options(options)
ds = strategy.experimental_distribute_dataset(train_dataset)

def shuffler(x,idx,global_batch_size,shuffle=False):
    ctx = tf.distribute.get_replica_context()
    num_replicas = ctx.num_replicas_in_sync
    replica_id = ctx.replica_id_in_sync_group

    x_shape = global_batch_size//num_replicas
    result_tensor = tf.scatter_nd([[replica_id]],[x],[num_replicas, x_shape])
    result_tensor = tf.reshape(result_tensor, [global_batch_size])

    all_x = ctx.all_reduce(
        tf.distribute.ReduceOp.SUM, result_tensor

    if shuffle:
        x_shuffled = tf.gather(all_x, idx)
        return x_shuffled[replica_id * x_shape: (replica_id + 1) * x_shape]

        unshuffled_idx = tf.math.invert_permutation(idx)
        x_unshuffled   = tf.gather(all_x, unshuffled_idx)
        return x_unshuffled[replica_id * x_shape: (replica_id + 1) * x_shape]

def train_step(x):
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    tf.print('before shuffle',replica_id, x, output_stream=sys.stdout)
    x = shuffler(x,idx,global_batch_size,shuffle = True)
    tf.print('shuffle = True',replica_id, x, output_stream=sys.stdout)
    x = shuffler(x,idx,global_batch_size,shuffle = False)
    tf.print('shuffle = False',replica_id, x, output_stream=sys.stdout)

#add @tf.function to run in graph mode
def distributed_train_step(x):, args=(x,))


Expected outputs:

using distribution strategy
number of gpus:3
before shuffle 0 [0 1]
before shuffle 1 [2 3]
before shuffle 2 [4 5]
shuffle = True 0 [4 1]
shuffle = True 1 [0 2]
shuffle = True 2 [3 5]
shuffle = False 0 [0 1]
shuffle = False 1 [2 3]
shuffle = False 2 [4 5]
before shuffle 0 [6 7]
before shuffle 1 [8 9]
before shuffle 2 [10 11]
shuffle = True 0 [9 8]
shuffle = True 1 [7 11]
shuffle = True 2 [10 6]
shuffle = False 0 [6 7]
shuffle = False 1 [8 9]
shuffle = False 2 [10 11]

Finally, please also note that tf.distribute.get_replica_context().all_gather() is made for exactly what you want to do instead of all_reduce() although that all_reduce() can do the same thing.

Upvotes: 1

Related Questions