I'm trying to use the MNIST dataset from `torchvision.datasets`. It seems to be provided as an N x H x W (uint8) tensor (batch dimension, height, width). However, all the PyTorch classes for working on images (for instance `Conv2d`) require an N x C x H x W (float32) tensor, where C is the number of colour channels. I've tried adding the `ToTensor` transform, but that didn't add a colour channel.

Is there a way using `torchvision.transforms` to add this additional dimension? For a raw tensor we could just do `.unsqueeze(1)`, but that doesn't look like a very elegant solution. I'm just trying to do it the "proper" way.
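For reference, the raw-tensor `.unsqueeze(1)` route looks roughly like this. It is a minimal sketch that uses a random uint8 tensor as a stand-in for real MNIST images; the random data and the 4-output-channel `Conv2d` are illustrative assumptions. The division by 255 mirrors the [0, 1] scaling that `ToTensor` applies to 8-bit images.

```python
import torch
import torch.nn as nn

# Stand-in for a batch of raw MNIST images: N x H x W, uint8 in [0, 255]
raw = torch.randint(0, 256, (8, 28, 28), dtype=torch.uint8)

# Insert a channel axis at position 1 and rescale to float32 in [0, 1],
# giving the N x C x H x W layout (C = 1 for grayscale) that Conv2d expects
x = raw.unsqueeze(1).float() / 255.0
print(x.shape, x.dtype)  # torch.Size([8, 1, 28, 28]) torch.float32

# A single-channel convolution accepts this layout directly
conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)
out = conv(x)
print(out.shape)  # torch.Size([8, 4, 26, 26])
```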
Here is the failed conversion:

```python
import torchvision

dataset = torchvision.datasets.MNIST("~/PyTorchDatasets/MNIST/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
print(dataset.train_data[0])  # still a 28 x 28 uint8 tensor with no channel dimension
```
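I had a misconception: `dataset.train_data` holds the raw stored images and is not affected by the specified `transform`. The transform is only applied to the samples returned by indexing the dataset (`dataset[i]`), which is what a `DataLoader(dataset, ...)` does when it fetches samples.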
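To make the distinction concrete, here is a minimal sketch of a map-style dataset that behaves the same way. `TinyMNISTLike` and `to_chw_float` are hypothetical names; the class is a simplified stand-in for torchvision's MNIST (which additionally converts each image to a PIL image before transforming it), and `to_chw_float` is a hand-written substitute for `ToTensor` so that the sketch needs only `torch`.

```python
import torch
from torch.utils.data import Dataset


class TinyMNISTLike(Dataset):
    """Hypothetical, simplified stand-in for torchvision's MNIST class:
    raw images are stored untouched in `self.data`, and the transform
    only runs inside __getitem__."""

    def __init__(self, n=4, transform=None):
        self.data = torch.randint(0, 256, (n, 28, 28), dtype=torch.uint8)
        self.targets = torch.randint(0, 10, (n,))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, target = self.data[idx], int(self.targets[idx])
        if self.transform is not None:
            img = self.transform(img)  # transform applied here, per sample
        return img, target


def to_chw_float(img):
    # Hypothetical substitute for ToTensor on a grayscale image:
    # add a channel axis and scale uint8 [0, 255] to float32 [0, 1]
    return img.unsqueeze(0).float() / 255.0


ds = TinyMNISTLike(transform=to_chw_float)
print(ds.data[0].shape, ds.data[0].dtype)  # torch.Size([28, 28]) torch.uint8
img, label = ds[0]
print(img.shape, img.dtype)  # torch.Size([1, 28, 28]) torch.float32
```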
will be. After checking data
from
for data, _ in DataLoader(dataset):
break
we can see that ToTensor
actually does exactly what is desired.