I want to call a callback every n epochs, but also always in the last epoch of training. Here it is explained how I can call a callback every n epochs.
At the moment I am using the following approach:
```python
from tensorflow import keras

class MyCallBack(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 10 == 0:  # <- add additional condition here
            self._do_the_stuff()

    def _do_the_stuff(self):
        print('Do the stuff')

    # Called once by Keras when fit() finishes
    def on_train_end(self, logs=None):
        self._do_the_stuff()
```
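For reference, Keras passes a 0-based epoch index to `on_epoch_end` (0 up to `epochs - 1`). The plain-Python sketch below uses a hypothetical helper, `trigger_epochs` (not a Keras API), to list the indices a condition fires for over 25 epochs. Neither `epoch % 10 == 0` nor `(epoch + 1) % 10 == 0` includes the last index, 24.

```python
def trigger_epochs(total_epochs, condition):
    """Return the 0-based epoch indices for which condition(epoch) is true."""
    return [epoch for epoch in range(total_epochs) if condition(epoch)]

# Keras calls on_epoch_end with epoch = 0, 1, ..., epochs - 1
print(trigger_epochs(25, lambda epoch: epoch % 10 == 0))        # [0, 10, 20]
print(trigger_epochs(25, lambda epoch: (epoch + 1) % 10 == 0))  # [9, 19]
```

The Keras progress bar counts from 1, so index 0 is displayed as `Epoch 1/25`.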
Is there a simpler way, where I add an additional condition to the `if` statement inside `on_epoch_end` and don't need `on_train_end`?
As suggested by @Ewran in the comments above, the total number of epochs is available as `self.params['epochs']`.
```python
from tensorflow import keras

class MyCallBack(keras.callbacks.Callback):
    def __init__(self, epoch_freq=10):
        super().__init__()
        self.epoch_freq = epoch_freq

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is 0-based, so the last epoch is self.params['epochs'] - 1
        if epoch % self.epoch_freq == 0 or epoch == self.params['epochs'] - 1:
            self._do_the_stuff()

    def _do_the_stuff(self):
        print('Do the stuff')
```
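Because the `epoch` argument is 0-based, the final index is `self.params['epochs'] - 1`; comparing `epoch` with `self.params['epochs']` directly would never match. A minimal plain-Python sketch of the condition (the function name `should_run` is hypothetical):

```python
def should_run(epoch, total_epochs, epoch_freq):
    """True every `epoch_freq` epochs and on the last 0-based epoch."""
    return epoch % epoch_freq == 0 or epoch == total_epochs - 1

print([e for e in range(25) if should_run(e, 25, 10)])  # [0, 10, 20, 24]
print([e for e in range(21) if should_run(e, 21, 10)])  # [0, 10, 20]
```

In the 21-epoch case the last epoch is also a multiple of 10, and since it is a single boolean check, the work still runs only once for it.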
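If other callbacks such as `tf.keras.callbacks.EarlyStopping` are used, I would keep the approach with `on_train_end`. Training may then stop before epoch `self.params['epochs'] - 1` is reached, so the last-epoch condition in `on_epoch_end` is never met and the callback is not guaranteed to run after the last epoch.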
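One way to combine both, sketched under assumptions and not taken from the answer above: keep `on_train_end`, but only run the work there if `on_epoch_end` has not already done it for the last finished epoch. That covers early stopping and avoids a double call at the regular end. `EveryNEpochsCallback` and `simulate_fit` are hypothetical names; the base class is left as `object` so the logic runs without TensorFlow, and `simulate_fit` only mimics the hook order of `model.fit()`. In real code, subclass `keras.callbacks.Callback` instead.

```python
class EveryNEpochsCallback:  # hypothetical; in real code subclass keras.callbacks.Callback
    """Runs every `epoch_freq` epochs and once for the final epoch, even after an early stop."""

    def __init__(self, epoch_freq=10):
        super().__init__()
        self.epoch_freq = epoch_freq
        self.runs = []               # 0-based epochs at which the work ran
        self._last_epoch = None      # last epoch that finished
        self._last_run_epoch = None  # last epoch the work ran for

    def on_epoch_end(self, epoch, logs=None):
        self._last_epoch = epoch
        if epoch % self.epoch_freq == 0:
            self._do_the_stuff(epoch)

    def on_train_end(self, logs=None):
        # Run for the final finished epoch unless on_epoch_end already did
        if self._last_epoch is not None and self._last_run_epoch != self._last_epoch:
            self._do_the_stuff(self._last_epoch)

    def _do_the_stuff(self, epoch):
        self._last_run_epoch = epoch
        self.runs.append(epoch)
        print('Do the stuff after epoch', epoch)


def simulate_fit(callback, epochs, stop_at=None):
    """Hypothetical stand-in for model.fit(): same hook order, optional early stop."""
    for epoch in range(epochs):
        callback.on_epoch_end(epoch)
        if stop_at is not None and epoch == stop_at:
            break  # plays the role of EarlyStopping setting model.stop_training
    callback.on_train_end()


full = EveryNEpochsCallback(epoch_freq=10)
simulate_fit(full, epochs=25)
print(full.runs)     # [0, 10, 20, 24]

stopped = EveryNEpochsCallback(epoch_freq=10)
simulate_fit(stopped, epochs=25, stop_at=13)
print(stopped.runs)  # [0, 10, 13]
```

Tracking the last finished epoch instead of reading `self.params['epochs']` means the check does not depend on how training ended.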