oat

Reputation: 484

calculation of mean and variance in batch normalization in convolutional neural network

May I ask if the following understanding of batch normalization in convolutional neural network is correct?

As shown in the diagram below, the mean and variance are calculated using all the cells on the same feature maps generated from respective examples in the current mini-batch, i.e. they are calculated across h, w and m axis.

[diagram: a mini-batch of m feature maps, each h × w, with the mean and variance taken over all cells across the h, w and m axes]

Upvotes: 5

Views: 4795

Answers (2)

Ivan

Reputation: 40648

It seems you are correct. The empirical mean and variance are measured over all dimensions except the feature (channel) dimension. The z-score is then computed to standardize the mini-batch to mean 0 and standard deviation 1. Finally, the result is scaled and shifted with two learnable per-channel parameters, gamma and beta.

Here is a description of a batch normalization layer for convolutional inputs:

Description: applies batch normalization over a mini-batch of feature maps, with one mean/variance pair per channel
Input: a mini-batch of shape (m, c, h, w)
Parameters: two learnable vectors, gamma and beta, each of length c (plus running estimates of the per-channel mean and variance, used at inference time)
Output: the normalized, scaled and shifted tensor, with the same shape (m, c, h, w)
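
For a concrete view of those parameter shapes, here is a small sketch assuming PyTorch's nn.BatchNorm2d with c=2 channels (the specific layer and channel count are just for illustration):

>>> import torch.nn as nn
>>> bn = nn.BatchNorm2d(num_features=2)            # one gamma/beta pair per channel
>>> bn.weight.shape, bn.bias.shape                 # gamma and beta, each of length c
(torch.Size([2]), torch.Size([2]))
>>> bn.running_mean.shape, bn.running_var.shape    # running estimates, also per channel
(torch.Size([2]), torch.Size([2]))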

And the calculation details, each computed per channel over the m, h and w axes:

Mini-batch mean: μ = mean of x over all cells of that channel
Mini-batch variance: σ² = mean of (x − μ)² over the same cells
Normalize: x̂ = (x − μ) / sqrt(σ² + ε)
Scale & shift: y = γ·x̂ + β

Here is a quick implementation to show you the normalization process without the scale-shift:

>>> import torch
>>> a = torch.eye(2, 4).reshape(2, 2, 2)
>>> b = torch.arange(8.).reshape(2, 2, 2)
>>> x = torch.stack([a, b])
>>> x
tensor([[[[1., 0.],
          [0., 0.]],

         [[0., 1.],
          [0., 0.]]],


        [[[0., 1.],
          [2., 3.]],

         [[4., 5.],
          [6., 7.]]]])

We are looking to measure the mean and variance on all axes except the channel axis. So we start by permuting the batch axis with the channel axis, then flatten all axes but the first. Finally we take the average and variance.

>>> x_ = x.permute(1, 0, 2, 3).flatten(start_dim=1)
>>> mean, var = x_.mean(dim=-1), x_.var(dim=-1)
>>> mean, var
(tensor([0.8750, 2.8750]), tensor([1.2679, 8.6964]))

To normalize, we reshape mean and var so they broadcast over the channel axis rather than the last one:

>>> y = (x - mean[None, :, None, None]) / (var[None, :, None, None] + 1e-8).sqrt()
>>> y
tensor([[[[ 0.1110, -0.7771],
          [-0.7771, -0.7771]],

         [[-0.9749, -0.6358],
          [-0.9749, -0.9749]]],


        [[[-0.7771,  0.1110],
          [ 0.9991,  1.8872]],

         [[ 0.3815,  0.7206],
          [ 1.0597,  1.3988]]]])

Notice the shapes of mean and variance: vectors whose length equals the number of input channels. The same could be said about the shapes of gamma and beta.
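
To tie this back to the built-in layer, here is a small sketch comparing the manual computation above with PyTorch's nn.BatchNorm2d in training mode. One caveat: BatchNorm normalizes with the biased variance, whereas var() defaults to the unbiased estimate, so we pass unbiased=False here; affine=False and bn.eps are used so the two computations match exactly.

>>> import torch.nn as nn
>>> bn = nn.BatchNorm2d(num_features=2, affine=False)   # no gamma/beta, just the normalization
>>> y_bn = bn(x)                                         # a fresh module is in training mode, so batch stats are used
>>> mu = x_.mean(dim=-1)
>>> var_b = x_.var(dim=-1, unbiased=False)               # biased variance, as BatchNorm uses
>>> y_manual = (x - mu[None, :, None, None]) / (var_b[None, :, None, None] + bn.eps).sqrt()
>>> torch.allclose(y_bn, y_manual)
True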

Upvotes: 4

Khalid Saifullah

Reputation: 785

The picture depicts BatchNorm correctly.

In BatchNorm we compute the mean and variance using the spatial feature maps of the same channel across the whole batch. The picture you attached may look confusing because the data in it is single-channel, so each grid/matrix represents one data sample. Coloured images, by contrast, need 3 such grids/matrices per sample, one for each of the R, G and B channels. So in your picture you can think of pooling all the cells of the corresponding grid/matrix from every one of the m samples and computing a single mean and variance over all of them.

So your picture does show the computation of the mean and variance for BatchNorm correctly; it just only covers the single-channel case, which can be confusing once you move to multi-channel data. To make the multi-channel case clearer, think of a colour image dataset: every batch contains a number of images, and each image has 3 channels, RED, GREEN and BLUE (visualize RED, GREEN and BLUE each as a matrix, so 3 matrices per image). In BatchNorm, assuming a batch size of 32, you would take all 32 RED matrices and calculate one mean and one variance over every cell in them, then repeat the process for the GREEN and BLUE channels. That is how it works for multi-channel data.
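
To make that concrete, here is a minimal sketch of the per-channel computation, assuming PyTorch tensors in (batch, channel, height, width) layout and a reasonably recent PyTorch where mean/var accept a tuple of dims; the shapes and random data are just for illustration:

>>> import torch
>>> imgs = torch.rand(32, 3, 8, 8)                   # a batch of 32 RGB images, each 8x8
>>> mean = imgs.mean(dim=(0, 2, 3))                  # one mean per channel: R, G, B
>>> var = imgs.var(dim=(0, 2, 3), unbiased=False)    # one (biased) variance per channel
>>> mean.shape, var.shape
(torch.Size([3]), torch.Size([3]))

Each of the three numbers pools every pixel of one channel across all 32 images, exactly as described above.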

Upvotes: 4
