doraemon

Reputation: 2512

tensorflow model.predict uses more and more memory

I am using TensorFlow 2.14. I have a dataset created using from_generator, which I then batch with a fixed batch_size of 1024. The code used is close to d1024 = from_generator(...).batch(1024).cache('cache_file').prefetch(AUTOTUNE).
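For reference, a minimal, self-contained version of the pipeline (the generator body, feature shape, and dataset size below are placeholders I made up, since they are not part of the question):

    import numpy as np
    import tensorflow as tf

    # Placeholder generator: the real data source, feature shape and size are not shown here.
    def gen():
        for _ in range(100_000):
            yield np.random.rand(32).astype("float32"), np.int32(0)

    d1024 = (
        tf.data.Dataset.from_generator(
            gen,
            output_signature=(
                tf.TensorSpec(shape=(32,), dtype=tf.float32),
                tf.TensorSpec(shape=(), dtype=tf.int32),
            ),
        )
        .batch(1024)
        .cache('cache_file')          # on-disk cache, as in the original snippet
        .prefetch(tf.data.AUTOTUNE)
    )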

Now if I call model.predict(d1024.take(n)) for a large n, I see GPU memory usage gradually increase as more batches are processed. This eventually results in an out-of-memory error if n is large enough.

In my understanding, TensorFlow performs the prediction batch by batch. Given that the batch size is fixed at 1024, memory usage should be determined by the batch size rather than by n. Is this correct?
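To make my expectation concrete, I imagine predict behaving roughly like the following loop, where memory is bounded by a single 1024-element batch (a sketch only, using predict_on_batch for illustration):

    for x_batch, _ in d1024.take(n):
        preds = model.predict_on_batch(x_batch)   # one fixed-size batch on the GPU at a time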

Why does a bigger n need more memory, and how can I mitigate this problem if n is large?

I searched and found this issue on the TensorFlow GitHub, but I don't know whether it is related to this problem.

Upvotes: 1

Views: 41

Answers (1)

Ruben Bento

Reputation: 11

I also had similar issues with GPU memory when loading a dataset into memory. I see you're using a generator, which is fine in your example, but why would you use model.predict(d1024.take(n)) instead of simply model.predict(d1024)?

When you use dataset.take(n), it creates a new dataset with n batches, so it will not process the entire dataset. Furthermore, it will try to load all n batches of your dataset into the GPU at once, which explains why you get memory problems.
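For illustration, with the names from the question (just the pattern, not a tested fix for your model):

    subset = d1024.take(n)              # a new tf.data.Dataset limited to the first n batches
    preds_subset = model.predict(subset)

    preds_all = model.predict(d1024)    # or simply predict over the whole dataset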

I found that the best approach for me is to use a custom generator that yields batches of data, so only one batch at a time is loaded into memory.

Something like:

def gen():
    while True:
        ...            # build one batch here: X (features) and Y (labels)
        yield X, Y

This way you can be sure of having no memory problems, and by using X and Y as NumPy arrays instead of a tf.data.Dataset you also have more flexibility.
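A usage sketch under those assumptions (each yielded X is one batch of NumPy arrays; because gen() never terminates, steps tells predict when to stop):

    predictions = model.predict(gen(), steps=n)   # stops after n batches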

Upvotes: 0
