Paul

Reputation: 382

Cannot Iterate through PyTorch MNIST dataset

I am trying to load the MNIST dataset in PyTorch and use the built-in DataLoader to iterate through the training examples. However, I get an error when calling next() on the iterator. I don't have this problem with CIFAR10.

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
dataiter = iter(dataloader)
next(dataiter)  # ERROR
# RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

I am using Python 3.7.3 with PyTorch 1.1.0

Upvotes: 2

Views: 1278

Answers (2)

Anubhav Singh

Reputation: 8699

The MNIST dataset consists of grayscale images, i.e., each image has just 1 channel, while the CIFAR10 dataset consists of color images, i.e., each image has 3 channels.
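To see why the channel count matters, here is a minimal sketch in plain PyTorch. It assumes that Normalize computes `(x - mean[c]) / std[c]` for each channel `c` using in-place tensor ops, which is consistent with the in-place shape error in the question. `normalize_in_place` is a hypothetical helper written for illustration, not a torchvision function, and random tensors stand in for real images.

```python
import torch

def normalize_in_place(x, mean, std):
    # Hypothetical helper (not part of torchvision): per-channel
    # (x - mean[c]) / std[c], done in place on x.
    mean = torch.as_tensor(mean, dtype=x.dtype)
    std = torch.as_tensor(std, dtype=x.dtype)
    # mean[:, None, None] has shape [C, 1, 1] and broadcasts over H and W.
    return x.sub_(mean[:, None, None]).div_(std[:, None, None])

mnist_like = torch.rand(1, 28, 28)  # 1 channel, like an MNIST image after ToTensor()
cifar_like = torch.rand(3, 32, 32)  # 3 channels, like a CIFAR10 image after ToTensor()

# Three per-channel values line up with a 3-channel image.
print(normalize_in_place(cifar_like, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).shape)

# Against a 1-channel image the broadcast result would be [3, 28, 28],
# which cannot be written back into the [1, 28, 28] tensor in place.
try:
    normalize_in_place(mnist_like, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
except RuntimeError as err:
    print("RuntimeError:", err)

# One mean and one std match the single grayscale channel.
print(normalize_in_place(mnist_like, [0.5], [0.5]).shape)
```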

So, in the case of the MNIST dataset, replace transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) with transforms.Normalize([0.5], [0.5]).

Upvotes: 7

thedch

Reputation: 177

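You are trying to normalize a 1-channel image using

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

This will not work: Normalize expects one mean and one standard deviation per channel, and subtracting a 3-element mean in place from a [1, 28, 28] tensor would produce a [3, 28, 28] result, which is exactly the error you mention. You should reconsider what transforms are necessary for your task.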

Upvotes: 0
