mjspier

TensorFlow model.predict runs out of memory

We have a TensorFlow Keras model that we want to evaluate after training. However, the predict call after training runs into out-of-memory errors, even though the fit call works fine.

The dataset is loaded like this:

import tensorflow as tf

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
dataset = (
    tf.data.Dataset.list_files(f'{dataset_location}/*.csv')
    # skip the header line of each CSV file, not only the first line overall
    .flat_map(lambda path: tf.data.TextLineDataset(path).skip(1))
    .map(decode_line)  # custom function that parses one CSV line
)
dataset = dataset.apply(tf.data.experimental.ignore_errors())
dataset = dataset.batch(batch_size).with_options(options)
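For reference, a `.skip(1)` placed after `flat_map` drops only the first line of the concatenated stream, so with several CSV files every header except the first reaches `decode_line`. Assuming `decode_line` cannot parse a header row, those lines fail and are silently dropped by `ignore_errors()`. Putting `.skip(1)` inside the `flat_map` function drops the header of each file instead. A small sketch with two hypothetical one-row files, using `from_tensor_slices` instead of `list_files` so the file order is fixed:

```python
import os
import tempfile

import tensorflow as tf


def first_lines(paths, per_file_skip):
    """Return the lines that survive the header skip, in order."""
    files = tf.data.Dataset.from_tensor_slices(paths)
    if per_file_skip:
        # drop the first line (header) of every file
        ds = files.flat_map(lambda p: tf.data.TextLineDataset(p).skip(1))
    else:
        # drop only the first line of the concatenated stream
        ds = files.flat_map(tf.data.TextLineDataset).skip(1)
    return [line.decode() for line in ds.as_numpy_iterator()]


# two hypothetical CSV files, each with header "x" and one data row
tmp = tempfile.mkdtemp()
paths = []
for name, rows in [("a.csv", ["x", "1"]), ("b.csv", ["x", "2"])]:
    path = os.path.join(tmp, name)
    with open(path, "w") as f:
        f.write("\n".join(rows) + "\n")
    paths.append(path)

print(first_lines(paths, per_file_skip=False))  # ['1', 'x', '2']
print(first_lines(paths, per_file_skip=True))   # ['1', '2']
```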

Calling model.fit on this dataset works fine with a batch size of 128:

model.fit(dataset, epochs=10)

But calling predict after the model is trained raises an out-of-memory error, even with a batch size of 1:

model.predict(dataset) 

We are using Google AI Platform with a custom image built on the following base image: gcr.io/deeplearning-platform-release/tf-gpu.2-8

Why does the fit call work while the predict call, for the same model and the same dataset, runs out of memory? I would expect fit to be more memory-intensive than predict.

This happens on TensorFlow 2.8.
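One plausible factor, not confirmed for this case: in TensorFlow 2.x, `Model.predict` keeps the output of every batch and concatenates them at the end, so its memory use grows with the size of the whole dataset, which is not true of `fit`. A smaller batch size does not reduce that total. A minimal sketch that avoids the accumulation by calling `predict_on_batch` per batch and handing each result to a caller-supplied `sink` (a hypothetical callback, for example one that writes to disk). It assumes each dataset element is either the features or a `(features, label)` tuple:

```python
import tensorflow as tf


def predict_streaming(model: tf.keras.Model, dataset: tf.data.Dataset, sink):
    """Run inference one batch at a time and pass each result to sink.

    This function keeps no predictions after sink returns, so its peak
    memory depends on the batch size rather than on the dataset size.
    """
    for batch in dataset:
        # assumption: elements are either features or (features, label) tuples
        features = batch[0] if isinstance(batch, tuple) else batch
        sink(model.predict_on_batch(features))
```

Usage might look like `predict_streaming(model, dataset, lambda preds: sizes.append(len(preds)))`, where `sizes` is a list the caller owns.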

Update:

I found a workaround, although I still think there is a memory leak in the predict function in TensorFlow 2.8.

The solution that worked for me is to not load all CSV files into one dataset, but to load them one by one and run the prediction on each smaller dataset, even though loading all CSV files into one dataset works fine for training.

import tensorflow as tf

predictions = []
for file_name in tf.io.gfile.listdir(dataset_location):
    file_path = f'{dataset_location}/{file_name}'
    # build a small dataset for this single CSV file only
    dataset = tf.data.experimental.make_csv_dataset(
        file_pattern=file_path,
        batch_size=10,
        shuffle=False,     # keep rows in file order
        label_name="label",
        field_delim=';',
        column_defaults=[[""], [""], [0.0]],
        num_epochs=1       # read the file once; the default (None) repeats forever
    )
    pred = model.predict(dataset)
    predictions.append(pred)
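The per-file loop above still keeps every file's predictions in the `predictions` list. If peak memory is the concern, a variant could write each file's result out as soon as it is computed. A sketch under assumptions: `make_dataset` is a hypothetical callable that builds the dataset for one file (for example by wrapping the `make_csv_dataset` call above), the model has a single output, and `out_dir` is any path `tf.io.gfile` can write to (local or `gs://`):

```python
import numpy as np
import tensorflow as tf


def predict_per_file(model, file_paths, make_dataset, out_dir):
    """Predict one file at a time and save each result as <name>.npy."""
    for path in file_paths:
        preds = model.predict(make_dataset(path), verbose=0)
        name = path.rstrip('/').rsplit('/', 1)[-1]
        # tf.io.gfile also handles gs:// paths, so out_dir may live on GCS
        with tf.io.gfile.GFile(f'{out_dir}/{name}.npy', 'wb') as f:
            np.save(f, preds)
        # preds is replaced on the next iteration instead of piling up in a list
```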
