Raha Moosavi

Reputation: 547

Plot the transformed (augmented) images in pytorch

I want to use one of the image augmentation techniques (for example rotation or horizontal flip) and apply it to some images of the CIFAR-10 dataset and plot them in PyTorch.

I know that we can use the following code to augment images:

from torchvision import models, datasets, transforms
from torchvision.datasets import CIFAR10

# Assumed per-channel statistics; mean = std = 0.5 maps [0, 1] to [-1, 1].
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

data_transforms = transforms.Compose([
    # add augmentations
    transforms.RandomHorizontalFlip(p=0.5),
    # torchvision datasets return PIL images; ToTensor converts them to
    # tensors in [0, 1], and Normalize then maps them to [-1, 1].
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

I then pass these transforms when loading the CIFAR-10 dataset:

train_set = CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=data_transforms)  # data_transforms is a Compose, not a dict

As far as I know, when this code is used, all CIFAR10 datasets are transformed.

Question

How can I apply a transform or augmentation to only some images in the dataset and plot them, for example 10 images alongside their augmented versions?

Upvotes: 0

Views: 1312

Answers (1)

Ivan

Reputation: 40768

when this code is used, all CIFAR10 datasets are transformed

Actually, the transform pipeline is only called when images are fetched from the dataset, either directly through __getitem__ or through a data loader. So at this point train_set doesn't contain augmented images; they are transformed on the fly.
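To make the "on the fly" behaviour concrete, here is a small pure-Python sketch. LazyDataset and random_flip are hypothetical stand-ins (not torchvision classes) that only mimic the pattern of applying the transform inside __getitem__:

```python
import random


class LazyDataset:
    """Hypothetical stand-in for a torchvision dataset.

    Raw items are stored unchanged; `transform` runs only inside __getitem__.
    """

    def __init__(self, items, transform=None):
        self.items = list(items)
        self.transform = transform
        self.transform_calls = 0  # counts how often the transform has run

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        item = self.items[index]
        if self.transform is not None:
            self.transform_calls += 1
            item = self.transform(item)
        return item


def random_flip(row, p=0.5):
    """Toy analogue of RandomHorizontalFlip on a list of pixel values."""
    return row[::-1] if random.random() < p else row


dataset = LazyDataset([[1, 2, 3], [4, 5, 6]], transform=random_flip)
calls_after_construction = dataset.transform_calls  # nothing transformed yet

first = dataset[0]   # the transform runs now
second = dataset[0]  # same index, the random flip is drawn again
calls_after_two_fetches = dataset.transform_calls

stored_item = dataset.items[0]  # the stored data itself is never modified
```

Fetching the same index twice can therefore give two different results, which is also why a second, non-augmented dataset is handy when you want to see the originals.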


To plot the originals next to their augmented versions, you will need to construct a second dataset without augmentations.

>>> non_augmented = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=transforms.ToTensor())  # tensors are needed for torch.stack

>>> train_set = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=data_transforms)

Stack some images together:

>>> import torch
>>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)],
...                     *[train_set[i][0] for i in range(10)]))

>>> imgs.shape
torch.Size([20, 3, 32, 32])

Then torchvision.utils.make_grid can be useful to create the desired layout:

>>> import torchvision
>>> grid = torchvision.utils.make_grid(imgs, nrow=10)
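As a rough check of what nrow=10 does to the layout, the arithmetic can be sketched in plain Python. This assumes make_grid's default padding of 2 and a batch of more than one equally sized image; grid_shape and grid_position are hypothetical helper names, not torchvision functions:

```python
import math


def grid_shape(num_images, channels, height, width, nrow=8, padding=2):
    """Expected (C, H, W) of the grid for equally sized images.

    Assumes each tile is padded by `padding` pixels, plus one extra border
    of `padding` pixels along the grid's edge.
    """
    columns = min(nrow, num_images)          # nrow = images per row
    rows = math.ceil(num_images / columns)
    grid_height = rows * (height + padding) + padding
    grid_width = columns * (width + padding) + padding
    return (channels, grid_height, grid_width)


def grid_position(index, nrow):
    """(row, column) at which image `index` is placed."""
    return divmod(index, nrow)


shape = grid_shape(20, 3, 32, 32, nrow=10)
last_original = grid_position(9, 10)     # last image of the first row
first_augmented = grid_position(10, 10)  # first image of the second row
```

With 20 images and nrow=10, the 10 originals fill the first row and the 10 augmented images the second, so each column pairs an image with its augmented version; the grid tensor should come out as (3, 70, 342).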

There you have it!

>>> transforms.ToPILImage()(grid)
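One caveat when viewing the result: the augmented images have been through Normalize, so their values can fall outside [0, 1], the range that float tensors are generally expected to be in when converted to a PIL image. A minimal sketch of undoing the normalization first, assuming mean = std = 0.5 per channel (the question does not give the actual values) and a hypothetical helper named denormalize:

```python
import torch

# Assumed normalization statistics; the question leaves `mean` and `std` undefined.
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)


def denormalize(batch, mean=MEAN, std=STD):
    """Invert Normalize(mean, std) for a (N, C, H, W) or (C, H, W) tensor."""
    mean_t = torch.tensor(mean, dtype=batch.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=batch.dtype).view(-1, 1, 1)
    # Normalize computes (x - mean) / std, so the inverse is x * std + mean.
    return (batch * std_t + mean_t).clamp(0.0, 1.0)


# Tiny synthetic batch: channel values -1, 0 and 1 stand in for normalized pixels.
normalized = torch.tensor([-1.0, 0.0, 1.0]).view(1, 3, 1, 1).expand(1, 3, 2, 2)
restored = denormalize(normalized)
```

In the pipeline above, this would presumably be applied to the tensors coming from train_set before they are stacked with the non-augmented ones, so that both rows of the grid share the same value range.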


Upvotes: 2
