Reputation: 369
I was wondering whether fit_generator() in Keras has any advantage with respect to memory usage over the usual fit() method called with the same batch_size that the generator yields. I've seen some examples similar to this:
from keras.datasets import mnist

def generator():
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # some data prep
    ...
    while 1:
        for i in range(1875):  # 1875 * 32 = 60000 -> # of training samples
            yield X_train[i*32:(i+1)*32], y_train[i*32:(i+1)*32]
If I pass this into the fit_generator() method, or instead pass all the data directly into the fit() method and set a batch_size of 32, would it make any difference regarding (GPU?) memory whatsoever?
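Concretely, the two calls being compared would look roughly like this (assuming model is some already compiled Keras model; the exact epochs value is just for illustration):

model = ...  # any compiled Keras model

# Variant 1: feed batches from the generator above
model.fit_generator(generator(), steps_per_epoch=1875, epochs=5)

# Variant 2: load everything at once and let fit() do the batching
(X_train, y_train), _ = mnist.load_data()
model.fit(X_train, y_train, batch_size=32, epochs=5)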
Upvotes: 3
Views: 699
Reputation: 811
Yes, the difference really comes in when you need augmented data for better model accuracy.
For efficiency, fit_generator() allows real-time data augmentation of the images on the CPU. That means the GPU can run the model training and weight updates while the CPU handles the load of augmenting the images and providing the batches to train on.
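A minimal sketch of that pattern with ImageDataGenerator, assuming model is an already compiled Keras model for MNIST (labels may need to be one-hot encoded depending on your loss), could look like this:

from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator

(X_train, y_train), _ = mnist.load_data()
# ImageDataGenerator expects a channel axis, so reshape to (N, 28, 28, 1)
X_train = X_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0

# Augmentation runs on the CPU while the GPU trains on the current batch
datagen = ImageDataGenerator(rotation_range=10,
                             width_shift_range=0.1,
                             height_shift_range=0.1)

model.fit_generator(datagen.flow(X_train, y_train, batch_size=32),
                    steps_per_epoch=len(X_train) // 32,
                    epochs=5)

If you only replay the raw arrays batch by batch, as in your generator, the memory behavior is essentially the same as fit() with batch_size=32; the generator approach pays off when the batches are produced (augmented, or loaded from disk) on the fly instead of being held in memory up front.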
Upvotes: 3