I am trying to understand why my classifier has a dimension issue. Here is my code:
import torch
import torch.nn as nn


class convnet(nn.Module):
    def __init__(self, num_classes=1000):
        super(convnet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(576, 128),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Linear(64, num_classes),
            nn.Softmax(),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def neuralnet(num_classes, **kwargs):
    model = convnet(**kwargs)
    return model
The issue I'm running into is this error: expected 4D input (got 2D input)
I'm quite sure the error arises from the flatten command, but I don't really understand why, since the classifier consists of fully connected layers. If someone can see where I'm going wrong, that would be very helpful!
Thank you
Upvotes: 0
Views: 41
Reputation: 32972
After flattening, the input to the classifier has 2 dimensions (size: [batch_size, 576]), therefore the output of the first linear layer will also have 2 dimensions (size: [batch_size, 128]). That output is then passed to nn.BatchNorm2d, which requires its input to have 4 dimensions (size: [batch_size, channels, height, width]).
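You can reproduce the mismatch in isolation with a 2-dimensional tensor (a minimal sketch; the batch size of 8 is just illustrative):

import torch
import torch.nn as nn

x = torch.randn(8, 128)          # 2D tensor: [batch_size, features]
out = nn.BatchNorm1d(128)(x)     # works: output keeps shape [8, 128]
# nn.BatchNorm2d(128)(x)         # would raise: expected 4D input (got 2D input)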
If you want to use batch norm on a 2D input, you need to use nn.BatchNorm1d, which accepts either a 3D input (size: [batch_size, channels, length]) or a 2D input (size: [batch_size, length]):
self.classifier = nn.Sequential(
    nn.Linear(576, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(inplace=True),
    nn.Linear(128, 64),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(64),
    nn.Linear(64, num_classes),
    nn.Softmax(),
)
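With that change, the classifier accepts the flattened features. As a quick sanity check, assuming a single-channel 36x36 input (one spatial size for which the flattened feature map comes out to 576 = 64 * 3 * 3) and an arbitrary num_classes of 10:

model = convnet(num_classes=10)            # convnet with the BatchNorm1d classifier above
out = model(torch.randn(4, 1, 36, 36))     # dummy batch of 4 single-channel 36x36 images
print(out.shape)                           # torch.Size([4, 10])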
Upvotes: 1