sparkonhdfs

Reputation: 1343

PyTorch is tiling images when loaded with DataLoader

I am trying to load an image dataset with the PyTorch DataLoader, but the transformed images come out tiled instead of showing the original image cropped to the center, as I expect.

import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Resize the shorter side to 224, crop the central 224x224 square, convert to a tensor
transform = transforms.Compose([transforms.Resize(224),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

dataset = datasets.ImageFolder('ml-models/downloads/', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

# Take one batch and display the 7th image
images, labels = next(iter(dataloader))
plt.imshow(images[6].reshape(224, 224, 3))

The resulting image is tiled rather than center-cropped, as shown in this Jupyter snapshot: https://i.sstatic.net/HtrIa.png

Is there something wrong with the transformation?

Upvotes: 0

Views: 953

Answers (1)

jodag

Reputation: 22284

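PyTorch stores image tensors in channel-first format, so a 3-channel image is a tensor of shape (3, H, W); this is the layout transforms.ToTensor() produces. Matplotlib expects channel-last data, i.e. (H, W, 3). Reshaping does not rearrange the dimensions: reshape keeps the elements in the same row-major order and only regroups them under the new shape, so each displayed "pixel" is built from neighbouring values of a single channel, which is what produces the tiled picture. To actually reorder the dimensions you need Tensor.permute.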
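To see the difference on a tensor small enough to check by hand, here is a minimal sketch using a made-up 3x2x2 "image" (the name chw is purely illustrative) in place of a real batch:

```python
import torch

# Made-up tiny channel-first "image": 3 channels, 2 rows, 2 columns.
# Channel 0 holds 0..3, channel 1 holds 4..7, channel 2 holds 8..11.
chw = torch.arange(12).reshape(3, 2, 2)

# reshape keeps the flat order 0..11 and just regroups it into triples,
# so the "RGB" of pixel (0, 0) is three values taken from channel 0.
reshaped = chw.reshape(2, 2, 3)
print(reshaped[0, 0].tolist())   # [0, 1, 2]

# permute moves the channel axis to the end, so pixel (0, 0) gets
# one value from each channel.
permuted = chw.permute(1, 2, 0)
print(permuted[0, 0].tolist())   # [0, 4, 8]
print(tuple(permuted.shape))     # (2, 2, 3)
```

Both results have shape (2, 2, 3), which is why the reshape version does not raise an error; only permute puts the right values in each pixel.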

plt.imshow(images[6].permute(1, 2, 0))
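The "tiled" look follows directly from how reshape regroups the buffer. A minimal sketch, assuming a random tensor as a stand-in for images[6] (so no dataset is needed), that checks where the reshaped pixels come from:

```python
import torch

C, H, W = 3, 224, 224
# Random stand-in for images[6]: channel-first, like ToTensor() output
img = torch.rand(C, H, W)

wrong = img.reshape(H, W, C)   # what the question does
right = img.permute(1, 2, 0)   # what the answer does

# permute keeps each pixel's three channel values together
assert torch.equal(right[10, 20], img[:, 10, 20])

# reshape re-slices the flat buffer: each output row takes the next
# W * C = 672 values, i.e. three consecutive input rows of ONE channel,
# shown side by side and each squeezed to a third of the width.
assert torch.equal(wrong[0, 0], img[0, 0, 0:3])
assert torch.equal(wrong[0].flatten(), img[0, 0:3].flatten())

# Moving down the output walks through the first channel plane, then the
# second, then the third: output row 100 already comes from channel 1
# (input row 76). Three bands of three squeezed strips give a 3x3 grid.
assert torch.equal(wrong[100, 0], img[1, 76, 0:3])
```

For a whole batch of shape (N, 3, H, W), images.permute(0, 2, 3, 1) applies the same reordering to every image.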

Upvotes: 1
