Reputation: 79
Since the data is high-dimensional for my task, 32 samples consume nearly 9% of the server's memory (about 105 GB free in total), so I have to make consecutive calls to fit() in a loop. I also want to do early stopping across these consecutive fit() calls. However, the callback mechanism described in the Keras documentation only applies to a single fit() call.
How can I do early stopping in this case?
Following is my code snippet:
for sen_batch, cls_batch in train_data_gen:
    sen_batch = np.array(sen_batch).reshape(-1, WORD_LENGTH, 50, 1)
    cls_batch = np.array(cls_batch)
    model.fit(x=sen_batch, y=cls_batch)
    num_iterations += 1
Upvotes: 1
Views: 1279
Reputation: 40516
Use fit_generator: since you already have a generator, you can use generator training instead of the classical fit. This method supports callbacks, so you can simply pass keras.callbacks.EarlyStopping.
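A minimal sketch of that route (the model, generator, and shapes below are made-up placeholders, not your actual setup; note that in newer Keras versions fit_generator has been folded into fit, which accepts a generator directly):

```python
import numpy as np
from tensorflow import keras

def batch_generator(batch_size=32):
    # Stand-in for train_data_gen: yields (x, y) batches forever,
    # which is why steps_per_epoch must be given below.
    while True:
        x = np.random.rand(batch_size, 10)
        y = np.random.randint(0, 2, size=(batch_size, 1))
        yield x, y

model = keras.Sequential([
    keras.layers.Dense(8, activation='relu', input_shape=(10,)),
    keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')

# Stop once the training loss fails to improve for `patience` epochs.
early_stop = keras.callbacks.EarlyStopping(monitor='loss', patience=2)
history = model.fit(batch_generator(),
                    steps_per_epoch=5,
                    epochs=20,
                    callbacks=[early_stop],
                    verbose=0)
```

With random data the loss has nothing to learn, so the callback will typically cut training well short of the 20 requested epochs.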
When you cannot use fit_generator:
First of all, you need to use the train_on_batch method, as a fit call resets many model states (e.g. optimizer state). train_on_batch returns a loss value, but it doesn't accept callbacks, so you need to implement early stopping on your own. You can do it e.g. like this:
patience = 4                     # epochs to wait for an improvement
best_loss = 1e6                  # best mean epoch loss seen so far
rounds_without_improvement = 0

for epoch_nb in range(nb_of_epochs):
    losses_list = list()
    for batch in range(nb_of_batches):
        x, y = next(train_data_gen)
        losses_list.append(model.train_on_batch(x, y))
    mean_loss = sum(losses_list) / len(losses_list)
    if mean_loss < best_loss:
        best_loss = mean_loss
        rounds_without_improvement = 0
    else:
        rounds_without_improvement += 1
    if rounds_without_improvement == patience:
        break
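To sanity-check the stopping logic in isolation, here is the same patience loop run against a hard-coded per-epoch loss sequence (the numbers are made up for illustration); the plateau triggers the stop before the later improvements are ever seen:

```python
patience = 4
best_loss = 1e6
rounds_without_improvement = 0

# Synthetic mean losses for ten epochs: improve, plateau, then improve again.
epoch_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.73, 0.74, 0.69, 0.5]
stopped_at = None

for epoch_nb, mean_loss in enumerate(epoch_losses):
    if mean_loss < best_loss:
        best_loss = mean_loss
        rounds_without_improvement = 0
    else:
        rounds_without_improvement += 1
    if rounds_without_improvement == patience:
        stopped_at = epoch_nb   # four non-improving epochs in a row
        break
```

Here best_loss settles at 0.7 and training stops at epoch index 6, so the 0.69 and 0.5 epochs are never reached, exactly as EarlyStopping would behave.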
Upvotes: 4