Denny Ceccon

Reputation: 115

DataLoader messing up transformed data

I am testing the MNIST dataset in PyTorch, and after I apply a transformation to the X data, the DataLoader seems to return the values out of their original order, potentially messing up the training step.

My transformation divides all values by 255. The transformation itself does not change positions, as the first pair of scatterplots shows. But after the data is passed to the DataLoader and I retrieve it, the values are out of order. With no transformation, everything is fine (not shown). The distribution of values is the same for before, after1 (divided by 255, before the DataLoader) and after2 (divided by 255, after the DataLoader) (also not shown); only the order seems to be affected.

import torch
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

transform = transforms.ToTensor()

train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)

before = train.data[0]

train.data = train.data.float()/255
after1 = train.data[0]

train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')

after2 = next(iter(train_loader))[0][0]

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')

fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')

I know that this might not be the best way to deal with this data (transforms.Normalize should solve it), but I would really like to understand what is happening.

Thank you!

Upvotes: 1

Views: 776

Answers (2)

Denny Ceccon

Reputation: 115

I posted this same question on PyTorch's GitHub page, and they answered the following:

It's unrelated to data loader. You are messing with an attribute of the particular dataset object, however, the actual __getitem__ of that object does much more: https://github.com/pytorch/vision/blob/6de158c473b83cf43344a0651d7c01128c7850e6/torchvision/datasets/mnist.py#L92

In particular this line (mode='L') assumes uint8 input. Since you replaced it with float, it is wrong.
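To illustrate the point about uint8, here is a minimal sketch using only Python's standard library. It assumes, as the quoted reply suggests, that the float buffer ends up being read one byte per pixel. The four sample values are made up for illustration, and this is not the torchvision code itself.

```python
import struct

# Four pixel values after dividing by 255, stored as 32-bit floats
# (the dtype produced by train.data.float() / 255).
pixels = [0.0, 0.5, 1.0, 0.25]
raw = struct.pack("<4f", *pixels)  # 16 bytes, little-endian float32

# Reading the same buffer as 8-bit pixels yields one value per byte,
# so each float turns into four unrelated "pixels".
as_uint8 = list(raw)

print(len(raw))   # 16
print(as_uint8)   # [0, 0, 0, 0, 0, 0, 0, 63, 0, 0, 128, 63, 0, 0, 128, 62]
```

Each float32 occupies four bytes, so under this assumption a 28x28 8-bit image (784 bytes) would be filled by the bytes of only the first 196 floats, which would be consistent with the scrambled scatterplot in the question.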

So I guess the preferred approach is to apply the scaling as a transform when creating the dataset at the very beginning of my code, rather than overwriting train.data afterwards.

Upvotes: 1

prosti

Reputation: 46291

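I had not originally tested the code you wrote. Here is a rewrite of the original that wraps the scaled images and the labels in a TensorDataset, so the batches no longer go through the MNIST dataset's __getitem__: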

import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

transform = transforms.ToTensor()

train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)

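# Scale the raw uint8 images to [0, 1]; the labels are used unchanged.
images = train.data.float()/255
labels = train.targets

# TensorDataset indexes the tensors directly, so no PIL conversion is involved.
train_ds = TensorDataset(images, labels)
train_loader = DataLoader(train_ds, batch_size=128)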
# img, target = next(iter(train_loader))

before = train.data[0]
train.data = train.data.float()/255
after1 = train.data[0]

# train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')

after2 = next(iter(train_loader))[0][0]

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')

fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')

Upvotes: 0
