rodrigo-silveira

Reputation: 13088

Updating batch_normalization mean & variance using Estimator API

The documentation isn't 100% clear on this:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

(see https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization)

Does that mean that all that is needed to save the moving_mean and moving_variance is the following?

def model_fn(features, labels, mode, params):
    training = mode == tf.estimator.ModeKeys.TRAIN
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    x = tf.reshape(features, [-1, 64, 64, 3])
    x = tf.layers.batch_normalization(x, training=training)

    # ...

    with tf.control_dependencies(extra_update_ops):
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

In other words, does simply using

with tf.control_dependencies(extra_update_ops):

take care of saving the moving_mean and moving_variance?

Upvotes: 2

Views: 861

Answers (2)

rodrigo-silveira

Reputation: 13088

As it turns out, those values are saved automatically. The edge case is that if you fetch the update-ops collection before the batch normalization op has been added to the graph, the collection will be empty, so the moving statistics are never updated. This wasn't documented before, but it is now.

The caveat when using batch normalization, then, is to call tf.get_collection(tf.GraphKeys.UPDATE_OPS) only after you've called tf.layers.batch_normalization.
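
A minimal sketch of that ordering, assuming a TF 1.x graph; the classification head, loss, and learning rate are placeholders, not part of the question:

import tensorflow as tf

def model_fn(features, labels, mode, params):
    training = mode == tf.estimator.ModeKeys.TRAIN

    x = tf.reshape(features, [-1, 64, 64, 3])
    # Build the batch norm layer first, so its update ops are
    # placed in tf.GraphKeys.UPDATE_OPS.
    x = tf.layers.batch_normalization(x, training=training)

    # Placeholder head and loss, just to make the sketch complete.
    logits = tf.layers.dense(tf.layers.flatten(x), 10)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # Fetch the collection only AFTER the layer above exists; otherwise
    # extra_update_ops is empty and the moving statistics never update.
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    with tf.control_dependencies(extra_update_ops):
        train_op = optimizer.minimize(
            loss=loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)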

Upvotes: 1

Alexandre Passos

Reputation: 5206

Yes, adding those control dependencies will save the mean and variance.
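
One quick way to sanity-check this, assuming a TF 1.x graph: the moving statistics are ordinary non-trainable global variables, so once they appear in the list below, the Estimator's checkpointing saves them along with everything else.

# After the model graph is built, list the moving statistics;
# they are non-trainable global variables and get checkpointed.
for v in tf.global_variables():
    if 'moving_mean' in v.name or 'moving_variance' in v.name:
        print(v.name, v.shape)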

Upvotes: 1
