safetyduck

Reputation: 6874

Tensorflow 2.0 Keras BatchNorm: how to update the online params in custom training?

How do you train the batch norm layer without using any keras.compile methods? Typically layers have losses that are accessible via the losses property. Here the losses list is empty.

UPDATE:

It seems like there is a lot of confusion about this, and even the way BatchNorm is implemented is fairly confusing.

First, there is only one way to train the online parameters (the ones used in training=False mode) that scale and shift the features: call the layer in training=True mode. And if you NEVER want to use the "batch" part of the batch normalization (i.e. you just want an online normalizer that trains itself, say with a Normal log-prob loss), you basically can't do this in a single call AFAIK.

Calling the layer with training=False does not update the params. Calling it with training=True updates the params, but then you get the batch-normed output (it does not use the online loc and scale).

import numpy as np
import tensorflow as tf

class Model(tf.keras.models.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)
        self.bn = tf.keras.layers.BatchNormalization()
    def call(self, x, training=False):
        x = self.dense(x)
        x = self.bn(x, training=training)
        return x

model = Model()    
x = 10 * np.random.randn(30, 4).astype(np.float32)

print(tf.math.reduce_std(model(x)))
tf.keras.backend.set_learning_phase(1)
print(tf.math.reduce_std(model(x)))
print(tf.math.reduce_std(model(x)))
tf.keras.backend.set_learning_phase(0)
print(tf.math.reduce_std(model(x)))
print(tf.math.reduce_std(model(x)))


tf.Tensor(9.504262, shape=(), dtype=float32)
tf.Tensor(0.99999136, shape=(), dtype=float32)
tf.Tensor(0.99999136, shape=(), dtype=float32)
tf.Tensor(5.4472375, shape=(), dtype=float32)
tf.Tensor(5.4472375, shape=(), dtype=float32)
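For reference, here is a minimal sketch (my own, under the TF 2.0 assumptions above) of a custom training step without compile/fit: calling the model with training=True inside a GradientTape updates the moving mean/variance as a side effect of the forward pass, and a later call with training=False then uses those accumulated statistics.

import numpy as np
import tensorflow as tf

model = Model()  # the Model class defined above
optimizer = tf.keras.optimizers.Adam(1e-3)

x = 10 * np.random.randn(30, 4).astype(np.float32)
y = np.random.randn(30, 4).astype(np.float32)  # dummy targets, just for illustration

for step in range(100):
    with tf.GradientTape() as tape:
        # training=True: normalizes with batch statistics AND updates the
        # moving mean/variance as a side effect of the call.
        pred = model(x, training=True)
        loss = tf.reduce_mean(tf.square(pred - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

# training=False: normalizes with the accumulated moving statistics,
# without updating them.
pred = model(x, training=False)

# If you only ever want the online statistics (never the batch statistics),
# the closest workaround seems to be two calls per step: model(x, training=True)
# to update the stats, then model(x, training=False) for the actual output.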

UPDATE:

Showing that keras layers sometimes have losses (when sub-objectives like regularization exist):

In [335]: l = tf.keras.layers.Dense(8, kernel_regularizer=tf.keras.regularizers.L1L2())

In [336]: l(np.random.randn(2, 4))

Out[336]:
<tf.Tensor: id=2521999, shape=(2, 8), dtype=float32, numpy=
array([[ 1.1332406 ,  0.32000083,  0.8104123 ,  0.5066328 ,  0.35904446, -1.4265257 ,  1.3057183 ,  0.34458983],
       [-0.23246719, -0.46841025,  0.9706465 ,  0.42356712,  1.705613  , -0.08619405, -0.5261058 , -1.1696107 ]], dtype=float32)>

In [337]: l.losses
Out[337]: [<tf.Tensor: id=2522000, shape=(), dtype=float32, numpy=0.0>]

In [338]: l = tf.keras.layers.Dense(8)

In [339]: l(np.random.randn(2, 4))

Out[339]:
<tf.Tensor: id=2522028, shape=(2, 8), dtype=float32, numpy=
array([[ 1.0674231 , -0.13423748,  0.01775402,  2.5400681 , -0.53589094,  1.4460006 , -1.7197075 ,  0.3285858 ],
       [ 2.2171447 , -1.7448915 ,  0.4758569 ,  0.58695656,  0.32054698,  0.7813705 , -2.3022552 ,  0.44061095]], dtype=float32)>

In [340]: l.losses
Out[340]: []
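Tying this back to the original question, the same check on a BatchNormalization layer (a quick sketch) shows that its losses list stays empty even after a training call; the running statistics live in its non-trainable weights instead:

import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
bn(np.random.randn(2, 4).astype(np.float32), training=True)

print(bn.losses)                                   # [] -- nothing to minimize
print([w.name for w in bn.non_trainable_weights])  # moving_mean, moving_variance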

Upvotes: 1

Views: 1370

Answers (1)

ben

Reputation: 1390

BatchNorm does train, but it does not have a loss. It just tracks the mean and std of consecutive batches in a weighted moving average. There is no loss/gradient involved.
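You can see this directly with a small sketch (same TF 2.0 setup as in the question): the moving statistics change after a training=True forward pass alone, with no optimizer or gradients involved.

import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = 10 * np.random.randn(30, 4).astype(np.float32)

bn(x, training=False)            # builds the layer; no update happens
print(bn.moving_mean.numpy())    # still the initial zeros

bn(x, training=True)             # the forward pass itself updates the statistics
print(bn.moving_mean.numpy())    # shifted toward the batch mean (momentum=0.99)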

Upvotes: 0
