Reputation: 9378
TensorFlow has a long-standing limitation of 2 GB on a single tensor. This means you can't train your model on more than 2 GB of data at one time without jumping through hoops. See "Initializing tensorflow Variable with an array larger than 2GB" and "Use large dataset in Tensorflow".
The standard solution referenced in those posts is to use a placeholder and to pass it to the "session" through feed_dict:
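import tensorflow as tf  # TF 1.x API
my_graph = tf.Graph()
with my_graph.as_default():  # build the ops in the graph the session will run
    # The data is fed at initialization time, so it is never stored as a graph constant.
    X_init = tf.placeholder(tf.float32, shape=(m_input, n_input))
    X = tf.Variable(X_init)
    init_op = tf.global_variables_initializer()
sess = tf.Session(graph=my_graph)
sess.run(init_op, feed_dict={X_init: data_for_X})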
However, this only works with the "old" API (tf.Session(), etc.). The recommended approach nowadays is to use Keras (all the tutorials on tensorflow.org use it). And with Keras there's no tf.Graph(), no tf.Session(), and no run() (at least none that are readily visible to the user).
How do I adapt the above code to work with Keras?
Upvotes: 1
Views: 2573
Reputation: 86600
In Keras, you wouldn't load your entire dataset into a tensor. You load it into NumPy arrays.
Thanks to @sebrockm's comment.
The most trivial usage of Keras is simply loading your dataset into a NumPy array (not a tf tensor) and calling model.fit(arrayWithInputs, arrayWithOutputs, ...).
If the data is too large for that, you'd create a generator or a keras.utils.Sequence to load batches one by one, and then train the model with model.fit_generator(generatorOrSequence, ...).
The limitation then applies to the batch size, but you'd hardly ever hit 2 GB in a single batch. So, go for it.
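For illustration, here is a minimal sketch of the keras.utils.Sequence approach. The class name ArrayBatches, the toy arrays, and the batch size are assumptions made up for this example, not part of any library. In a real setting the arrays would likely be memory-mapped or read from disk inside __getitem__. The import fallback exists only so the sketch can be run without TensorFlow installed.

```python
import math
import numpy as np

try:
    from tensorflow.keras.utils import Sequence
except ImportError:  # fallback so the sketch still runs without TensorFlow
    Sequence = object


class ArrayBatches(Sequence):
    """Hypothetical helper: serves (inputs, targets) one batch at a time."""

    def __init__(self, inputs, targets, batch_size=32):
        super().__init__()
        self.inputs = inputs
        self.targets = targets
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.inputs) / self.batch_size)

    def __getitem__(self, index):
        start = index * self.batch_size
        stop = start + self.batch_size
        # Only this slice is materialized, which matters for on-disk arrays.
        return (np.asarray(self.inputs[start:stop]),
                np.asarray(self.targets[start:stop]))


# Toy data standing in for a dataset too large to hold as one tensor.
inputs = np.arange(100, dtype=np.float32).reshape(50, 2)
targets = np.arange(50, dtype=np.float32)
batches = ArrayBatches(inputs, targets, batch_size=16)
first_x, first_y = batches[0]
```

An object like this would be passed as model.fit_generator(batches, ...); as far as I know, recent TensorFlow releases also accept it directly in model.fit(batches, ...).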
Upvotes: 6
Reputation: 56357
Keras doesn't have a 2 GB limitation for datasets; I've trained on much larger datasets with Keras with no issues.
The limitation could come from TensorFlow constants, which do have a 2 GB limit. In any case, you should NOT store datasets as constants: they are saved as part of the graph, and that is not what a stored model is meant to contain.
Keras has the model.fit_generator function, which you can use to pass a generator function that loads data on the fly and makes batches. This allows you to load a large dataset on the fly, and you usually adjust the batch size to maximize performance with acceptable RAM usage. The low-level TensorFlow session API doesn't offer an equivalent convenience function; there you have to implement it manually, as you say, with feed_dict.
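A generator function of the kind described above could look roughly like the following sketch. The name batch_generator and the toy arrays are assumptions for illustration only. With real data, the inputs would typically be something like an array opened with np.load(..., mmap_mode="r"), so that only the current slice is read into RAM.

```python
import numpy as np


def batch_generator(inputs, targets, batch_size=32):
    """Hypothetical generator: yields (inputs, targets) batches indefinitely."""
    n = len(inputs)
    while True:  # loop forever; the caller decides how many steps form an epoch
        for start in range(0, n, batch_size):
            stop = start + batch_size
            # Copy just this slice into memory as the current batch.
            yield (np.asarray(inputs[start:stop]),
                   np.asarray(targets[start:stop]))


# Toy data standing in for a large dataset.
x = np.arange(20, dtype=np.float32).reshape(10, 2)
y = np.arange(10, dtype=np.float32)
gen = batch_generator(x, y, batch_size=4)
batch_sizes = [len(next(gen)[0]) for _ in range(4)]
```

Because the generator never stops, training would need an explicit step count, for example model.fit_generator(gen, steps_per_epoch=3, epochs=...) for this toy data (10 samples in batches of 4 gives 3 steps per epoch).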
Upvotes: 3