Eugene Smith

Reputation: 9378

Using datasets larger than 2 GB with Keras

TensorFlow has a long-standing 2 GB limit on a single tensor, which means you can't train your model on more than 2 GB of data at once without jumping through hoops. See Initializing tensorflow Variable with an array larger than 2GB and Use large dataset in Tensorflow.

The standard solution referenced in those posts is to use a placeholder and to pass the data to the session through feed_dict:

import tensorflow as tf  # TF 1.x API

my_graph = tf.Graph()
with my_graph.as_default():
    X_init = tf.placeholder(tf.float32, shape=(m_input, n_input))  # data is fed here, not stored in the graph
    X = tf.Variable(X_init)
    init_op = tf.global_variables_initializer()

sess = tf.Session(graph=my_graph)
sess.run(init_op, feed_dict={X_init: data_for_X})

However, this only works with the "old" API (tf.Session() and friends). The recommended approach nowadays is Keras (all the tutorials on tensorflow.org use it), and with Keras there is no tf.Graph(), no tf.Session(), and no run() (at least none readily visible to the user).

How do I adapt the above code to work with Keras?

Upvotes: 1

Views: 2573

Answers (2)

Daniel Möller

Reputation: 86600

In Keras you wouldn't load your entire dataset into a tensor; you load it into numpy arrays.

If the entire data can be in a single numpy array:

Thanks to @sebrockm's comment.

The most trivial usage of Keras is simply loading your dataset into a numpy array (not a tf tensor) and calling model.fit(arrayWithInputs, arrayWithOutputs, ...)
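A minimal sketch of that call, with a toy model and random arrays standing in for real data (the shapes and layers here are just assumptions for illustration):

import numpy as np
from tensorflow import keras

# Toy model just to make the example runnable; use your own architecture.
model = keras.Sequential([
    keras.layers.Dense(64, activation="relu", input_shape=(100,)),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# Plain numpy arrays, not tf tensors.
arrayWithInputs = np.random.rand(10000, 100).astype("float32")
arrayWithOutputs = np.random.rand(10000, 1).astype("float32")

# Keras slices the arrays into batches internally; no placeholders, sessions or feed_dict.
model.fit(arrayWithInputs, arrayWithOutputs, batch_size=32, epochs=2)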

If the entire data doesn't fit in a numpy array:

You'd create a generator or a keras.utils.Sequence to load batches one by one and then train the model with model.fit_generator(generatorOrSequence, ...)

The limitation then becomes the batch size, but you'd hardly ever hit 2 GB in a single batch. So, go for it; a sketch of a Sequence follows below.
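For example, a minimal keras.utils.Sequence that reads batches from memory-mapped .npy files (the file names, the mmap approach, and the compiled model from the earlier sketch are assumptions for illustration):

import numpy as np
from tensorflow import keras

class LargeDatasetSequence(keras.utils.Sequence):
    # Loads one batch at a time from .npy files on disk (hypothetical file layout).
    def __init__(self, x_path, y_path, batch_size=32):
        # mmap_mode="r" keeps the arrays on disk and reads slices lazily.
        self.x = np.load(x_path, mmap_mode="r")
        self.y = np.load(y_path, mmap_mode="r")
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        # Only this slice is ever materialized in memory.
        batch = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return np.array(self.x[batch]), np.array(self.y[batch])

model.fit_generator(LargeDatasetSequence("inputs.npy", "targets.npy"), epochs=2)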

Upvotes: 6

Dr. Snoopy

Reputation: 56357

Keras doesn't have a 2 GB limitation for datasets; I've trained much larger datasets with Keras with no issues.

The limitation could come from TensorFlow constants, which do have a 2 GB limit, but in any case you should NOT store datasets as constants, as these are saved as part of the graph, and the graph is meant to store the model, not the data.

Keras has the model.fit_generator function, which lets you pass a generator function that loads data on the fly and produces batches. This allows you to load a large dataset incrementally, and you usually adjust the batch size so you maximize performance with acceptable RAM usage. TensorFlow doesn't have a similar API; you have to implement it manually, as you say, with feed_dict.
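A minimal sketch of such a generator (the .npy file names, the memory-mapping, and the compiled model are assumptions; steps_per_epoch must be given because a plain generator has no length):

import numpy as np

def batch_generator(x_path, y_path, batch_size=32):
    # Yields (inputs, targets) batches forever, as Keras expects from a generator.
    x = np.load(x_path, mmap_mode="r")  # stays on disk, sliced lazily
    y = np.load(y_path, mmap_mode="r")
    n_batches = int(np.ceil(len(x) / batch_size))
    while True:
        for i in range(n_batches):
            s = slice(i * batch_size, (i + 1) * batch_size)
            yield np.array(x[s]), np.array(y[s])

steps = int(np.ceil(len(np.load("inputs.npy", mmap_mode="r")) / 32))
model.fit_generator(batch_generator("inputs.npy", "targets.npy", batch_size=32),
                    steps_per_epoch=steps, epochs=2)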

Upvotes: 3
