Reputation: 2155
I am currently implementing a model in which I need to change the running mean and standard deviation during test time. As such, I assume nn.functional.batch_norm would be a better choice than nn.BatchNorm2d. However, I have batches of images as input and am currently not sure how to take them in. How would I apply nn.functional.batch_norm on batches of 2D images?
Here is my current code; I am posting it even though I know it is not correct:
import torch
import torch.nn as nn

mu = torch.mean(inp[0])
stddev = torch.std(inp[0])
x = nn.functional.batch_norm(inp[0], mu, stddev, training=True, momentum=0.9)
Upvotes: 4
Views: 5158
Reputation: 1505
The key is that 2D batchnorm performs the same normalization for each channel, i.e. if you have a batch of data with shape (N, C, H, W), then your mu and stddev should have shape (C,). If your images do not have a channel dimension, add one using view, as sketched below.
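For example, a minimal sketch of adding the channel dimension, assuming grayscale images stored as an (N, H, W) tensor (the tensor name and sizes here are hypothetical):

import torch

inp = torch.randn(8, 32, 32)   # hypothetical batch of 8 grayscale 32x32 images
N, H, W = inp.shape
inp = inp.view(N, 1, H, W)     # add a channel dimension -> shape (8, 1, 32, 32)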
Warning: if you set training=True, then batch_norm computes and uses the appropriate normalization statistics for the batch you pass in (this means we don't need to calculate the mean and std ourselves). The mu and stddev arguments are supposed to be the running mean and running variance over all training batches (note that batch_norm expects a running variance, not a standard deviation). These tensors are updated in place with the new batch statistics inside the batch_norm function.
import torch
import torch.nn as nn

# inp is shape (N, C, H, W)
n_chans = inp.shape[1]
running_mu = torch.zeros(n_chans)   # zeros are fine for first training iter
running_var = torch.ones(n_chans)   # ones are fine for first training iter (batch_norm expects variance)
x = nn.functional.batch_norm(inp, running_mu, running_var, training=True, momentum=0.9)
# running_mu and running_var now hold updated running statistics
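Since the question is about test time, here is a minimal sketch of how the accumulated running statistics might then be used for evaluation (test_inp is a hypothetical test batch with the same (N, C, H, W) layout):

# at test time, normalize with the accumulated running statistics
x_test = nn.functional.batch_norm(test_inp, running_mu, running_var, training=False)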
If you want to use just your own batch statistics instead, try this:
import torch
import torch.nn as nn

# inp is shape (N, C, H, W)
n_chans = inp.shape[1]
reshaped_inp = inp.permute(1, 0, 2, 3).contiguous().view(n_chans, -1)  # shape (C, N*H*W)
mu = reshaped_inp.mean(-1)
var = reshaped_inp.var(-1)   # batch_norm expects a variance, not a standard deviation
x = nn.functional.batch_norm(inp, mu, var, training=False)
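Equivalently, you can compute the same per-channel statistics without the permute/view step by reducing over the non-channel dimensions directly (a sketch, assuming a PyTorch version whose mean/var accept a tuple of dims):

# equivalent per-channel statistics via dim-tuple reduction
mu = inp.mean(dim=(0, 2, 3))
var = inp.var(dim=(0, 2, 3))
x = nn.functional.batch_norm(inp, mu, var, training=False)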
Upvotes: 3