DaveTheAl

Reputation: 2155

Pytorch nn.functional.batch_norm for 2D input

I am currently implementing a model in which I need to change the running mean and standard deviation at test time. As such, I assume nn.functional.batch_norm would be a better choice than nn.BatchNorm2d.

However, my input is batches of images, and I am not sure how to pass them in. How would I apply nn.functional.batch_norm to batches of 2D images?

This is my current code, which I know is not correct:

mu = torch.mean(inp[0])
stddev = torch.std(inp[0])
x = nn.functional.batch_norm(inp[0], mu, stddev, training=True, momentum=0.9)

Upvotes: 4

Views: 5158

Answers (1)

saetch_g

Reputation: 1505

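The key is that 2D batch norm normalizes each channel separately: every channel gets its own mean and variance, computed over the batch and spatial dimensions. So if your batch has shape (N, C, H, W), the running mean and running variance you pass in should both have shape (C,). Note that the third argument of batch_norm is a variance, not a standard deviation. If your images have no channel dimension (shape (N, H, W)), add one with inp.unsqueeze(1) (or view) to get shape (N, 1, H, W).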
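A minimal sketch to make the per-channel behaviour concrete, assuming a small random batch (the names inp and gray are illustrative). It checks that batch_norm in training mode matches normalizing each channel with its own mean and biased variance pooled over N, H and W, and shows adding a channel dimension to grayscale images with unsqueeze(1).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(8, 3, 5, 5)  # illustrative batch of shape (N, C, H, W)
n_chans = inp.shape[1]
eps = 1e-5

# Training mode: the batch's own statistics are used for normalization
out = F.batch_norm(inp, torch.zeros(n_chans), torch.ones(n_chans), training=True, eps=eps)

# Manual version: one mean and one biased variance per channel, pooled over N, H, W
mean = inp.mean(dim=(0, 2, 3), keepdim=True)                 # shape (1, C, 1, 1)
var = inp.var(dim=(0, 2, 3), unbiased=False, keepdim=True)   # shape (1, C, 1, 1)
manual = (inp - mean) / torch.sqrt(var + eps)

print(torch.allclose(out, manual, atol=1e-5))  # True

# Grayscale images without a channel dimension: (N, H, W) -> (N, 1, H, W)
gray = torch.randn(8, 5, 5)
gray_out = F.batch_norm(gray.unsqueeze(1), torch.zeros(1), torch.ones(1), training=True)
print(gray_out.shape)  # torch.Size([8, 1, 5, 5])
```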

Warning: if you set training=True, batch_norm computes and uses the normalization statistics of the given batch, so you don't need to calculate the mean and variance yourself. The tensors you pass in are supposed to be the running mean and running variance over all training batches, and batch_norm updates them in place with the new batch statistics.

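import torch
import torch.nn.functional as F

# inp is shape (N, C, H, W)
n_chans = inp.shape[1]
running_mu = torch.zeros(n_chans, device=inp.device)  # zeros are fine for the first training iteration
running_var = torch.ones(n_chans, device=inp.device)  # a variance, not a std; ones are fine for the first iteration
# PyTorch's momentum is the weight of the NEW batch: running = (1 - momentum) * running + momentum * batch_stat (default 0.1)
x = F.batch_norm(inp, running_mu, running_var, training=True, momentum=0.9)
# running_mu and running_var have now been updated in place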
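To see exactly how the running tensors change, here is a minimal sketch on illustrative random data. It relies on PyTorch's documented update rule, running = (1 - momentum) * running + momentum * batch_stat, which is the opposite of the Keras/TensorFlow convention: in PyTorch a momentum of 0.9 gives the new batch 90% of the weight. It also shows that the stored running variance uses the unbiased estimator, even though the biased one is used to normalize the batch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(16, 3, 4, 4)  # illustrative (N, C, H, W) batch
n_chans = inp.shape[1]
momentum = 0.1  # PyTorch's default: weight given to the new batch statistics

running_mean = torch.zeros(n_chans)
running_var = torch.ones(n_chans)
F.batch_norm(inp, running_mean, running_var, training=True, momentum=momentum)

# Recompute what the in-place update should have produced
batch_mean = inp.mean(dim=(0, 2, 3))
batch_var_unbiased = inp.var(dim=(0, 2, 3), unbiased=True)  # running estimate uses the unbiased variance
expected_mean = (1 - momentum) * torch.zeros(n_chans) + momentum * batch_mean
expected_var = (1 - momentum) * torch.ones(n_chans) + momentum * batch_var_unbiased

print(torch.allclose(running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(running_var, expected_var, atol=1e-5))    # True
```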

If you just want to normalize with statistics you compute yourself, try this:

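import torch
import torch.nn.functional as F

# inp is shape (N, C, H, W)
n_chans = inp.shape[1]
reshaped_inp = inp.permute(1, 0, 2, 3).contiguous().view(n_chans, -1)  # shape (C, N*H*W)
mu = reshaped_inp.mean(-1)
var = reshaped_inp.var(-1, unbiased=False)  # batch_norm expects a variance (biased, as in training mode), not a std
# With training=False, mu and var are used as-is and are not modified
x = F.batch_norm(inp, mu, var, training=False)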
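With training=False nothing is updated and the supplied statistics are used as-is, which is what lets you swap in different statistics at test time. A minimal sketch on illustrative random data, assuming made-up per-channel affine parameters weight and bias (the optional gamma/beta arguments of F.batch_norm). It checks the inference formula gamma * (x - mean) / sqrt(var + eps) + beta, which is also why the third argument must be a variance rather than a standard deviation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(8, 3, 5, 5)  # illustrative (N, C, H, W) batch
eps = 1e-5

# Statistics you supply yourself (here: computed from this batch, for illustration)
mu = inp.mean(dim=(0, 2, 3))
var = inp.var(dim=(0, 2, 3), unbiased=False)
mu_before, var_before = mu.clone(), var.clone()

# Hypothetical affine parameters, one value per channel
weight = torch.full((3,), 2.0)
bias = torch.full((3,), 0.5)

out = F.batch_norm(inp, mu, var, weight=weight, bias=bias, training=False, eps=eps)

# Inference formula, broadcast per channel with shape (1, C, 1, 1)
def per_chan(t):
    return t.view(1, -1, 1, 1)

manual = per_chan(weight) * (inp - per_chan(mu)) / torch.sqrt(per_chan(var) + eps) + per_chan(bias)

print(torch.allclose(out, manual, atol=1e-4))                           # True
print(torch.equal(mu, mu_before) and torch.equal(var, var_before))      # True: not modified
```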

Upvotes: 3
