Reputation: 665
I'm training a NN and would like to save the model weights every N epochs for a prediction phase. I propose this draft code; it's inspired by @grovina's answer here. Could you please make suggestions? Thanks in advance.
from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, model, N):
        self.model = model
        self.N = N
        self.epoch = 0

    def on_batch_end(self, epoch, logs={}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1
Then add it to the fit call to save weights every 5 epochs:
model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])
Upvotes: 13
Views: 17886
Reputation: 7690
You shouldn't need to pass the model to the callback: it already has access to it via its superclass. So remove the model argument from __init__ and the self.model = model line; you can still access the current model through self.model. You are also saving on every batch end, which is not what you want; you probably want on_epoch_end.
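To see why passing the model in is redundant, here is a dependency-free sketch of the mechanism this relies on: before training starts, Keras calls set_model on each callback, which is what populates self.model (the Callback class below is a hypothetical minimal stand-in for keras.callbacks.Callback, not the real one):

```python
class Callback:
    """Minimal stand-in for keras.callbacks.Callback."""
    def set_model(self, model):
        # Keras calls this on every callback before training begins.
        self.model = model

class WeightsSaver(Callback):
    def __init__(self, N):
        self.N = N  # no model argument needed

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.N == 0:
            print('would save weights at epoch %d' % epoch)

cb = WeightsSaver(5)
cb.set_model('my_model')  # done internally by Keras, not by user code
print(cb.model)           # -> 'my_model': available without passing it to __init__
```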
But in any case, what you are doing can be done with the built-in ModelCheckpoint callback; you don't need to write a custom one. You can use it as follows:
mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
                                     save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])
Upvotes: 29
Reputation: 5412
You should implement on_epoch_end rather than on_batch_end. Also, passing the model as an argument to __init__ is redundant.
from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, N):
        self.N = N
        self.epoch = 0

    def on_epoch_end(self, epoch, logs={}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1
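One thing to note about the modulo check above: since the counter starts at 0, weights are written on epochs 0, N, 2N, and so on (i.e. including a save after the very first epoch). A quick dependency-free sketch of the resulting save schedule, assuming 20 epochs and N = 5:

```python
# Epochs at which the modulo check in on_epoch_end fires (0-based counter)
N = 5
total_epochs = 20
saved_epochs = [e for e in range(total_epochs) if e % N == 0]
print(saved_epochs)  # -> [0, 5, 10, 15]
```

If you want the first save to happen only after N full epochs, check (self.epoch + 1) % self.N == 0 instead.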
Upvotes: 4