XIN LIU

Predict batches using TensorFlow Data API and Keras Model

Suppose I have a dataset and a Keras model. The dataset has been divided into batches using batch() from the tf.data Dataset API. I am looking for an efficient and clean way to run batch predictions over all test samples.

I have tried the following code and it works.

import math

batch_size = 32
dataset = dataset.batch(batch_size)
# One step per batch; math.ceil counts the final, possibly partial, batch
predictions = keras_model.predict(dataset, steps=math.ceil(num_testing_samples / batch_size))

I wonder whether there is a more efficient and elegant way to do this.
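For reference, the steps value in the snippet above is simply the number of batches that batch() produces when drop_remainder is left at its default of False: every batch is full except possibly the last one. A small pure-Python sketch of that arithmetic (the helpers num_batches and batch_sizes are hypothetical names for illustration, not part of TensorFlow):

```python
from typing import List


def num_batches(num_samples: int, batch_size: int) -> int:
    """Batches yielded by Dataset.batch(batch_size) with drop_remainder=False.

    Integer ceiling division; equivalent to math.ceil(num_samples / batch_size)
    but without going through floating point.
    """
    return (num_samples + batch_size - 1) // batch_size


def batch_sizes(num_samples: int, batch_size: int) -> List[int]:
    """Size of each batch; only the last one can be smaller than batch_size."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


# 100 test samples in batches of 32 -> 4 steps of sizes 32, 32, 32, 4
print(num_batches(100, 32))
print(batch_sizes(100, 32))
```

If num_samples is an exact multiple of batch_size, no partial batch is produced and both helpers agree on the count.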

Answers (1)

Olivier Dehaene

TF >= 1.14.0

You can just set steps=None, which is the default, so you can also omit the argument entirely. From the official documentation of tf.keras.Model.predict():

If x is a tf.data dataset and steps is None, predict will run until the input dataset is exhausted.

Just make sure your dataset object is not in repeat mode (i.e. you have not called repeat() on it without a count, otherwise it is never exhausted) and you are good to go :).
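A minimal sketch of this, assuming TensorFlow 2.x with tf.keras. The toy data, the single Dense layer model and the sample counts are made up for illustration; only predict() without steps is the point here:

```python
import numpy as np
import tensorflow as tf

num_testing_samples = 100
batch_size = 32

# Toy test set: 100 samples with 4 features each (features only, no labels)
x_test = np.random.rand(num_testing_samples, 4).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices(x_test).batch(batch_size)  # no repeat()

# Tiny untrained model, just so there is something to call predict() on
keras_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])

# steps left at None: predict() iterates until the dataset is exhausted,
# including the final partial batch of 100 - 3 * 32 = 4 samples
predictions = keras_model.predict(dataset, verbose=0)
print(predictions.shape)
```

The outputs of all batches are concatenated, so predictions has one row per test sample, here (100, 1).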

TF 1.12.0 & 1.13.0

Support for tf.data.Dataset in tf.keras is very poor in these versions. Internally, the tf.data.Dataset object is converted into an iterator, and Keras then raises an error if you did not set the steps argument, so you have to pass steps explicitly, as in the question. This was fixed in 1.14.0.