Pierluigi

Reputation: 1118

Keras Wrapping Layers and Save

I created a custom Keras Layer with the purpose of wrapping another Layer so that I can have some additional computation before and/or after it. Here is some pseudo code:

from tensorflow.keras.layers import Dense, Layer

class Wrapper(Layer):

    def __init__(self, layer: Layer, **kwargs):
        super(Wrapper, self).__init__(**kwargs)
        self._layer = layer

    def build(self, input_shape):
        self._layer.build(input_shape)

        # e.g. the following is just an example
        self._after_layer = Dense(10)

    def call(self, x, **kwargs):
        y = self._layer(x)
        y = self._after_layer(y)
        return y

    def compute_output_shape(self, input_shape):
        return self._after_layer.compute_output_shape(self._layer.compute_output_shape(input_shape))

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'layer': self._layer,
        })
        return config

Now you can use the wrapper above in the following way:

y = Wrapper(SomeOtherLayer(...))(x)

Everything works like a charm, but taking a Layer as a constructor argument makes it impossible to save the model. Trying to save raises "TypeError: Cannot convert ... to a TensorFlow DType", and it is triggered by the 'layer': self._layer entry that I added to the layer config.

Is there any workaround, or a better way to achieve the same thing, that also lets me save and load the model?

Upvotes: 1

Views: 944

Answers (1)

M. Perier--Dulhoste

Reputation: 1039

Since you are wrapping a TF layer, you must serialize it in get_config: the config should hold the serialized form of the wrapped layer (the output of tf.keras.layers.serialize), not the layer object itself.

You can also implement the classmethod from_config, which should be able to recreate the wrapper from the output of get_config. To do that, it needs to deserialize the wrapped layer. This is useful when you are saving the architecture and not only the weights.

Here is the full working code:

import tensorflow as tf

class Wrapper(tf.keras.layers.Layer):

    def __init__(self, layer: tf.keras.layers.Layer, **kwargs):
        super(Wrapper, self).__init__(**kwargs)
        self._layer = layer

    def build(self, input_shape):
        self._layer.build(input_shape)

        # e.g. the following is just an example
        self._after_layer = tf.keras.layers.Dense(10)

    def call(self, x, **kwargs):
        y = self._layer(x)
        y = self._after_layer(y)
        return y

    def compute_output_shape(self, input_shape):
        return self._after_layer.compute_output_shape(self._layer.compute_output_shape(input_shape))

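    def get_config(self):
        config = super().get_config().copy()
        # Store the wrapped layer as a serialized dict, not as the layer object.
        config["layer"] = tf.keras.layers.serialize(self._layer)
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the wrapped layer first, then pass the remaining config to __init__.
        layer = tf.keras.layers.deserialize(config.pop("layer"))
        return cls(layer, **config)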
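
To make the get_config / from_config contract easier to see, here is a minimal sketch of the same round trip in plain Python, without TensorFlow. Everything in it (ToyLayer, ToyWrapper, REGISTRY, and the local serialize / deserialize helpers) is hypothetical and only imitates the idea behind tf.keras.layers.serialize and tf.keras.layers.deserialize; it is not how Keras implements them.

```python
import json


# Hypothetical stand-in for a Keras layer; not a real Keras class.
class ToyLayer:
    def __init__(self, units):
        self.units = units

    def get_config(self):
        return {"units": self.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Hypothetical registry mapping class names to classes (assumption for this sketch).
REGISTRY = {"ToyLayer": ToyLayer}


def serialize(layer):
    # Reduce an object to plain data: its class name plus its own config.
    return {"class_name": type(layer).__name__, "config": layer.get_config()}


def deserialize(data):
    # Look the class up by name and rebuild the object from its config.
    return REGISTRY[data["class_name"]].from_config(data["config"])


class ToyWrapper:
    def __init__(self, layer, name="wrapper"):
        self.layer = layer
        self.name = name

    def get_config(self):
        # Store the serialized form of the wrapped object, never the live object.
        return {"name": self.name, "layer": serialize(self.layer)}

    @classmethod
    def from_config(cls, config):
        config = dict(config)  # copy so the caller's dict is not mutated
        layer = deserialize(config.pop("layer"))
        return cls(layer, **config)


original = ToyWrapper(ToyLayer(units=10), name="wrapped")
text = json.dumps(original.get_config())  # succeeds because the config is plain data
restored = ToyWrapper.from_config(json.loads(text))
print(type(restored.layer).__name__, restored.layer.units, restored.name)
```

The point of the sketch is that the config only ever contains dicts, strings, and numbers, so it can be written to disk and read back; putting the layer object itself into the config, as in the question, breaks that property.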

Upvotes: 1
