Neal Carter

Reputation: 1

PyTorch mat1 and mat2 shapes cannot be multiplied

I am attempting to train an image classifier using PyTorch. I followed the tutorial at https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.htm, and it worked perfectly fine.

I am now trying to use a custom dataset instead of the one provided by the tutorial, and I am encountering some issues.

Here is my code for preparing the images:

transform = transforms.Compose(
[transforms.ToTensor(),
 transforms.Resize((224,224)),
 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

Here is my code for creating the data loaders:

trainset = datasets.ImageFolder('./Dataset/train', transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)

testset = datasets.ImageFolder('./Dataset/test', transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=0)

The issue appears in this class:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) 
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

When I run my code, I get the error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x44944 and 400x120)

This error did not appear when I used the dataset provided by the tutorial, even though the Net() class is exactly the same. The only differences between the code that did work and the code that does not work are the data loaders.

Here is the code for the original data loaders in the tutorial:

trainset = torchvision.datasets.CIFAR10(root='./data/train', train=True,
                                    download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                      shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data/test', train=False,
                                   download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                     shuffle=False, num_workers=0)

I have already tried many suggestions from answers to similar questions, but none of them worked. Any suggestions about what I should do?

Upvotes: 0

Views: 2542

Answers (2)

Prajot Kuvalekar

Reputation: 6658

Please remove Resize from the transforms: CIFAR-10 images are already 32x32x3, which is the input size the tutorial's network expects, so there is no need to resize them. Keep the transform this way, and the code will work:

transform = transforms.Compose(
[transforms.ToTensor(),
 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
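
To see why the tutorial's network only fits 32x32 inputs, the spatial size can be traced through the layers by hand. This is a minimal sketch, assuming the layer settings from the question's Net (5x5 kernels, stride 1, no padding, 2x2 max pooling); the helper names conv_out and flattened_features are hypothetical and not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output length along one spatial axis for a conv or pooling window.
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(side, channels=16):
    # Trace a square input through conv1 -> pool -> conv2 -> pool
    # (layer settings assumed from the question's Net).
    side = conv_out(side, 5)            # conv1: 5x5 kernel
    side = conv_out(side, 2, stride=2)  # 2x2 max pool
    side = conv_out(side, 5)            # conv2: 5x5 kernel
    side = conv_out(side, 2, stride=2)  # 2x2 max pool
    return channels * side * side


print(flattened_features(32))   # 400, matches nn.Linear(16 * 5 * 5, 120)
print(flattened_features(224))  # 44944, matches the 4x44944 in the error
```

If the custom images are not 32x32 to begin with, resizing them to (32, 32) instead of (224, 224) would be another way to keep the network unchanged.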

Upvotes: 0

Sudhanshu

Reputation: 732

You are getting this error because you are using the model structure from the tutorial but resizing the images to 224x224. After the two conv/pool stages and flattening, each image has 44944 features (16 * 53 * 53), so a batch of 4 has shape 4x44944 (179776 values in total), while fc1 expects 400 (16 * 5 * 5) inputs. You have two choices: either do not use the Resize operation in the transform, because the original size of CIFAR-10 images is 32x32, or, if you do want to resize the images, change the input dimension of fc1 accordingly.
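
The second option might look like the sketch below. It measures the flattened size with a dummy forward pass instead of hard-coding it, so fc1 follows whatever input size is chosen. The input_size and num_classes parameters and the _features helper are assumptions added for illustration; they are not in the original Net, and num_classes should match the number of class folders in your dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, input_size=224, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Run a dummy image through the conv stack to find how many
        # features reach fc1 (16 * 53 * 53 = 44944 for 224x224 inputs).
        with torch.no_grad():
            dummy = torch.zeros(1, 3, input_size, input_size)
            n_features = self._features(dummy).shape[1]

        self.fc1 = nn.Linear(n_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def _features(self, x):
        # Convolution and pooling stages, then flatten all dims except batch.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        return torch.flatten(x, 1)

    def forward(self, x):
        x = self._features(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


net = Net(input_size=224)
batch = torch.zeros(4, 3, 224, 224)  # stand-in for one batch from the loader
print(net(batch).shape)
```

Note that fc1 now has roughly 5.4 million weights (44944 * 120), so the model is much larger than the tutorial version; resizing to a smaller resolution keeps it lighter.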

Upvotes: 0
