Mihai.Mehe
Mihai.Mehe

Reputation: 504

Keras Custom Layer gives errors when saving the full model

class ConstLayer(tf.keras.layers.Layer):
    def __init__(self, x, **kwargs):
        super(ConstLayer, self).__init__(**kwargs)
        self.x = tf.Variable(x, trainable=False)

    def call(self, input):
        return self.x

    def get_config(self):
        #Note: all original model has eager execution disabled
        config = super(ConstLayer, self).get_config()
        config['x'] = self.x
        return config
    


model_test_const_layer = keras.Sequential([
    keras.Input(shape=(784)),
    ConstLayer([[1.,1.]], name="anchors"),
    keras.layers.Dense(10),
])

model_test_const_layer.summary()
model_test_const_layer.save("../models/my_model_test_constlayer.h5")
del model_test_const_layer
model_test_const_layer = keras.models.load_model("../models/my_model_test_constlayer.h5",custom_objects={'ConstLayer': ConstLayer,})
model_test_const_layer.summary()

This code is a sandbox replication of an error given by a larger Keras model with a RESNet 101 backbone.

Errors: If the model includes the custom layer ConstLayer:

Any help and clues are greatly appreciated!

Upvotes: 1

Views: 585

Answers (1)

Plagon
Plagon

Reputation: 3137

As far as I understand it, TF has problems with copying variables. Just save the original value / config passed to the layer instead:

import tensorflow as tf
import tensorflow.keras as keras

tf.compat.v1.disable_eager_execution()

class ConstLayer(tf.keras.layers.Layer):
    def __init__(self, x, **kwargs):
        super(ConstLayer, self).__init__(**kwargs)
        self._config = {'x': x}
        self.x = tf.Variable(x, trainable=False)

    def call(self, input):
        return self.x

    def get_config(self):
        #Note: all original model has eager execution disabled
        config = {
            **super(ConstLayer, self).get_config(),
            **self._config
        }
        return config


model_test_const_layer = keras.Sequential([
    keras.Input(shape=(784)),
    ConstLayer([[1., 1.]], name="anchors"),
    keras.layers.Dense(10),
])

model_test_const_layer.summary()
model_test_const_layer.save("../models/my_model_test_constlayer.h5")
del model_test_const_layer
model_test_const_layer = keras.models.load_model(
    "../models/my_model_test_constlayer.h5", custom_objects={'ConstLayer': ConstLayer, })
model_test_const_layer.summary()

Upvotes: 1

Related Questions