Hamza Yerlikaya

Reputation: 49329

Keras LSTM - Input shape for time series prediction

I am trying to predict the output of a function (eventually it will be multi-input, multi-output), but for now, just to get the mechanics right, I am trying to predict the output of the sine function. My dataset is as follows:

    t0          t1
0   0.000000    0.125333
1   0.125333    0.248690
2   0.248690    0.368125
3   0.368125    0.481754
4   0.481754    0.587785
5   0.587785    0.684547
6   0.684547    0.770513
7   0.770513    0.844328
8   0.844328    0.904827
9   0.904827    0.951057
.....

There are 100 values in total. t0 is the current input and t1 is the next output I want to predict. The data is then split into train/test sets with scikit-learn:

x_train, x_test, y_train, y_test = train_test_split(wave["t0"].values, wave["t1"].values, test_size=0.20)

The problem happens in fit: I get an error saying the input has the wrong dimensions.
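A minimal sketch of where the dimension mismatch comes from. It assumes (this is inferred from the numbers, not stated in the question) that the table is sin(2πi/50) sampled at consecutive integers i. A single column's .values is 1-D, while an LSTM layer expects 3-D input:

```python
import numpy as np

# Assumption: the table is sin(2*pi*i/50) for i = 0..100,
# so t0[i] = sin(2*pi*i/50) and t1[i] = t0[i+1].
samples = np.sin(2 * np.pi * np.arange(101) / 50)
t0 = samples[:-1]   # current value, shape (100,)
t1 = samples[1:]    # next value to predict, shape (100,)

# Like wave["t0"].values, this is a 1-D array.
print(t0.shape)  # (100,)

# An LSTM layer wants (samples, time_steps, num_features).
# Using one previous value per prediction: time_steps = 1, num_features = 1.
x = t0.reshape(-1, 1, 1)
y = t1.reshape(-1, 1)
print(x.shape, y.shape)  # (100, 1, 1) (100, 1)
```

The same reshape can be applied to x_train and x_test after the split (e.g. x_train.reshape(-1, 1, 1)), in which case the LSTM layer would take input_shape=(1, 1).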

model = Sequential()
model.add(LSTM(128, input_shape=??? ,stateful=True))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')

model.fit(x_train, y_train, 
          batch_size=10, epochs=100,
          validation_data=(x_test, y_test))

I've tried the answers to other questions on the site, but no matter what I try I cannot get Keras to accept the input shape.

Upvotes: 1

Views: 3605

Answers (1)

Manoj Mohan

Reputation: 6044

The LSTM expects input data of shape (batch_size, time_steps, num_features). In sine-wave prediction, num_features is 1 and time_steps is the number of previous time points the LSTM uses for each prediction. In the example below, the batch size is 1, time_steps is 2 and num_features is 1.

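import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

# Dummy data: 1 sample, 2 time steps, 1 feature -> shape (1, 2, 1)
x_train = np.ones((1,2,1))
y_train = np.ones((1,1))

x_test = np.ones((1,2,1))
y_test = np.ones((1,1))

model = Sequential()
# input_shape omits the batch dimension: (time_steps, num_features)
model.add(LSTM(128, input_shape=(2,1)))
# For a stateful LSTM, specify the full batch shape instead:
#model.add(LSTM(128, batch_input_shape=(1,2,1), stateful=True))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')

model.fit(x_train, y_train,
          batch_size=1, epochs=100,
          validation_data=(x_test, y_test))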
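To get from a real sine series to the (samples, time_steps, num_features) layout with time_steps = 2, one common approach is a sliding window. This is a sketch under assumptions: make_windows is a hypothetical helper (not part of Keras), and the series is assumed to be the sin(2πi/50) samples behind the question's table.

```python
import numpy as np

def make_windows(series, time_steps):
    """Hypothetical helper: pair each run of `time_steps` consecutive values
    with the value that follows it. Returns x of shape
    (samples, time_steps, 1) and y of shape (samples, 1)."""
    series = np.asarray(series, dtype=float)
    n_samples = len(series) - time_steps
    if n_samples < 1:
        raise ValueError("series must be longer than time_steps")
    x = np.stack([series[i:i + time_steps] for i in range(n_samples)])
    y = series[time_steps:]
    return x.reshape(n_samples, time_steps, 1), y.reshape(n_samples, 1)

# Assumption: the 100 sine samples behind the question's table.
wave = np.sin(2 * np.pi * np.arange(100) / 50)
x, y = make_windows(wave, time_steps=2)
print(x.shape, y.shape)  # (98, 2, 1) (98, 1)
```

The resulting x matches the input_shape=(2, 1) model above. If the windows are then split with train_test_split, note that it shuffles by default; passing shuffle=False keeps a chronological train/test split.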

Upvotes: 3
