Hao


How to use sequence/generator on tf.data.Dataset object to fit partial data into memory?

I am doing image classification with Keras on Google Colab. I load images with the tf.keras.preprocessing.image_dataset_from_directory() function (https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory), which returns a tf.data.Dataset object:

import tensorflow as tf

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=1234,
  image_size=(img_height, img_width),
  batch_size=batch_size,
  label_mode="categorical")
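To see that the returned dataset is already batched, here is a self-contained sketch that writes a handful of hypothetical blank PNG files to a temporary directory and loads them. It calls tf.keras.utils.image_dataset_from_directory, the name under which newer TensorFlow releases expose the same loader; the directory layout, image size and batch size are made up for the example.

```python
import os
import tempfile

import tensorflow as tf

# Build a tiny throwaway directory tree: <root>/<class_name>/<image>.png
# (hypothetical data, only so the example runs without a real dataset).
data_dir = tempfile.mkdtemp()
for class_name in ("cats", "dogs"):
    os.makedirs(os.path.join(data_dir, class_name))
    for i in range(5):
        png = tf.io.encode_png(tf.zeros((8, 8, 3), dtype=tf.uint8))
        tf.io.write_file(os.path.join(data_dir, class_name, f"{i}.png"), png)

train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    image_size=(8, 8),
    batch_size=4,
    label_mode="categorical",
)

# The leading dimension is None: every element is already a batch of images.
print(train_ds.element_spec)
# Only the file paths are listed up front; images are decoded while iterating.
print(int(tf.data.experimental.cardinality(train_ds)))  # 3
```

The None in the first dimension of each TensorSpec is the batch dimension added by batch_size.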

I found that when the data contains thousands of images, model.fit() uses up all available memory after training on a number of batches (on Google Colab I can see RAM usage grow during the first epoch). I then tried Keras Sequence, which is a suggested way of loading only part of the data into RAM (https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence):

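  class DatasetGenerator(tf.keras.utils.Sequence):
      def __init__(self, dataset):
          self.dataset = dataset

      def __len__(self):
          # Number of batches in the (already batched) dataset
          return tf.data.experimental.cardinality(self.dataset).numpy()

      def __getitem__(self, idx):
          # Converts every batch to NumPy and stores them all in a Python list
          # on each call, then returns only the batch at position idx
          return list(self.dataset.as_numpy_iterator())[idx]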

And I train the model with:

history = model.fit(DatasetGenerator(train_ds), ...)

The problem is that __getitem__() must return the batch at a given index, and a tf.data.Dataset object does not support indexing with []. The list() call I use instead materializes the whole dataset in RAM, so I hit the memory limit as soon as the DatasetGenerator starts fetching batches.

My questions:

  1. Is there any way to implement __getitem__() (i.e. get a specific batch from the dataset object) without putting the whole dataset into memory?
  2. If item 1 is not possible, is there any workaround?
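For question 1, one possible approach is sketched below, assuming the dataset is already batched and has a known cardinality: Dataset.skip() and Dataset.take() pull out a single batch without converting the dataset to a list. The class name DatasetSequence and the attribute batch_count are hypothetical. This is not free: skip(idx) may still read and decode the preceding batches on every call, and indices are only stable if the order is deterministic, so with image_dataset_from_directory() (which shuffles by default) you would need shuffle=False.

```python
import tensorflow as tf


class DatasetSequence(tf.keras.utils.Sequence):
    """Hypothetical wrapper returning batch `idx` of an already-batched dataset
    without turning the whole dataset into a list (sketch, not optimized)."""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        # Assumes the cardinality is known (not UNKNOWN or INFINITE).
        self.batch_count = int(tf.data.experimental.cardinality(dataset))

    def __len__(self):
        return self.batch_count

    def __getitem__(self, idx):
        # Build a one-element view of the dataset and pull that single batch.
        batch = next(iter(self.dataset.skip(idx).take(1)))
        return tuple(t.numpy() for t in batch)


# Usage on a small deterministic dataset of (features, labels) batches.
ds = tf.data.Dataset.range(10).map(lambda x: (x, x * 2)).batch(4)
seq = DatasetSequence(ds)
print(len(seq))  # 3
print(seq[1])    # second batch: features 4..7, labels 8, 10, 12, 14
```

Because of the repeated skipping, passing the dataset directly to model.fit() is usually preferable to wrapping it like this.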

Thanks in advance!

Upvotes: 1

Views: 2781

Answers (1)

pratsbhatt


I understand that you are concerned about having your complete dataset in memory.

Don't worry: the tf.data.Dataset API is designed for this and does not load your complete dataset into memory.

Internally, a Dataset is a lazily evaluated pipeline of transformations. When model.fit() iterates over it, it reads and processes only the elements needed for the current batch (plus anything held in shuffle or prefetch buffers), not the complete dataset.

You can read more in the tf.data.Dataset documentation; here is the important part:

The tf.data.Dataset API supports writing descriptive and efficient input pipelines. Dataset usage follows a common pattern:

  1. Create a source dataset from your input data.
  2. Apply dataset transformations to preprocess the data.
  3. Iterate over the dataset and process the elements. Iteration happens in a streaming fashion, so the full dataset does not need to fit into memory.

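From the last sentence you can see that the tf.data.Dataset API streams the data instead of loading the complete dataset into memory: only the batches currently in flight (plus the contents of any shuffle or prefetch buffers) are held in RAM.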
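A toy illustration of this streaming behaviour, with a range dataset standing in for real images: the pipeline below nominally holds 10**12 elements, yet building it and iterating its first few batches is instant and uses almost no memory, because elements are only produced when iteration asks for them.

```python
import tensorflow as tf

# A conceptually huge dataset: 10**12 elements would never fit in RAM if it were
# materialized, but tf.data only builds a description of the pipeline here.
ds = (
    tf.data.Dataset.range(10**12)
    .map(lambda x: tf.cast(x, tf.float32) / 255.0)  # stand-in for image decoding
    .batch(32)
)

# Iterating produces elements on demand: only the first three batches are computed.
for batch in ds.take(3):
    print(batch.shape)  # (32,)
```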

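Note that image_dataset_from_directory() already batches the data (that is what its batch_size argument does), so train_ds is already a dataset of batches and you should not call .batch() on it again; doing so would group batches into batches of batches. For an unbatched dataset, you would create batches of size 32 with:

dataset.batch(32)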
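The effect of calling .batch() on an already-batched dataset can be checked on a toy stand-in for train_ds (a range dataset batched by 4, assumed here for illustration). If a different batch size is needed, unbatch() followed by batch() is one way to get it.

```python
import tensorflow as tf

# Stand-in for train_ds: a dataset that is already batched (batch size 4),
# like the output of image_dataset_from_directory(batch_size=...).
already_batched = tf.data.Dataset.range(20).batch(4)
print(already_batched.element_spec.shape)  # (None,)

# Batching again groups *batches* together: each element is a batch of batches.
double_batched = already_batched.batch(2)
print(double_batched.element_spec.shape)   # (None, None)
print(next(iter(double_batched)).shape)    # (2, 4)

# To change the batch size of an already-batched dataset, unbatch first.
rebatched = already_batched.unbatch().batch(8)
print(next(iter(rebatched)).shape)         # (8,)
```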

You can also use prefetch() so that the next batch is prepared while the model is training on the current one. This removes the bottleneck where the model sits idle after each training step, waiting for the next batch. Passing tf.data.AUTOTUNE (tf.data.experimental.AUTOTUNE in older TensorFlow versions) lets TensorFlow choose the buffer size:

train_ds.prefetch(tf.data.AUTOTUNE)
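A rough way to see what prefetch() buys is to simulate slow loading and a slow training step with time.sleep(). The helper names and delays below are hypothetical and the measured times depend on the machine, but the data and its order are the same with or without prefetching.

```python
import time

import tensorflow as tf


def slow_load(x):
    # Hypothetical stand-in for reading/decoding an image from disk.
    time.sleep(0.01)
    return x


def make_dataset():
    ds = tf.data.Dataset.range(20)
    ds = ds.map(lambda x: tf.py_function(slow_load, [x], tf.int64))
    return ds.batch(4)


def train(ds):
    start = time.perf_counter()
    seen = []
    for batch in ds:
        time.sleep(0.04)  # hypothetical training step
        seen.extend(batch.numpy().tolist())
    return seen, time.perf_counter() - start


plain, t_plain = train(make_dataset())
prefetched, t_prefetched = train(make_dataset().prefetch(1))

# Same data, same order; only the scheduling changes.
print(plain == prefetched)  # True
print(f"without prefetch: {t_plain:.2f}s, with prefetch(1): {t_prefetched:.2f}s")
```

With prefetching, loading of the next batch runs in a background thread while the simulated training step sleeps, so the total time should drop.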

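You can also use cache() to make later epochs faster: the files are read and decoded during the first epoch, and subsequent epochs read from the cache. Be careful, though: cache() without arguments keeps the entire dataset in RAM, which is exactly what you want to avoid here (if your pipeline already calls cache() that way, that alone would explain RAM growing during the first epoch). Pass a filename to cache on disk instead, and call cache() before prefetch() so that prefetching still overlaps with training:

train_ds.cache(cache_path).prefetch(tf.data.AUTOTUNE)

where cache_path is a path to a file on local disk.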
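A minimal sketch of caching to a file on disk rather than in RAM, with cache() placed before prefetch(). The small range dataset stands in for train_ds and the temporary cache path is an assumption; on Colab you would point it at local disk.

```python
import os
import tempfile

import tensorflow as tf

# Stand-in for train_ds (already batched).
train_ds = tf.data.Dataset.range(100).map(lambda x: tf.cast(x, tf.float32)).batch(32)

# cache() with no argument would keep every element in RAM after the first epoch.
# Passing a filename caches to disk instead (hypothetical temporary path).
cache_file = os.path.join(tempfile.mkdtemp(), "train_cache")

train_ds = (
    train_ds
    .cache(cache_file)           # expensive loading/decoding happens only in epoch 1
    .prefetch(tf.data.AUTOTUNE)  # prefetch last, so it overlaps with training
)

for epoch in range(2):
    total = sum(float(tf.reduce_sum(b)) for b in train_ds)
    print(epoch, total)  # same total (4950.0) in both epochs; epoch 1 reads the cache
```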

So, in short: you do not need a Sequence or generator to avoid loading the whole dataset into memory. Pass the tf.data.Dataset directly to model.fit(), which streams it batch by batch.
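To illustrate passing a dataset straight to model.fit(), here is a sketch with a hypothetical tiny model and random stand-in data built with from_tensor_slices (which, unlike image_dataset_from_directory(), holds its arrays in memory; it is used only to keep the example self-contained). With real images you would pass the dataset returned by image_dataset_from_directory() in the same way.

```python
import numpy as np
import tensorflow as tf

# Hypothetical synthetic data standing in for the images: 64 samples, 3 classes.
rng = np.random.default_rng(0)
images = rng.random((64, 8, 8, 3)).astype("float32")
labels = tf.keras.utils.to_categorical(rng.integers(0, 3, size=64), num_classes=3)

train_ds = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .batch(16)
    .prefetch(tf.data.AUTOTUNE)
)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# No Sequence or generator needed: Keras iterates the dataset batch by batch.
history = model.fit(train_ds, epochs=2, verbose=0)
print(history.history["loss"])
```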

I hope this helps.

Upvotes: 5
