Reputation: 2512
I am using TensorFlow 2.14. I have a dataset created using from_generator, which I then batched to a fixed batch size of 1024. The code used is close to:

d1024 = from_generator(...).batch(1024).cache('cache_file').prefetch(AUTOTUNE)
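For reference, the full pipeline looks roughly like this (the generator body and output signature below are placeholders for illustration, not my real data):

import numpy as np
import tensorflow as tf

def sample_gen():
    # Placeholder generator; the real one yields my actual samples.
    for _ in range(1_000_000):
        yield np.zeros(32, np.float32), 0

d1024 = (
    tf.data.Dataset.from_generator(
        sample_gen,
        output_signature=(
            tf.TensorSpec(shape=(32,), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
        ),
    )
    .batch(1024)
    .cache('cache_file')
    .prefetch(tf.data.AUTOTUNE)
)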
Now if I do model.predict(d1024.take(n)) for a big n, I see GPU memory usage gradually increase as more batches are processed, eventually resulting in an out-of-memory error if n is big enough.
In my understanding, TensorFlow will do the prediction batch by batch. Given that the batch size is fixed at 1024, memory usage should be determined by the batch size rather than by n. Is this correct?
Why does a bigger n need more memory, and how can I mitigate this problem when n is big?
I googled and found this issue in the TensorFlow GitHub, but I don't know whether it is related to this problem.
Upvotes: 1
Views: 41
Reputation: 11
I also had similar issues with GPU memory when loading the dataset into memory.
I see you're using a generator, which is fine in your example, but why would you use model.predict(d1024.take(n)) instead of simply model.predict(d1024)?
When you use dataset.take(n), it creates a new dataset containing n batches, so it will not process the entire dataset. Furthermore, it will try to load all n batches of your dataset into the GPU at once, which explains why you get memory problems.
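If you really only need the first n batches, one workaround is to pull batches one at a time on the Python side and predict on each, so only a single batch is resident at once. A sketch, assuming each element of d1024 is an (X, Y) pair:

import itertools
import numpy as np

preds = []
# Slice off n batches on the Python side instead of passing take(n) to predict.
for batch_x, _ in itertools.islice(iter(d1024), n):
    # predict_on_batch runs exactly one batch through the model.
    preds.append(model.predict_on_batch(batch_x))
predictions = np.concatenate(preds)  # combined on the host, not the GPU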
I found the best approach for me is using a custom generator that yields batches of data, so only one batch at a time is loaded into memory. Something like:
def gen():
    while True:
        ...             # build the next batch of data here
        yield X, Y      # X and Y are NumPy arrays holding one batch each
You'll be sure of having no memory problems, and by using X and Y as NumPy arrays instead of a tf.data.Dataset you also get more flexibility.
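Since a generator like this never terminates, you tell predict how many batches to run via steps (the value n here is just whatever batch count you need):

# steps bounds the infinite generator; without it predict would never stop.
predictions = model.predict(gen(), steps=n)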
Upvotes: 0