user11764168

Dimension error in implementation of a convolutional network

I am trying to understand why my classifier has a dimension issue. Here is my code:

import torch
import torch.nn as nn


class convnet(nn.Module):

    def __init__(self, num_classes=1000):
        super(convnet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride = 2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride = 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride = 2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(576, 128),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Linear(64,num_classes),
            nn.Softmax(),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x,1) #x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def neuralnet(num_classes, **kwargs):
    # pass num_classes through; otherwise the default of 1000 is always used
    model = convnet(num_classes=num_classes, **kwargs)
    return model

The error I get is: expected 4D input (got 2D input)

I'm fairly sure the error comes from the flatten call, but I don't understand why, since the classifier consists only of fully connected layers. If someone can point out where I'm going wrong, that would be very helpful!

Thank you

Answers (1)

Michael Jungo

After flattening, the input to the classifier has 2 dimensions (size: [batch_size, 576]), therefore the output of the first linear layer will also have 2 dimensions (size: [batch_size, 128]). That output is then passed to nn.BatchNorm2d, which requires its input to have 4 dimensions (size: [batch_size, channels, height, width]).
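A minimal sketch that traces the shapes and reproduces the error. It assumes 1-channel 32x32 inputs, which is not stated in the question; it is simply the size for which the question's features block yields 64 * 3 * 3 = 576 values per sample. The dummy batch size of 4 is arbitrary.

```python
import torch
import torch.nn as nn

# Assumption: inputs are 1-channel 32x32 images (not stated in the question).
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # 32x32 -> 32x32
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(32),
    nn.MaxPool2d(kernel_size=2, stride=2),                 # 32 -> 16
    nn.Conv2d(32, 32, kernel_size=3, padding=1),           # 16 -> 16
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(32),
    nn.MaxPool2d(kernel_size=2, stride=2),                 # 16 -> 8
    nn.Conv2d(32, 64, kernel_size=3, stride=1),            # no padding: 8 -> 6
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(64),
    nn.MaxPool2d(kernel_size=2, stride=2),                 # 6 -> 3
)

x = torch.randn(4, 1, 32, 32)        # dummy batch of 4 images
feat = features(x)                   # [4, 64, 3, 3]
flat = torch.flatten(feat, 1)        # [4, 576]: everything after dim 0 merged
hidden = nn.Linear(576, 128)(flat)   # [4, 128]: Linear keeps the input 2D

try:
    nn.BatchNorm2d(128)(hidden)      # BatchNorm2d insists on [N, C, H, W]
    error = None
except ValueError as e:
    error = str(e)

print(feat.shape, flat.shape, hidden.shape)
print(error)  # expected 4D input (got 2D input)
```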

If you want to use batch norm on a 2D input, you need to use nn.BatchNorm1d, which accepts either a 3D input (size: [batch_size, channels, length]) or a 2D input (size: [batch_size, num_features]). In the latter case each of the num_features columns is normalized over the batch.

self.classifier = nn.Sequential(
    nn.Linear(576, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(inplace=True),
    nn.Linear(128, 64),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(64),
    nn.Linear(64,num_classes),
    nn.Softmax(dim=1),  # softmax over the class dimension
)
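As a sanity check, here is a sketch of the whole network with the BatchNorm1d fix, run on a dummy batch. As above, it assumes 1-channel 32x32 inputs (needed for the 576 in the first Linear layer); num_classes=10 and the batch size of 8 are arbitrary illustrative values. It also shows BatchNorm1d accepting both the 2D and 3D shapes.

```python
import torch
import torch.nn as nn


class convnet(nn.Module):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),                 # 4D input here, so 2d is right
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(576, 128),
            nn.BatchNorm1d(128),                # 2D input [N, 128], so 1d
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(64),
            nn.Linear(64, num_classes),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        x = self.features(x)        # [N, 64, 3, 3] for 32x32 inputs
        x = torch.flatten(x, 1)     # [N, 576]
        return self.classifier(x)   # [N, num_classes]


model = convnet(num_classes=10)
out = model(torch.randn(8, 1, 32, 32))   # training mode, batch of 8
print(out.shape)                         # torch.Size([8, 10])

# Single-sample inference: switch to eval so BatchNorm uses running stats.
model.eval()
single = model(torch.randn(1, 1, 32, 32))

# BatchNorm1d handles both [N, C] and [N, C, L].
bn = nn.BatchNorm1d(16)
y2d = bn(torch.randn(8, 16))
y3d = bn(torch.randn(8, 16, 20))
```

Note that in training mode BatchNorm1d needs more than one sample per batch to compute batch statistics, which is why the single-image forward pass is done after model.eval().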
