Sanaullah Ashfat

Reputation: 71

What is the meaning of axis=3 in BatchNormalization?

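from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization,
                                     Activation, MaxPooling2D)

# Hypothetical values; the question does not give them
img_height, img_width, img_ch = 64, 64, 3
n_filters, k, s, padding = 32, 3, 2, 'same'

inputs = Input((img_height, img_width, img_ch))
conv1 = Conv2D(n_filters, (k, k), padding=padding)(inputs)
conv1 = BatchNormalization(scale=False, axis=3)(conv1)  # axis=3: the channel axis
conv1 = Activation('relu')(conv1)
conv1 = Conv2D(n_filters, (k, k), padding=padding)(conv1)
conv1 = BatchNormalization(scale=False, axis=3)(conv1)
conv1 = Activation('relu')(conv1)
pool1 = MaxPooling2D(pool_size=(s, s))(conv1)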

What does axis=3 mean in BatchNormalization? I read the Keras documentation but couldn't understand it. Can anyone explain what axis means?

Upvotes: 6

Views: 6645

Answers (2)

Bincy

Reputation: 73

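A small clarification to the other answer: the input dimension you declare is [height, width, channel], yet the axis is still 3. That is because the batch is not part of the input dimension passed to Input(); Keras adds it as axis 0, so the channels sit at index 3 of the tensor that BatchNormalization receives.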
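A minimal NumPy sketch of the point above, assuming hypothetical sizes (32x32 RGB images, 16 filters, a batch of 8). Input() only receives the per-sample shape, but the arrays flowing through the model are 4-D with the batch at index 0, so the channels of conv1 end up at index 3.

```python
import numpy as np

# Hypothetical sizes, not taken from the question
img_height, img_width, img_ch = 32, 32, 3
n_filters, batch_size = 16, 8

# Input() only receives the per-sample shape; the batch dimension is added in front
per_sample_shape = (img_height, img_width, img_ch)
batch = np.zeros((batch_size,) + per_sample_shape)
print(batch.ndim)  # 4, not 3

# With padding='same', conv1 keeps height/width and has n_filters channels
conv1_out = np.zeros((batch_size, img_height, img_width, n_filters))
for i, name in enumerate(["batch", "height", "width", "channel"]):
    print(i, name, conv1_out.shape[i])
```

The loop prints 0 batch 8, 1 height 32, 2 width 32 and 3 channel 16, so index 3 is the channel dimension. In Keras itself, a layer's output shape shows the batch dimension as None, but it still occupies index 0.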

Upvotes: 2

unlut

Reputation: 3775

It depends on how the dimensions of your conv1 tensor are ordered. After a convolution, batch normalization should be performed per channel: for example, if your dimension order is [batch, height, width, channel], you want axis=3. In short, set axis to the index of the dimension that holds your channels.
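To make "normalize per channel" concrete, here is a minimal NumPy sketch of what the layer computes in training mode. It is a simplification, not Keras's implementation: it omits the learned scale (gamma) and offset (beta) and the moving averages used at inference, and the function name batch_norm and the tensor sizes are hypothetical. The key idea is that the mean and variance are computed over every dimension except axis, giving one mean and one variance per channel. Keras's default is axis=-1, which for a 4-D channels-last tensor is the same as axis=3.

```python
import numpy as np

def batch_norm(x, axis=-1, eps=1e-3):
    """Hypothetical helper: give each slice along `axis` zero mean, unit variance.

    Simplified training-mode sketch: no gamma/beta, no moving averages.
    """
    axis = axis % x.ndim  # so axis=-1 and axis=3 mean the same thing for 4-D input
    # Reduce over every dimension EXCEPT the channel axis
    reduce_axes = tuple(i for i in range(x.ndim) if i != axis)
    mean = x.mean(axis=reduce_axes, keepdims=True)
    var = x.var(axis=reduce_axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
# Hypothetical conv output in [batch, height, width, channel] order,
# with a different scale and offset for each of the 16 channels
x = rng.normal(size=(8, 32, 32, 16)) * np.arange(1, 17) + np.arange(16)

y = batch_norm(x, axis=3)
print(y.mean(axis=(0, 1, 2)))  # 16 per-channel means, all close to 0
print(y.std(axis=(0, 1, 2)))   # 16 per-channel stds, all close to 1
```

If the data were in [batch, channel, height, width] order (channels_first), the same function would be called with axis=1; choosing any other index would mix statistics across channels.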

Upvotes: 5
