Reputation: 484
May I ask if the following understanding of batch normalization in convolutional neural networks is correct?
As shown in the diagram below, the mean and variance are calculated using all the cells of the feature maps of the same channel, generated from the respective examples in the current mini-batch, i.e. they are calculated across the h, w and m axes.
Upvotes: 5
Views: 4795
Reputation: 40648
It seems you are correct. The empirical mean and variance are measured over all dimensions except the feature (channel) dimension. The z-score is then calculated to standardize the mini-batch to mean=0 and std=1. Additionally, the result is scaled and shifted with two learnable parameters, gamma and beta.
Here is a description of a batch normalization layer and the calculation details:
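For a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$, where for convolutional inputs $m$ counts every value of a given channel across the batch and spatial dimensions, the BatchNorm paper defines:

$$\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_i - \mu_\mathcal{B}\right)^2$$

$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$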
Here is a quick implementation to show you the normalization process without the scale-shift:
>>> import torch
>>> a = torch.eye(2, 4).reshape(2, 2, 2)
>>> b = torch.arange(8.).reshape(2, 2, 2)
>>> x = torch.stack([a, b])
>>> x
tensor([[[[1., 0.],
          [0., 0.]],
         [[0., 1.],
          [0., 0.]]],
        [[[0., 1.],
          [2., 3.]],
         [[4., 5.],
          [6., 7.]]]])
We are looking to measure the mean and variance over all axes except the channel axis. So we start by swapping the batch axis with the channel axis, then flatten all axes but the first. Finally, we take the mean and variance along that flattened dimension.
>>> x_ = x.permute(1, 0, 2, 3).flatten(start_dim=1)
>>> mean, var = x_.mean(dim=-1), x_.var(dim=-1)
>>> mean, var
(tensor([0.8750, 2.8750]), tensor([1.2679, 8.6964]))
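As an aside (not part of the original walkthrough), the same statistics can be computed directly by reducing over a dimension tuple, without the permute and flatten:

>>> x.mean(dim=(0, 2, 3)), x.var(dim=(0, 2, 3))
(tensor([0.8750, 2.8750]), tensor([1.2679, 8.6964]))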
>>> # align the per-channel stats with the channel axis before broadcasting
>>> y = (x - mean[None, :, None, None]) / (var[None, :, None, None] + 1e-8).sqrt()
>>> y
tensor([[[[ 0.1110, -0.7771],
          [-0.7771, -0.7771]],
         [[-0.9749, -0.6358],
          [-0.9749, -0.9749]]],
        [[[-0.7771,  0.1110],
          [ 0.9991,  1.8872]],
         [[ 0.3815,  0.7206],
          [ 1.0597,  1.3988]]]])
Notice the shapes of mean and var: vectors whose length equals the number of input channels. The same can be said about the shapes of gamma and beta.
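As a sanity check (a small sketch, not from the original answer, using a random input and nn.BatchNorm2d with its default settings), the manual per-channel normalization can be compared against PyTorch's built-in layer in training mode; note that the built-in layer uses the biased variance and eps=1e-5, and initializes gamma to ones and beta to zeros:

>>> import torch
>>> import torch.nn as nn
>>> x = torch.randn(4, 3, 8, 8)                       # (batch, channels, height, width)
>>> x_ = x.permute(1, 0, 2, 3).flatten(start_dim=1)
>>> mean = x_.mean(dim=-1)
>>> var = x_.var(dim=-1, unbiased=False)              # biased variance, as BatchNorm uses in training
>>> y_manual = (x - mean[None, :, None, None]) / (var[None, :, None, None] + 1e-5).sqrt()
>>> bn = nn.BatchNorm2d(3).train()                    # gamma=1, beta=0 at initialization
>>> torch.allclose(y_manual, bn(x), atol=1e-6)
True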
Upvotes: 4
Reputation: 785
The picture depicts BatchNorm correctly.
In BatchNorm we compute the mean and variance using the spatial feature maps of the same channel across the whole batch. The picture you've attached may look confusing because, in that picture, the data is single-channel, which means each grid/matrix represents one data sample; however, if you think of colored images, those will require 3 such grids/matrices to represent one data sample, as they have 3 channels (RGB) per sample. So in your picture, you can think of pooling all the cells from all m grids/matrices together and then calculating their mean and variance.
So your picture does show the computation of mean and variance for BatchNorm correctly; however, when you think of multi-channel data, you might get confused, as the picture is only good for understanding single-channel data. To make the multi-channel case a bit clearer, you may think of a colored image dataset. In every batch there are a number of images, and each image has 3 channels: RED, GREEN, and BLUE (to visualize, think of RED as a matrix, GREEN as a matrix, and BLUE as a matrix, so 3 matrices per image). In BatchNorm, what you would do now is (assuming a batch size of 32) take all 32 matrices of the RED channel and calculate one mean and one variance from them; you then repeat the process for the GREEN and BLUE channels. That's how it works for multi-channel data.
Upvotes: 4