I have a custom Keras layer that contains a Masking layer.
To make save/load work, I need to serialize that Masking layer correctly. This standard code doesn't work:
```python
def get_config(self):
    config = {'mask': self.mask}
    base_config = super(Mixing, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
```
where `self.mask` is a reference to a `Masking` layer.
I'm not sure how to serialize Masking (or Keras Layers in general). Can anyone help?
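The code above fails because `get_config()` is expected to return plain, JSON-friendly data (when a model is saved, Keras encodes each layer's config as JSON), and a layer object is not such data. The dependency-free sketch below shows the difference; `Masking` here is a hypothetical stand-in, not the real Keras class.

```python
import json


class Masking:
    """Hypothetical stand-in for a Keras Masking layer (not the real class)."""

    def __init__(self, mask_value=0.0):
        self.mask_value = mask_value

    def get_config(self):
        # Plain values only, like a real Keras layer config
        return {"mask_value": self.mask_value}


mask = Masking(mask_value=0.0)

# Putting the layer object itself into the config: json can't encode it.
try:
    json.dumps({"mask": mask})
    raw_ok = True
except TypeError:
    raw_ok = False

# Putting the layer's own (plain) config in instead encodes fine.
plain = json.dumps({"mask": mask.get_config()})

print(raw_ok)  # False
print(plain)   # {"mask": {"mask_value": 0.0}}
```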
You can implement the same serialization methods as the built-in `Wrapper` class:
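```python
# Adapted from keras.layers.Wrapper. In your own layer, replace
# Wrapper with your class name (e.g. Mixing) and self.layer with
# the attribute that holds the inner layer.
def get_config(self):
    config = {'layer': {'class_name': self.layer.__class__.__name__,
                        'config': self.layer.get_config()}}
    base_config = super(Wrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
    # Inside Keras this is a relative import (`from . import deserialize`);
    # from user code, import it from the public keras.layers package.
    from keras.layers import deserialize as deserialize_layer
    layer = deserialize_layer(config.pop('layer'),
                              custom_objects=custom_objects)
    return cls(layer, **config)
```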
During serialization, in `get_config`, the inner layer's class name and config are saved in `config['layer']`.
In `from_config`, the inner layer is deserialized with `deserialize_layer` using `config['layer']`.
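To see the round trip end to end without installing Keras, here is a minimal pure-Python sketch of the same pattern. `Masking`, `Mixing`, `_REGISTRY` and this `deserialize_layer` are all hypothetical stand-ins: the registry lookup only mimics what `keras.layers.deserialize` does with `class_name`, and the real base-class config from `super().get_config()` is omitted.

```python
import json


class Masking:
    """Hypothetical stand-in for keras.layers.Masking."""

    def __init__(self, mask_value=0.0):
        self.mask_value = mask_value

    def get_config(self):
        return {"mask_value": self.mask_value}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Hypothetical registry playing the role of Keras's class-name lookup
_REGISTRY = {"Masking": Masking}


def deserialize_layer(spec, custom_objects=None):
    """Rebuild a layer from {'class_name': ..., 'config': ...} (simplified)."""
    registry = dict(_REGISTRY, **(custom_objects or {}))
    layer_cls = registry[spec["class_name"]]
    return layer_cls.from_config(spec["config"])


class Mixing:
    """Hypothetical custom layer that holds an inner layer, Wrapper-style."""

    def __init__(self, layer, units=8):
        self.layer = layer
        self.units = units

    def get_config(self):
        # Real Keras code would also merge in super().get_config()
        return {"units": self.units,
                "layer": {"class_name": self.layer.__class__.__name__,
                          "config": self.layer.get_config()}}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = dict(config)  # copy so the caller's dict isn't mutated
        layer = deserialize_layer(config.pop("layer"),
                                  custom_objects=custom_objects)
        return cls(layer, **config)


original = Mixing(Masking(mask_value=-1.0), units=4)
saved = json.dumps(original.get_config())          # plain JSON text
restored = Mixing.from_config(json.loads(saved))   # rebuilt from text
print(type(restored.layer).__name__, restored.layer.mask_value, restored.units)
# Masking -1.0 4
```

With real Keras, the outer custom class also has to be known at load time, e.g. `keras.models.load_model(path, custom_objects={'Mixing': Mixing})`.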