Niamh

Reputation: 93

Using nn.Linear() and nn.BatchNorm1d() together

I don't understand how BatchNorm1d works when the data is 3D, e.g. of shape (batch size, H, W).

Example

Passing an input of shape (2, 50, 70) through nn.Linear(70, 20) gives an output of shape (2, 50, 20).

If I then include a batch normalisation layer, it requires num_features=50 (Example 1 below works), and I don't understand why it isn't 20 (Example 2 below fails):

Example 1)

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.bn11 = nn.BatchNorm1d(50)
        self.fc11 = nn.Linear(70, 20)

    def forward(self, inputs):
        out = self.fc11(inputs)           # (2, 50, 70) -> (2, 50, 20)
        out = torch.relu(self.bn11(out))  # BatchNorm1d normalises over dim 1, which is 50
        return out

model = Net()
inputs = torch.randn(2, 50, 70)
outputs = model(inputs)  # runs without error

Example 2)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.bn11 = nn.BatchNorm1d(20)
        self.fc11 = nn.Linear(70, 20)

    def forward(self, inputs):
        out = self.fc11(inputs)           # (2, 50, 70) -> (2, 50, 20)
        out = torch.relu(self.bn11(out))  # fails: dim 1 is 50, not 20
        return out

model = Net()
inputs = torch.randn(2, 50, 70)
outputs = model(inputs)  # RuntimeError: running_mean should contain 50 elements not 20

2D example:

I thought the 20 in the BN layer was needed because the linear layer outputs 20 nodes, each of which requires a running mean/std for its incoming values.
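A minimal sketch of that 2D case, where nn.BatchNorm1d(20) works as expected:

import torch
import torch.nn as nn

fc = nn.Linear(70, 20)
bn = nn.BatchNorm1d(20)  # one running mean/std per output node

inputs = torch.randn(2, 70)  # (batch, features)
out = torch.relu(bn(fc(inputs)))
print(out.shape)  # torch.Size([2, 20])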

So why, in the 3D case, doesn't the BN layer have 20 features when the linear layer has 20 output nodes?

Upvotes: 9

Views: 14250

Answers (1)

Szymon Maszke

Reputation: 24894

One can find the answer in the torch.nn.Linear documentation.

It takes input of shape (N, *, I) and returns output of shape (N, *, O), where I stands for the input dimension, O for the output dimension, and * is any number of dimensions in between.
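A quick check of that shape rule (a minimal sketch; the sizes are arbitrary):

import torch
import torch.nn as nn

fc = nn.Linear(70, 20)  # I = 70, O = 20
print(fc(torch.randn(2, 70)).shape)         # torch.Size([2, 20])
print(fc(torch.randn(2, 50, 70)).shape)     # torch.Size([2, 50, 20])
print(fc(torch.randn(2, 3, 50, 70)).shape)  # torch.Size([2, 3, 50, 20])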

If you pass torch.randn(2, 50, 70) into nn.Linear(70, 20), you get an output of shape (2, 50, 20). BatchNorm1d, however, computes its running statistics over the first non-batch dimension (dim 1), which here is 50, not 20. That's the reason behind your error.
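If you actually want one running statistic per linear output (all 20 of them), a minimal sketch of one workaround is to permute the feature dimension into position 1 before normalising:

import torch
import torch.nn as nn

fc = nn.Linear(70, 20)
bn20 = nn.BatchNorm1d(20)  # statistics over the 20 linear outputs

out = fc(torch.randn(2, 50, 70))  # (2, 50, 20)
try:
    bn20(out)  # fails: BatchNorm1d reads dim 1, which is 50
except RuntimeError as e:
    print(e)   # running_mean should contain 50 elements not 20

# Move the 20 features into dim 1, normalise, then move them back:
out = bn20(out.permute(0, 2, 1)).permute(0, 2, 1)  # (2, 50, 20)
print(out.shape)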

Upvotes: 4
