omatai

Reputation: 3728

Cannot load saved Keras model due to use of "lambda"

I have a simple Keras network that makes use of a custom activation function defined as a lambda:

from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

lrelu = lambda x: relu( x, alpha=0.01 )
model = Sequential()
model.add(Dense( 10, activation=lrelu, input_dim=12 ))
...

It compiles, trains, and tests fine (code omitted), and I can save it without trouble using model.save( 'model.h5' ). But when I try to load it with loaded = tf.keras.models.load_model( 'model.h5', custom_objects={'lrelu': lrelu} ), despite defining lrelu exactly as shown above, it complains:

ValueError: Unknown activation function:<lambda>

Wait a minute: isn't lambda a Python keyword? I'm not about to redefine Python just so I can load a model - where would it end? How do I overcome this? What do I need to specify in my custom_objects?

According to the TF Keras guide to saving and loading with custom objects and functions...

Custom-defined functions (e.g. activation loss or initialization) do not need a get_config method. The function name is sufficient for loading as long as it is registered as a custom object.

It seems to me that's exactly what I've done. Could it be that this only applies to functions defined with def, and not to lambda functions?

Upvotes: 2

Views: 1157

Answers (2)

Marco Cerliani

Reputation: 22031

This is another approach: wrap your activation function in a Lambda layer.

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Lambda

model = Sequential()
model.add(Dense( 10, input_dim=12 ))
model.add(Lambda( lambda x: tf.keras.activations.relu( x, alpha=0.01 ) ))

This is the same concept as doing model.add(Activation('...')), but with a custom, modified activation.
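A minimal sketch of that comparison (the model names here are only for illustration): a built-in activation added as its own layer by name, next to the custom leaky ReLU wrapped in a Lambda layer.

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Activation, Dense, Lambda

# Built-in activation, added as a separate layer and referenced by name:
builtin_model = Sequential()
builtin_model.add(Dense(10, input_dim=12))
builtin_model.add(Activation('relu'))

# Custom activation: same pattern, but the function lives inside a Lambda layer:
custom_model = Sequential()
custom_model.add(Dense(10, input_dim=12))
custom_model.add(Lambda(lambda x: tf.keras.activations.relu(x, alpha=0.01)))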

For saving and loading:

model.save( 'model.h5' )
loaded = tf.keras.models.load_model( 'model.h5' )

I have no problem saving and loading the model this way: https://colab.research.google.com/drive/1K-4_nt66AH5PQDv9Fn-l69-eu5S6Y5EU?usp=sharing

Upvotes: 1

AKX

Reputation: 169032

Lambdas don't have a useful __name__ attribute Keras could introspect (it is always '<lambda>'), so Keras gets confused during serialization. Use a named function instead.

from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

def lrelu(x):
    return relu(x, alpha=0.01)

model = Sequential()
model.add(Dense( 10, activation=lrelu, input_dim=12 ))

To wit:

>>> lrelu1 = lambda x: 0
>>> def lrelu2(x):
...   return 0
...
>>> lrelu1.__name__
'<lambda>'
>>> lrelu2.__name__
'lrelu2'
>>>
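With the named function, the load call from the question should now resolve the activation (a sketch, reusing the question's file name and custom_objects mapping):

model.save( 'model.h5' )

# The key 'lrelu' now matches lrelu.__name__, so Keras can look the function up.
loaded = tf.keras.models.load_model( 'model.h5', custom_objects={'lrelu': lrelu} )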

Upvotes: 4
