batuman

Reputation: 7314

Image augmentation in PyTorch

I would like to augment only every other image. My PyTorch transform code is as follows.

import torchvision.transforms as tt
from torchvision.datasets import ImageFolder
from PIL import Image  # needed for Image.BICUBIC below

# Data transforms (normalization & data augmentation)
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_tfms = tt.Compose([tt.RandomCrop(32, padding=4, padding_mode='reflect'),
                         tt.RandomHorizontalFlip(),
                         tt.RandomAffine(degrees=(10, 30),
                                         translate=(0.1, 0.3),
                                         scale=(0.7, 1.3),
                                         shear=0.1,
                                         # newer torchvision releases name this argument `interpolation`
                                         resample=Image.BICUBIC),
                         tt.ToTensor(),
                         tt.Normalize(*stats)])

When I create the dataset as follows and train on it, every image is augmented.

train_ds = ImageFolder('content/train', train_tfms)

But I want the augmentation to alternate: the first image should be used as is, the next image should be augmented, and so on.

How can I do that?

Upvotes: 0

Views: 1784

Answers (1)

Guillem

Reputation: 2647

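From a single image folder you can create two datasets, one with augmentation and one without, take alternating indices from each, and then concatenate them. Every image still appears exactly once: even-indexed images unaugmented, odd-indexed images augmented. Within each half the original order is kept, because torch.utils.data.Subset simply indexes into the underlying dataset. The combined dataset lists the unaugmented half first and the augmented half second, so shuffle in the DataLoader if you want the two kinds mixed within a batch.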

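import torch

# Non-augmented images still need ToTensor/Normalize so both subsets yield tensors
base_tfms = tt.Compose([tt.ToTensor(), tt.Normalize(*stats)])
train_ds_no_aug = ImageFolder('content/train', base_tfms)
train_ds_aug = ImageFolder('content/train', train_tfms)

# Even indices stay unaugmented, odd indices are augmented (no overlap)
no_aug_idx = list(range(0, len(train_ds_no_aug), 2))
aug_idx = list(range(1, len(train_ds_no_aug), 2))

train_ds_no_aug = torch.utils.data.Subset(train_ds_no_aug, no_aug_idx)
train_ds_aug = torch.utils.data.Subset(train_ds_aug, aug_idx)

# ConcatDataset joins map-style datasets; ChainDataset is only for IterableDataset
train_ds = torch.utils.data.ConcatDataset([train_ds_no_aug, train_ds_aug])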
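
If you would rather keep the original interleaved order (image 0 plain, image 1 augmented, image 2 plain, ...), another option is a thin wrapper that picks the transform from the index. This is only a minimal sketch, not something from the question: the class name AlternatingAugmentDataset and the base_tfms argument are hypothetical, and the wrapped dataset is assumed to return (raw_image, label) pairs with no transform applied. A DataLoader accepts any object that implements __len__ and __getitem__, so the sketch uses plain Python and stand-in data to stay runnable without image files.

```python
class AlternatingAugmentDataset:
    """Map-style dataset wrapper (hypothetical name).

    Even indices get base_tfms, odd indices get aug_tfms.
    """

    def __init__(self, base_ds, base_tfms, aug_tfms):
        # base_ds must yield (raw_image, label) with no transform applied
        self.base_ds = base_ds
        self.base_tfms = base_tfms
        self.aug_tfms = aug_tfms

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx):
        img, label = self.base_ds[idx]
        # Odd index -> augmented pipeline, even index -> plain pipeline
        tfm = self.aug_tfms if idx % 2 == 1 else self.base_tfms
        return tfm(img), label


# Stand-in data and transforms so the sketch runs on its own
toy_ds = [(i, 0) for i in range(6)]
demo = AlternatingAugmentDataset(
    toy_ds,
    base_tfms=lambda x: ("plain", x),
    aug_tfms=lambda x: ("augmented", x),
)
kinds = [demo[i][0][0] for i in range(len(demo))]
```

With the real data this would be used roughly as AlternatingAugmentDataset(ImageFolder('content/train'), base_tfms, train_tfms), where base_tfms is assumed to be a ToTensor plus Normalize pipeline. Note that which images get augmented is fixed by index, exactly as in the Subset approach.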

Upvotes: 1
