Reputation: 1053
I have an LSTM whose output is nearly the same for every input. How could I go about fixing this? The parameters are below. I'd appreciate a general answer, since that would help me recognise and fix the problem if I see it again.
```python
from tensorflow.keras import regularizers
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import RMSprop

batch_size = 32

# X_train.shape, Y_train.shape, X_test.shape, Y_test.shape
# -> ((1920, 30, 5), (1920, 6), (1696, 30, 5), (1696, 6))
data_dim = X_train.shape[2]   # 5 features per time step
timesteps = X_train.shape[1]  # 30 time steps per sample

# Expected input batch shape: (batch_size, timesteps, data_dim)
# Note that we have to provide the full batch_input_shape since the network is stateful.
# The sample at index i in batch k is the continuation of sample i in batch k-1.
model = Sequential()
model.add(LSTM(32,
               return_sequences=True,
               stateful=True,
               kernel_regularizer=regularizers.l2(0.0001),
               batch_input_shape=(batch_size, timesteps, data_dim)))
model.add(Dropout(0.4))
model.add(LSTM(32,
               return_sequences=True,
               kernel_regularizer=regularizers.l2(0.0001),
               stateful=True))
model.add(Dropout(0.4))
model.add(LSTM(32, stateful=True))
model.add(Dropout(0.4))
model.add(Dense(6, activation='softmax', use_bias=True))

rms = RMSprop(learning_rate=0.001)  # older Keras versions called this `lr`
model.compile(loss='categorical_crossentropy',
              optimizer=rms,
              metrics=['accuracy'])

history = model.fit(X_train, Y_train,
                    batch_size=batch_size,
                    epochs=5,
                    shuffle=False,
                    validation_data=(X_test, Y_test))
```
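The comments in the model code require sample i of batch k to continue sample i of batch k-1. Keras does not arrange the data that way for you; with `shuffle=False` it simply feeds the samples in array order. Below is a minimal NumPy sketch of one way to build that ordering, assuming the windows are consecutive, non-overlapping chunks of a single series stored in time order (with overlapping sliding windows the "continuation" is only approximate). The helper name `stateful_order` is hypothetical.

```python
import numpy as np


def stateful_order(n_samples: int, batch_size: int) -> np.ndarray:
    """Return indices that rearrange time-ordered samples for a stateful RNN.

    Hypothetical helper. Assumes sample j+1 directly follows sample j in time.
    The series is cut into `batch_size` contiguous streams; batch k then holds
    the k-th sample of every stream, so position i in batch k continues
    position i in batch k-1.
    """
    if n_samples % batch_size != 0:
        # A fixed batch_input_shape needs every batch to be full.
        raise ValueError("n_samples must be a multiple of batch_size")
    steps = n_samples // batch_size
    # Row s of the grid is stream s; transposing makes each row one batch.
    return np.arange(n_samples).reshape(batch_size, steps).T.ravel()


# Example: 8 time-ordered samples, batch size 2.
idx = stateful_order(8, 2)
print(idx)  # [0 4 1 5 2 6 3 7]
# Batches: [0, 4], [1, 5], [2, 6], [3, 7]; sample 1 continues 0, 5 continues 4.
```

With the question's data, `idx = stateful_order(len(X_train), batch_size)` would give `X_train[idx]` and `Y_train[idx]` to pass to `fit` with `shuffle=False`. Both 1920 and 1696 are multiples of 32, so no trimming is needed.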
After training, I get the following output:
```
                  0b        1b        2b        3b        4b        5b
2017-06-30  0.077203  0.180573  0.314528  0.287455  0.110213  0.030026
2017-07-03  0.077225  0.180570  0.314542  0.287430  0.110204  0.030029
2017-07-04  0.077220  0.180586  0.314541  0.287423  0.110207  0.030023
2017-07-05  0.077193  0.180622  0.314523  0.287426  0.110221  0.030015
2017-07-06  0.077125  0.180695  0.314496  0.287435  0.110257  0.029992
```
The predicted probabilities are almost identical for every date :(
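One quick check when a softmax classifier gives nearly the same output for every input is to compare the average prediction with the class frequencies in the training labels. If a model ignores its input, the constant output that minimises categorical cross-entropy is exactly the empirical class distribution, so a close match suggests the network has collapsed to predicting that prior. A minimal NumPy sketch, assuming `preds` would come from something like `model.predict(X_test, batch_size=batch_size)` and the labels are one-hot like `Y_train`; the helper name `collapse_report` is hypothetical and the arrays below are synthetic.

```python
import numpy as np


def collapse_report(preds: np.ndarray, y_onehot: np.ndarray) -> dict:
    """Summarise how input-dependent softmax predictions are.

    Hypothetical helper. preds: (n_samples, n_classes) probabilities.
    y_onehot: (n_labels, n_classes) one-hot training targets.
    """
    class_freq = y_onehot.mean(axis=0)  # empirical class distribution
    mean_pred = preds.mean(axis=0)      # average predicted distribution
    return {
        # Largest per-class standard deviation across samples; near 0 means
        # the output barely changes from one input to the next.
        "max_std": float(preds.std(axis=0).max()),
        # Largest gap between the average prediction and the class prior.
        "gap_to_prior": float(np.abs(mean_pred - class_freq).max()),
    }


# Synthetic example: every prediction equals the label frequencies.
y = np.eye(3)[[0, 1, 1, 2]]             # frequencies: 0.25, 0.5, 0.25
preds = np.tile(y.mean(axis=0), (5, 1))
print(collapse_report(preds, y))  # {'max_std': 0.0, 'gap_to_prior': 0.0}
```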
EDIT: I forgot to mention that I used the sklearn MinMaxScaler to scale the data to the range (-7, 7), since that seemed to work in the past. Is this the right approach?
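MinMaxScaler expects 2-D input of shape (n_samples, n_features), so 3-D LSTM input has to be flattened to (samples × timesteps, features) before scaling and reshaped afterwards. The scaler is normally fitted on the training data only and then reused on the test data. A minimal sketch; `feature_range=(-7, 7)` mirrors the question and is not a recommendation, the random arrays stand in for `X_train`/`X_test`, and the helper name `scale_3d` is hypothetical.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler


def scale_3d(x_train, x_test, feature_range=(-7, 7)):
    """Scale (samples, timesteps, features) arrays per feature.

    Hypothetical helper: the scaler is fitted on the training data only,
    then the same transform is applied to the test data, so test values
    can fall slightly outside feature_range.
    """
    n_features = x_train.shape[2]
    scaler = MinMaxScaler(feature_range=feature_range)
    # Flatten samples and timesteps together: each row is one time step.
    train_2d = scaler.fit_transform(x_train.reshape(-1, n_features))
    test_2d = scaler.transform(x_test.reshape(-1, n_features))
    return (train_2d.reshape(x_train.shape),
            test_2d.reshape(x_test.shape),
            scaler)


rng = np.random.default_rng(0)
x_tr = rng.normal(size=(64, 30, 5))  # stand-in for X_train
x_te = rng.normal(size=(32, 30, 5))  # stand-in for X_test
tr, te, _ = scale_3d(x_tr, x_te)
print(tr.min(), tr.max())  # close to -7 and 7
```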
Upvotes: 4
Views: 3985
Reputation: 4537
Don't worry; it's a common problem. To solve it, you have to find the optimal parameters for your network.
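Finding those parameters usually means trying combinations systematically and keeping the one with the best validation score. A minimal, library-agnostic sketch of an exhaustive grid search, assuming a hypothetical `train_and_score` callable that you would write to build and fit the model for one setting (for example, the model from the question with different units, dropout, or learning rate) and return its validation loss; the grid values are illustrative only.

```python
import itertools
from typing import Callable, Dict, Iterable, Tuple


def grid_search(
    train_and_score: Callable[..., float],
    grid: Dict[str, Iterable],
) -> Tuple[dict, float]:
    """Try every combination in `grid`; return the one with the lowest score.

    `train_and_score` is hypothetical: it should train a model with the given
    keyword arguments and return a validation loss (lower is better).
    """
    names = list(grid)
    best_params, best_score = None, float("inf")
    for values in itertools.product(*(grid[n] for n in names)):
        params = dict(zip(names, values))
        score = train_and_score(**params)
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Illustrative grid; each value here is an assumption, not a recommendation.
grid = {"units": [16, 32, 64], "dropout": [0.0, 0.2, 0.4], "lr": [1e-3, 1e-4]}


def fake_score(units, dropout, lr):
    # Stand-in for real training so the example runs without Keras.
    return abs(units - 32) / 32 + dropout + lr


print(grid_search(fake_score, grid))
# ({'units': 32, 'dropout': 0.0, 'lr': 0.0001}, 0.0001)
```

An exhaustive grid grows multiplicatively with each parameter, so with a slow-to-train LSTM it is common to start with a coarse grid and refine around the best setting.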
Unfortunately, I can't tell you exactly how to fix your ANN, but here are some ideas you can try:
Upvotes: 3