Reputation: 477
I went back to TensorFlow after quite a while, and it seems the landscape has completely changed.
Previously I used tf.contrib....batch_normalization
with the following in the training loop:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(cnn.loss, global_step=global_step)
But it seems contrib is nowhere to be found, and tf.keras.layers.BatchNormalization
does not work the same way. I also couldn't find any training instructions in its documentation.
Any information or help would be appreciated.
Upvotes: 2
Views: 538
Reputation: 78
Since batch normalization behaves differently during training and inference, the training
argument your model's call() receives needs to be passed on to the BatchNormalization layer.
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Create the layer once so its weights and moving statistics persist
        self.bn = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=False):
        # Forward the training flag to the layer
        return self.bn(inputs, training=training)
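In TF 2.x this replaces the old UPDATE_OPS pattern: the moving-average updates run automatically whenever the layer is called with training=True, so no control_dependencies block is needed. A minimal usage sketch (MyModel is the subclass above; note that model.fit() sets the flag for you):

model = MyModel()
x = tf.random.normal([8, 16])
y_train = model(x, training=True)   # normalizes with batch statistics, updates moving averages
y_infer = model(x, training=False)  # normalizes with the stored moving averages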
Upvotes: 0