Ximin Lin

Reputation: 79

How to have Keras model do early-stopping in different fit calls

Because my data is high-dimensional, a batch of just 32 samples consumes nearly 9% of the memory on the server, which has about 105 GB free in total. I therefore have to call fit() repeatedly in a loop, and I also want early stopping to work across these consecutive fit() calls.

However, the callbacks described in the Keras documentation only apply within a single fit() call.

How can I do early-stopping in this case?

Following is my code snippet:

for sen_batch, cls_batch in train_data_gen:

    sen_batch = np.array(sen_batch).reshape(-1, WORD_LENGTH, 50, 1)
    cls_batch = np.array(cls_batch)

    model.fit(x=sen_batch, y=cls_batch)

    num_iterations += 1

Upvotes: 1

Views: 1279

Answers (1)

Marcin Możejko

Reputation: 40516

  1. Use fit_generator: since you already have a generator, you can train from it directly instead of calling the classical fit. This method supports callbacks, so you can use keras.callbacks.EarlyStopping.

  2. When you cannot use fit_generator: first of all, you need to use the train_on_batch method, because a fit call resets many model states (e.g. optimizer states).

    The train_on_batch method returns a loss value (or a list of values if the model was compiled with metrics), but it doesn't accept callbacks, so you need to implement early stopping on your own. You can do it e.g. like this:

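    from six import next  # Python 2/3 compatible next(); the builtin works on Python 3
    
    # model, train_data_gen, nb_of_epochs and nb_of_batches are assumed to be defined.
    patience = 4  # number of epochs to wait for an improvement
    best_loss = float("inf")
    rounds_without_improvement = 0
    
    for epoch_nb in range(nb_of_epochs):
        losses_list = list()
        for batch in range(nb_of_batches):
            x, y = next(train_data_gen)
            # Assumes no metrics were compiled, so a scalar loss is returned.
            losses_list.append(model.train_on_batch(x, y))
        mean_loss = sum(losses_list) / len(losses_list)
    
        if mean_loss < best_loss:
            best_loss = mean_loss
            rounds_without_improvement = 0  # reset the counter on improvement
        else:
            rounds_without_improvement += 1
    
        if rounds_without_improvement == patience:
            break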
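
If you need the same bookkeeping in several places, the patience logic above can be pulled into a small helper. The following is only a sketch in plain Python: the EarlyStopper class, its should_stop method and the min_delta parameter are hypothetical names of my own, not part of Keras, and the list of losses is made-up data standing in for the mean loss you would compute from train_on_batch.

```python
class EarlyStopper:
    """Hypothetical helper (not a Keras class) that tracks the best loss
    across any number of separate training calls."""

    def __init__(self, patience=4, min_delta=0.0):
        self.patience = patience    # rounds to wait for an improvement
        self.min_delta = min_delta  # smallest decrease that counts as one
        self.best_loss = float("inf")
        self.rounds_without_improvement = 0

    def should_stop(self, loss):
        """Record one loss value; return True once patience is exhausted."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.rounds_without_improvement = 0
        else:
            self.rounds_without_improvement += 1
        return self.rounds_without_improvement >= self.patience


# Made-up per-epoch mean losses; in real training each value would be
# computed from the results of model.train_on_batch(x, y).
simulated_epoch_losses = [0.9, 0.7, 0.6, 0.65, 0.61, 0.62, 0.63, 0.5]

stopper = EarlyStopper(patience=4)
stopped_at = None
for epoch_nb, mean_loss in enumerate(simulated_epoch_losses):
    if stopper.should_stop(mean_loss):
        stopped_at = epoch_nb
        break
```

Because the state lives in the object rather than in the loop, the same stopper can be consulted after every training call, however the calls are organised. In practice you would probably feed it a validation loss rather than the training loss.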
    

Upvotes: 4
