Belkacem Thiziri

Reputation: 665

save model weights at the end of every N epochs

I'm training a neural network and would like to save the model weights every N epochs for use in a later prediction phase. I've written the draft code below, inspired by @grovina's response here. Could you please suggest improvements? Thanks in advance.

from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, model, N):
        self.model = model
        self.N = N
        self.epoch = 0

    def on_batch_end(self, epoch, logs={}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1

Then add it to the fit call. For example, to save weights every 5 epochs:

model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])

Upvotes: 13

Views: 17886

Answers (2)

umutto

Reputation: 7690

You shouldn't need to pass the model to the callback. It already has access to the model through its base class: Keras sets self.model (via set_model()) when the callback is passed to fit(). So remove the model argument from __init__(..., model, ...) and the self.model = model line; you can still access the current model via self.model. You are also saving on every batch end, which is not what you want; you probably want on_epoch_end instead.
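To make both points concrete, here is a small pure-Python sketch that mimics the order in which a training loop calls callback hooks. Callback and mock_fit below are hypothetical stand-ins written for illustration, NOT the real Keras classes; they only show that the loop hands the model to the callback itself and that a batch hook fires once per batch, not once per epoch.

```python
class Callback:
    """Hypothetical stand-in for keras.callbacks.Callback (NOT the real class)."""

    def __init__(self):
        self.model = None

    def set_model(self, model):
        # The training loop hands the model to every callback before training starts.
        self.model = model

    def on_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass


def mock_fit(model, callbacks, epochs, batches_per_epoch):
    """Hypothetical training loop: batch hooks run per batch, epoch hooks per epoch."""
    for cb in callbacks:
        cb.set_model(model)
    for epoch in range(epochs):
        for batch in range(batches_per_epoch):
            for cb in callbacks:
                cb.on_batch_end(batch)
        for cb in callbacks:
            cb.on_epoch_end(epoch)


class HookCounter(Callback):
    """Counts hook calls; it never receives the model as a constructor argument."""

    def __init__(self):
        super().__init__()
        self.batch_calls = 0
        self.epoch_calls = 0

    def on_batch_end(self, batch, logs=None):
        self.batch_calls += 1

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_calls += 1


counter = HookCounter()
mock_fit(model="my-model", callbacks=[counter], epochs=3, batches_per_epoch=100)
print(counter.model)        # my-model
print(counter.batch_calls)  # 300
print(counter.epoch_calls)  # 3
```

With 100 batches per epoch, a counter incremented in on_batch_end advances 100 times per epoch, so "every N" in that hook means every N batches.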

But in any case, what you are doing can be done with the built-in ModelCheckpoint callback; you don't need to write a custom one. You can use it as follows:

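import keras
# Save only the weights every 5 epochs; {epoch:08d} is filled in with the 1-based epoch number.
# `period` is the Keras 2.x argument; newer tf.keras versions deprecate it in favor of `save_freq`.
mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5',
                                     save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])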
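If your Keras version no longer accepts period, a hedged workaround is to express "every 5 epochs" as a batch count. This assumes a newer tf.keras where an integer save_freq is measured in batches, and a fixed dataset size and batch size with the default number of steps per epoch. batches_per_n_epochs is a hypothetical helper, not part of Keras:

```python
import math


def batches_per_n_epochs(num_samples: int, batch_size: int, n_epochs: int) -> int:
    """Hypothetical helper: number of batches that make up n_epochs epochs."""
    # The last, possibly partial, batch still counts as one training step.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return n_epochs * steps_per_epoch


save_freq = batches_per_n_epochs(num_samples=60000, batch_size=32, n_epochs=5)
print(save_freq)  # 9375
```

The result would then be passed as save_freq=save_freq instead of period=5. This only lines up with epoch boundaries if every epoch runs exactly that many batches (for example, no custom steps_per_epoch).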

Upvotes: 29

Mitiku

Reputation: 5412

You should implement on_epoch_end rather than on_batch_end. Also, passing the model as an argument to __init__ is redundant, since Keras already sets self.model for the callback.

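from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, N):
        super().__init__()  # let the base Callback initialize its own attributes
        self.N = N

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes the 0-based epoch index, so no manual counter is needed.
        # Save after every N-th completed epoch (epochs N, 2N, ...).
        if (epoch + 1) % self.N == 0:
            name = 'weights%08d.h5' % (epoch + 1)
            self.model.save_weights(name)  # self.model is set by Keras during fit()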
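A detail worth checking is which epochs a save rule actually fires on. A counter that starts at 0 and saves when counter % N == 0 fires after the 1st epoch and then every N epochs, whereas (epoch + 1) % N == 0 fires after the N-th epoch. The pure-Python sketch below (hypothetical helper names, no Keras required) lists the 1-based epochs each rule saves on:

```python
def saved_epochs_counter(num_epochs, n):
    """1-based epochs saved by a counter that starts at 0 and checks counter % n == 0."""
    saved = []
    counter = 0
    for epoch in range(num_epochs):  # epoch is the 0-based index a training loop passes
        if counter % n == 0:
            saved.append(epoch + 1)
        counter += 1
    return saved


def saved_epochs_every_nth(num_epochs, n):
    """1-based epochs saved by the (epoch + 1) % n == 0 rule."""
    return [epoch + 1 for epoch in range(num_epochs) if (epoch + 1) % n == 0]


print(saved_epochs_counter(12, 5))    # [1, 6, 11]
print(saved_epochs_every_nth(12, 5))  # [5, 10]
```

Either schedule can be reasonable; the second one matches the usual reading of "every N epochs".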

Upvotes: 4
