flawr


Add channel to MNIST via transform?

I'm trying to use the MNIST dataset from torchvision.datasets. It seems to be provided as an N x H x W (uint8) tensor (batch dimension, height, width). However, all the PyTorch classes for working on images (for instance Conv2d) require an N x C x H x W (float32) tensor, where C is the number of colour channels. I've tried adding the ToTensor transform, but that didn't add a colour channel.

Is there a way to add this extra dimension using torchvision.transforms? For a raw tensor we could just call .unsqueeze(1), but that doesn't look like a very elegant solution. I'm trying to do it the "proper" way.
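For comparison, the raw-tensor route mentioned in the question might look like the sketch below. It assumes a random uint8 tensor named `raw` as a hypothetical stand-in for the dataset's stored images (so nothing has to be downloaded), and a made-up 8-filter Conv2d just to check that the shape is accepted.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the stored MNIST images: 4 random 28x28 uint8 images (N x H x W)
raw = torch.randint(0, 256, (4, 28, 28), dtype=torch.uint8)

# Insert a channel axis at position 1 and rescale to float32 in [0, 1]
images = raw.unsqueeze(1).float() / 255.0  # N x C x H x W with C = 1

# Illustrative layer: 1 input channel, 8 output channels, 3x3 kernel, no padding
conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)
out = conv(images)

print(images.shape, images.dtype)  # torch.Size([4, 1, 28, 28]) torch.float32
print(out.shape)                   # torch.Size([4, 8, 26, 26])
```

This works, but it bypasses the dataset's `transform` argument entirely, which is why it feels less idiomatic than letting the transform pipeline produce the channel dimension.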

Here is the failed conversion.

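import torchvision
dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
# Still prints a 28 x 28 uint8 tensor with no channel dimension.
# (In newer torchvision versions train_data is a deprecated alias of dataset.data.)
print(dataset.train_data[0])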



Answers (1)

flawr


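I had a misconception: dataset.train_data is not affected by the specified transform. The transform is only applied when an item is fetched, i.e. to the output of dataset[i] and therefore to the batches produced by a DataLoader(dataset, ...). After checking the data from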

from torch.utils.data import DataLoader

# Grab the first batch (default batch_size=1) and stop
for data, _ in DataLoader(dataset):
    break

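we can see that data has shape 1 x 1 x 28 x 28 (float32, values in [0, 1]). So ToTensor actually does exactly what is desired: it turns each H x W uint8 image into a 1 x H x W float tensor, and the DataLoader stacks these into N x 1 x H x W batches.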

