user3358117


Batch Normalization in TensorFlow

I noticed there are already batch normalization functions in the TensorFlow API. One thing I don't understand, though, is how to switch the procedure between training and test.

Batch normalization acts differently during test than during training. Specifically, during training it normalizes with the mean and variance of the current mini-batch, while at test time it uses a fixed mean and variance.
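To make the difference concrete, here is a minimal NumPy sketch (not TensorFlow code) of the two modes. The `batch_norm` helper is hypothetical; it applies the same formula as `tf.nn.batch_normalization`, and the "fixed" statistics in the inference branch are made-up values standing in for the moving averages a real model would accumulate during training.

```python
import numpy as np

def batch_norm(x, mean, var, beta=0.0, gamma=1.0, eps=1e-3):
    """gamma * (x - mean) / sqrt(var + eps) + beta, per feature."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, scale=2.0, size=(64, 3))  # 64 examples, 3 features

# Training mode: statistics come from the current mini-batch itself.
train_out = batch_norm(batch, batch.mean(axis=0), batch.var(axis=0))

# Inference mode: statistics are fixed (hypothetical values standing in for
# moving averages collected during training).
fixed_mean = np.array([5.0, 5.0, 5.0])
fixed_var = np.array([4.0, 4.0, 4.0])
single = batch[:1]  # a "batch" containing one example
test_out = batch_norm(single, fixed_mean, fixed_var)

# Using batch statistics on a batch of one collapses every output to beta (0 here).
collapsed = batch_norm(single, single.mean(axis=0), single.var(axis=0))
```

The last line shows why test time needs fixed statistics: with batch statistics, a single example would always normalize to the same value regardless of its input.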

Is there good example code for this somewhere? I found some, but the variable scopes made it confusing.

Upvotes: 5


Answers (1)

keveman


You are right: tf.nn.batch_normalization provides only the basic functionality for batch normalization. You have to add the extra logic yourself to keep track of moving means and variances during training, and to use those trained means and variances during inference. You can look at this example for a very general implementation, but a quick version that doesn't use gamma (the scale parameter) is here:

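  import tensorflow as tf
  from tensorflow.python.training import moving_averages

  def batch_norm(image, shape, is_training, decay):
    """Batch norm without gamma. `image` is NHWC, `shape` is [channels]."""
    beta = tf.Variable(tf.zeros(shape), name='beta')
    moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
                              trainable=False)
    moving_variance = tf.Variable(tf.ones(shape), name='moving_variance',
                                  trainable=False)
    control_inputs = []
    if is_training:
      # Training: normalize with this batch's statistics and fold them into
      # the moving averages that will be used at test time.
      mean, variance = tf.nn.moments(image, [0, 1, 2])
      update_moving_mean = moving_averages.assign_moving_average(
          moving_mean, mean, decay)
      update_moving_variance = moving_averages.assign_moving_average(
          moving_variance, variance, decay)
      control_inputs = [update_moving_mean, update_moving_variance]
    else:
      # Test: use the moving averages accumulated during training.
      mean = moving_mean
      variance = moving_variance
    # Make the moving-average updates run whenever the output is computed.
    with tf.control_dependencies(control_inputs):
      return tf.nn.batch_normalization(
          image, mean=mean, variance=variance, offset=beta,
          scale=None, variance_epsilon=0.001)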
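The bookkeeping in the snippet above can also be sketched without TensorFlow. The `SimpleBatchNorm` class below is a hypothetical NumPy illustration, not part of any library. It assumes 2-D input of shape (batch, features) instead of NHWC images, omits gamma like the snippet, and uses the plain exponential moving average `moving = decay * moving + (1 - decay) * batch_stat` (what `assign_moving_average` computes with `zero_debias=False`; its default also applies zero-debiasing).

```python
import numpy as np

class SimpleBatchNorm:
    """Hypothetical NumPy stand-in for the TF snippet (no gamma)."""

    def __init__(self, num_features, decay=0.99, eps=1e-3):
        self.beta = np.zeros(num_features)
        self.moving_mean = np.zeros(num_features)  # same init as the TF version
        self.moving_var = np.ones(num_features)
        self.decay = decay
        self.eps = eps

    def __call__(self, x, is_training):
        if is_training:
            # Training: use this batch's statistics...
            mean, var = x.mean(axis=0), x.var(axis=0)
            # ...and update the moving averages in place; this is equivalent to
            # moving = decay * moving + (1 - decay) * batch_stat.
            self.moving_mean -= (1 - self.decay) * (self.moving_mean - mean)
            self.moving_var -= (1 - self.decay) * (self.moving_var - var)
        else:
            # Inference: use the stored averages; nothing is updated.
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.eps) + self.beta

# Usage: train on batches drawn from N(3, 0.5^2), then normalize one example.
rng = np.random.default_rng(1)
bn = SimpleBatchNorm(num_features=2, decay=0.9)
for _ in range(200):
    bn(rng.normal(loc=3.0, scale=0.5, size=(128, 2)), is_training=True)

single = np.array([[3.0, 3.5]])
out = bn(single, is_training=False)
```

After enough training steps the moving statistics approach the data's mean (3.0) and variance (0.25), so inference output no longer depends on which examples happen to share a batch. Calling with `is_training=False` never touches the stored statistics, mirroring the Python `is_training` switch in the TensorFlow version.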

Upvotes: 9
