Reputation: 721
Is there a way in Keras to cross-validate the early stopping metric being monitored, e.g. EarlyStopping(monitor = 'val_acc', patience = 5)? Before allowing training to proceed to the next epoch, could the model be cross-validated to get a more robust estimate of the test error? What I have found is that the early stopping metric, say the accuracy on a validation set, can suffer from high variance. Early-stopped models often do not perform nearly as well on unseen data, and I suspect this is because of the high variance associated with the validation set approach.
To minimize the variance in the early stopping metric, I would like to k-fold cross-validate it as the model trains from epoch i to epoch i + 1. I would like to take the model at epoch i, divide the training data into 10 parts, learn on 9 parts, estimate the error on the remaining part, repeat so that all 10 parts have had a chance to be the validation set, and then proceed with training to epoch i + 1 with the full training data as usual. The average of the 10 error estimates will hopefully be a more robust metric that can be used for early stopping.
I have tried to write a custom metric function that includes k-fold cross-validation, but I can't get it to work. Is there a way to cross-validate the early stopping metric being monitored, perhaps through a custom function inside the Keras model or a loop outside the Keras model?
Thanks!!
Upvotes: 6
Views: 1156
Reputation: 721
I imagine that using a callback as suggested by @VincentPakson would be cleaner and more efficient, but the level of programming required is beyond me. I was able to create a for loop that does what I wanted (a rough sketch is given below) by:

1. Training a model for a single epoch and saving it using model.save().
2. Loading the saved model and training it for a single epoch on each of the 10 folds (i.e. 10 models), then averaging the 10 validation set errors.
3. Loading the saved model and training it for a single epoch on all of the training data, then overwriting the saved model with this model.
4. Repeating steps 2-3 until the estimate from step 2 stops improving for a given patience.
I'd love a better answer but this seems to work. Slowly.
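For reference, a minimal sketch of that loop. build_model(), X, and y are hypothetical placeholders for your own compiled-model factory and training arrays, and it assumes the model is compiled so that evaluate() returns a single loss value:

import numpy as np
from sklearn.model_selection import KFold
from keras.models import load_model

model = build_model()                 # hypothetical factory for a compiled model
model.fit(X, y, epochs=1, verbose=0)  # step 1: train one epoch...
model.save('base.h5')                 # ...and save the starting point

best, wait, patience = np.inf, 0, 5
for epoch in range(100):
    # Step 2: average validation loss of one extra epoch over 10 folds.
    scores = []
    for train_idx, val_idx in KFold(n_splits=10, shuffle=True).split(X):
        fold_model = load_model('base.h5')
        fold_model.fit(X[train_idx], y[train_idx], epochs=1, verbose=0)
        scores.append(fold_model.evaluate(X[val_idx], y[val_idx], verbose=0))
    cv_loss = np.mean(scores)

    # Step 3: one epoch on the full training data, overwriting the save.
    model = load_model('base.h5')
    model.fit(X, y, epochs=1, verbose=0)
    model.save('base.h5')

    # Step 4: stop when the averaged estimate stops improving.
    if cv_loss < best:
        best, wait = cv_loss, 0
    else:
        wait += 1
        if wait >= patience:
            break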
Upvotes: 0
Reputation: 1949
Keras really does have highly customizable callback functionality, as can be seen here. If you are not satisfied with Keras' current EarlyStopping callback, which for me does the job of checking the validation loss during training, you could create a custom callback. Custom callbacks can also be chained.
If your issue is accessing the model inside the callback, then self.model is the attribute you want to access, as can be seen in this answer. I don't completely understand why you want to "redevelop" the model during the test phase, but even then you could still use callbacks: after the EarlyStopping callback, you could create another callback that "redevelops" the model.
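For illustration, a minimal sketch of such a chain. compute_cv_loss is a hypothetical helper that would clone and retrain the model per fold; the point is that callbacks run in order and share the same logs dict, so a value written by one callback is visible to the ones after it:

from keras.callbacks import Callback, EarlyStopping

class CVMetric(Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # self.model is the model being trained; write the averaged
        # cross-validated loss into logs for later callbacks to read.
        logs['cv_loss'] = compute_cv_loss(self.model)  # hypothetical helper

# EarlyStopping runs after CVMetric each epoch, so it can monitor
# the 'cv_loss' value that CVMetric just wrote into logs.
model.fit(X, y, callbacks=[CVMetric(),
                           EarlyStopping(monitor='cv_loss', patience=5)])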
If you want to access the model's deeper variables, you could use the Keras backend.
I hope I helped.
Upvotes: 3