Reputation: 21
A custom TensorFlow layer in my model is simply not loading after saving, no matter what I do: the model trains and saves fine, but fails to load when I need it at a later time. I save the model with model.save and load it with tf.keras.models.load_model. This is the layer in question:
import tensorflow as tf
import keras
from keras import layers

@keras.utils.register_keras_serializable(package='Custom', name='TokenAndPositionEmbedding')
class TokenAndPositionEmbedding(keras.layers.Layer):
    def __init__(self, max_len, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=max_len, output_dim=embed_dim)

    def call(self, x):
        # Embed the token ids, then add a learned embedding per position.
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_len": self.max_len,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
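For context, this is roughly how I save and load (a minimal sketch; the surrounding architecture and the file name are illustrative, not my exact code):

import tensorflow as tf
import keras
from keras import layers

# Toy functional model wrapping the custom layer (illustrative only;
# the shapes and hyperparameters match the error dump below).
inputs = keras.Input(shape=(20,))
x = TokenAndPositionEmbedding(max_len=20, vocab_size=256, embed_dim=512)(inputs)
outputs = layers.Dense(256, activation="softmax")(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Save in the native Keras format, then load it back at a later time.
model.save("model.keras")
reloaded = tf.keras.models.load_model("model.keras")  # fails with the error below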
The error message is:
TypeError: <class 'keras.src.models.functional.Functional'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config() are explicitly deserialized in the model's from_config() method.
config={'module': 'keras.src.models.functional', 'class_name': 'Functional', 'config': {}, 'registered_name': 'Functional', 'build_config': {'input_shape': None}, 'compile_config': {'optimizer': {'module': 'keras.optimizers', 'class_name': 'Adam', 'config': {'name': 'adam', 'learning_rate': 0.0010000000474974513, 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'loss_scale_factor': None, 'gradient_accumulation_steps': None, 'beta_1': 0.9, 'beta_2': 0.999, 'epsilon': 1e-07, 'amsgrad': False}, 'registered_name': None}, 'loss': 'categorical_crossentropy', 'loss_weights': None, 'metrics': ['accuracy'], 'weighted_metrics': None, 'run_eagerly': False, 'steps_per_execution': 1, 'jit_compile': False}}.
Exception encountered: Could not locate class 'TokenAndPositionEmbedding'. Make sure custom classes are decorated with @keras.saving.register_keras_serializable(). Full object config: {'module': None, 'class_name': 'TokenAndPositionEmbedding', 'config': {'name': 'token_and_position_embedding', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'max_len': 20, 'vocab_size': 256, 'embed_dim': 512}, 'registered_name': 'Custom>TokenAndPositionEmbedding', 'build_config': {'input_shape': [None, 20]}, 'name': 'token_and_position_embedding', 'inbound_nodes': [{'args': [{'class_name': 'keras_tensor', 'config': {'shape': [None, 20], 'dtype': 'float32', 'keras_history': ['input_layer', 0, 0]}}], 'kwargs': {}}]}
I have tried passing the custom sub-layers through get_config and from_config, using different decorators, and more. The only suspicious thing I noticed is that while the @keras.utils.register_keras_serializable decorator lets the code run, switching to @keras.saving.register_keras_serializable fails because the saving class is not found; still, I believe it should work with @keras.utils.register_keras_serializable regardless.
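For concreteness, one variant of "passing the custom sub-layers through get_config and from_config" looked roughly like this (a paraphrased sketch, reusing the __init__ and call from the class above; keras.saving.serialize_keras_object and keras.saving.deserialize_keras_object are the standard helpers):

import keras

@keras.utils.register_keras_serializable(package='Custom', name='TokenAndPositionEmbedding')
class TokenAndPositionEmbedding(keras.layers.Layer):
    # __init__ and call exactly as in the class shown above ...

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_len": self.max_len,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
                # Also serialize the sub-layers explicitly.
                "token_emb": keras.saving.serialize_keras_object(self.token_emb),
                "pos_emb": keras.saving.serialize_keras_object(self.pos_emb),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the sub-layers, then strip them from the config so
        # __init__ does not receive unexpected keyword arguments.
        token_emb = keras.saving.deserialize_keras_object(config.pop("token_emb"))
        pos_emb = keras.saving.deserialize_keras_object(config.pop("pos_emb"))
        layer = cls(**config)
        layer.token_emb = token_emb
        layer.pos_emb = pos_emb
        return layer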
Upvotes: 2
Views: 48