dzdws

Reputation: 89

Input dimension of Pytorch CNN model

I have input data for my 2D CNN model, X_train, with shape torch.Size([716, 50, 50]).

My model is:

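import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Shapes below assume an input of (N, 1, 50, 50): batch, channels, height, width
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=1, padding=1)  # -> (N, 32, 49, 49)
        self.mp1 = nn.MaxPool2d(kernel_size=4, stride=2)                   # -> (N, 32, 23, 23)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=1)            # -> (N, 64, 20, 20)
        self.mp2 = nn.MaxPool2d(kernel_size=4, stride=2)                   # -> (N, 64, 9, 9)
        self.fc1 = nn.Linear(64 * 9 * 9, 256)  # 5184 features for 50x50 input; 2304 would not match
        self.dp1 = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        in_size = x.size(0)                  # batch size N
        x = F.relu(self.mp1(self.conv1(x)))
        x = F.relu(self.mp2(self.conv2(x)))
        x = x.view(in_size, -1)              # flatten to (N, 64 * 9 * 9)
        x = F.relu(self.fc1(x))
        x = self.dp1(x)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)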
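As a side check on the fully connected layer, the number of features that reach fc1 can be worked out from the layer settings above, assuming 50×50 inputs. The sketch below repeats only the convolution and pooling layers from the question, computes the spatial size with the standard output-size arithmetic (dilation 1, no ceil mode; `out_size` is a hypothetical helper), and confirms it with a dummy batch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same convolution/pooling layers as the CNN above (fc layers omitted).
conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=1, padding=1)
mp1 = nn.MaxPool2d(kernel_size=4, stride=2)
conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=1)
mp2 = nn.MaxPool2d(kernel_size=4, stride=2)

def out_size(size, kernel, stride=1, padding=0):
    """Hypothetical helper: output length along one spatial axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1

# Arithmetic: 50 -> 49 (conv1) -> 23 (mp1) -> 20 (conv2) -> 9 (mp2)
h = out_size(50, 4, stride=1, padding=1)
h = out_size(h, 4, stride=2)
h = out_size(h, 4, stride=1)
h = out_size(h, 4, stride=2)
flat_by_formula = 64 * h * h

# Empirical check with a dummy batch of two single-channel 50x50 images.
with torch.no_grad():
    x = torch.zeros(2, 1, 50, 50)
    x = F.relu(mp1(conv1(x)))
    x = F.relu(mp2(conv2(x)))
    flat_by_tensor = x.view(x.size(0), -1).size(1)

print(h, flat_by_formula, flat_by_tensor)  # 9 5184 5184
```

So for 50×50 inputs fc1 needs in_features=5184 (64 × 9 × 9); a value of 2304 (64 × 6 × 6) would raise a shape-mismatch error at fc1 once the channel problem is fixed.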

But when I run the model, I always get this error at the line x = F.relu(self.mp1(self.conv1(x))):

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 1, 4, 4], but got 3-dimensional input of size [64, 50, 50] instead

I understand that my input to the model has size 64 (batch size) × 50 × 50 (the size of each input, in this case a signal image).

But I don't understand why it still requires a 4-dimensional input when I have set in_channels of nn.Conv2d to 1.

How can I solve this input dimension problem, or change the input dimension the model expects?

Upvotes: 1

Views: 1790

Answers (2)

Reza

Reputation: 31

That's the problem: you've set in_channels=1, but that doesn't mean the channel dimension doesn't exist. Expanding your data to shape [64, 1, 50, 50] should solve your problem.

Use .view() on the input tensor.
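A minimal sketch of that reshape, assuming the data are a float tensor shaped like X_train in the question (random values stand in for the real data here). Reshaping the whole dataset once, before batching, means every batch already carries the channel axis:

```python
import torch
import torch.nn as nn

# Stand-in for X_train from the question: 716 single-channel 50x50 signal images.
X_train = torch.randn(716, 50, 50)

# Insert a channel axis of size 1: (N, H, W) -> (N, 1, H, W).
# -1 lets view infer N, so the same call also works for a final, smaller batch.
X_train_4d = X_train.view(-1, 1, 50, 50)
print(X_train_4d.shape)  # torch.Size([716, 1, 50, 50])

# The same reshape applied to a single batch of 64 images.
batch = X_train[:64].view(-1, 1, 50, 50)

conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=1, padding=1)
out = conv1(batch)
print(out.shape)  # torch.Size([64, 32, 49, 49])
```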

Upvotes: 0

Bram Vanroy

Reputation: 28437

Whether in_channels is 1 or 42 does not matter: the channel axis is still a separate dimension of the input. It is useful to read the nn.Conv2d documentation in this respect.

Input and output have the shape (N, C, H, W), where:

  • N: batch size
  • C: channels
  • H: height in pixels
  • W: width in pixels
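To connect this to the error message: the weight of nn.Conv2d is itself 4-dimensional, laid out as (out_channels, in_channels, kernel height, kernel width) for the default groups=1, so in_channels=1 still occupies an axis. A short check using the conv1 definition from the question:

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=1, padding=1)

# Weight layout: (out_channels, in_channels, kernel_height, kernel_width).
print(conv1.weight.shape)  # torch.Size([32, 1, 4, 4])

# A batch in (N, C, H, W) layout: 64 images, 1 channel, 50x50 pixels.
x = torch.randn(64, 1, 50, 50)
y = conv1(x)
print(y.shape)  # torch.Size([64, 32, 49, 49]) -> N, C_out, H_out, W_out
```

As far as I know, recent PyTorch releases also accept an unbatched 3-D input of shape (C, H, W); there a [64, 50, 50] tensor would be read as one image with 64 channels and rejected because conv1 expects 1 channel, so the channel axis still has to be added explicitly.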

So in your case you need to add the channel dimension:

# Add a dimension at index 1
x = x.unsqueeze(1)
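If you would rather have the model itself accept (N, H, W) batches, one option (a sketch, not the only approach) is to call unsqueeze at the start of forward. The class below is a hypothetical, trimmed variant of the question's CNN that keeps only the first conv/pool stage for brevity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleChannelCNN(nn.Module):
    """Hypothetical variant that accepts (N, H, W) or (N, 1, H, W) input."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=1, padding=1)
        self.mp1 = nn.MaxPool2d(kernel_size=4, stride=2)

    def forward(self, x):
        if x.dim() == 3:        # (N, H, W): no channel axis yet
            x = x.unsqueeze(1)  # -> (N, 1, H, W)
        return F.relu(self.mp1(self.conv1(x)))

model = SingleChannelCNN()
a = model(torch.randn(64, 50, 50))     # 3-D batch
b = model(torch.randn(64, 1, 50, 50))  # already 4-D
print(a.shape, b.shape)  # torch.Size([64, 32, 23, 23]) torch.Size([64, 32, 23, 23])
```

Adding the axis once in the data pipeline keeps the model strict about its input layout; doing it inside forward is convenient but can hide shape mistakes elsewhere.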

Upvotes: 2
