justtrying

Reputation: 21

How can I properly save and load this custom tensorflow model?

A custom layer in my model simply will not load back after saving, no matter what I do: the model trains and saves fine, but fails to load when I need it at a later time. I save the model with model.save and load it with tf.keras.models.load_model. This is the layer in question:

import tensorflow as tf
import keras
from keras import layers

@keras.utils.register_keras_serializable(package='Custom', name='TokenAndPositionEmbedding')
class TokenAndPositionEmbedding(keras.layers.Layer):
    def __init__(self, max_len, vocab_size, embed_dim, **kwargs):
        super(TokenAndPositionEmbedding, self).__init__(**kwargs)
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=max_len, output_dim=embed_dim)

    def call(self, x):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_len": self.max_len,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

The error message is:

TypeError: <class 'keras.src.models.functional.Functional'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config() are explicitly deserialized in the model's from_config() method.

config={'module': 'keras.src.models.functional', 'class_name': 'Functional', 'config': {}, 'registered_name': 'Functional', 'build_config': {'input_shape': None}, 'compile_config': {'optimizer': {'module': 'keras.optimizers', 'class_name': 'Adam', 'config': {'name': 'adam', 'learning_rate': 0.0010000000474974513, 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'loss_scale_factor': None, 'gradient_accumulation_steps': None, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}, 'registered_name': None}, 'loss': 'categorical_crossentropy', 'loss_weights': None, 'metrics': ['accuracy'], 'weighted_metrics': None, 'run_eagerly': False, 'steps_per_execution': 1, 'jit_compile': False}}.

Exception encountered: Could not locate class 'TokenAndPositionEmbedding'. Make sure custom classes are decorated with @keras.saving.register_keras_serializable(). Full object config: {'module': None, 'class_name': 'TokenAndPositionEmbedding', 'config': {'name': 'token_and_position_embedding', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'max_len': 20, 'vocab_size': 256, 'embed_dim': 512}, 'registered_name': 'Custom>TokenAndPositionEmbedding', 'build_config': {'input_shape': [None, 20]}, 'name': 'token_and_position_embedding', 'inbound_nodes': [{'args': [{'class_name': 'keras_tensor', 'config': {'shape': [None, 20], 'dtype': 'float32', 'keras_history': ['input_layer', 0, 0]}}], 'kwargs': {}}]}

I have tried passing the custom layer through get_config and from_config, using different decorators, and more. The only suspicious thing I noticed is that, while the code runs with @keras.utils.register_keras_serializable, switching to @keras.saving.register_keras_serializable fails because the saving attribute cannot be found at all; still, as far as I can tell, @keras.utils.register_keras_serializable should work just as well.
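For reference, this is the explicit route I also experimented with: passing the class directly via the documented custom_objects argument of load_model. A minimal sketch with a trivial stand-in custom layer (the Scale layer, file name, and factor are placeholders, just to show the pattern):

```python
import numpy as np
import keras
from keras import layers


@keras.utils.register_keras_serializable(package="Custom")
class Scale(layers.Layer):
    """Trivial stand-in custom layer, just to demonstrate the pattern."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        config = super().get_config()
        config.update({"factor": self.factor})
        return config


inputs = keras.Input(shape=(4,))
model = keras.Model(inputs, Scale(3.0)(inputs))
model.save("scale_model.keras")

# Explicitly map the class name to the class at load time,
# instead of relying on the registration decorator alone.
loaded = keras.models.load_model(
    "scale_model.keras",
    custom_objects={"Scale": Scale},
)

out = loaded(np.ones((1, 4), dtype="float32"))
print(float(out[0, 0]))  # 3.0
```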

Upvotes: 2

Views: 48

Answers (0)
