manv

Reputation: 138

How can I modify ModelCheckpoint in Keras to monitor both val_acc and val_loss and save the best model accordingly?

ModelCheckpoint lets me save the best model according to val_acc or according to val_loss, but only one of them at a time. I want to modify it so that: if val_acc improves, the model is saved; if val_acc equals the previous best val_acc, val_loss is checked, and the model is saved only if val_loss is lower than the previous best val_loss.

    if val_acc(epoch i) > best_val_acc:
        save model
    elif val_acc(epoch i) == best_val_acc:
        if val_loss(epoch i) < best_val_loss:
            save model
        else:
            do not save model
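The rule above is a lexicographic comparison: higher accuracy first, then lower loss as a tie-breaker. A minimal, framework-independent sketch (the function name `is_improvement` is hypothetical) can express it by comparing tuples, negating the loss so that "larger is better" holds for both components:

```python
def is_improvement(val_acc, val_loss, best_acc, best_loss):
    """Return True if (val_acc, val_loss) beats the best seen so far.

    Higher accuracy wins; on an exact accuracy tie, lower loss wins.
    Python compares tuples element by element, so negating the loss
    turns "lower is better" into "higher is better".
    """
    return (val_acc, -val_loss) > (best_acc, -best_loss)


# Start from the worst possible state so the first epoch always counts.
print(is_improvement(0.90, 0.30, float("-inf"), float("inf")))  # True
print(is_improvement(0.90, 0.25, 0.90, 0.30))  # True: same acc, lower loss
print(is_improvement(0.90, 0.35, 0.90, 0.30))  # False: same acc, higher loss
print(is_improvement(0.85, 0.10, 0.90, 0.30))  # False: lower acc
```

An exact tie on both values returns False, matching the "strictly less than" condition in the question.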

Upvotes: 4

Views: 9157

Answers (3)

Daniel Möller

Reputation: 86600

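You can just add two callbacks, one per metric, each saving to its own file (with save_best_only=True so that each one keeps only its best model):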

callbacks = [ModelCheckpoint(filepathAcc, monitor='val_acc', ...),
             ModelCheckpoint(filepathLoss, monitor='val_loss', ...)]

model.fit(......., callbacks=callbacks)

Using custom callbacks

A LambdaCallback(on_epoch_end=saveModel) lets you run any logic you want at the end of each epoch:

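import sys
from keras.callbacks import LambdaCallback

best_val_acc = 0
best_val_loss = sys.float_info.max

def saveModel(epoch, logs):
    # Rebinding module-level names inside a function requires `global`.
    global best_val_acc, best_val_loss
    val_acc = logs['val_acc']    # newer Keras versions name this 'val_accuracy'
    val_loss = logs['val_loss']

    if val_acc > best_val_acc:
        # New best accuracy: also record its loss for future tie-breaks.
        best_val_acc = val_acc
        best_val_loss = val_loss
        model.save(...)
    elif val_acc == best_val_acc and val_loss < best_val_loss:
        # Same accuracy but lower loss: this model is better.
        best_val_loss = val_loss
        model.save(...)

callbacks = [LambdaCallback(on_epoch_end=saveModel)]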

In practice, though, this behaves almost exactly like a single ModelCheckpoint monitoring val_acc. You will rarely get exactly identical accuracies unless you're using very few validation samples, or a custom accuracy metric that doesn't vary much.
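If you prefer to avoid module-level globals, the same logic can be kept as state on an object. The sketch below is framework-independent so that it runs on its own: it exposes an `on_epoch_end(epoch, logs)` method with the same signature Keras uses for callbacks, and takes a hypothetical `save_fn` callable in place of `model.save(...)`. In real Keras code you would instead subclass `keras.callbacks.Callback` and call `self.model.save(...)`; the class name and the default log keys are assumptions for illustration.

```python
class AccThenLossCheckpoint:
    """Save when val accuracy improves, or when it ties and val loss drops.

    Hypothetical helper: `save_fn(epoch)` stands in for model.save(...).
    """

    def __init__(self, save_fn, acc_key="val_acc", loss_key="val_loss"):
        self.save_fn = save_fn
        self.acc_key = acc_key    # newer Keras versions log 'val_accuracy'
        self.loss_key = loss_key
        self.best_acc = float("-inf")
        self.best_loss = float("inf")

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        acc = logs[self.acc_key]
        loss = logs[self.loss_key]
        better_acc = acc > self.best_acc
        tie_lower_loss = acc == self.best_acc and loss < self.best_loss
        if better_acc or tie_lower_loss:
            # Update both bests together so a later tie is compared
            # against the loss of the model that was actually saved.
            self.best_acc = acc
            self.best_loss = loss
            self.save_fn(epoch)


saved_epochs = []
ckpt = AccThenLossCheckpoint(save_fn=saved_epochs.append)
history = [
    {"val_acc": 0.80, "val_loss": 0.50},  # epoch 0: first result, saved
    {"val_acc": 0.85, "val_loss": 0.45},  # epoch 1: higher acc, saved
    {"val_acc": 0.85, "val_loss": 0.40},  # epoch 2: tie, lower loss, saved
    {"val_acc": 0.85, "val_loss": 0.42},  # epoch 3: tie, higher loss, skipped
    {"val_acc": 0.83, "val_loss": 0.30},  # epoch 4: lower acc, skipped
]
for epoch, logs in enumerate(history):
    ckpt.on_epoch_end(epoch, logs)
print(saved_epochs)  # [0, 1, 2]
```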

Upvotes: 11

prosti

Reputation: 46331

Check out ModelCheckpoint in the Keras documentation. The model.fit() method takes a list of callbacks as a parameter. Make sure you have something like:

model.fit(..., callbacks=[mcp]), where mcp = ModelCheckpoint(...) is configured as needed.

Note: You may have multiple callbacks in the callback list.
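Roughly speaking, fit() calls each callback's hooks in list order, and at the end of every epoch each callback's on_epoch_end(epoch, logs) receives the same logs dict. The toy loop below is a simplified, framework-independent illustration of that dispatch; `toy_fit` and `Recorder` are hypothetical names, the metric values are placeholders, and real Keras wraps the list in its own callback container and calls many more hooks.

```python
class Recorder:
    """Minimal callback-like object that remembers what it was given."""

    def __init__(self):
        self.calls = []

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append((epoch, dict(logs or {})))


def toy_fit(epochs, callbacks):
    for epoch in range(epochs):
        # Placeholder numbers standing in for real validation results.
        logs = {"val_acc": 0.5 + 0.1 * epoch, "val_loss": 1.0 - 0.1 * epoch}
        for cb in callbacks:          # every callback sees every epoch
            cb.on_epoch_end(epoch, logs)


a, b = Recorder(), Recorder()
toy_fit(epochs=3, callbacks=[a, b])
print(len(a.calls), len(b.calls))  # 3 3
print(a.calls == b.calls)          # True
```

This is why two independent ModelCheckpoint instances in the same list each track their own best value without interfering.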

For clarity, here is the class docstring with more details; effectively, the callback does the same as calling model.save() (or model.save_weights()) yourself:

class ModelCheckpoint(Callback):
    """Save the model after every epoch.
    `filepath` can contain named formatting options,
    which will be filled with the value of `epoch` and
    keys in `logs` (passed in `on_epoch_end`).
    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
    then the model checkpoints will be saved with the epoch number and
    the validation loss in the filename.
    # Arguments
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`,
            the latest best model according to
            the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}.
            If `save_best_only=True`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        period: Interval (number of epochs) between checkpoints.
    """

Upvotes: -1

Vincent Pakson

Reputation: 1949

You can actually check this in the Keras documentation!

To save you some time, though: the ModelCheckpoint callback accepts an argument called save_best_only, which does what you want; just set it to True. See the ModelCheckpoint entry in the Keras callbacks documentation.
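As a rough, simplified sketch (not the Keras source) of what save_best_only=True does: the checkpoint tracks the best value of a single monitored quantity and saves only on a strict improvement, with mode='auto' choosing the direction from the metric name. Because it tracks just one quantity, it cannot break accuracy ties by loss, which is what the question asks for. The function names `infer_mode` and `epochs_saved` are hypothetical, and the 'auto' rule shown is an approximation.

```python
import math


def infer_mode(monitor, mode="auto"):
    """Approximate the 'auto' rule: names containing 'acc' are maximized."""
    if mode != "auto":
        return mode
    return "max" if "acc" in monitor else "min"


def epochs_saved(values, monitor, mode="auto"):
    """Return the epochs at which a save_best_only checkpoint would save."""
    direction = infer_mode(monitor, mode)
    best = -math.inf if direction == "max" else math.inf
    saved = []
    for epoch, value in enumerate(values):
        improved = value > best if direction == "max" else value < best
        if improved:          # strict improvement only: ties are not saved
            best = value
            saved.append(epoch)
    return saved


print(epochs_saved([0.80, 0.85, 0.85, 0.83], "val_acc"))   # [0, 1]
print(epochs_saved([0.50, 0.45, 0.40, 0.42], "val_loss"))  # [0, 1, 2]
```

Note how epoch 2 in the val_acc run is skipped even if its loss were lower: a tie is never an improvement for a single-metric checkpoint.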

I misunderstood your question. If you want a more complex kind of callback, you can always subclass the base Callback class, which gives you more power since you can access both params and model. Check the documentation. You can start by experimenting: print the params and decide which ones you want to keep track of.

Upvotes: -1
