Lorenzo Cutrupi

Reputation: 720

Augmenting data proportionally

I'm working on a classification problem with 2 classes. Currently I augment the dataset with this code:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

aug_train_data_gen = ImageDataGenerator(rotation_range=0,
                                        height_shift_range=40,
                                        width_shift_range=40,
                                        zoom_range=0,
                                        horizontal_flip=True,
                                        vertical_flip=True,
                                        fill_mode='reflect',
                                        rescale=1/255.)

aug_train_gen = aug_train_data_gen.flow_from_directory(directory=training_dir,
                                                       target_size=(96, 96),
                                                       color_mode='rgb',
                                                       classes=None,  # can be set to labels
                                                       class_mode='categorical',
                                                       batch_size=64,
                                                       shuffle=False)  # set to False to compare images

But I think that generating more class1 data through augmentation would improve my performance: at the moment there are 6x as many class2 images as class1 images, so the CNN tends to classify everything as class2. How can I augment the classes proportionally? For completeness, the loss-weighting alternative I'm aware of is sketched below, but I'd prefer to actually oversample class1.
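(A minimal sketch of that alternative: class_weight is a real argument of Keras's model.fit, while the inverse-frequency formula and the commented-out fit call are only illustrative, with model standing in for my compiled CNN.)

import numpy as np

# Per-class image counts exposed by the directory iterator above
counts = np.bincount(aug_train_gen.classes)

# Inverse-frequency weights: the rarer class gets a larger loss weight
class_weight = {i: len(aug_train_gen.classes) / (len(counts) * c)
                for i, c in enumerate(counts)}

# model.fit(aug_train_gen, class_weight=class_weight, epochs=...)  # model = my compiled CNN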

Upvotes: 0

Views: 56

Answers (1)

Chen Brestel

Reputation: 21

To get balanced batches you can use the class below.

On init you supply a list of datasets, one dataset per class, so the number of datasets equals the number of classes.

At runtime, __getitem__() first chooses a class at random and then picks a sample from within that class.


from typing import Iterable

from numpy.random import randint
from torch.utils.data import Dataset, IterableDataset


class MultipleDataset(Dataset):
    """
    Choose at random which dataset to return an item from
    on each call to __getitem__().
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(MultipleDataset, self).__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), \
                "MultipleDataset does not support IterableDataset"
        self.dataset_sizes = [len(d) for d in self.datasets]
        self.num_datasets = len(self.datasets)

    def __len__(self):
        # One epoch is as long as the largest per-class dataset
        return max(self.dataset_sizes)

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Pick a class uniformly at random, then a sample within it
        dataset_idx = randint(self.num_datasets)
        sample_idx = idx % self.dataset_sizes[dataset_idx]
        return self.datasets[dataset_idx][sample_idx]
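For completeness, a usage sketch (not from the original answer: the folder paths are hypothetical, and each ImageFolder root is assumed to have the usual one-subfolder-per-class layout):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

tf = transforms.Compose([transforms.Resize((96, 96)),
                         transforms.ToTensor()])

# Hypothetical single-class folders; target_transform pins the label per dataset
class1 = datasets.ImageFolder('data/class1_only', transform=tf,
                              target_transform=lambda _: 0)
class2 = datasets.ImageFolder('data/class2_only', transform=tf,
                              target_transform=lambda _: 1)

balanced = MultipleDataset([class1, class2])
loader = DataLoader(balanced, batch_size=64, shuffle=True)  # roughly 50/50 batches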

Upvotes: 1
