Reputation: 43
I am doing image classification with Keras on Google Colab. I load images with the tf.keras.preprocessing.image_dataset_from_directory() function (https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory) which returns a tf.data.Dataset object:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=1234,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode="categorical")
I found that when the data contains thousands of images, model.fit() uses up all the memory after training for a number of batches (I am using Google Colab and can see the RAM usage grow during the first epoch). I then tried Keras Sequence, which is a suggested way to load only part of the data into RAM (https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence):
class DatasetGenerator(tf.keras.utils.Sequence):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        # number of batches in the dataset
        return tf.data.experimental.cardinality(self.dataset).numpy()

    def __getitem__(self, idx):
        # materializes the whole dataset as a list just to pick one batch
        return list(self.dataset.as_numpy_iterator())[idx]
And I train the model with:
history = model.fit(DatasetGenerator(train_ds), ...)
The problem is that __getitem__() must return a batch of data by index, but a tf.data.Dataset object does not support indexing with []. The list() call I use therefore has to put the whole dataset into RAM, which hits the memory limit as soon as a DatasetGenerator object is instantiated.
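For reference, here is a minimal snippet (separate from my training code; the dataset below is just a toy stand-in for train_ds) showing both points: indexing a tf.data.Dataset raises a TypeError, while iterating it streams one batch at a time:

import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(2)  # toy stand-in for train_ds

try:
    ds[0]  # tf.data.Dataset is not subscriptable
except TypeError as err:
    print(err)

for batch in ds.take(1):  # iteration yields one batch at a time
    print(batch.numpy())  # [0 1]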
My questions:
Thanks in advance!
Upvotes: 1
Views: 2781
Reputation: 1538
I understand that you are concerned about having your complete dataset in memory.
Do not worry: the tf.data.Dataset API is very efficient, and it does not load your complete dataset into memory.
Internally it just creates a sequence of functions, and when called with model.fit() it loads only one batch into memory at a time, not the complete dataset.
You can read more at this link; I am pasting the important part from the documentation below.
The tf.data.Dataset API supports writing descriptive and efficient input pipelines. Dataset usage follows a common pattern:
1. Create a source dataset from your input data.
2. Apply dataset transformations to preprocess the data.
3. Iterate over the dataset and process the elements.
Iteration happens in a streaming fashion, so the full dataset does not need to fit into memory.
From the last line you can see that the tf.data.Dataset API does not load the complete dataset into memory, but one batch at a time.
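As a rough sketch of that pattern (the numeric dataset and the map function below are just placeholders, not your image pipeline):

import tensorflow as tf

# 1. Create a source dataset.
ds = tf.data.Dataset.range(1000)

# 2. Apply transformations; nothing is executed yet, the pipeline is lazy.
ds = ds.map(lambda x: x * 2).batch(32)

# 3. Iterate; batches are produced one at a time, in a streaming fashion.
for batch in ds.take(2):
    print(batch.shape)  # (32,)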
You will have to do the following to create batches of your dataset.
train_ds.batch(32)
This will create batches of size 32. You can also use prefetch to prepare the next batch while the current one is being trained on. This removes the bottleneck where the model sits idle after one batch, waiting for the next one.
train_ds.batch(32).prefetch(1)
You can also use the cache API to make your data pipeline even faster. It caches the elements of your dataset as they are produced, so later epochs run much faster.
train_ds.batch(32).prefetch(1).cache()
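Putting it together for your case, a rough sketch (this assumes the train_ds from your question; since image_dataset_from_directory already returns batched data via its batch_size argument, only cache and prefetch are applied here, with prefetch kept as the last step):

# train_ds is the batched dataset returned by image_dataset_from_directory.
# cache() with no argument keeps elements in memory once they have been
# produced during the first epoch; pass a filename, e.g. cache("train.cache"),
# to cache on disk instead.
train_ds = train_ds.cache().prefetch(1)

history = model.fit(train_ds, epochs=10)  # epochs value is just an example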
So, to answer in short: you do not need the generator. If you are concerned about loading the whole dataset into memory, the tf.data.Dataset API takes care of it.
I hope this helps.
Upvotes: 5