Reputation: 1745
I am training a CNN using the Keras ImageDataGenerator class. My code looks something like:
from keras.callbacks import LearningRateScheduler
from keras.callbacks import EarlyStopping
from keras.preprocessing.image import ImageDataGenerator

data_generator = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True
)
data_generator.fit(xtrain)

es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=100)
history = CNN_model.fit(
    data_generator.flow(xtrain, ytrain, batch_size=batch_size),
    epochs=num_epochs,
    validation_data=(xval, yval),
    callbacks=[LearningRateScheduler(learning_rate_schedule), es]
)
I have looked at the documentation, but I don't entirely understand whether it applies the augmentation to every image in each batch or only to a randomly selected percentage of them. I assume it augments every image, but I would like to be sure.
If it does augment every image in my dataset, does that mean I never train on an un-augmented image in any epoch?
Thanks!
Upvotes: 3
Views: 1866
Reputation: 36674
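I ran a test, and the transformations are applied randomly per image: each picture gets its own random transformation, even when the pictures come from the same batch. The test stacks the same image 16 times and passes the stack through the generator as a single batch.
Based on my reading of the documentation, I have no reason to think that only some of the images are transformed while others are left intact.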
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from skimage.data import chelsea
import matplotlib.pyplot as plt
import numpy as np

# Stack 16 copies of the same image so any difference comes from augmentation
imgs = np.stack([chelsea() for i in range(4*4)], axis=0)

data_gen = ImageDataGenerator(
    rotation_range=90,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    # Randomly permute the colour channels to make per-image differences obvious
    preprocessing_function=lambda x: x[..., np.random.permutation([0, 1, 2])]
)

fig = plt.figure()
plt.subplots_adjust(wspace=.2, hspace=.2)

# Draw one batch (all 16 images) and plot it on a 4x4 grid;
# cast to int so imshow treats the values as 0-255 pixel intensities
for index, image in enumerate(next(data_gen.flow(imgs)).astype(int)):
    ax = plt.subplot(4, 4, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)

plt.show()
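As for whether an un-augmented image is ever used: here is a rough sketch that does not call Keras at all. It assumes the generator draws, independently for each image, a rotation angle uniformly from [-rotation_range, rotation_range], shifts uniformly from [-shift_range, shift_range], and a horizontal flip with probability 0.5. That is my understanding of the behaviour, not something copied from the library. The helper names `sample_params` and `is_nearly_identity` and the "nearly unchanged" thresholds are made up for illustration.

```python
import random


def sample_params(rng, rotation_range=15, width_shift_range=0.1,
                  height_shift_range=0.1, horizontal_flip=True):
    """Draw one set of augmentation parameters for a single image.

    Assumption: this mimics per-image sampling (uniform rotation and
    shifts, coin-flip for the horizontal flip); it is not library code.
    """
    return {
        "theta": rng.uniform(-rotation_range, rotation_range),       # degrees
        "tx": rng.uniform(-height_shift_range, height_shift_range),  # fraction of height
        "ty": rng.uniform(-width_shift_range, width_shift_range),    # fraction of width
        "flip": horizontal_flip and rng.random() < 0.5,
    }


def is_nearly_identity(params, max_deg=1.0, max_shift=0.01):
    """True if the drawn transform leaves the image almost untouched.

    The thresholds are arbitrary choices for this illustration.
    """
    return (abs(params["theta"]) <= max_deg
            and abs(params["tx"]) <= max_shift
            and abs(params["ty"]) <= max_shift
            and not params["flip"])


rng = random.Random(0)

# One independent draw per image, using the settings from the question
batch = [sample_params(rng) for _ in range(10000)]

fraction = sum(is_nearly_identity(p) for p in batch) / len(batch)
print(f"fraction of nearly unchanged images: {fraction:.4f}")
```

Under these assumptions, an image comes through close to its original only when every random draw happens to land near zero and the flip is skipped, which is rare with the settings in the question. So in practice the network almost never sees an exactly un-augmented training image, although the images passed via validation_data are not touched by the generator.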
Upvotes: 4