Setting
As the title says, I have a problem with my custom loss function when I try to load the saved model. My loss looks as follows:
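```python
from keras import backend as K

def weighted_cross_entropy(weights):
    weights = K.variable(weights)

    def loss(y_true, y_pred):
        # clip predictions to avoid log(0)
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        # per-class log-likelihood, scaled by the class weights
        loss = y_true * K.log(y_pred) * weights
        # sum over the class axis and negate
        loss = -K.sum(loss, -1)
        return loss

    return loss

weighted_loss = weighted_cross_entropy([0.1, 0.9])
```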
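To make clear what this loss computes for a single sample, here is a plain-Python sketch of the same formula (no Keras needed). It assumes a one-hot `y_true` and uses `EPSILON = 1e-7` as a stand-in for `K.epsilon()` (Keras' default fuzz factor); the helper name `reference_weighted_ce` is hypothetical.

```python
import math

# Assumption: 1e-7 mirrors Keras' default K.epsilon(); adjust if yours differs.
EPSILON = 1e-7


def reference_weighted_ce(y_true, y_pred, weights):
    """Hypothetical plain-Python version of the Keras loss, for one sample."""
    total = 0.0
    for t, p, w in zip(y_true, y_pred, weights):
        # Clip like K.clip(y_pred, K.epsilon(), 1 - K.epsilon()) to avoid log(0)
        p = min(max(p, EPSILON), 1 - EPSILON)
        # Accumulate y_true * log(y_pred) * weights for each class
        total += t * math.log(p) * w
    # Negate the sum over the class axis, like -K.sum(loss, -1)
    return -total


# One-hot target for class 1, weights [0.1, 0.9] as in weighted_cross_entropy
value = reference_weighted_ce([0, 1], [0.2, 0.8], [0.1, 0.9])
print(round(value, 4))  # 0.2008
```

With a one-hot target only the true class contributes, scaled by its weight: here the result is 0.9 * -ln(0.8).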
During training I used `weighted_loss` as the loss function, and everything worked well. After training, I saved the model as an .h5 file with the standard `model.save` function from the Keras API.
Problem
When I try to load the model via

```python
model = load_model(path, custom_objects={"weighted_loss": weighted_loss})
```

I get a `ValueError` telling me that the loss is unknown.
Error
The error message looks as follows:
```
  File "...\predict.py", line 29, in my_script
    "weighted_loss": weighted_loss})
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 312, in _deserialize_model
    sample_weight_mode=sample_weight_mode)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\training.py", line 139, in compile
    loss_function = losses.get(loss)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 133, in get
    return deserialize(identifier)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 114, in deserialize
    printable_module_name='loss function')
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\utils\generic_utils.py", line 165, in deserialize_keras_object
    ':' + function_name)
ValueError: Unknown loss function:loss
```
Questions
How can I fix this problem? Could the cause be my wrapped loss definition, i.e. that Keras doesn't know how to handle the `weights` variable?
For complete examples of saving and loading Keras models with custom loss functions or custom models, see the following GitHub gists:
- Custom loss function defined using a wrapper: https://gist.github.com/ashkan-abbasi66/a81fe4c4d588e2c187180d5bae734fde
- Custom loss function defined by subclassing: https://gist.github.com/ashkan-abbasi66/327efe2dffcf9788847d26de934ef7bd
- Custom model: https://gist.github.com/ashkan-abbasi66/d5a525d33600b220fa7b095f7762cb5b

Note: I tested the above examples on Python 3.8 with TensorFlow 2.5.
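The function your wrapper returns is named `loss` (i.e. `def loss(y_true, y_pred):`), and that name, not the variable name `weighted_loss`, is what gets stored in the saved model, as the message `Unknown loss function:loss` shows. Therefore, when loading the model back, you need to specify `'loss'` as its name:

```python
model = load_model(path, custom_objects={'loss': weighted_loss})
```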
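Why the key matters can be sketched in plain Python: for a plain function used as a loss, the name recorded is the function's `__name__`, and on loading that string is looked up in `custom_objects`. The `resolve_loss` helper below is a hypothetical simplification of that lookup (not Keras' actual code), and the inner `loss` body is a placeholder because only its name matters here.

```python
def weighted_cross_entropy(weights):
    # Same wrapper shape as in the question; the body is a placeholder
    # because only the *name* of the returned function matters here.
    def loss(y_true, y_pred):
        return 0.0  # placeholder, not the real computation
    return loss


weighted_loss = weighted_cross_entropy([0.1, 0.9])

# The name recorded for a function loss is its __name__ attribute.
stored_name = weighted_loss.__name__
print(stored_name)  # loss


def resolve_loss(name, custom_objects):
    """Hypothetical stand-in for the name lookup done when loading."""
    if name not in custom_objects:
        raise ValueError("Unknown loss function:" + name)
    return custom_objects[name]


# Registering under the variable name fails, as in the question...
try:
    resolve_loss(stored_name, {"weighted_loss": weighted_loss})
except ValueError as err:
    print(err)  # Unknown loss function:loss

# ...while registering under the function's own name succeeds.
assert resolve_loss(stored_name, {"loss": weighted_loss}) is weighted_loss
```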