Reputation: 87
Suppose I have a dataset and a Keras model. The dataset has been split into batches using `batch()` from the tf.data Dataset API. I am looking for an efficient and clean way to run batch predictions on all testing samples.
I have tried the following code, and it works:
```python
import math

batch_size = 32
dataset = dataset.batch(batch_size)
# One step per batch; ceil also counts the final, smaller batch.
predictions = keras_model.predict(dataset, steps=math.ceil(num_testing_samples / batch_size))
```
Is there a more efficient or elegant way to implement this?
Upvotes: 1
Views: 1515
Reputation: 1680
You can just set `steps=None`. From the official documentation of `tf.keras.Model.predict()`:
> If x is a tf.data dataset and steps is None, predict will run until the input dataset is exhausted.
Just make sure that your `dataset` object is not in repeat mode and you are good to go :).
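A minimal sketch of the `steps=None` approach, assuming TensorFlow 2.x. The toy one-layer model, the random input array, and the names `x_test` and `num_testing_samples` are illustrative placeholders, not part of the original question.

```python
import numpy as np
import tensorflow as tf

# Illustrative stand-ins for the real test data and trained model (assumptions).
num_testing_samples = 100
batch_size = 32
x_test = np.random.rand(num_testing_samples, 4).astype("float32")
keras_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])

# Finite dataset (no .repeat()): it is exhausted after ceil(100 / 32) = 4 batches.
dataset = tf.data.Dataset.from_tensor_slices(x_test).batch(batch_size)

# steps is left at its default (None), so predict stops when the dataset runs out.
predictions = keras_model.predict(dataset, verbose=0)
print(predictions.shape)  # (100, 1)
```

If the dataset had been built with `.repeat()`, it would never be exhausted, so in that case an explicit `steps` value, as in the question, is still needed.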
Support for `tf.data.Dataset` in `tf.keras` is very limited in older releases (before 1.14.0). There, the `tf.data.Dataset` object is converted into an iterator, and that code path raises an error if you did not set the `steps` argument. This was patched in TensorFlow 1.14.0.
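On releases where `steps` must still be passed explicitly, it has to count the final partial batch too, which is why the question uses `math.ceil` rather than floor division. The hypothetical helper `batch_sizes` below only mimics, in plain Python (not TensorFlow), how `batch()` with its default `drop_remainder=False` splits the samples, and checks the batch count against the ceil formula.

```python
import math

def batch_sizes(num_samples, batch_size):
    """Hypothetical helper: sizes of the batches produced when the last,
    smaller batch is kept (like batch() with drop_remainder=False)."""
    return [min(batch_size, num_samples - start)
            for start in range(0, num_samples, batch_size)]

def steps_for(num_samples, batch_size):
    # Same formula as in the question: round up so the partial batch counts.
    return math.ceil(num_samples / batch_size)

sizes = batch_sizes(100, 32)
print(sizes)              # [32, 32, 32, 4]
print(steps_for(100, 32))  # 4
print(100 // 32)          # 3: floor division would skip the last 4 samples
```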
Upvotes: 0