Reputation: 115
I am testing the MNIST dataset in PyTorch. After I apply a transformation to the X data, the DataLoader seems to return the pixel values out of their original order, which could corrupt the training step.
My transformation divides all values by 255. The transformation itself does not change positions, as the first pair of scatterplots shows. But once the data has passed through the DataLoader and I retrieve it, the values are out of order. Without the transformation, everything is fine (not shown). The distribution of values is the same for before, after1 (divided by 255, before the DataLoader) and after2 (divided by 255, after the DataLoader) (also not shown); only the order seems to be affected.
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
transform = transforms.ToTensor()
train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)
before = train.data[0]
train.data = train.data.float()/255
after1 = train.data[0]
train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)
fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')
after2 = next(iter(train_loader))[0][0]
fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')
fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')
I know that this might not be the best way to deal with this data (transforms.Normalize should solve it), but I would really like to understand what is happening.
Thank you!
Upvotes: 1
Views: 776
Reputation: 115
So... I posted this same question on PyTorch's GitHub page, and they answered the following:
It's unrelated to data loader. You are messing with an attribute of the particular dataset object, however, the actual __getitem__ of that object does much more: https://github.com/pytorch/vision/blob/6de158c473b83cf43344a0651d7c01128c7850e6/torchvision/datasets/mnist.py#L92
In particular this line (mode='L') assumes uint8 input. Since you replaced it with float, it is wrong.
Then I guess the preferred approach would be to apply a transform when preparing the dataset at the very beginning of my code.
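To make the "assumes uint8 input" point concrete, here is a minimal sketch using only the standard library. It assumes the image is built by reading the underlying buffer one byte per pixel; the exact behaviour may depend on the Pillow version, and the four-pixel list is just a hypothetical stand-in for a 28x28 image.

```python
import struct

# Four hypothetical uint8 pixel values standing in for an MNIST image.
pixels = [255, 51, 128, 0]

# The in-place transformation from the question: cast to float, divide by 255.
scaled = [p / 255 for p in pixels]

# A float32 buffer uses 4 bytes per value (little-endian here).
raw = struct.pack('<4f', *scaled)

# A reader that expects one uint8 per pixel only consumes the first
# len(pixels) bytes, which are really the four bytes of the first float (1.0).
misread = list(raw[:len(pixels)])

print(len(raw))   # 16 bytes for 4 values
print(misread)    # [0, 0, 128, 63]
```

Under that assumption, the values coming out of the dataset are pieces of the float encoding rather than pixels, which would be consistent with the image looking scrambled even though train.data itself is still in order.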
Upvotes: 1
Reputation: 46291
I hadn't originally tested the code you wrote. Here is a rewrite of it that builds train_loader from a TensorDataset of the scaled images and the labels:
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
import matplotlib.pyplot as plt
transform = transforms.ToTensor()
train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)
dl = DataLoader(train)
images = dl.dataset.data.float()/255
labels = dl.dataset.targets
train_ds = TensorDataset(images, labels)
train_loader = DataLoader(train_ds, batch_size=128)
# img, target = next(iter(train_loader))
before = train.data[0]
train.data = train.data.float()/255
after1 = train.data[0]
# train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)
fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')
after2 = next(iter(train_loader))[0][0]
fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')
fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')
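As far as I can tell, this works because TensorDataset simply indexes the tensors it was given, so the samples never go through the MNIST __getitem__ and its image conversion. A minimal sketch of that, with random synthetic data standing in for MNIST (the shapes and sizes are assumptions chosen for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-ins for train.data and train.targets.
raw = torch.randint(0, 256, (256, 28, 28), dtype=torch.uint8)
labels = torch.randint(0, 10, (256,))

# Same scaling as in the question, applied before building the dataset.
images = raw.float() / 255

# TensorDataset returns (images[i], labels[i]) without any further conversion.
loader = DataLoader(TensorDataset(images, labels), batch_size=128)
batch_images, batch_labels = next(iter(loader))

# With the default shuffle=False, the first batch item is the first image.
print(torch.equal(batch_images[0], images[0]))
```

Note that the batch here has shape (128, 28, 28) with no channel dimension, so a model expecting (N, 1, 28, 28) would need an unsqueeze(1) on the images first.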
Upvotes: 0