Gunnarsi

Reputation: 354

When to use model.predict(x) vs model(x) in tensorflow

I've got a keras.models.Model that I load with tf.keras.models.load_model.

There are two ways to use this model: I can call model.predict(x), or I can call model(x).numpy(). Both give me the same result, but model.predict(x) takes over 10x longer to run.

The comments in the source code state:

Computation is done in batches. This method is designed for performance in large scale inputs. For small amount of inputs that fit in one batch, directly using __call__ is recommended for faster execution, e.g., model(x), or model(x, training=False)

I've tested with x containing 1; 1,000,000; and 10,000,000 rows and model(x) still performs better.

How large does the input need to be to be classified as a large scale input, and for the model.predict(x) to perform better?

Upvotes: 5

Views: 4912

Answers (1)

jkr

Reputation: 19260

There is an existing Stack Overflow answer that you might find useful: https://stackoverflow.com/a/58385156/5666087. I found it via tensorflow/tensorflow#33340. That answer suggests passing experimental_run_tf_function=False to the model.compile call to revert to the TF 1.x style of model execution. You can also omit the model.compile call entirely, since compiling is not necessary for prediction.

> How large does the input need to be to be classified as a large scale input, and for the model.predict(x) to perform better?

This is something you can test. As the documentation states, model(x) will likely be faster than model.predict(x) if your data fit in one batch. One thing that model.predict(x) provides over model(x) is the ability to predict on multiple batches. If you want to predict on multiple batches with model(x), you have to write the loop yourself. model.predict also provides other features, like callbacks.
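To illustrate what "writing the loop yourself" might look like, here is a minimal sketch. The helper name predict_in_batches and the stand-in toy_model are hypothetical and only for illustration; the helper assumes model is any callable that accepts training=False (as a Keras model does), that x is non-empty, and that it can be sliced along its first axis.

```python
import numpy as np


def predict_in_batches(model, x, batch_size=1024):
    """Call model(batch) on consecutive slices of x and concatenate the results.

    Assumes x is non-empty and sliceable along axis 0 (e.g. a NumPy array).
    """
    outputs = []
    for start in range(0, len(x), batch_size):
        batch = x[start:start + batch_size]
        # training=False asks layers such as dropout to run in inference mode.
        out = model(batch, training=False)
        # np.asarray also converts eager TensorFlow tensors to NumPy arrays.
        outputs.append(np.asarray(out))
    return np.concatenate(outputs, axis=0)


def toy_model(batch, training=False):
    """Hypothetical stand-in for a loaded Keras model: doubles its input."""
    return batch * 2.0


x = np.arange(10, dtype=np.float32).reshape(5, 2)
result = predict_in_batches(toy_model, x, batch_size=2)
print(result.shape)
```

With a real model you would pass the object returned by tf.keras.models.load_model in place of toy_model. Unlike model.predict, this loop gives you no callbacks or progress bar, which is part of why it can be lighter.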

FYI the documentation in the source code was added in commit 42f469be0f3e8c36624f0b01c571e7ed15f75faf, as a result of tensorflow/tensorflow#33340.

The main behavior of model.predict(x) is implemented here. It contains more than just the forward pass of the model. This could account for some of the speed differences.
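If you want to measure the difference yourself, a small timing harness along these lines may help. The helper name time_call is hypothetical, and the commented-out lines assume a model and an input x are already defined as in the question.

```python
import time


def time_call(fn, repeats=5):
    """Return the best wall-clock time in seconds over `repeats` calls of fn()."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        elapsed = time.perf_counter() - start
        best = min(best, elapsed)
    return best


# Hypothetical usage, assuming `model` and `x` exist:
# t_call = time_call(lambda: model(x, training=False).numpy())
# t_predict = time_call(lambda: model.predict(x))
# print(t_call, t_predict)
```

Taking the best of several runs reduces the influence of one-off costs such as function tracing on the first call. Repeating the measurement for several input sizes and batch_size values should show where, if anywhere, model.predict catches up on your hardware.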

> I've tested with x containing 1; 1,000,000; and 10,000,000 rows and model(x) still performs better.

Do these 10,000,000 rows fit into one batch...?

Upvotes: 3
