anon

Keras: Serializing a Masking Layer for save/load

I have a custom Keras layer that uses a Masking layer internally.

To make it work with save/load, I need to serialize the Masking layer correctly, but the standard `get_config` below doesn't work:

def get_config(self):
    # Fails on save: self.mask is a Layer object, not a JSON-serializable value.
    config = {'mask': self.mask}
    base_config = super(Mixing, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

where `self.mask` is a reference to the Masking layer.

I'm not sure how to serialize a Masking layer (or Keras layers in general). Can anyone help?
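To see why the snippet above fails, it helps to know that Keras writes layer configs out as JSON (for example via `model.to_json()` and in the HDF5 files written by `model.save()`), so every value returned by `get_config` has to be JSON-friendly. The sketch below uses a hypothetical stand-in class, `FakeMasking` (not the real Keras class), to contrast storing the layer object itself with storing its class name plus its own config:

```python
import json


class FakeMasking:
    """Hypothetical stand-in for keras.layers.Masking (illustration only)."""

    def __init__(self, mask_value=0.0):
        self.mask_value = mask_value

    def get_config(self):
        # Only plain, JSON-friendly values go into the config.
        return {"mask_value": self.mask_value}


mask = FakeMasking(mask_value=0.0)

# Storing the layer object itself: json cannot encode arbitrary objects.
try:
    json.dumps({"mask": mask})
    raw_object_ok = True
except TypeError:
    raw_object_ok = False

# Storing the class name plus the layer's own config: only dicts/strings/floats.
nested = {"mask": {"class_name": type(mask).__name__, "config": mask.get_config()}}
encoded = json.dumps(nested)

print(raw_object_ok)  # False
print(encoded)        # {"mask": {"class_name": "FakeMasking", "config": {"mask_value": 0.0}}}
```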

Answers (1)

Yu-Yang

You can implement the same serialization methods as the built-in `Wrapper` class:

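def get_config(self):
    # Store the inner layer as its class name plus its own (JSON-friendly) config.
    config = {'layer': {'class_name': self.layer.__class__.__name__,
                        'config': self.layer.get_config()}}
    # In your own layer, replace Wrapper with your class name (e.g. Mixing).
    base_config = super(Wrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
    # Wrapper uses a package-relative import (`from . import deserialize`);
    # outside the Keras source tree, import it from keras.layers instead.
    from keras.layers import deserialize as deserialize_layer
    layer = deserialize_layer(config.pop('layer'),
                              custom_objects=custom_objects)
    return cls(layer, **config)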

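During serialization, `get_config` saves the inner layer's class name and config under `config['layer']`.

In `from_config`, the inner layer is rebuilt by `deserialize_layer` from `config['layer']`, and the remaining keys are passed to the constructor as keyword arguments, so `__init__` must accept the inner layer as its first argument.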
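The whole round trip can be sketched without Keras. In the pure-Python mock below, `FakeMasking`, `Mixing`, `deserialize_layer` and the `_REGISTRY` dict are all hypothetical stand-ins; the real `keras.layers.deserialize` also looks classes up by name (built-in layers plus anything passed in `custom_objects`), but its internals differ from this toy version.

```python
import json


class FakeMasking:
    """Hypothetical stand-in for keras.layers.Masking."""

    def __init__(self, mask_value=0.0, name="masking"):
        self.mask_value = mask_value
        self.name = name

    def get_config(self):
        return {"mask_value": self.mask_value, "name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Hypothetical registry of "built-in" classes, keyed by class name.
_REGISTRY = {"FakeMasking": FakeMasking}


def deserialize_layer(config, custom_objects=None):
    """Rebuild a layer from {'class_name': ..., 'config': ...}."""
    lookup = dict(_REGISTRY)
    lookup.update(custom_objects or {})
    cls = lookup[config["class_name"]]
    return cls.from_config(config["config"])


class Mixing:
    """Hypothetical custom layer that wraps an inner masking layer."""

    def __init__(self, mask, units=8):
        self.mask = mask
        self.units = units

    def get_config(self):
        # Store the inner layer as class name + its own config (JSON-friendly).
        return {
            "units": self.units,
            "mask": {"class_name": type(self.mask).__name__,
                     "config": self.mask.get_config()},
        }

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = dict(config)  # avoid mutating the caller's dict
        mask = deserialize_layer(config.pop("mask"), custom_objects=custom_objects)
        # Remaining keys are passed straight to __init__.
        return cls(mask, **config)


original = Mixing(FakeMasking(mask_value=-1.0), units=16)
saved = json.dumps(original.get_config())  # what a save step would write
restored = Mixing.from_config(json.loads(saved))
print(type(restored.mask).__name__, restored.mask.mask_value, restored.units)
# FakeMasking -1.0 16
```

With real Keras you would call `keras.layers.deserialize` instead of the mock helper, and register your own layer class when loading, e.g. `load_model(path, custom_objects={'Mixing': Mixing})`. Copying `config` before `pop` is a small design choice so that `from_config` does not modify the dict it was given.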
