Reputation: 720
I'm facing a classification problem between 2 classes. Currently I augment the dataset using this code:
aug_train_data_gen = ImageDataGenerator(rotation_range=0,
                                        height_shift_range=40,
                                        width_shift_range=40,
                                        zoom_range=0,
                                        horizontal_flip=True,
                                        vertical_flip=True,
                                        fill_mode='reflect',
                                        rescale=1/255.)
aug_train_gen = aug_train_data_gen.flow_from_directory(directory=training_dir,
                                                       target_size=(96, 96),
                                                       color_mode='rgb',
                                                       classes=None,  # can be set to labels
                                                       class_mode='categorical',
                                                       batch_size=64,
                                                       shuffle=False  # set to False if you need to compare images
                                                       )
But I think that increasing the amount of class1 data through augmentation would improve performance: at the moment there are 6x more class2 images than class1 images, so the CNN tends to classify images as class2. How can I do this?
Upvotes: 0
Views: 56
Reputation: 21
To get balanced batches, you can use the PyTorch Dataset class below.
On init you pass it a list of datasets, one dataset per class, so the number of datasets equals the number of classes.
At runtime, __getitem__() picks one of the classes at random and returns a sample from that class's dataset. The index wraps around for the smaller datasets, so samples from the minority class are repeated.
Best
import random
from typing import Iterable

from torch.utils.data import Dataset, IterableDataset


class MultipleDataset(Dataset):
    """
    On each call to __getitem__(), pick one of the datasets at random and return an item from it.
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(MultipleDataset, self).__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "MultipleDataset does not support IterableDataset"
        self.dataset_sizes = [len(d) for d in self.datasets]
        self.num_datasets = len(self.datasets)

    def __len__(self):
        # One epoch is as long as the largest per-class dataset.
        return max(self.dataset_sizes)

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Pick a class uniformly at random, then wrap idx into that class's range.
        dataset_idx = random.randrange(self.num_datasets)
        sample_idx = idx % self.dataset_sizes[dataset_idx]
        return self.datasets[dataset_idx][sample_idx]
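To see what this sampling scheme does to a 6:1 imbalance like the one in the question, here is a minimal library-free sketch of the same two-step idea (pick a class uniformly, then wrap the index into that class). The helper name `balanced_draw` and the toy data are hypothetical and only for illustration; the wrap-around and the epoch length mirror the `idx % size` and `max(...)` logic of the class above.

```python
import random
from collections import Counter


def balanced_draw(per_class_samples, idx, rng):
    """Pick a class uniformly at random, then wrap idx into that class's range."""
    class_idx = rng.randrange(len(per_class_samples))
    samples = per_class_samples[class_idx]
    return samples[idx % len(samples)]


# Hypothetical toy data: 100 class1 items vs 600 class2 items (6:1 imbalance).
class1 = [("class1", i) for i in range(100)]
class2 = [("class2", i) for i in range(600)]
per_class = [class1, class2]

rng = random.Random(0)  # seeded only so the sketch is reproducible
epoch_len = max(len(c) for c in per_class)  # same rule as __len__ above

# Count how often each class label is drawn over one epoch.
counts = Counter(balanced_draw(per_class, i, rng)[0] for i in range(epoch_len))
print(counts)
```

Each class should be drawn roughly half of the time, regardless of how many items it contains. Note that the minority class is balanced by repetition: with 100 items and about 300 draws per epoch, each class1 item is seen several times, so this is probably most useful when the per-class datasets apply random augmentation on access.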
Upvotes: 1