I've been working on a recurrent neural network implementation with the Keras framework, and I've run into some problems when building the model.
Keras 2.2.4
TensorFlow 1.14.0
My model consists of only three layers: an Embedding layer, a recurrent layer and a Dense layer. It currently looks like this:
```python
from keras.models import Sequential
from keras.layers import Embedding, SimpleRNN, CuDNNGRU, CuDNNLSTM, Dense

model = Sequential()
model.add(Embedding(input_dim=vocab_size, output_dim=EMBEDDING_DIM, input_length=W_SIZE))
if MODEL == 'GRU':
    model.add(CuDNNGRU(NUM_UNITS))
if MODEL == 'RNN':
    model.add(SimpleRNN(NUM_UNITS))
if MODEL == 'LSTM':
    model.add(CuDNNLSTM(NUM_UNITS))
model.add(Dense(vocab_size, activation='softmax'))
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
```
What I'm trying to do is add `return_state=True` to the recurrent layers so that I can get their states when I call `model.predict()`. However, when I add it, I get the following error:
TypeError: All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.
I've tried using the TimeDistributed wrapper layer around the Dense layer, but it didn't change anything.
Thanks in advance!
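The Sequential API is designed for straightforward models that form a single chain: the output of one layer is fed into the next, and so on, so every layer must produce exactly one output tensor. With `return_state=True`, a recurrent layer returns a list of tensors (`[output, state]`, or `[output, state_h, state_c]` for an LSTM), which breaks that rule.
So if you want a layer with multiple outputs, you need the Keras Functional API.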
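To see why `return_state=True` yields more than one tensor, here is a minimal NumPy sketch of a single-layer SimpleRNN forward pass. It assumes tanh activation and a zero initial state (the SimpleRNN defaults); the function name `simple_rnn_forward` and the weight names `W`, `U`, `b` are hypothetical, and this is an illustration rather than Keras's actual implementation. It returns the last output together with the final hidden state, which is the pair Keras hands back as a list.

```python
import numpy as np

def simple_rnn_forward(x, W, U, b):
    """Hypothetical SimpleRNN forward pass mirroring return_state=True.

    x: (timesteps, input_dim), W: (input_dim, units),
    U: (units, units), b: (units,). Returns [output, state].
    """
    h = np.zeros(U.shape[0])              # zero initial state (SimpleRNN default)
    for x_t in x:                         # step through the timesteps
        h = np.tanh(x_t @ W + h @ U + b)  # new hidden state
    output = h                            # return_sequences=False: last output only
    return [output, h]                    # two tensors, not one

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))               # 5 timesteps, 3 features
W = rng.normal(size=(3, 4))
U = rng.normal(size=(4, 4))
b = np.zeros(4)

out, state = simple_rnn_forward(x, W, U, b)
print(out.shape, state.shape)             # (4,) (4,)
print(np.allclose(out, state))            # True
```

For SimpleRNN (and GRU) with `return_sequences=False`, the returned output and the returned state hold the same values; only the LSTM carries an extra, separate cell state.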
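```python
from tensorflow.keras import layers, models

inp = layers.Input(shape=(n_timesteps,))
out = layers.Embedding(input_dim=vocab_size, output_dim=EMBEDDING_DIM, input_length=n_timesteps)(inp)
# With return_state=True the layer returns [output, *states]:
# one state for GRU/SimpleRNN, two (h and c) for LSTM.
if MODEL == 'GRU':
    out, *states = layers.CuDNNGRU(NUM_UNITS, return_state=True)(out)
if MODEL == 'RNN':
    out, *states = layers.SimpleRNN(NUM_UNITS, return_state=True)(out)
if MODEL == 'LSTM':
    out, *states = layers.CuDNNLSTM(NUM_UNITS, return_state=True)(out)
out = layers.Dense(vocab_size, activation='softmax')(out)

# Train only on the softmax output; the states have no targets.
model = models.Model(inputs=inp, outputs=out)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
model.summary()

# A second model sharing the same layers (and therefore weights) for inference:
# state_model.predict(x) returns [probabilities, *states].
state_model = models.Model(inputs=inp, outputs=[out] + states)
```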
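One detail matters for the LSTM branch: an LSTM carries two state vectors (the hidden state `h` and the cell state `c`), so with `return_state=True` it returns three tensors, `[output, state_h, state_c]`, and unpacking the result into just `out, state` fails. The NumPy sketch below illustrates this with a hypothetical LSTM forward pass (`lstm_forward` and the weight names are illustrative, not Keras code), assuming gates packed in the order input, forget, cell, output, plain sigmoid gates and zero initial states.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, W, U, b):
    """Hypothetical LSTM forward pass mirroring return_state=True.

    x: (timesteps, input_dim), W: (input_dim, 4*units),
    U: (units, 4*units), b: (4*units,), gates packed as i, f, c, o.
    Returns [output, state_h, state_c].
    """
    units = U.shape[0]
    h = np.zeros(units)                    # hidden state
    c = np.zeros(units)                    # cell state
    for x_t in x:
        z = x_t @ W + h @ U + b
        i = sigmoid(z[:units])             # input gate
        f = sigmoid(z[units:2 * units])    # forget gate
        g = np.tanh(z[2 * units:3 * units])  # candidate cell values
        o = sigmoid(z[3 * units:])         # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
    return [h, h, c]                       # output equals state_h when return_sequences=False

rng = np.random.default_rng(1)
x = rng.normal(size=(5, 3))
W = rng.normal(size=(3, 16))
U = rng.normal(size=(4, 16))
b = np.zeros(16)

result = lstm_forward(x, W, U, b)
print(len(result))                         # 3
out, *states = result                      # same line works for 1 or 2 states
print(len(states))                         # 2
```

Star-unpacking (`out, *states = ...`) lets one unpacking line handle GRU and SimpleRNN (one state) as well as LSTM (two states).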