daniel451

Reputation: 11002

Save complete tf.keras model without custom objects?

Suppose we have a model with custom losses and metrics that are needed during training. Is it possible to save the complete model (i.e., weights plus GraphDef / .pb file) without the custom objects?

During inference the custom losses & metrics are not needed, thus...

tf.keras.models.load_model("some_model", custom_objects={...})

...would only make the inference code more complicated, since the custom-object code has to be shipped with it even though it is never used at inference time.
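For comparison, a minimal sketch of the custom_objects route, assuming TensorFlow 2.x and a hypothetical loss function my_custom_loss. The model.keras filename is also an assumption, chosen because recent Keras versions require a file extension. It shows that the inference side has to define or import my_custom_loss just so loading succeeds, even though predict() never calls it.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# Hypothetical training-only loss (stands in for the question's custom objects).
def my_custom_loss(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))


# Training side: a small model compiled with the custom loss, then saved.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss=my_custom_loss)
path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)

# Inference side: the loss must still be supplied so the compile config
# can be deserialized, although prediction never uses it.
loaded = tf.keras.models.load_model(
    path, custom_objects={"my_custom_loss": my_custom_loss}
)
x = np.random.rand(2, 4).astype("float32")
preds = loaded.predict(x, verbose=0)
```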

However, both tf.keras.callbacks.ModelCheckpoint and model.save() (even with include_optimizer=False) always save the model definition together with references to the custom objects.

Hence, simply loading the model with...

tf.keras.models.load_model("some_model")

...will always fail and complain about the missing custom objects.

Is it possible to somehow save the whole model without custom losses/metrics? To get an "inference" version of the network that is easy to load?

Or is the only solution to this to freeze everything to a TFLite model?

Of course, one could simply use model.save_weights(), but then the code that builds the model architecture has to be included for inference later, which is not desired.
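To illustrate that trade-off, a minimal sketch assuming TensorFlow 2.x; build_model is a hypothetical helper, and the .weights.h5 suffix is an assumption picked because recent Keras versions require it for save_weights(). Only the weights are written to disk, so the inference script must contain build_model() to recreate the architecture before load_weights() can be called.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# Hypothetical architecture definition; this code must exist at inference time too.
def build_model():
    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(1)(inputs)
    return tf.keras.Model(inputs, outputs)


# Training side: save only the weights (no architecture, no loss/metrics).
model = build_model()
path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
model.save_weights(path)

# Inference side: rebuild the same architecture, then restore the weights.
restored = build_model()
restored.load_weights(path)

x = np.random.rand(2, 4).astype("float32")
preds = restored.predict(x, verbose=0)
```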

Upvotes: 6

Views: 2870

Answers (1)

Dr. Snoopy

Reputation: 56377

If the goal is to avoid loading the loss and metrics, you can pass the compile argument to load_model:

model = tf.keras.models.load_model("some_model", compile=False)

This should skip the need for the loss, metrics, and optimizer, since the model is not compiled. Of course, you cannot train the model without compiling it again, but it should work fine for inference using model.predict().
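An end-to-end sketch of this approach, assuming TensorFlow 2.x; my_custom_loss is a hypothetical stand-in for the question's training-only objects, and the model.keras filename is likewise an assumption. The model is trained and saved with the custom loss, then reloaded with compile=False and no custom_objects at all.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# Hypothetical custom loss, used only during training.
def my_custom_loss(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))


# Training side: standard layers only, compiled with the custom loss.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss=my_custom_loss)

x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)

# Inference side: compile=False skips restoring loss/metrics/optimizer,
# so my_custom_loss does not have to be defined or passed here.
inference_model = tf.keras.models.load_model(path, compile=False)
preds = inference_model.predict(x, verbose=0)
```

Note that compile=False only skips the training configuration. If the architecture itself contains custom layers, those still have to be provided via custom_objects when loading.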

Upvotes: 10
