David Kettler

Reputation: 31

Is there a way to keep a TensorFlow record file in memory?

Here is the situation: I am working with a large TensorFlow record file, about 50 GB. However, the machine I'm doing this training on has 128 GB of RAM. 50 is less than 128, so even though this is a large file, you would think it would be possible to keep it in memory and save on slow I/O operations. But I'm using the TFRecordDataset class, and the whole TFRecord system seems designed specifically not to do that; I don't see any way to force it to keep the records in memory. And since it reloads them every epoch, I am wasting an inordinate amount of time on slow I/O reading from that 50 GB file.

I suppose I could load the records into memory in Python and then feed them into my model one by one with a feed_dict, bypassing the whole Dataset class. But that seems like a less elegant way to handle things and would require some redesign. Everything would be much simpler if I could just force the TFRecordDataset to load everything into memory and keep it there between epochs...
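For reference, the feed_dict workaround I have in mind would look roughly like this (just a sketch; the path and the feature spec are made up):

  import tensorflow as tf

  # Read every serialized record into RAM up front, bypassing the Dataset class.
  records = list(tf.python_io.tf_record_iterator("train.tfrecord"))

  serialized = tf.placeholder(tf.string, shape=[])
  # Invented feature spec, just to illustrate parsing a fed record.
  example = tf.parse_single_example(
      serialized, features={"label": tf.FixedLenFeature([], tf.int64)})

  with tf.Session() as sess:
      for epoch in range(10):
          for rec in records:            # no disk I/O after the initial load
              sess.run(example, feed_dict={serialized: rec})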

Upvotes: 3

Views: 2868

Answers (1)

Vlad-HC

Reputation: 4757

You need the tf.data.Dataset.cache() operation. To achieve the desired effect (keeping the records in memory), put it right after the TFRecordDataset and don't pass it any arguments:

  dataset = tf.data.TFRecordDataset(filenames)
  dataset = dataset.cache()

When the cache() operation is invoked without arguments, caching is done in memory.
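A minimal sketch of a full pipeline built that way (TF 1.x style; the filename and the shuffle/batch parameters are placeholders) could look like this:

  import tensorflow as tf

  filenames = ["train.tfrecord"]          # placeholder path for the 50 GB file

  dataset = tf.data.TFRecordDataset(filenames)
  dataset = dataset.cache()               # first pass reads from disk and fills the cache
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.repeat()              # subsequent epochs are served from RAM
  dataset = dataset.batch(32)

  iterator = dataset.make_one_shot_iterator()
  next_batch = iterator.get_next()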

Also, if you do some postprocessing of these records, e.g. with dataset.map(...), it can be even more beneficial to put the cache() operation at the end of the input pipeline, as shown below.
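For example, caching after a (hypothetical) parsing map keeps the decoded tensors, rather than the raw bytes, in memory, so the map only ever runs once:

  import tensorflow as tf

  filenames = ["train.tfrecord"]          # placeholder path

  def _parse(serialized):
      # Invented feature spec, just to illustrate; adapt it to your records.
      features = {"image": tf.FixedLenFeature([], tf.string),
                  "label": tf.FixedLenFeature([], tf.int64)}
      return tf.parse_single_example(serialized, features)

  dataset = tf.data.TFRecordDataset(filenames)
  dataset = dataset.map(_parse)
  dataset = dataset.cache()               # caches the parsed tensors, not the raw records
  dataset = dataset.shuffle(buffer_size=10000).repeat().batch(32)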

More information can be found in the "Map and Cache" section of the Input Pipeline Performance Guide.

Upvotes: 3
