ponir

Reputation: 477

How to use tf.keras.layers.BatchNormalization() in a custom training loop?

I came back to TensorFlow after quite a while, and it seems the landscape has completely changed.

Previously, I used tf.contrib....batch_normalization together with the following in the training loop:

# TF1 pattern: make the train op depend on the batch-norm moving-average updates
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(cnn.loss, global_step=global_step)

But it seems contrib is nowhere to be found, and tf.keras.layers.BatchNormalization does not work the same way. I also couldn't find any instructions on training in its documentation.

Any information or help is appreciated.

Upvotes: 2

Views: 538

Answers (2)

Because batch normalization behaves differently during training and inference, the training flag you pass into call() needs to be forwarded to the BatchNormalization layer. Note that the layer itself should be created once (e.g. in __init__), not inside call(), so that its weights and moving statistics persist across calls.

import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.bn = tf.keras.layers.BatchNormalization()  # create once so state persists

    def call(self, inputs, training=False):
        # forward the flag: batch stats in training, moving averages at inference
        return self.bn(inputs, training=training)
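To show where this fits in a custom training loop: in TF2 the batch-norm update ops run automatically as a side effect of calling the layer with training=True, so there is no UPDATE_OPS collection to handle anymore. Below is a minimal sketch; the toy model, layer sizes, and loss function are assumptions for illustration.

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10),
])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        # training=True: BatchNormalization normalizes with batch statistics
        # and updates its moving mean/variance as a side effect
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# At inference time, call with training=False to use the moving statistics:
# predictions = model(x, training=False)

Unlike the tf.contrib-era pattern in the question, no control_dependencies wrapper is needed.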

Upvotes: 0

ponir

Reputation: 477

I started using PyTorch. It solved the problem.
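For anyone curious how this maps over: in PyTorch the train/inference behaviour of batch norm is toggled globally on the module rather than per call. A minimal sketch, with illustrative layer sizes:

import torch.nn as nn

model = nn.Sequential(nn.Linear(20, 32), nn.BatchNorm1d(32), nn.Linear(32, 10))

model.train()  # BatchNorm uses batch statistics and updates its running estimates
# ... training loop ...
model.eval()   # BatchNorm switches to the stored running statistics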

Upvotes: 0
