Reputation: 3275
I'm trying to train an LSTM for a binary classification problem. When I plot the loss
curve after training, there are strange spikes in it. Here are some examples:
Here is the basic code:
from numpy import newaxis
from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense, Activation
from keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt

model = Sequential()
model.add(LSTM(128, input_shape=(columnCount, 1), return_sequences=True))
model.add(Dropout(0.5))
model.add(LSTM(128, return_sequences=False))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
# add a feature dimension: (samples, timesteps) -> (samples, timesteps, 1)
new_train = X_train[..., newaxis]
# epochs replaces the deprecated nb_epoch argument
history = model.fit(new_train, y_train, epochs=500, batch_size=100,
                    callbacks=[EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=2, verbose=0, mode='auto'),
                               ModelCheckpoint(filepath="model.h5", verbose=0, save_best_only=True)],
                    validation_split=0.1)
# list all data in history
print(history.history.keys())
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
I don't understand why these spikes occur. Any ideas?
Upvotes: 5
Views: 3632
Reputation: 11
This question is old, but I've seen this happen before when restarting training from a checkpoint. If a spike corresponds to a break in training, you may be inadvertently resetting some of the weights or optimizer state.
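One common way to trip over this in Keras (a minimal sketch, assuming the question's model.h5 checkpoint; the fit arguments are illustrative): loading only the weights into a freshly compiled model resets the Adam optimizer's moment estimates, which can cause exactly this kind of spike right after the resume, whereas load_model restores weights and optimizer state together.
from keras.models import load_model

# Resuming like this silently resets the optimizer state (fresh Adam
# moments), which can produce a loss spike right after the restart:
#   model = build_model()            # hypothetical rebuild helper
#   model.compile(optimizer='adam', loss='binary_crossentropy')
#   model.load_weights("model.h5")   # weights only, no optimizer state

# Loading the full saved model restores weights *and* optimizer state,
# so training continues smoothly from where it left off:
model = load_model("model.h5")
model.fit(new_train, y_train, epochs=100, batch_size=100,
          validation_split=0.1)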
Upvotes: 0
Reputation: 40516
There are many possible reasons why something like this occurs:
Your parameter trajectory changed its basin of attraction: the system left one stable trajectory and switched to another, probably due to randomization such as batch sampling or dropout.
LSTM instability: LSTMs are believed to be quite unstable during training, and it has been reported that they often take a long time to stabilize.
Based on recent research (e.g. from here), I would recommend decreasing the batch size and training for more epochs. I would also check whether the network topology is too complex (or too simple) for the number of patterns it needs to learn, and I would try switching to either GRU or SimpleRNN; see the sketch below.
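A minimal sketch of those suggestions (assuming the same columnCount, new_train, and y_train as in the question; the batch size of 32 and the epoch count are illustrative choices, not prescribed values):
from keras.models import Sequential
from keras.layers import GRU, Dropout, Dense, Activation

model = Sequential()
# GRU has fewer gates than LSTM and is often easier to stabilize
model.add(GRU(128, input_shape=(columnCount, 1), return_sequences=True))
model.add(Dropout(0.5))
model.add(GRU(128, return_sequences=False))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# smaller batch size, more epochs
history = model.fit(new_train, y_train, epochs=1000, batch_size=32,
                    validation_split=0.1)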
Upvotes: 8