Reputation: 1
I am attempting to train an image classifier using PyTorch. I followed the tutorial at https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.htm, and it worked perfectly fine.
I am now trying to use a custom dataset instead of the one provided by the tutorial, and I am encountering some issues.
Here is my code for preparing the images:
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((224, 224)),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
Here is my code for creating the data loaders:
trainset = datasets.ImageFolder('./Dataset/train', transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
testset = datasets.ImageFolder('./Dataset/test', transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=0)
The issue appears in this class:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
When I run my code, I get the error:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x44944 and 400x120)
This error did not appear when I used the dataset provided by the tutorial, even though the Net() class is exactly the same. The only difference between the code that worked and the code that does not is the data loaders.
Here is the code for the original data loaders in the tutorial:
trainset = torchvision.datasets.CIFAR10(root='./data/train', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data/test', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=0)
I have already tried many suggestions from answers to similar questions, but none of them worked. Any suggestions about what I should do?
Upvotes: 0
Views: 2542
Reputation: 6658
Please remove Resize from transforms, since CIFAR-10 images are already 32x32x3 and there is no need to resize them. Keep the transform this way and the code will work:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
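Why 32x32 inputs fit the unchanged Net can be checked by hand. The following is a minimal pure-Python trace, not a PyTorch API: the helper names conv_out, pool_out and trace are hypothetical, and it assumes square 32x32 input images, stride-1 convolutions without padding, and 2x2 max-pooling with stride 2, as in the tutorial's Net. It uses the usual output-size formula floor((size + 2*padding - kernel) / stride) + 1.

```python
# Shape trace for the tutorial's Net (hypothetical helpers, not part of PyTorch).
# Assumes square inputs, Conv2d with kernel 5 / stride 1 / no padding,
# and MaxPool2d(2, 2) with floor rounding.


def conv_out(size: int, kernel: int = 5, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Output side length of MaxPool2d (no padding, floor mode)."""
    return (size - kernel) // stride + 1


def trace(side: int) -> list:
    """Return (layer name, side length) after each stage of Net's feature extractor."""
    steps = []
    side = conv_out(side)
    steps.append(("conv1", side))
    side = pool_out(side)
    steps.append(("pool1", side))
    side = conv_out(side)
    steps.append(("conv2", side))
    side = pool_out(side)
    steps.append(("pool2", side))
    return steps


if __name__ == "__main__":
    steps = trace(32)
    for name, side in steps:
        print(f"{name}: {side}x{side}")
    final = steps[-1][1]
    # conv2 has 16 output channels, so flattening gives 16 * side * side features
    print("flattened features:", 16 * final * final)
```

Running it prints 28x28, 14x14, 10x10 and 5x5 for the four stages and 400 flattened features, which matches nn.Linear(16 * 5 * 5, 120). This only holds if the images reaching the network really are 32x32.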
Upvotes: 0
Reputation: 732
You are getting this error because you are using the model structure shown in the tutorial but resizing the images to 224x224. After flattening, the tensor becomes 4x44944 (16 * 53 * 53 = 44944 features per image, or 179776 values for a batch of 4), while fc1 expects 16 * 5 * 5 = 400 input features. So you have two choices: either do not use the Resize operation in transform, because the original size of CIFAR-10 images is 32x32, or, if you do want to resize the images, change the input dimension of fc1 accordingly.
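To choose the value for fc1 when keeping Resize((224, 224)), the arithmetic above can be sketched in plain Python. fc1_in_features and mat1_shape are hypothetical helpers, and the sketch assumes the tutorial's layer settings (two stages of 5x5 convolution with stride 1 and no padding, each followed by 2x2 max-pooling with stride 2) and a batch size of 4, which is what the 4x44944 in the error message implies.

```python
# Hypothetical helpers: input features fc1 needs when images are resized to
# side x side before entering the tutorial's Net.
# Assumes Conv2d(3, 6, 5) -> MaxPool2d(2, 2) -> Conv2d(6, 16, 5) -> MaxPool2d(2, 2)
# followed by torch.flatten(x, 1).


def fc1_in_features(side: int, out_channels: int = 16) -> int:
    for _ in range(2):        # two conv + pool stages
        side = side - 5 + 1   # 5x5 conv, stride 1, no padding
        side = side // 2      # 2x2 max-pool, stride 2 (floor)
    return out_channels * side * side


def mat1_shape(batch_size: int, side: int) -> tuple:
    """Shape of the tensor fc1 receives (what the error message calls mat1)."""
    return (batch_size, fc1_in_features(side))


if __name__ == "__main__":
    print(mat1_shape(4, 224))   # (4, 44944): with Resize((224, 224))
    print(fc1_in_features(32))  # 400: matches nn.Linear(16 * 5 * 5, 120)
```

With Resize((224, 224)) kept, the first fully connected layer would then be declared as nn.Linear(16 * 53 * 53, 120), with the rest of Net unchanged.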
Upvotes: 0