DragonKnight

Reputation: 1870

How to balance data with keras.ImageDataGenerator()

I have unbalanced data. How can I use ImageDataGenerator() to generate enough augmented samples for the under-represented classes so that all categories are balanced?

Upvotes: 2

Views: 1389

Answers (2)

Taisa

Reputation: 123

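You need to build a dictionary of per-class weights and pass it to model.fit_generator. This reweights the loss so that rare classes count more, rather than generating extra images. In recent Keras versions, model.fit accepts generators and takes the same class_weight argument: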

from sklearn.utils import class_weight
import numpy as np

# Recent scikit-learn versions require keyword arguments here.
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes)

train_class_weights = dict(enumerate(class_weights))
model.fit_generator(..., class_weight=train_class_weights)
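To make the 'balanced' option less of a black box, here is a minimal sketch of the formula scikit-learn documents for it, n_samples / (n_classes * count_per_class), written in plain Python. The function name and the label list are made up for illustration; in practice the labels would come from train_generator.classes.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Return {class: weight} with weight = n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count)
            for cls, count in sorted(counts.items())}

# Hypothetical labels: six images of class 0 and two of class 1.
labels = [0, 0, 0, 0, 0, 0, 1, 1]
weights = balanced_class_weights(labels)
print(weights)  # the rarer class 1 receives the larger weight
```

The resulting dictionary has the same shape as train_class_weights above, so it could be passed as class_weight in the same way.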

Upvotes: 0

Harshit Ruwali

Reputation: 1088

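You can use the following code:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Random augmentations applied to each batch as it is generated.
# Note: the featurewise_* options require calling datagen.fit(x_train) first.
datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)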

This does not modify your dataset on disk. The transformations are applied on the fly as each batch of images is fed to the model.
You may refer to the documentation: Image Preprocessing.
Hope this helps.
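Note that augmentation on its own leaves the class proportions unchanged, since each image is still drawn once per epoch. One possible way to actually balance the classes is to oversample the minority classes before the generator sees them, so each repeated image receives a different random augmentation. A minimal sketch in plain Python follows; the function name, the seed parameter, and the example labels are all hypothetical.

```python
import random
from collections import Counter, defaultdict

def oversample_indices(labels, seed=0):
    """Return sample indices in which every class appears as often as the largest class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    target = max(len(idxs) for idxs in by_class.values())
    balanced = []
    for idxs in by_class.values():
        balanced.extend(idxs)  # keep every original sample
        # Draw extra samples with replacement until the class reaches the target size.
        balanced.extend(rng.choices(idxs, k=target - len(idxs)))
    rng.shuffle(balanced)
    return balanced

# Hypothetical labels: five "cat" images and two "dog" images.
example_labels = ["cat"] * 5 + ["dog"] * 2
balanced = oversample_indices(example_labels)
balanced_counts = Counter(example_labels[i] for i in balanced)
print(balanced_counts)
```

The balanced index list could then be used to build a table of repeated file paths and labels and, assuming your images are listed that way, handed to the generator's flow_from_dataframe method.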

Upvotes: 1
