Reputation: 500
I'm trying to build a seq2seq model with an encoder LSTM and a decoder LSTM, both wrapped in Bidirectional layers.
I can pass the hidden state and memory cell forward to the decoder LSTM, but I can't see how I'd possibly pass the values back from the decoder to the encoder.
def sequence_model(total_words, emb_dimension, lstm_units):
    # Encoder
    encoder_input = Input(shape=(None,), name="Enc_Input")
    x = Embedding(total_words, emb_dimension, input_length=max_sequence_length, name="Enc_Embedding")(encoder_input)
    x, state_h, state_c, _, _ = Bidirectional(LSTM(lstm_units, return_state=True, name="Enc_LSTM1"), name="Enc_Bi1")(x)  # pass hidden activation and memory cell states forward
    encoder_states = [state_h, state_c]  # package states to pass to decoder

    # Decoder
    decoder_input = Input(shape=(None,), name="Dec_Input")
    x = Embedding(total_words, emb_dimension, name="Dec_Embedding")(decoder_input)
    x = LSTM(lstm_units, return_sequences=True, name="Dec_LSTM1")(x, initial_state=encoder_states)
    decoder_output = Dense(total_words, activation="softmax", name="Dec_Softmax")(x)

    func_model = tf.keras.Model(inputs=[encoder_input, decoder_input], outputs=decoder_output)
    return func_model
The forward states are passed to the initial_state of the decoder LSTM layer. But if I wrap this Dec_LSTM1 layer with a Bidirectional layer, it doesn't like me passing the initial_state value in and breaks.
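A stripped-down sketch of the failing call (toy sizes, layer names left out, not my real code): with only the two forward states in the list, the Bidirectional wrapper raises a ValueError (the exact message depends on the Keras/TF version).

from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM

enc_in = Input(shape=(None,))
enc_x = Embedding(1000, 64)(enc_in)
# keep only the forward hidden/cell states, discard the backward pair
_, state_h, state_c, _, _ = Bidirectional(LSTM(32, return_state=True))(enc_x)

dec_in = Input(shape=(None,))
dec_x = Embedding(1000, 64)(dec_in)
# passing just [state_h, state_c] to a Bidirectional decoder fails, because the
# wrapper expects initial states for both the forward and the backward LSTM
dec_x = Bidirectional(LSTM(32, return_sequences=True))(
    dec_x, initial_state=[state_h, state_c])  # <- raises ValueError here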
Am I right in thinking I don't need the backwards states from the encoder LSTM layer?
Attached is an image of the architecture I'm trying to achieve.
Upvotes: 1
Views: 581
Reputation: 17613
Your code is breaking when you add Bidirectional to your decoder because you have left out two elements from the encoder state.
x, state_h, state_c, _, _ = ...
#                    ^  ^
# -------------------|--|
An LSTM state has two tensors in it, each of shape (batch, hidden); when you run your LSTM in both directions, this adds two more states (from the backward pass).
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Bidirectional
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Dense
enc_in = Input(shape=(None,))
enc_x = Embedding(1024, 128, input_length=92)(enc_in)
# vanilla LSTM
s_enc_x, *s_enc_state = LSTM(256, return_state=True)(enc_x)
print(len(s_enc_state))
print(s_enc_state)
# 2
# [<KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'lstm_7')>,
# <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'lstm_7')>]
# bi-directional LSTM
bi_enc_x, *bi_enc_state = Bidirectional(LSTM(256, return_state=True))(enc_x)
print(len(bi_enc_state))
print(bi_enc_state)
# 4
# [<KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>,
# <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>,
# <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>,
# <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>]
# decoder
dec_in = Input(shape=(None,))
dec_x = Embedding(1024, 128, input_length=92)(dec_in)
dec_x = Bidirectional(LSTM(256, return_sequences=True))(
    dec_x, initial_state=bi_enc_state)  # <= use bidirectional state
output = Dense(1024, activation="softmax")(dec_x)
print(output.shape)
# TensorShape([None, None, 1024])
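For completeness, here is roughly how the same fix could look in your original sequence_model (a sketch only; I've kept your layer names, added a name for the decoder's Bidirectional wrapper, and moved max_sequence_length into the signature since it was a free variable in your snippet):

def sequence_model(total_words, emb_dimension, lstm_units, max_sequence_length):
    # Encoder
    encoder_input = Input(shape=(None,), name="Enc_Input")
    x = Embedding(total_words, emb_dimension, input_length=max_sequence_length,
                  name="Enc_Embedding")(encoder_input)
    # keep all four state tensors: forward h/c and backward h/c
    x, fwd_h, fwd_c, bwd_h, bwd_c = Bidirectional(
        LSTM(lstm_units, return_state=True, name="Enc_LSTM1"), name="Enc_Bi1")(x)
    encoder_states = [fwd_h, fwd_c, bwd_h, bwd_c]

    # Decoder: the Bidirectional wrapper needs initial states for both
    # directions, so hand it the full four-tensor list from the encoder
    decoder_input = Input(shape=(None,), name="Dec_Input")
    x = Embedding(total_words, emb_dimension, name="Dec_Embedding")(decoder_input)
    x = Bidirectional(
        LSTM(lstm_units, return_sequences=True, name="Dec_LSTM1"),
        name="Dec_Bi1")(x, initial_state=encoder_states)
    decoder_output = Dense(total_words, activation="softmax", name="Dec_Softmax")(x)

    return tf.keras.Model(inputs=[encoder_input, decoder_input], outputs=decoder_output)

With return_state=True on the encoder's Bidirectional layer you get five return values (the sequence output plus four state tensors), so nothing is discarded before it reaches the decoder.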
Upvotes: 2