tenticon

Reputation: 2933

keras - keyword initial_state not understood

I am trying to implement a Seq2Seq model in Keras, but I can't resolve this error from SimpleRNN:

TypeError: ('Keyword argument not understood:', 'initial_state')

Here is a small example:

from keras.models import Model
from keras.layers import Input, SimpleRNN, Embedding, Dense

# Encoder: return_state=True also returns the final hidden state
encoder_input = Input(shape=(16,))
encoder_embedding = Embedding(input_dim=40, output_dim=12)(encoder_input)
encoder_rnn_out, encoder_rnn_state = SimpleRNN(32, activation='relu', return_sequences=False, return_state=True)(encoder_embedding)

# Decoder: start from the encoder's final state (this line raises the TypeError)
decoder_input = Input(shape=(11,))
decoder_embedding = Embedding(input_dim=12, output_dim=12)(decoder_input)
decoder_rnn = SimpleRNN(32, activation='relu', initial_state=encoder_rnn_state, return_sequences=True)(decoder_embedding)
decoder_predictions = Dense(12, activation='softmax')(decoder_rnn)

# The model depends on both inputs, so both are listed
model = Model([encoder_input, decoder_input], decoder_predictions)

These are my TensorFlow and Keras versions (I have already uninstalled and reinstalled both with pip):

$ conda list -n py36 | grep tensorflow
tensorflow                1.13.1                    <pip>
tensorflow-estimator      1.13.0                    <pip>
$ conda list -n py36 | grep Keras
Keras                     2.2.4                     <pip>
Keras-Applications        1.0.7                     <pip>
Keras-Preprocessing       1.0.9                     <pip>

My ~/.keras/keras.json:

{
    "epsilon": 1e-07,
    "floatx": "float32",
    "backend": "tensorflow"
}
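Since keras.json is only one input to backend selection, a small sketch of the lookup order can help rule the backend out as a cause. It assumes the behaviour of multi-backend Keras 2.x, where the KERAS_BACKEND environment variable overrides the "backend" field in keras.json and 'tensorflow' is the default. The function name resolve_keras_backend is hypothetical, and the KERAS_HOME override and Keras's fallback config directory are deliberately ignored.

```python
import json
import os


def resolve_keras_backend(config_path=None, environ=None):
    """Hypothetical helper: mimic how multi-backend Keras 2.x picks its backend.

    Simplified sketch: ignores KERAS_HOME and Keras's fallback config directory.
    """
    environ = os.environ if environ is None else environ
    if config_path is None:
        config_path = os.path.join(os.path.expanduser('~'), '.keras', 'keras.json')

    backend = 'tensorflow'  # Keras 2.x default when no config file is found
    if os.path.exists(config_path):
        with open(config_path) as f:
            backend = json.load(f).get('backend', backend)

    # An explicit KERAS_BACKEND environment variable takes precedence over keras.json
    return environ.get('KERAS_BACKEND', backend)


print(resolve_keras_backend())
```

If this prints something other than tensorflow, the environment variable is the first thing to check.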

Upvotes: 1

Views: 1184

Answers (2)

LegenDUST

Reputation: 142

I had the same problem and found the answer.

Change this

decoder_rnn = SimpleRNN(32, activation='relu', initial_state=encoder_rnn_state, return_sequences=True)(decoder_embedding)

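to this, so that initial_state is passed when the layer is called on its input rather than to the SimpleRNN constructor: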

decoder_rnn = SimpleRNN(32, activation='relu', return_sequences=True)(decoder_embedding, initial_state=encoder_rnn_state)
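A complete version of the question's model with this fix applied might look like the sketch below. It assumes Keras 2.x as in the question, and it also imports Dense (used but not imported in the original snippet) and passes both encoder_input and decoder_input to Model, since a functional model has to list every Input it depends on. The random integer batch exists only for a shape check.

```python
import numpy as np
from keras.models import Model
from keras.layers import Input, SimpleRNN, Embedding, Dense

# Encoder: its final hidden state (32 units) is returned alongside the output
encoder_input = Input(shape=(16,))
encoder_embedding = Embedding(input_dim=40, output_dim=12)(encoder_input)
encoder_rnn_out, encoder_rnn_state = SimpleRNN(
    32, activation='relu', return_sequences=False, return_state=True)(encoder_embedding)

# Decoder: initial_state goes to the call, not to the SimpleRNN constructor
decoder_input = Input(shape=(11,))
decoder_embedding = Embedding(input_dim=12, output_dim=12)(decoder_input)
decoder_rnn = SimpleRNN(32, activation='relu', return_sequences=True)(
    decoder_embedding, initial_state=encoder_rnn_state)
decoder_predictions = Dense(12, activation='softmax')(decoder_rnn)

# Both Inputs must be listed, otherwise Keras reports a disconnected graph
model = Model([encoder_input, decoder_input], decoder_predictions)

# Shape check on a random batch of 2 integer-encoded sequences
enc_batch = np.random.randint(0, 40, size=(2, 16))
dec_batch = np.random.randint(0, 12, size=(2, 11))
preds = model.predict([enc_batch, dec_batch], verbose=0)
print(preds.shape)  # (2, 11, 12)
```

The decoder's unit count (32) has to match the encoder's, because the encoder state is used directly as the decoder's starting hidden state.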

Upvotes: 1

Peter

Reputation: 658

The SimpleRNN constructor does not accept initial_state as an argument. You probably meant to use the kernel_initializer or the recurrent_initializer argument instead.

See https://keras.io/layers/recurrent/.
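To make the distinction concrete: kernel_initializer and recurrent_initializer are constructor arguments that control how the layer's weight matrices are initialised, whereas initial_state is a call-time argument that supplies the starting hidden state. The sketch below assumes Keras 2.x; the 'orthogonal' choice and the standalone state Input are illustrative, not taken from the question.

```python
import numpy as np
from keras.models import Model
from keras.layers import Input, SimpleRNN

seq_input = Input(shape=(11, 12))  # 11 timesteps, 12 features per step
state_input = Input(shape=(32,))   # hidden state fed in explicitly

# Initializers are constructor arguments: they set the weights, not the state
rnn = SimpleRNN(32, kernel_initializer='glorot_uniform',
                recurrent_initializer='orthogonal', return_sequences=True)

# initial_state is a call argument: it sets the hidden state at t = 0
rnn_out = rnn(seq_input, initial_state=state_input)
model = Model([seq_input, state_input], rnn_out)

kernel, recurrent_kernel, bias = rnn.get_weights()
print(kernel.shape, recurrent_kernel.shape, bias.shape)  # (12, 32) (32, 32) (32,)

x = np.random.rand(2, 11, 12).astype('float32')
h0 = np.zeros((2, 32), dtype='float32')
out = model.predict([x, h0], verbose=0)
print(out.shape)  # (2, 11, 32)
```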

Upvotes: 0
