srakrn

Reputation: 362

Transforming every training point without using DataLoaders

I just found out that even though torchvision.datasets.MNIST accepts a transform parameter, ...

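from torchvision import datasets, transforms

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
mnist_trainset = datasets.MNIST(
    root="mnist", train=True, download=True, transform=transform
)
# mnist_testset (used below) is built the same way, with train=False
mnist_testset = datasets.MNIST(
    root="mnist", train=False, download=True, transform=transform
)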

...the values stored in mnist_trainset.data are still not transformed. The raw data is in the range [0, 255], whereas the transform should normalise it to [-1, 1].

[102] mnist_testset.data[0].min()
tensor(0, dtype=torch.uint8)

[103] mnist_testset.data[0].max()
tensor(255, dtype=torch.uint8)
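The expected [-1, 1] range follows from the two steps of the transform: ToTensor() converts uint8 pixels to float and divides by 255, giving [0, 1], and Normalize((0.5,), (0.5,)) computes (x - mean) / std = (x - 0.5) / 0.5, giving [-1, 1]. A minimal sketch of that arithmetic on a few sample pixel values, using plain torch ops rather than the torchvision transform objects themselves:

```python
import torch

# A few raw MNIST-style pixel values (uint8, range [0, 255]).
raw = torch.tensor([0, 128, 255], dtype=torch.uint8)

# Step 1, as ToTensor() does for uint8 input: convert to float and divide by 255.
scaled = raw.float() / 255.0

# Step 2, as Normalize(mean=(0.5,), std=(0.5,)) does: (x - mean) / std.
normalized = (scaled - 0.5) / 0.5

print(normalized.min().item(), normalized.max().item())  # -1.0 1.0
```

So a correctly transformed image should have a minimum of -1.0 and a maximum of 1.0, not 0 and 255.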

I tried applying mnist_trainset.transform to mnist_trainset.data directly, but the output shape is not what I intended:

[104] mnist_testset.data.shape
torch.Size([10000, 28, 28])

[105] transform(mnist_testset.data).shape
torch.Size([3, 28, 28])

# Should be [10000, 28, 28], identical to the original data.

I could use a DataLoader to load the entire set with shuffling disabled, but that seems like overkill. What is the best way to transform the entire mnist_testset with the defined transform object, so that I get the transformed images without having to transform them manually one by one?
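For comparison, the DataLoader route mentioned above only takes a couple of lines: with batch_size equal to the dataset length and shuffle=False, a single batch holds every transformed sample in its original order. The sketch below uses a hypothetical ToyImageDataset with random uint8 data (instead of downloading MNIST) and a hypothetical scale_and_normalize function that mimics ToTensor() + Normalize((0.5,), (0.5,)):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyImageDataset(Dataset):
    # Hypothetical stand-in for MNIST: raw uint8 images in .data,
    # transform applied per sample inside __getitem__.
    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.transform(self.data[index]), 0  # dummy label


def scale_and_normalize(img):
    # Same arithmetic as ToTensor() + Normalize((0.5,), (0.5,)); adds a channel dim.
    return ((img.float() / 255.0 - 0.5) / 0.5).unsqueeze(0)


raw = torch.randint(0, 256, (10, 28, 28), dtype=torch.uint8)
ds = ToyImageDataset(raw, scale_and_normalize)

# One batch containing the whole dataset, in the original order.
loader = DataLoader(ds, batch_size=len(ds), shuffle=False)
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([10, 1, 28, 28])
```

Note that the default collation keeps the channel dimension, so images[:, 0] would be needed to get the [N, 28, 28] shape.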

Upvotes: 3

Views: 196

Answers (1)

jodag

Reputation: 22184

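Transforms are only applied when you sample the dataset through its __getitem__ method (for example mnist_testset[0], or iterating over the dataset); the raw .data attribute is never modified. So you could do something like the following to get all the transformed data.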

imgs_transformed = []
for img, label in mnist_testset:
    # each img is a transformed tensor of shape [1, 28, 28]; drop the channel dim
    imgs_transformed.append(img[0, :, :])

or, using a list comprehension:

imgs_transformed = [img[0,:,:] for img, label in mnist_testset]

If you want to turn this into one big tensor, you can use torch.stack:

import torch
data_transformed = torch.stack(imgs_transformed, dim=0)  # shape: [10000, 28, 28]
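For a fixed, deterministic pipeline like ToTensor() followed by Normalize((0.5,), (0.5,)), the per-sample loop could also be replaced by applying the same arithmetic to the whole .data tensor at once. The sketch below is an assumption-laden illustration: it uses a hypothetical ToyImageDataset with random data instead of the real MNIST class, and a hypothetical scale_and_normalize function in place of the torchvision transforms. It checks that the raw data stays uint8 and that both routes give the same result:

```python
import torch
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    # Hypothetical stand-in for torchvision's MNIST: raw uint8 data in .data,
    # transform applied lazily, one sample at a time, inside __getitem__.
    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.transform(self.data[index]), 0  # dummy label


def scale_and_normalize(img):
    # Same arithmetic as ToTensor() + Normalize((0.5,), (0.5,)); adds a channel dim.
    return ((img.float() / 255.0 - 0.5) / 0.5).unsqueeze(0)


raw = torch.randint(0, 256, (100, 28, 28), dtype=torch.uint8)
ds = ToyImageDataset(raw, scale_and_normalize)

# Per-sample route, as in the answer: iterate, drop the channel dim, stack.
looped = torch.stack([img[0, :, :] for img, label in ds], dim=0)

# Vectorized route: apply the same arithmetic to the whole .data tensor at once.
vectorized = (ds.data.float() / 255.0 - 0.5) / 0.5

print(ds.data.dtype)                       # torch.uint8 (raw data untouched)
print(looped.shape)                        # torch.Size([100, 28, 28])
print(torch.allclose(looped, vectorized))  # True
```

The shortcut only holds while the transform is this simple scale-and-shift. The real torchvision MNIST class also converts each sample to a PIL image before calling the transform, and random augmentations would still require iterating over the dataset as shown in the answer.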

Upvotes: 1
