Reputation: 1509
I'm trying to implement an encoder-decoder type network in Keras with Bidirectional GRUs.
The following code seems to work:
from keras.layers import Input, Embedding, GRU, Bidirectional

src_input = Input(shape=(5,))
ref_input = Input(shape=(5,))
src_embedding = Embedding(output_dim=300, input_dim=vocab_size)(src_input)
ref_embedding = Embedding(output_dim=300, input_dim=vocab_size)(ref_input)
# With return_state=True the Bidirectional GRU returns [outputs, forward_state, backward_state]
encoder = Bidirectional(
    GRU(2, return_sequences=True, return_state=True)
)(src_embedding)
# Unidirectional decoder seeded with the encoder's forward state
decoder = GRU(2, return_sequences=True)(ref_embedding, initial_state=encoder[1])
But when I change the decoder to use the Bidirectional wrapper, it stops showing the encoder and src_input layers in model.summary(). The new decoder looks like:
decoder = Bidirectional(
    GRU(2, return_sequences=True)
)(ref_embedding, initial_state=encoder[1:])
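For reference, the summary below assumes the tensors above are wrapped in a Model roughly like this (the Model construction is a minimal sketch and is not part of the original snippet; the output head is omitted):

from keras.models import Model

# Hypothetical wrapper used only to produce the summary shown below
model = Model(inputs=[src_input, ref_input], outputs=decoder)
model.summary()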
The output of model.summary() with the Bidirectional decoder:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         (None, 5)                 0
_________________________________________________________________
embedding_2 (Embedding)      (None, 5, 300)            6610500
_________________________________________________________________
bidirectional_2 (Bidirection (None, 5, 4)              3636
=================================================================
Total params: 6,614,136
Trainable params: 6,614,136
Non-trainable params: 0
_________________________________________________________________
Question: Am I missing something when I pass initial_state to the Bidirectional decoder? How can I fix this? Is there any other way to make this work?
Upvotes: 5
Views: 3917
Reputation: 14619
It's a bug. The RNN layer implements __call__ so that tensors passed in initial_state can be collected into a model instance. However, the Bidirectional wrapper did not implement this, so the topological information about the initial_state tensors is missing and strange bugs like this happen.

I wasn't aware of this when I implemented initial_state support for Bidirectional. It should be fixed now, after this PR. You can install the latest master branch from GitHub to get the fix.
Upvotes: 1