Reputation: 43491
This has been asked a few times before (here and here, for instance).
My model is:
model = Sequential()
model.add(LSTM(128, input_shape=(10, VECTOR_SIZE), return_sequences=True))
model.add(TimeDistributed(Dense(VECTOR_SIZE, activation='linear')))
model.compile(loss='mean_squared_error', optimizer='rmsprop')
This works well.
When I try to stack a second LSTM layer:
model = Sequential()
model.add(LSTM(128, input_shape=(10, VECTOR_SIZE), return_sequences=True))
model.add(LSTM(32))
model.add(TimeDistributed(Dense(VECTOR_SIZE, activation='linear')))
model.compile(loss='mean_squared_error', optimizer='rmsprop')
I get an error:
Traceback (most recent call last):
File "train_tf.py", line 112, in <module>
main()
File "train_tf.py", line 89, in main
model.add(TimeDistributed(Dense(VECTOR_SIZE, activation='linear')))
File "/Users/shamoon/.local/share/virtualenvs/pytorch-lstm-audio-Pq4zK81J/lib/python3.6/site-packages/keras/engine/sequential.py", line 182, in add
output_tensor = layer(self.outputs[0])
File "/Users/shamoon/.local/share/virtualenvs/pytorch-lstm-audio-Pq4zK81J/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 75, in symbolic_fn_wrapper
return func(*args, **kwargs)
File "/Users/shamoon/.local/share/virtualenvs/pytorch-lstm-audio-Pq4zK81J/lib/python3.6/site-packages/keras/engine/base_layer.py", line 463, in __call__
self.build(unpack_singleton(input_shapes))
File "/Users/shamoon/.local/share/virtualenvs/pytorch-lstm-audio-Pq4zK81J/lib/python3.6/site-packages/keras/layers/wrappers.py", line 197, in build
assert len(input_shape) >= 3
AssertionError
The previous answers make stacking look easy, but following them still gives me this error.
Upvotes: 0
Views: 64
Reputation: 56347
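TimeDistributed needs 3D input of shape (batch, timesteps, features), which is what the assert len(input_shape) >= 3 in your traceback checks. Without return_sequences=True, LSTM(32) returns only its final output, with shape (batch, 32), so the assertion fails. You need to set return_sequences=True on every recurrent layer whose output feeds another recurrent layer or a TimeDistributed layer. In your example that means both LSTM layers: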
model = Sequential()
model.add(LSTM(128, input_shape=(10, VECTOR_SIZE), return_sequences=True))
model.add(LSTM(32, return_sequences=True))
model.add(TimeDistributed(Dense(VECTOR_SIZE, activation='linear')))
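To see why the second LSTM needs the flag, it can help to trace the shapes by hand. The sketch below is plain Python, not Keras: lstm_shape and time_distributed_dense_shape are hypothetical helpers that only imitate the shape rules described above, and VECTOR_SIZE = 8 is an assumed value since the question does not state it.

```python
# Hypothetical helpers that imitate Keras shape rules; they are not library functions.
VECTOR_SIZE = 8  # assumed value for illustration only


def lstm_shape(input_shape, units, return_sequences=False):
    """Shape produced by an LSTM fed a (batch, timesteps, features) input."""
    if len(input_shape) != 3:
        raise ValueError(f"LSTM expects 3D input, got {input_shape}")
    batch, timesteps, _ = input_shape
    # With return_sequences the time axis is kept; otherwise only the last step is returned.
    return (batch, timesteps, units) if return_sequences else (batch, units)


def time_distributed_dense_shape(input_shape, units):
    """Shape produced by TimeDistributed(Dense(units)); mirrors the failing assert."""
    assert len(input_shape) >= 3, f"TimeDistributed needs >= 3 dims, got {input_shape}"
    return input_shape[:-1] + (units,)


x = (None, 10, VECTOR_SIZE)                          # input_shape=(10, VECTOR_SIZE) plus batch axis
h1 = lstm_shape(x, 128, return_sequences=True)       # 3D, time axis kept
broken = lstm_shape(h1, 32)                          # 2D, time axis dropped
fixed = lstm_shape(h1, 32, return_sequences=True)    # 3D again
out = time_distributed_dense_shape(fixed, VECTOR_SIZE)

print(h1, broken, fixed, out)
```

Passing broken to time_distributed_dense_shape trips the same kind of assertion as in the traceback, while fixed goes through and yields one VECTOR_SIZE-sized output per timestep.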
Upvotes: 1