gazm2k5

Reputation: 500

How to pass Bidirectional LSTM state to earlier LSTM layer?

I'm trying to do a seq2seq model with an encoder LSTM and a decoder LSTM, both wrapped in Bidirectional layers.

I can pass the hidden state and memory cell forward to the decoder LSTM, but I can't see how I'd possibly pass the values back from the decoder to the encoder.

import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM, Dense

def sequence_model(total_words, emb_dimension, lstm_units):
    # Encoder
    encoder_input = Input(shape=(None,), name="Enc_Input")
    x = Embedding(total_words, emb_dimension, input_length=max_sequence_length, name="Enc_Embedding")(encoder_input)
    x, state_h, state_c, _, _ = Bidirectional(LSTM(lstm_units, return_state=True, name="Enc_LSTM1"), name="Enc_Bi1")(x) # pass hidden activation and memory cell states forward
    encoder_states = [state_h, state_c] # package states to pass to decoder
    
    # Decoder
    decoder_input = Input(shape=(None,), name="Dec_Input")
    x = Embedding(total_words, emb_dimension, name="Dec_Embedding")(decoder_input)
    x = LSTM(lstm_units, return_sequences=True, name="Dec_LSTM1")(x, initial_state=encoder_states)
    decoder_output = Dense(total_words, activation="softmax", name="Dec_Softmax")(x)

    func_model = tf.keras.Model(inputs=[encoder_input,decoder_input], outputs=decoder_output)
    return func_model

The forward states are passed as the initial_state of the decoder LSTM layer. But if I wrap this Dec_LSTM1 layer in a Bidirectional layer, it no longer accepts the initial_state argument and breaks (sketch of the failing variant below).
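Roughly, this is the variant that fails (a minimal sketch, reusing the same names as the model above):

    x = Embedding(total_words, emb_dimension, name="Dec_Embedding")(decoder_input)
    # wrapping the decoder LSTM in Bidirectional and passing only the two
    # forward states raises an error: the Bidirectional wrapper expects
    # initial states for both directions
    x = Bidirectional(LSTM(lstm_units, return_sequences=True, name="Dec_LSTM1"), name="Dec_Bi1")(
        x, initial_state=encoder_states)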

Am I right in thinking I don't need the backwards states from the encoder LSTM layer?

Attached is an image of the architecture I'm trying to achieve.


Upvotes: 1

Views: 581

Answers (1)

o-90

Reputation: 17613

Your code is breaking when you add Bidirectional to your decoder because you have left out two elements from the encoder state.

x, state_h, state_c, _, _ = ...
#                    ^  ^
# -------------------|--|

An LSTM's state consists of two tensors, each of shape (batch, hidden); when you run the LSTM in both directions, the backward pass adds two more state tensors, so a Bidirectional encoder returns four.

import tensorflow as tf

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import Bidirectional
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Dense


enc_in = Input(shape=(None,))
enc_x = Embedding(1024, 128, input_length=92)(enc_in)

# vanilla LSTM
s_enc_x, *s_enc_state = LSTM(256, return_state=True)(enc_x)

print(len(s_enc_state))
print(s_enc_state)
# 2
# [<KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'lstm_7')>,
#  <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'lstm_7')>]

# bi-directional LSTM
bi_enc_x, *bi_enc_state = Bidirectional(LSTM(256, return_state=True))(enc_x)

print(len(bi_enc_state))
print(bi_enc_state)
# 4
# [<KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>,
#  <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>,
#  <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>,
#  <KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'bidirectional_6')>]

# decoder
dec_in = Input(shape=(None,))
dec_x = Embedding(1024, 128, input_length=92)(dec_in)
dec_x = Bidirectional(LSTM(256, return_sequences=True))(
    dec_x, initial_state=bi_enc_state)  # <= use bidirectional state
output = Dense(1024, activation="softmax")(dec_x)

print(output.shape)
# TensorShape([None, None, 1024])
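To sanity-check the wiring, here is a minimal sketch (assuming the enc_in, dec_in, and output tensors defined above) that assembles the full encoder-decoder model:

model = tf.keras.Model(inputs=[enc_in, dec_in], outputs=output)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.summary()  # all four encoder state tensors feed the bidirectional decoder's initial state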

Upvotes: 2
