Reputation: 2104
I want to run an LSTM over a few different sequences on every batch and then join the last outputs. Here is what I've been trying:
from keras.layers import Dense, Input, LSTM, Embedding, TimeDistributed
num_sentences = 4
num_features = 3
num_time_steps = 5
inputs = Input([num_sentences, num_time_steps])
emb_layer = Embedding(10, num_features)
embedded = emb_layer(inputs)
lstm_layer = LSTM(4)
shape = [num_sentences, num_time_steps, num_features]
lstm_outputs = TimeDistributed(lstm_layer, input_shape=shape)(embedded)
This is giving me the following error:
Traceback (most recent call last):
File "test.py", line 12, in <module>
lstm_outputs = TimeDistributed(lstm_layer, input_shape=shape)(embedded)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/engine/topology.py", line 546, in __call__
self.build(input_shapes[0])
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/layers/wrappers.py", line 94, in build
self.layer.build(child_input_shape)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/layers/recurrent.py", line 702, in build
self.input_dim = input_shape[2]
IndexError: tuple index out of range
I tried omitting the input_shape argument in TimeDistributed, but it didn't change anything.
Upvotes: 0
Views: 833
Reputation: 2104
After trying michetonu's answer and getting the same error, I realized my version of Keras might be outdated. Indeed, I was running Keras 1.2, and the same code ran fine on 2.0.
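For anyone wanting to rule out the same version problem, here is a minimal sketch of a version check. The helper names parse_version and is_keras_2_or_newer are made up for illustration; the only Keras attribute used is keras.__version__, and the import is guarded in case Keras is not installed.

```python
import re


def parse_version(version):
    """Turn a string such as '2.0.8' into a tuple of ints, e.g. (2, 0, 8)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first piece with no leading digits
        parts.append(int(match.group()))
    return tuple(parts)


def is_keras_2_or_newer(version):
    """True when the major version number is 2 or higher."""
    return parse_version(version) >= (2,)


try:
    import keras
    print("Keras", keras.__version__,
          "is 2.x or newer:", is_keras_2_or_newer(keras.__version__))
except ImportError:
    print("Keras is not installed in this environment")
```

If this reports a 1.x version, upgrading is probably worth trying before changing the model code.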
Upvotes: 0
Reputation: 4348
The input_shape argument needs to go to the LSTM layer, not to TimeDistributed (which is only a wrapper). With it omitted, everything works fine for me:
from keras.layers import Dense, Input, LSTM, Embedding, TimeDistributed
num_sentences = 4
num_features = 3
num_time_steps = 5
inputs = Input([num_sentences, num_time_steps])
emb_layer = Embedding(10, num_features)
embedded = emb_layer(inputs)
lstm_layer = LSTM(4)
shape = [num_sentences, num_time_steps, num_features]
lstm_outputs = TimeDistributed(lstm_layer)(embedded)
#OUTPUT:
Using TensorFlow backend.
[Finished in 1.5s]
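To make the shapes easier to follow, here is a rough sketch in plain Python of the bookkeeping involved. It is not Keras source code, and the function names are invented for this example. It assumes the embedded tensor has shape (batch, num_sentences, num_time_steps, num_features), that TimeDistributed applies the wrapped layer to each slice along axis 1, and that the LSTM returns only its last output.

```python
# Shape bookkeeping only; an illustrative sketch, not Keras internals.

def child_input_shape(input_shape):
    """Shape seen by the wrapped layer: axis 1 (the distributed axis) is dropped."""
    return (input_shape[0],) + tuple(input_shape[2:])


def lstm_output_shape(input_shape, units):
    """Last-output shape of an LSTM, which needs (batch, time_steps, features)."""
    if len(input_shape) != 3:
        raise ValueError("LSTM expects a 3D input, got shape %r" % (input_shape,))
    return (input_shape[0], units)


def time_distributed_lstm_shape(input_shape, units):
    """Apply the LSTM per slice, then put the distributed axis back at position 1."""
    inner = lstm_output_shape(child_input_shape(input_shape), units)
    return (inner[0], input_shape[1]) + inner[1:]


num_sentences, num_time_steps, num_features = 4, 5, 3
embedded_shape = (None, num_sentences, num_time_steps, num_features)

print(child_input_shape(embedded_shape))               # (None, 5, 3)
print(time_distributed_lstm_shape(embedded_shape, 4))  # (None, 4, 4)
```

So each of the 4 sentences yields one 4-dimensional LSTM output. If the wrapped LSTM ever receives a shape with fewer than three entries, a lookup like input_shape[2] from the traceback has nothing to index.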
Upvotes: 1