jkschin

Reputation: 5834

How do I use Batch Normalization in a multi-GPU setting in TensorFlow?

This question references the post How could I use Batch Normalization in TensorFlow?.

I have a multi-GPU setup similar to the CIFAR10 example. When I insert tf.contrib.layers.batch_norm into my network definition, I get a NoneType error in average_gradients: specifically, the variable g is None.

def average_gradients(tower_grads):
    average_grads = []
    # Each entry of zip(*tower_grads) holds the (gradient, variable) pairs
    # for one variable across all towers.
    for grad_and_vars in zip(*tower_grads):
        grads = []
        for g, _ in grad_and_vars:
            # g is None here once batch_norm is added; this is where it fails.
            expanded_g = tf.expand_dims(g, 0)
            grads.append(expanded_g)
        # Average the gradient over the tower dimension.
        grad = tf.concat(0, grads)
        grad = tf.reduce_mean(grad, 0)
        # Variables are shared across towers, so take the first tower's copy.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads
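
For reference, a defensive variant of this function (just a sketch, with the hypothetical name average_gradients_skip_none and the same tf.concat argument order as above) simply skips (gradient, variable) pairs whose gradient came back as None, which is what opt.compute_gradients returns for variables that do not feed into the tower loss. This matches the observation in the edit below that dropping the batch_norm variables makes the error disappear:

import tensorflow as tf

def average_gradients_skip_none(tower_grads):
    # Sketch: same as average_gradients above, but variables whose gradient
    # is None (e.g. batch-norm bookkeeping variables that the tower loss
    # does not depend on) are left out of the averaged list.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars if g is not None]
        if not grads:
            continue  # no tower produced a gradient for this variable
        grad = tf.reduce_mean(tf.concat(0, grads), 0)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads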

Some sample code showing how to run Batch Normalization in a multi-GPU environment would help.

EDIT:

Simply removing the "batch_norm" variables makes this error go away. However, the pressing question is that each Batch Normalization layer then has its own beta and gamma on each GPU, along with its own moving averages. How are all these moving averages across the GPUs reconciled at inference time?
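
One way to avoid ending up with a separate beta/gamma and moving averages per GPU in the first place is to share variables across towers, the way the CIFAR10 multi-GPU example does with tf.get_variable_scope().reuse_variables(). The sketch below assumes a hypothetical tower_loss function, num_gpus, and per-GPU image_batches:

import tensorflow as tf

# Sketch: share one set of BN variables (beta, gamma, moving_mean,
# moving_variance) across all towers by reusing the variable scope.
# tower_loss, num_gpus and image_batches are placeholders for the real
# model and input pipeline.
tower_losses = []
with tf.variable_scope(tf.get_variable_scope()):
    for i in range(num_gpus):
        with tf.device('/gpu:%d' % i), tf.name_scope('tower_%d' % i):
            tower_losses.append(tower_loss(image_batches[i]))
            # From the second tower onward, tf.contrib.layers.batch_norm
            # reuses the variables created by the first tower instead of
            # creating a new beta/gamma/moving average per GPU.
            tf.get_variable_scope().reuse_variables()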

Upvotes: 3

Views: 3391

Answers (1)

jayant

Reputation: 493

Just use BN independently on each GPU (tower), while using the batch statistics from one of the towers to update the shared moving mean/variance.

with tf.device('..'):
  x, y = iterator.get_next()

  # NN with variables copied over to each of the GPUs
  loss = tower_loss(..)
  tower_grads.append(opt.compute_gradients(loss))

  # use the last tower's statistics to update the moving mean/variance
  batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=scope)

grads = average_gradients(tower_grads)
apply_gradient_op = opt.apply_gradients(grads)
batchnorm_updates_op = tf.group(*batchnorm_updates)
train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

As gleaned from multiple comments here, this simple asynchronous approach works well in practice for most domains, with the exception of problems like semantic segmentation and video action recognition, where the per-GPU batch size is extremely small and asynchronous BN isn't able to provide the boost it normally does.
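
For those small-batch cases, the alternative is synchronized (cross-replica) BN, which pools the batch statistics across all GPUs before normalizing. A minimal NumPy sketch of the idea (conceptual only, not tied to any particular TensorFlow API):

import numpy as np

# 4 GPUs, each seeing only 2 examples of an 8-feature activation.
per_gpu_batches = [np.random.randn(2, 8) for _ in range(4)]

# Each GPU contributes its local count, sum and sum of squares ...
counts = [b.shape[0] for b in per_gpu_batches]
sums = [b.sum(axis=0) for b in per_gpu_batches]
sq_sums = [(b ** 2).sum(axis=0) for b in per_gpu_batches]

# ... which an all-reduce would combine into global statistics.
n = sum(counts)
mean = sum(sums) / n
var = sum(sq_sums) / n - mean ** 2

# Identical to normalizing over the full (concatenated) batch.
full = np.concatenate(per_gpu_batches, axis=0)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(var, full.var(axis=0))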

Upvotes: 2
