Reputation: 1343
I am trying to load an image dataset using the PyTorch `DataLoader`, but the transformed images come out tiled instead of showing the original image cropped to the center, as I expect.
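```python
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Resize the shorter side to 224, crop the central 224x224 patch, convert to a tensor
transform = transforms.Compose([transforms.Resize(224),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])
dataset = datasets.ImageFolder('ml-models/downloads/', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

# Take one batch and display one image from it
images, labels = next(iter(dataloader))
plt.imshow(images[6].reshape(224, 224, 3))
```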
The resulting image is tiled rather than center cropped, as shown in this Jupyter snapshot:

[![Jupyter snapshot of the tiled image][1]][1]

Is there something wrong with the transformation I'm using?

  [1]: https://i.sstatic.net/HtrIa.png
Upvotes: 0
Views: 953
Reputation: 22284
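PyTorch stores image tensors in channel-first format, so a 3-channel image is a tensor of shape (3, H, W). Matplotlib expects channel-last data, i.e. (H, W, 3). `reshape` does not rearrange the dimensions: it only reinterprets the same values, in memory order, under a new shape. Consecutive values from a single channel therefore end up grouped as one pixel's RGB triple, which is what produces the tiled picture. To move the channel axis to the end, use `Tensor.permute`: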
```python
# (3, H, W) -> (H, W, 3): move the channel axis last for Matplotlib
plt.imshow(images[6].permute(1, 2, 0))
```
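A minimal sketch of the difference, using a hypothetical 3-channel 2x2 tensor in place of a real image so each value can be traced back to its channel:

```python
import torch

# Channel-first "image" of shape (3, 2, 2), like ToTensor() output.
# Channel c holds 10*c + 0..3, so the tens digit identifies the channel.
chw = torch.stack([torch.arange(4).reshape(2, 2) + 10 * c for c in range(3)])

# reshape keeps the flat memory order: the first "pixel" gets three
# consecutive values from channel 0 instead of one value per channel.
wrong = chw.reshape(2, 2, 3)
print(wrong[0, 0].tolist())  # [0, 1, 2]

# permute moves the channel axis to the end, so each pixel gets its own
# (R, G, B) triple taken from the three channels.
right = chw.permute(1, 2, 0)
print(right[0, 0].tolist())  # [0, 10, 20]
```

Both results have shape (2, 2, 3), which is why the `reshape` version raises no error and simply displays scrambled pixels.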
Upvotes: 1