Coldchain9

Reputation: 1745

Does Keras' ImageDataGenerator randomly apply transformations to every image?

I am training a CNN using the Keras ImageDataGenerator class. My code looks something like:

from keras.callbacks import LearningRateScheduler
from keras.callbacks import EarlyStopping
from keras.preprocessing.image import ImageDataGenerator

# Random augmentations applied on the fly to each training batch
data_generator = ImageDataGenerator(
    rotation_range = 15,
    width_shift_range = 0.1,
    height_shift_range = 0.1,
    horizontal_flip = True
)
# fit() only computes dataset statistics needed by featurewise_center,
# featurewise_std_normalization or zca_whitening; none are enabled here
data_generator.fit(xtrain)

es = EarlyStopping(monitor = 'val_loss', mode = 'min', verbose = 1, patience = 100)
history = CNN_model.fit(data_generator.flow(xtrain, ytrain, batch_size = batch_size),
                        epochs = num_epochs,
                        validation_data = (xval, yval),
                        callbacks = [LearningRateScheduler(learning_rate_schedule), es])

Now, I have looked at the documentation and I don't entirely understand whether it applies this augmentation to every image in each batch, or whether it randomly selects a percentage of them. I am assuming it augments every image, but I am trying to understand.

If it is doing this to every image in my dataset, this means that I am not at all using an un-augmented image for any of my training epochs, correct?

Thanks!

Upvotes: 3

Views: 1866

Answers (1)

Nicolas Gervais

Reputation: 36674

Well, I ran the test, and the transformations are applied randomly per image, even when the images come from the same batch.

Based on my reading of the documentation, I have no reason to think that only some of the images are transformed and that others are left intact.
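To make "random per image" concrete, here is a simplified NumPy sketch, not Keras's actual source, of drawing a fresh set of transform parameters for every image, using the settings from the question. Keras does draw parameters separately for each sample internally (via `ImageDataGenerator.get_random_transform`), but the helper name `sample_transform` and its dictionary keys are hypothetical, and details such as how Keras scales fractional shifts may differ.

```python
import numpy as np

def sample_transform(rng, img_height, img_width,
                     rotation_range=15.0,
                     width_shift_range=0.1,
                     height_shift_range=0.1,
                     horizontal_flip=True):
    """Hypothetical helper: draw one random set of augmentation parameters."""
    return {
        # rotation angle in degrees, uniform in [-rotation_range, rotation_range)
        "angle": rng.uniform(-rotation_range, rotation_range),
        # fractional shift ranges converted to pixels
        "shift_x": rng.uniform(-width_shift_range, width_shift_range) * img_width,
        "shift_y": rng.uniform(-height_shift_range, height_shift_range) * img_height,
        # left-right flip with probability 0.5 when enabled
        "flip": bool(horizontal_flip and rng.random() < 0.5),
    }

rng = np.random.default_rng(0)
batch_size = 8

# One independent draw per image in the batch
params = [sample_transform(rng, 32, 32) for _ in range(batch_size)]
for i, p in enumerate(params):
    print(i, f"angle={p['angle']:+.1f}", f"dx={p['shift_x']:+.2f}",
          f"dy={p['shift_y']:+.2f}", f"flip={p['flip']}")

# Over many draws, roughly half of the images get flipped
flips = [sample_transform(rng, 32, 32)["flip"] for _ in range(10_000)]
print("fraction flipped:", np.mean(flips))
```

Every image goes through the transform; what varies is the randomly drawn strength and direction of each operation. The plotting test that follows uses the real `ImageDataGenerator`.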

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from skimage.data import chelsea
import matplotlib.pyplot as plt
import numpy as np

# 16 identical copies of the same test image (the "chelsea" cat photo from scikit-image)
imgs = np.stack([chelsea() for _ in range(4*4)], axis=0)

data_gen = ImageDataGenerator(
    rotation_range = 90,
    width_shift_range = 0.1,
    height_shift_range = 0.1,
    horizontal_flip = True,
    # extra random step: shuffle the RGB channel order of each image
    preprocessing_function=lambda x: x[..., np.random.permutation([0, 1, 2])]
)

fig = plt.figure()
plt.subplots_adjust(wspace=.2, hspace=.2)
# Take a single batch from the generator and plot each augmented copy in a 4x4 grid
for index, image in enumerate(next(data_gen.flow(imgs)).astype(int)):
    ax = plt.subplot(4, 4, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()

[Image: the resulting 4x4 grid of augmented copies from a single batch]
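The plot is a visual check. The follow-up question in the post (does the model ever see an un-augmented image?) can also be probed numerically. The sketch below is a toy stand-in in plain NumPy, not Keras: the hypothetical `toy_augment` applies a random horizontal flip and a random integer shift, drawn independently per image and redrawn every time the batch is generated, as happens each epoch. One assumption to flag: `np.roll` wraps pixels around the border, whereas Keras fills shifted-in pixels according to `fill_mode` (default `'nearest'`).

```python
import numpy as np

def toy_augment(batch, rng, max_shift=3):
    """Hypothetical stand-in for a Keras-style augmenter: random
    horizontal flip plus random integer shift, drawn per image."""
    out = np.empty_like(batch)
    for i, img in enumerate(batch):
        if rng.random() < 0.5:  # flip left-right with probability 0.5
            img = img[:, ::-1]
        # independent integer shifts in [-max_shift, max_shift] pixels
        dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
        # np.roll wraps pixels around the border (Keras fills them instead)
        out[i] = np.roll(img, shift=(dy, dx), axis=(0, 1))
    return out

rng = np.random.default_rng(42)
original = rng.random((32, 32, 3))   # one random 32x32 RGB "image"
batch = np.stack([original] * 16)    # 16 identical copies, as in the test above

# Each "epoch" regenerates the batch, so every copy gets a fresh draw
for epoch in range(3):
    augmented = toy_augment(batch, rng)
    unchanged = sum(np.array_equal(img, original) for img in augmented)
    print(f"epoch {epoch}: {unchanged} of {len(augmented)} copies unchanged")
```

In this toy pipeline a copy only comes back exactly unchanged when it draws no flip and a zero shift in both directions, a probability of 0.5 * (1/7)^2, about 1%. With a continuous `rotation_range` as in the question, an exact zero angle is practically never drawn, so the model essentially never sees the untouched original, although small draws leave some images very close to it.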

Upvotes: 4
