I want to use one of the image augmentation techniques (for example rotation or horizontal flip) and apply it to some images of the CIFAR-10 dataset and plot them in PyTorch.
I know that we can use the following code to augment images:
from torchvision import models, datasets, transforms
from torchvision.datasets import CIFAR10

# Assumed per-channel statistics; with 0.5 everywhere, Normalize maps [0, 1] to [-1, 1].
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

data_transforms = transforms.Compose([
    # add augmentations
    transforms.RandomHorizontalFlip(p=0.5),
    # torchvision datasets return PIL images; ToTensor converts them to
    # tensors in [0, 1], and Normalize with the mean/std above maps them to [-1, 1].
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
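For reference, Normalize(mean, std) computes (x - mean) / std for each channel, so with mean = std = 0.5 it sends the [0, 1] output of ToTensor to [-1, 1], and undoing it for plotting means computing y * std + mean. A plain-Python sketch of that per-value arithmetic (normalize and denormalize are hypothetical helper names, not torchvision functions):

```python
def normalize(x: float, mean: float, std: float) -> float:
    """Per-value arithmetic of transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y: float, mean: float, std: float) -> float:
    """Inverse of normalize: brings a value back to the [0, 1] display range."""
    return y * std + mean


# Assumed statistics: 0.5 for every channel, as in the pipeline above.
MEAN, STD = 0.5, 0.5

for pixel in (0.0, 0.25, 1.0):
    y = normalize(pixel, MEAN, STD)
    print(pixel, "->", y, "->", denormalize(y, MEAN, STD))
```

This is why normalized images look wrong if they are plotted directly: their values have to be mapped back to [0, 1] first.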
Then I pass these transforms when loading the CIFAR-10 dataset:
train_set = CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=data_transforms)
As far as I know, when this code is used, all CIFAR10 datasets are transformed.
Question
How can I apply data transforms or augmentation techniques to only some images of the dataset and plot them? For example, 10 images next to their augmented versions.
when this code is used, all CIFAR10 datasets are transformed
Actually, the transform pipeline is only called when an image is fetched from the dataset through its __getitem__ method, either directly by the user or through a data loader. So at this point, train_set doesn't contain augmented images; they are transformed on the fly, and because RandomHorizontalFlip is random, fetching the same index twice can give different results.
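To make the on-the-fly behaviour concrete, here is a simplified stand-in for a torchvision dataset. LazyDataset and CountingHFlip are hypothetical illustration classes, not part of torchvision (the real CIFAR10 also converts its stored array to a PIL image before transforming). The stored samples are never modified, and the random transform runs again on every index access:

```python
import random
from typing import Callable, List, Optional, Tuple

Image = List[int]  # toy 1-pixel-high "image": a row of pixel values


class CountingHFlip:
    """Hypothetical toy random horizontal flip that counts its calls."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p
        self.calls = 0

    def __call__(self, img: Image) -> Image:
        self.calls += 1
        # Reverse the row with probability p, otherwise return a copy.
        return img[::-1] if random.random() < self.p else list(img)


class LazyDataset:
    """Hypothetical dataset: stores raw samples, transforms only in __getitem__."""

    def __init__(self, samples: List[Tuple[Image, int]],
                 transform: Optional[Callable[[Image], Image]] = None) -> None:
        self.samples = samples
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[Image, int]:
        img, target = self.samples[index]
        if self.transform is not None:
            img = self.transform(img)  # runs on every fetch; nothing is cached
        return img, target


flip = CountingHFlip(p=0.5)
ds = LazyDataset([([1, 2, 3], 0)], transform=flip)
print(flip.calls)     # nothing has been transformed yet
print(ds[0], ds[0])   # two fetches, two independent random flips
print(flip.calls)     # one call per fetch
print(ds.samples[0])  # the stored sample itself is untouched
```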
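You will need to construct another dataset without augmentations. Give it a plain ToTensor() transform so that it returns tensors that torch.stack can combine with the augmented ones (without any transform it returns PIL images):
>>> import torch
>>> import torchvision
>>> non_augmented = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=transforms.ToTensor())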
>>> train_set = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=data_transforms)
Stack some images together:
>>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)],
...                     *[train_set[i][0] for i in range(10)]))
>>> imgs.shape
torch.Size([20, 3, 32, 32])
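Then torchvision.utils.make_grid can be useful to create the desired layout. Because train_set applies Normalize, its values are no longer in [0, 1], so ask make_grid to rescale each image to [0, 1] for display (normalize=True, scale_each=True); alternatively, leave Normalize out of the pipeline you use for plotting:
>>> grid = torchvision.utils.make_grid(imgs, nrow=10, normalize=True, scale_each=True)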
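There you have it! Converting the grid to a PIL image displays it inline in a notebook (elsewhere, call .show() on the result):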
>>> transforms.ToPILImage()(grid)