Estel

Reputation: 3

Data augmentation in Pytorch for CNN

I want to apply data augmentation to my set of images so that I have more data to train a convolutional neural network in PyTorch.

Example of transformations:

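 # Assumption: these transform classes come from MONAI.
 import numpy as np
 from monai.transforms import Compose, EnsureChannelFirst, LoadImage, RandFlip, RandRotate, ScaleIntensity

 train_transforms = Compose([
     LoadImage(image_only=True),
     EnsureChannelFirst(),
     ScaleIntensity(),
     RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
     RandFlip(spatial_axis=0, prob=0.5),
 ])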

As I understand it, transforms in PyTorch modify each image, and only the transformed image is then used, not the original. My goal is to augment the data, so I want to train on both the original images and the transformed ones. How can these transformations actually increase the number of input samples? For example, if I augment with a flip, I want to use both the original data and the flipped data, so that the model is trained on more data.

I tried adding transformations to my data, but only the transformed data seems to be used: the data changes, but the amount of data does not increase.

Upvotes: 0

Views: 1160

Answers (2)

bezirganyan

Reputation: 424

In your torch dataset class, you can report a length larger than the real one, check whether the index is greater than or equal to the real length of your dataset and, if so, return an augmented image.

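from torch.utils.data import Dataset

class ExampleDataset(Dataset):
    def __init__(self):
        self.data = ...
        self.real_length = len(self.data)
        # Report twice the real length: originals first, then augmented copies.
        self.length = self.real_length * 2

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if idx < self.real_length:
            # First half: return the original sample unchanged.
            return self.data[idx]
        else:
            # Second half: map back to the original index and augment that sample.
            # `augment` stands for your own augmentation function.
            return augment(self.data[idx - self.real_length])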

You can extend your data by a larger factor (3x, 4x) depending on the augmentations you want to apply.
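As a rough sketch of that generalization, the dataset below exposes every sample a configurable number of times, using the modulo of the index to find the underlying sample. The class name `RepeatedAugmentDataset`, the `copies` parameter, and the `horizontal_flip` helper are hypothetical names chosen for illustration, and the tiny tensor stands in for real images.

```python
import torch
from torch.utils.data import Dataset


class RepeatedAugmentDataset(Dataset):
    """Exposes each sample `copies` times: once unchanged, the rest augmented."""

    def __init__(self, data, augment, copies=3):
        self.data = data
        self.augment = augment  # callable applied to the extra copies
        self.copies = copies
        self.real_length = len(data)

    def __len__(self):
        return self.real_length * self.copies

    def __getitem__(self, idx):
        if idx < 0 or idx >= len(self):
            raise IndexError(idx)
        # Map the extended index back onto the real data.
        sample = self.data[idx % self.real_length]
        if idx < self.real_length:
            return sample  # first pass: original sample
        return self.augment(sample)  # later passes: augmented sample


def horizontal_flip(image):
    # Hypothetical augmentation: flip along the last (width) axis.
    return torch.flip(image, dims=[-1])


# Four dummy single-channel 2x2 "images".
images = torch.arange(16, dtype=torch.float32).reshape(4, 1, 2, 2)
dataset = RepeatedAugmentDataset(images, horizontal_flip, copies=3)
```

Note that with a deterministic augmentation such as this flip, the second and third copies are identical, so a factor above 2 probably only pays off when the augmentation is random.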

Upvotes: 0

TheEngineerProgrammer

Reputation: 1461

If you want your original data and augmented data at the same time, you can concatenate the two datasets and create a dataloader from the result. The steps are:

  1. Create a dataset with data augmentations.
  2. Create a dataset without data augmentations.
  3. Create a dataset by concatenating both.
  4. Create a dataloader with the concatenated dataset.

I guess you already know how to create datasets with data augmentation. To concatenate several datasets you can use:

from torch.utils.data import ConcatDataset
concat_dataset = ConcatDataset([dataset1, dataset2])

More information is available here:
https://discuss.pytorch.org/t/how-does-concatdataset-work/60083/2
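The four steps can be sketched end to end as follows. This is only an illustration under assumptions: `TransformedDataset` and `flip_horizontal` are hypothetical names, and random tensors stand in for real images and labels.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class TransformedDataset(Dataset):
    """Wraps in-memory images and labels, optionally applying a transform."""

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]


def flip_horizontal(image):
    # Hypothetical augmentation: flip along the last (width) axis.
    return torch.flip(image, dims=[-1])


# Dummy data: six single-channel 8x8 images with binary labels.
images = torch.rand(6, 1, 8, 8)
labels = torch.tensor([0, 1, 0, 1, 0, 1])

# Step 1: dataset with augmentation.
augmented_dataset = TransformedDataset(images, labels, transform=flip_horizontal)
# Step 2: dataset without augmentation.
original_dataset = TransformedDataset(images, labels)
# Step 3: concatenate both (originals first, then augmented copies).
concat_dataset = ConcatDataset([original_dataset, augmented_dataset])
# Step 4: dataloader over the concatenated dataset.
loader = DataLoader(concat_dataset, batch_size=4, shuffle=True)
```

Iterating over `loader` then yields batches that mix original and flipped images, and `len(concat_dataset)` is twice the number of original samples.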

Upvotes: 0
