user3180

Reputation: 1487

When to use layernorm/batch norm?

Where should you splice the normalization when designing a network? E.g. if you have a stacked Transformer or Attention network, does it make sense to normalize any time after you have a dense layer?

Upvotes: 3

Views: 7898

Answers (1)

prosti

Reputation: 46401

The original paper explains that Batch Normalization also acts as a regularizer, which helps reduce overfitting.

Where should you splice the normalization when designing a network?

Apply the normalization early, on the inputs: unbalanced inputs with extreme values can cause instability.

Normalizing only the outputs, on the other hand, will not prevent the inputs from causing that instability all over again.
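
For illustration, a minimal sketch (the layer sizes are made up, not from the question) of placing the normalization directly on the inputs, before any learned layers:

import torch
import torch.nn as nn

# Illustrative sizes: normalize the raw input features first,
# so extreme values are tamed before they reach the linear layer.
net = nn.Sequential(
    nn.BatchNorm1d(10),
    nn.Linear(10, 20),
    nn.ReLU(),
)
x = 1000 * torch.randn(4, 10)  # unbalanced, extreme-valued inputs
print(net(x))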

Here is a little code example that shows what batch norm does:

import torch
import torch.nn as nn

m = nn.BatchNorm1d(100, affine=False)  # normalize each of the 100 features over the batch; no learnable scale/shift
input = 1000*torch.randn(3, 100)       # 3 samples with extreme values
print(input)
output = m(input)
print(output)
print(output.mean()) # should be ~ 0
print(output.std()) # should be ~ 1
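
Since the question also mentions layer norm, here is the analogous sketch with nn.LayerNorm, which normalizes each sample over its features rather than over the batch (one reason it is the usual choice inside Transformers, where batch statistics are less reliable):

import torch
import torch.nn as nn

# LayerNorm normalizes over the feature dimension of each sample,
# independently of the other samples in the batch.
m = nn.LayerNorm(100, elementwise_affine=False)
input = 1000 * torch.randn(3, 100)
output = m(input)
print(output.mean(dim=-1))  # ~0 for every sample
print(output.std(dim=-1))   # ~1 for every sample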

Does it make sense to normalize any time after you have a dense layer?

Yes, you may do so, since matrix multiplication can produce extreme values. The same goes for convolution layers, which are also matrix multiplications, though the effect is usually milder than with a dense (nn.Linear) layer. If you print the ResNet model, for instance, you will see that a batch norm is placed after every conv layer, like this:

(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
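
Following the same pattern after dense layers, a minimal sketch (the sizes are illustrative):

import torch.nn as nn

# Normalize right after each matrix multiplication so the
# extremes it may produce are rescaled before the activation.
mlp = nn.Sequential(
    nn.Linear(128, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Linear(256, 10),
)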

To print the full ResNet you can use this:

import torchvision.models as models
r = models.resnet18()
print(r)

Upvotes: 1
