mikal94305

Reputation: 5093

Keras model.fit_generator() several times slower than model.fit()

As of Keras 1.2.2 (referencing merge), multiprocessing is included, but model.fit_generator() is still about 4-5x slower than model.fit() because of disk reading speed limitations. How can this be sped up, say through additional multiprocessing?

Upvotes: 9

Views: 8155

Answers (2)

Mach_Zero

Reputation: 524

You may want to check out the workers and max_queue_size parameters of fit_generator() in the documentation (in Keras 1.x they are named nb_worker and max_q_size). Essentially, more workers create more threads that load data into the queue feeding your network. A full queue can cause memory problems, though, so you might need to decrease max_queue_size to avoid this.
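To make the workers/queue idea concrete, here is a minimal standard-library sketch of the mechanism. It is not Keras's actual implementation: `load_batch` and `prefetch` are hypothetical names made up for illustration, and the sleep just stands in for a slow disk read. Several threads load batches into a bounded queue, and the bound is what keeps memory in check.

```python
import queue
import threading
import time


def load_batch(index):
    """Hypothetical stand-in for a slow disk read that returns one batch."""
    time.sleep(0.01)  # simulate I/O latency
    return [index] * 4


def prefetch(num_batches, workers=4, max_queue_size=10):
    """Yield batches loaded by `workers` threads through a bounded queue."""
    # Work list: every batch index is handed to exactly one worker.
    indices = queue.Queue()
    for i in range(num_batches):
        indices.put(i)

    # Bounded queue: at most `max_queue_size` batches wait in memory.
    batches = queue.Queue(maxsize=max_queue_size)

    def worker():
        while True:
            try:
                i = indices.get_nowait()
            except queue.Empty:
                return  # no work left
            batches.put(load_batch(i))  # blocks while the queue is full

    threads = [
        threading.Thread(target=worker, daemon=True) for _ in range(workers)
    ]
    for t in threads:
        t.start()

    # The consumer (the training loop, in the Keras case) drains the queue.
    for _ in range(num_batches):
        yield batches.get()


if __name__ == "__main__":
    for batch in prefetch(8, workers=4, max_queue_size=2):
        print(batch)
```

Note that batches arrive in whatever order the threads finish, not in index order. In Keras 2.x the corresponding call would look something like model.fit_generator(gen, steps_per_epoch=..., workers=4, max_queue_size=10); the values here are only placeholders to tune against your own memory and disk.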

Upvotes: 3

nafizh

Reputation: 185

I had a similar problem. I was using a generator that read the data with pandas, and I switched to dask to load the data into memory instead. So, depending on your data size, if possible, load the data into memory and use the fit function.

Upvotes: 1
