Soerendip

Reputation: 9148

How to do early stopping with tensorflow.keras.models.Sequential()?

I am using a Sequential model generated like this:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense

# MAXLENGTH (sequence length) and NAMESPACELENGTH (features per step)
# are defined elsewhere in the script.
def generate_model():
    model = Sequential()
    model.add(Conv1D(64, kernel_size=10, strides=1,
                     activation='relu', padding='same',
                     input_shape=(MAXLENGTH, NAMESPACELENGTH)))
    model.add(MaxPooling1D(pool_size=4, strides=2))
    model.add(Conv1D(32, 3, activation='relu', padding='same'))
    model.add(MaxPooling1D(pool_size=4))
    model.add(Flatten())
    model.add(Dense(10, activation='relu'))
    model.add(Dense(1, activation='linear'))  # single regression output
    model.compile(loss='mean_squared_error',
                  optimizer='adam', metrics=['mean_squared_error'])
    return model

I want to do K-fold cross-validated modeling, so I train K models in a loop:

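# kfold is a K-fold splitter such as sklearn.model_selection.KFold;
# X and y are NumPy arrays defined elsewhere.
models = []
for ndx_train, ndx_val in kfold.split(X, y):
    model = generate_model()  # fresh, untrained model for every fold
    N_train = len(ndx_train)
    X_batch = X[ndx_train]
    y_batch = y[ndx_train]
    model.fit(X_batch, y_batch, epochs=100, verbose=1, steps_per_epoch=10,
              validation_data=(X[ndx_val], y[ndx_val]), validation_steps=100)

    models.append(model)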

From the training output I can see where each model should stop, i.e. when the validation error starts to increase again. Is there an easy way to do this automatically with plain TensorFlow and this higher-level API? There are some suggestions along these lines using tflearn here.
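The stopping rule described above ("stop once the validation error starts rising again") is usually made robust to noise with a patience counter: stop only after the error has failed to beat its best value for a few consecutive epochs. The following is a minimal pure-Python sketch of that idea, run on a hypothetical list of per-epoch validation losses; the function name and its parameters are illustrative, not part of any library.

```python
def early_stopping_epoch(val_losses, patience=2, min_delta=0.0):
    """Apply a simple patience rule to per-epoch validation losses.

    Returns (stopped_epoch, best_epoch). stopped_epoch is None if the rule
    never fired, i.e. training would have run through every epoch.
    """
    best = float("inf")
    best_epoch = None
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # New best value: remember it and reset the counter.
            best = loss
            best_epoch = epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


losses = [1.0, 0.8, 0.6, 0.65, 0.7, 0.5]
print(early_stopping_epoch(losses, patience=2))  # (4, 2)
print(early_stopping_epoch(losses, patience=3))  # (None, 5)
```

With `patience=2` the run stops at epoch 4 and never sees the later improvement at epoch 5; a larger patience tolerates longer plateaus at the cost of extra epochs.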

Upvotes: 0

Views: 887

Answers (1)

Dmytro Prylipko

Reputation: 5064

Use the EarlyStopping callback:

from tensorflow.keras.callbacks import EarlyStopping

callbacks = [
    # Stop once val_mean_squared_error has not improved for 2 consecutive epochs.
    EarlyStopping(monitor='val_mean_squared_error', patience=2, verbose=1),
]
model.fit(..., callbacks=callbacks)
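Put together with the K-fold loop from the question, this could look roughly like the sketch below. It is only a sketch under several assumptions: the shapes, the synthetic `X` and `y`, `n_splits=3` and `batch_size=32` are made up for illustration (replacing the question's `steps_per_epoch`/`validation_steps`), `kfold` is assumed to be scikit-learn's `KFold`, and `restore_best_weights=True` is an extra option not used in the answer above. It monitors `val_loss` rather than `val_mean_squared_error`; because the loss is already mean squared error, this tracks the same quantity without depending on how the metric name is logged.

```python
import numpy as np
from sklearn.model_selection import KFold
from tensorflow.keras import Input
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Conv1D, Dense, Flatten, MaxPooling1D
from tensorflow.keras.models import Sequential

# Hypothetical shapes and synthetic data standing in for the real X and y.
MAXLENGTH, NAMESPACELENGTH = 50, 8
rng = np.random.default_rng(0)
X = rng.normal(size=(200, MAXLENGTH, NAMESPACELENGTH)).astype("float32")
y = X.mean(axis=(1, 2)).reshape(-1, 1).astype("float32")


def generate_model():
    # Same architecture as in the question, declared with an Input layer.
    model = Sequential([
        Input(shape=(MAXLENGTH, NAMESPACELENGTH)),
        Conv1D(64, kernel_size=10, strides=1, activation="relu", padding="same"),
        MaxPooling1D(pool_size=4, strides=2),
        Conv1D(32, 3, activation="relu", padding="same"),
        MaxPooling1D(pool_size=4),
        Flatten(),
        Dense(10, activation="relu"),
        Dense(1, activation="linear"),
    ])
    model.compile(loss="mean_squared_error", optimizer="adam")
    return model


kfold = KFold(n_splits=3, shuffle=True, random_state=0)
models, histories = [], []
for ndx_train, ndx_val in kfold.split(X, y):
    model = generate_model()
    # Stop when val_loss has not improved for 2 epochs and roll the
    # weights back to the best epoch seen so far.
    early_stopping = EarlyStopping(monitor="val_loss", patience=2,
                                   restore_best_weights=True, verbose=1)
    history = model.fit(X[ndx_train], y[ndx_train], epochs=100, batch_size=32,
                        validation_data=(X[ndx_val], y[ndx_val]),
                        callbacks=[early_stopping], verbose=0)
    models.append(model)
    histories.append(history)
    print(f"fold {len(models)}: trained for {len(history.history['loss'])} epochs")
```

Each fold can stop at a different epoch; `len(history.history['loss'])` shows how many epochs that fold actually ran.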

Upvotes: 2
