hua wei

Reputation: 870

How to do inference when using a bidirectional decoder in a seq2seq model?

I built a seq2seq model using Keras, like this:

from keras import layers
from keras.layers import Input, Embedding, Dense, Bidirectional
from keras.models import Model

def build_bidi_model():
    rnn = layers.LSTM

    # Encoder
    # encoder inputs
    encoder_inputs = Input(shape=(None,), name='encoder_inputs')
    # encoder embedding
    encoder_embedding = Embedding(num_encoder_tokens, encoder_embedding_dim, name='encoder_embedding')(encoder_inputs)
    # encoder lstm; the Bidirectional wrapper returns forward (h, c) and backward (h, c) states
    encoder_lstm = Bidirectional(rnn(latent_dim, return_state=True, dropout=0.2, recurrent_dropout=0.5), name='encoder_lstm')
    _, *encoder_states = encoder_lstm(encoder_embedding)

    # Decoder
    # decoder inputs
    decoder_inputs = Input(shape=(None,), name='decoder_inputs')
    # decoder embeddding
    decoder_embedding = Embedding(num_decoder_tokens, decoder_embedding_dim, name='decoder_embedding')(decoder_inputs)
    # decoder lstm (also bidirectional)
    decoder_lstm = Bidirectional(rnn(latent_dim, return_state=True,
                                     return_sequences=True, dropout=0.2,
                                     recurrent_dropout=0.5), name='decoder_lstm')
    # get outputs and decoder states
    rnn_outputs, *decoder_states = decoder_lstm(decoder_embedding, initial_state=encoder_states)
    # decoder dense
    decoder_dense = Dense(num_decoder_tokens, activation='softmax', name='decoder_dense')
    decoder_outputs = decoder_dense(rnn_outputs)

    bidi_model = Model([encoder_inputs,decoder_inputs], [decoder_outputs])
    bidi_model.compile(optimizer='adam', loss='categorical_crossentropy')

    return bidi_model

The training loss and validation loss are both really low, but when I try to do inference with the trained model, the results are bad. Here is the inference code:

import numpy as np

reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_word_index = dict((i, word) for word, i in target_token_index.items())

def decode_sequence(input_seq, encoder_model, decoder_model):
    # get encoder states
    states_value = encoder_model.predict(input_seq)

    # create an empty target sequence seeded with the start token
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = target_token_index['start']

    stop_condition = False
    decoded_words = []

    while not stop_condition:
        output, *decoder_states = decoder_model.predict([target_seq] + states_value)

        # greedy sampling: take the most likely token at the last timestep
        sampled_token_index = np.argmax(output[0, -1, :])
        sampled_word = reverse_target_word_index[sampled_token_index]
        decoded_words.append(sampled_word)

        # stop on the end token, or once the output reaches the maximum length
        if sampled_word == 'end' or len(decoded_words) > max_decoder_seq_length:
            stop_condition = True

        # feed the sampled token back in as the next decoder input
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = sampled_token_index

        # carry the decoder states forward
        states_value = decoder_states

    return ' '.join(decoded_words)
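
(encoder_model and decoder_model are not shown here; they are split out of the trained layers following the standard Keras seq2seq inference recipe, roughly like this. This is just a sketch: decoder_embedding_layer, decoder_lstm and decoder_dense are assumed handles to the trained layers, which build_bidi_model would need to return or expose.)

# sketch of the inference models, reusing the trained layers
encoder_model = Model(encoder_inputs, encoder_states)

# the bidirectional decoder carries four states: forward (h, c) and backward (h, c)
decoder_state_inputs = [Input(shape=(latent_dim,)) for _ in range(4)]
dec_emb = decoder_embedding_layer(decoder_inputs)
rnn_outputs, *decoder_states = decoder_lstm(dec_emb, initial_state=decoder_state_inputs)
decoder_outputs = decoder_dense(rnn_outputs)
decoder_model = Model([decoder_inputs] + decoder_state_inputs,
                      [decoder_outputs] + decoder_states)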

It is weird that the inference results are pretty good when I build another, similar model that removes the Bidirectional wrapper from the decoder. So I am wondering: is my code wrong? Please help.

Upvotes: 0

Views: 864

Answers (1)

kesang

Reputation: 21

I think the issue is that you cannot do inference with a bidirectional decoder. A bidirectional layer reads the sequence in both directions, the original order and its reverse. During training the full target sequence is available, so both directions see real data. At inference time, though, the model only has the tokens generated so far: it can see the starting token but does not know the ending tokens yet, so the backward direction has nothing meaningful to read. Feeding it only the starting token is bad input for the decoder. Garbage in, garbage out, so your model's inference performance is way off.
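
If you still want bidirectionality, the usual pattern is to keep the encoder bidirectional and make the decoder unidirectional, concatenating the encoder's forward and backward states so they can initialize a decoder LSTM of size 2 * latent_dim. A minimal sketch (variable names such as num_encoder_tokens, latent_dim and the embedding dims follow the question's code):

from keras.layers import Input, LSTM, Embedding, Dense, Bidirectional, Concatenate
from keras.models import Model

# bidirectional encoder: merge the forward/backward states
encoder_inputs = Input(shape=(None,))
enc_emb = Embedding(num_encoder_tokens, encoder_embedding_dim)(encoder_inputs)
_, fh, fc, bh, bc = Bidirectional(LSTM(latent_dim, return_state=True))(enc_emb)
state_h = Concatenate()([fh, bh])
state_c = Concatenate()([fc, bc])

# unidirectional decoder: at inference time it only ever needs past tokens,
# which is exactly what autoregressive decoding provides
decoder_inputs = Input(shape=(None,))
dec_emb = Embedding(num_decoder_tokens, decoder_embedding_dim)(decoder_inputs)
decoder_lstm = LSTM(latent_dim * 2, return_sequences=True, return_state=True)
rnn_outputs, _, _ = decoder_lstm(dec_emb, initial_state=[state_h, state_c])
decoder_outputs = Dense(num_decoder_tokens, activation='softmax')(rnn_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)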

Hope this makes sense.

Upvotes: 2
