nabroyan

Reputation: 3275

Strange loss curve while training LSTM with Keras

I'm trying to train an LSTM for a binary classification problem. When I plot the loss curve after training, there are strange spikes in it. Here are some examples:

(two loss-curve plots, each showing sudden spikes during training)

Here is the basic code

from numpy import newaxis
import matplotlib.pyplot as plt

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.layers import recurrent
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Two stacked LSTM layers with dropout and a sigmoid output for binary classification.
model = Sequential()
model.add(recurrent.LSTM(128, input_shape=(columnCount, 1), return_sequences=True))
model.add(Dropout(0.5))
model.add(recurrent.LSTM(128, return_sequences=False))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Add a trailing feature axis so the input shape becomes (samples, columnCount, 1).
new_train = X_train[..., newaxis]

# Train with early stopping on val_loss and checkpoint the best model to model.h5.
history = model.fit(new_train, y_train, nb_epoch=500, batch_size=100,
                    callbacks=[EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=2, verbose=0, mode='auto'),
                               ModelCheckpoint(filepath="model.h5", verbose=0, save_best_only=True)],
                    validation_split=0.1)

# list all data in history
print(history.history.keys())
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

I don't understand why these spikes occur. Any ideas?

Upvotes: 5

Views: 3632

Answers (2)

stu

Reputation: 11

This question is old, but I've seen this happen before when restarting training from a checkpoint. If the spike corresponds to a break in training, you may be inadvertently resetting some of the weights.
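
For example, a minimal sketch of resuming from the checkpoint written by the question's ModelCheckpoint callback, rather than rebuilding and recompiling the model (which would re-initialize the weights). Variable names come from the question, and the epoch count here is just an arbitrary example:

from keras.models import load_model

# Restore the architecture and trained weights from the saved checkpoint
# instead of calling Sequential()/compile() again, which would reset them.
model = load_model("model.h5")

# Continue training from where the checkpoint left off.
history = model.fit(new_train, y_train, nb_epoch=100, batch_size=100,
                    validation_split=0.1)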

Upvotes: 0

Marcin Możejko

Reputation: 40516

There are several possible reasons why something like this occurs:

  1. Your parameter trajectory changed its basin of attraction - this means that the system left one stable trajectory and switched to another, probably due to some source of randomization such as batch sampling or dropout.

  2. LSTM instability - LSTMs are believed to be extremely unstable in terms of training, and it has also been reported that it often takes a long time for them to stabilize.

Based on recent research (e.g. from here), I would recommend decreasing the batch size and training for more epochs. I would also check whether the network topology is too complex (or too simple) for the number of patterns it needs to learn, and I would try switching to either GRU or SimpleRNN.
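
As a rough sketch only (the batch size of 32 is an arbitrary example, and columnCount, new_train and y_train are the variables from the question), the GRU variant could look like this:

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.layers.recurrent import GRU

# Same topology as in the question, but with GRU cells instead of LSTM cells.
model = Sequential()
model.add(GRU(128, input_shape=(columnCount, 1), return_sequences=True))
model.add(Dropout(0.5))
model.add(GRU(128, return_sequences=False))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Smaller batch size, left to run for more epochs.
history = model.fit(new_train, y_train, nb_epoch=500, batch_size=32,
                    validation_split=0.1)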

Upvotes: 8
