Reputation: 71
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D

# img_height, img_width, img_ch, n_filters, k, padding and s are defined elsewhere
inputs = Input((img_height, img_width, img_ch))
conv1 = Conv2D(n_filters, (k, k), padding=padding)(inputs)
conv1 = BatchNormalization(scale=False, axis=3)(conv1)
conv1 = Activation('relu')(conv1)
conv1 = Conv2D(n_filters, (k, k), padding=padding)(conv1)
conv1 = BatchNormalization(scale=False, axis=3)(conv1)
conv1 = Activation('relu')(conv1)
pool1 = MaxPooling2D(pool_size=(s, s))(conv1)
What is the meaning of axis=3 in BatchNormalization?
I read the Keras documentation but couldn't understand it. Can anyone explain what axis means?
Upvotes: 6
Views: 6645
Reputation: 73
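A small note on the other answer: if the shape you pass to Input is [height, width, channel], the axis is still 3. The batch is not part of the input dimension you specify, but Keras adds it as axis 0 of every layer's tensor, and axis counts from there, so axis=3 is the channel axis.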
Upvotes: 2
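To see why the index is 3 even though the shape given to Input has only three entries, here is a minimal sketch in plain Python. The sizes are hypothetical, chosen only so that each dimension is distinguishable; the point is that a layer's axis argument indexes the full tensor, batch axis included.

```python
# Hypothetical sizes for illustration; channels plays the role of n_filters
batch, height, width, channels = 4, 32, 24, 16

# The shape given to Input((height, width, channels)): one sample, no batch axis
sample_shape = (height, width, channels)

# The shape of the tensor the layers actually see: Keras prepends the batch axis
batch_shape = (batch,) + sample_shape

print(batch_shape)      # (4, 32, 24, 16)
print(sample_shape[2])  # 16: without the batch axis, channels would be index 2
print(batch_shape[3])   # 16: with it, channels is index 3, hence axis=3
print(batch_shape[-1])  # 16: axis=-1 (the Keras default) names the same axis
```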
Reputation: 3775
It depends on how the dimensions of your conv1 tensor are ordered. First, note that after a convolution, batch normalization should normalize each channel separately: one mean and one variance per channel, computed over the batch, height and width axes. For example, if your dimension order is [batch, height, width, channel], you want axis=3 (equivalently axis=-1, which is the Keras default). Basically, you choose the index of the axis that holds your channels.
Upvotes: 5
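The per-channel behaviour can be illustrated with a minimal NumPy sketch. This is not Keras's implementation: the function names `reduction_axes` and `batch_norm_sketch` are hypothetical, the sketch covers training mode only (it ignores the moving averages used at inference), and it assumes the layer is freshly initialized, so with scale=False there is no gamma and the learned offset beta is still zero. epsilon=1e-3 matches the Keras default.

```python
import numpy as np


def reduction_axes(ndim, axis):
    """Axes averaged over when `axis` is the one kept (the channel axis)."""
    axis = axis % ndim  # so axis=-1 and axis=3 agree on a 4-D tensor
    return tuple(i for i in range(ndim) if i != axis)


def batch_norm_sketch(x, axis=3, epsilon=1e-3):
    """Training-mode normalization with one mean/variance per entry along `axis`.

    Approximates BatchNormalization(axis=axis, scale=False) at initialization:
    gamma is disabled by scale=False and beta starts at zero, so both are omitted.
    """
    axes = reduction_axes(x.ndim, axis)
    mean = x.mean(axis=axes, keepdims=True)  # shape (1, 1, 1, channels) for axis=3
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)


# Hypothetical channels-last batch: (batch, height, width, channels)
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(4, 32, 24, 16))

print(reduction_axes(x.ndim, 3))  # (0, 1, 2): statistics over batch, height, width
y = batch_norm_sketch(x, axis=3)
print(np.allclose(y.mean(axis=(0, 1, 2)), 0.0))  # True: each channel has mean ~0
print(np.allclose(y.std(axis=(0, 1, 2)), 1.0, atol=1e-3))  # True: and std ~1
print(np.allclose(batch_norm_sketch(x, axis=-1), y))  # True: axis=-1 is the same axis
```

Choosing a different axis would change which statistics are computed: for example, axis=1 on this tensor would compute one mean per image row instead of one per channel, which is why the axis must point at the channels.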