Reputation: 138
ModelCheckpoint gives options to save based on either val_acc or val_loss, separately.
I want to modify this so that if val_acc improves, the model is saved. If val_acc equals the previous best val_acc, then val_loss is checked, and if val_loss is lower than the previous best val_loss, the model is saved.
    if val_acc > best_val_acc:              # val_acc of the current epoch
        save model
    elif val_acc == best_val_acc:
        if val_loss < best_val_loss:
            save model
    else:
        do not save model
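The rule above can be written as a plain Python predicate. This is a minimal sketch; the name `should_save` and its argument order are hypothetical and not part of Keras:

```python
def should_save(val_acc, val_loss, best_acc, best_loss):
    """Return True if the current epoch should be checkpointed.

    Higher validation accuracy always wins; on an exact accuracy tie,
    the lower validation loss wins. Anything else is not saved.
    """
    if val_acc > best_acc:
        return True
    if val_acc == best_acc:
        return val_loss < best_loss
    return False
```

For example, `should_save(0.90, 0.25, 0.90, 0.30)` is true (tie on accuracy, lower loss), while `should_save(0.89, 0.10, 0.90, 0.30)` is false (lower accuracy, even though the loss is better).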
Upvotes: 4
Views: 9157
Reputation: 86600
You can just add two callbacks:
    callbacks = [ModelCheckpoint(filepathAcc, monitor='val_acc', ...),
                 ModelCheckpoint(filepathLoss, monitor='val_loss', ...)]
    model.fit(..., callbacks=callbacks)
You can do anything you want in a LambdaCallback(on_epoch_end=saveModel):
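    import sys
    from keras.callbacks import LambdaCallback
    best_val_acc = 0
    best_val_loss = sys.float_info.max
    def saveModel(epoch, logs):
        # the best values live at module level, so declare them global before reassigning
        global best_val_acc, best_val_loss
        val_acc = logs['val_acc']
        val_loss = logs['val_loss']
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_loss = val_loss  # remember the new best model's loss for later ties
            model.save(...)
        elif val_acc == best_val_acc:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                model.save(...)
    callbacks = [LambdaCallback(on_epoch_end=saveModel)]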
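If you would rather avoid module-level globals, the same logic can keep its state in a closure. This is a sketch: `make_save_on_epoch_end` and the `save_fn` hook are hypothetical names, and the simulated logs stand in for the dict Keras passes at the end of each epoch:

```python
def make_save_on_epoch_end(save_fn, acc_key='val_acc', loss_key='val_loss'):
    """Build an (epoch, logs) function usable as LambdaCallback(on_epoch_end=...).

    The best accuracy/loss are kept in the closure instead of globals;
    save_fn(epoch) is called whenever the checkpoint should be written.
    """
    best = {'acc': float('-inf'), 'loss': float('inf')}

    def on_epoch_end(epoch, logs=None):
        logs = logs or {}
        acc, loss = logs[acc_key], logs[loss_key]
        # better accuracy, or equal accuracy with lower loss
        if acc > best['acc'] or (acc == best['acc'] and loss < best['loss']):
            best['acc'], best['loss'] = acc, loss
            save_fn(epoch)

    return on_epoch_end


saved = []
on_epoch_end = make_save_on_epoch_end(saved.append)
for epoch, logs in enumerate([
    {'val_acc': 0.80, 'val_loss': 0.50},
    {'val_acc': 0.85, 'val_loss': 0.45},
    {'val_acc': 0.85, 'val_loss': 0.40},
    {'val_acc': 0.84, 'val_loss': 0.38},
]):
    on_epoch_end(epoch, logs)
print(saved)  # [0, 1, 2]
```

In a real run you would pass something like `LambdaCallback(on_epoch_end=make_save_on_epoch_end(lambda epoch: model.save('best_model.h5')))`, where the filename is only a placeholder.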
But this is hardly different from a single ModelCheckpoint with val_acc. You won't really get identical accuracies unless you're using very few samples, or you have a custom accuracy that doesn't vary much.
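The point about sample size is plain arithmetic: with n validation samples of which k are classified correctly, accuracy is k/n, so it can only move in steps of 1/n. The sample sizes below are chosen purely for illustration:

```latex
\begin{aligned}
\text{acc} &= \frac{k}{n}, \qquad k \in \{0, 1, \dots, n\} \\
n = 50 &: \quad \text{acc} \in \{0,\ 0.02,\ 0.04,\ \dots,\ 1\}, \qquad \tfrac{45}{50} = 0.90 \\
n = 10000 &: \quad \text{step} = \tfrac{1}{10000} = 0.0001
\end{aligned}
```

With 50 samples, any two epochs that each get 45 samples right report exactly 0.90 and tie; with 10,000 samples they would have to agree on the number of correct predictions down to the last sample.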
Upvotes: 11
Reputation: 46331
Check out ModelCheckpoint here.
The model.fit() method takes the callback list as a parameter. Make sure you have something like:
    model.fit(..., callbacks=[mcp])
where mcp = ModelCheckpoint(...) as defined.
Note: you may have multiple callbacks in the callback list.
For clarity, here is its docstring; effectively it calls model.save() (or model.save_weights()) for you:
    # Docstring of ModelCheckpoint from the Keras 2.x source (implementation omitted)
    from keras.callbacks import Callback
    class ModelCheckpoint(Callback):
        """Save the model after every epoch.
        `filepath` can contain named formatting options,
        which will be filled with the value of `epoch` and
        keys in `logs` (passed in `on_epoch_end`).
        For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
        then the model checkpoints will be saved with the epoch number and
        the validation loss in the filename.
        # Arguments
            filepath: string, path to save the model file.
            monitor: quantity to monitor.
            verbose: verbosity mode, 0 or 1.
            save_best_only: if `save_best_only=True`,
                the latest best model according to
                the quantity monitored will not be overwritten.
            mode: one of {auto, min, max}.
                If `save_best_only=True`, the decision
                to overwrite the current save file is made
                based on either the maximization or the
                minimization of the monitored quantity. For `val_acc`,
                this should be `max`, for `val_loss` this should
                be `min`, etc. In `auto` mode, the direction is
                automatically inferred from the name of the monitored quantity.
            save_weights_only: if True, then only the model's weights will be
                saved (`model.save_weights(filepath)`), else the full model
                is saved (`model.save(filepath)`).
            period: Interval (number of epochs) between checkpoints.
        """
Upvotes: -1
Reputation: 1949
You can actually check their documentation! To save you some time, though: the ModelCheckpoint callback accepts an argument called save_best_only, which does what you want; just set it to True. Here's the link to the documentation.
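One caveat: save_best_only only overwrites on a strict improvement, so an epoch that merely ties the best val_acc is not saved even if its val_loss is lower. A pure-Python comparison of the two rules, as a sketch: the history values are invented, and strict `>` is assumed to match the `mode='max'` check:

```python
history = [(0.80, 0.50), (0.85, 0.45), (0.85, 0.40)]  # (val_acc, val_loss) per epoch


def last_saved(history, tie_break):
    """Return the 0-based epoch whose model would end up on disk."""
    best_acc, best_loss, saved = float('-inf'), float('inf'), None
    for epoch, (acc, loss) in enumerate(history):
        better = acc > best_acc  # strict improvement, as with save_best_only
        if tie_break:
            # the question's extra rule: equal accuracy but lower loss also counts
            better = better or (acc == best_acc and loss < best_loss)
        if better:
            best_acc, best_loss, saved = acc, loss, epoch
    return saved


print(last_saved(history, tie_break=False))  # 1  (save_best_only on val_acc)
print(last_saved(history, tie_break=True))   # 2  (the rule from the question)
```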
I misunderstood your question. If you want a more complex type of callback, you can always subclass the base Callback class, which gives you more power since you can access both params and model. Check the documentation out. You can start by testing it, printing the params, and deciding which ones you want to keep track of.
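A sketch of that approach, assuming TensorFlow's bundled Keras (`tensorflow.keras.callbacks.Callback`). The class name `AccThenLossCheckpoint`, the `save_fn` hook, and the fallback base class are hypothetical additions so the logic can be exercised without a model. Depending on the Keras version and whether you compile with `metrics=['acc']` or `metrics=['accuracy']`, the logged key may be `val_acc` or `val_accuracy`, so pass `acc_key` accordingly:

```python
try:
    from tensorflow.keras.callbacks import Callback  # real base class when TensorFlow is installed
except ImportError:
    class Callback:  # minimal stand-in so the sketch also runs without TensorFlow
        def __init__(self):
            self.model = None


class AccThenLossCheckpoint(Callback):
    """Save when val accuracy improves, or when it ties and val loss drops."""

    def __init__(self, filepath, acc_key='val_acc', loss_key='val_loss', save_fn=None):
        super().__init__()
        self.filepath = filepath
        self.acc_key = acc_key
        self.loss_key = loss_key
        # By default write the full model; save_fn can be swapped out for testing.
        self.save_fn = save_fn or (lambda model, path: model.save(path))
        self.best_acc = float('-inf')
        self.best_loss = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        acc = logs.get(self.acc_key)
        loss = logs.get(self.loss_key)
        if acc is None or loss is None:
            return  # metric not logged this epoch (e.g. no validation data)
        if acc > self.best_acc or (acc == self.best_acc and loss < self.best_loss):
            self.best_acc, self.best_loss = acc, loss
            # 1-based epoch in the filename, mirroring ModelCheckpoint's templating
            self.save_fn(self.model, self.filepath.format(epoch=epoch + 1, **logs))
```

Usage would look like `model.fit(x, y, validation_data=(x_val, y_val), callbacks=[AccThenLossCheckpoint('best_model.h5')])`, where the data names and filename are placeholders.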
Upvotes: -1