vivek

Reputation: 369

MLP to initialize LSTM cell state in Keras

Can we use the output of an MLP as the cell state in an LSTM network and train the MLP too with back-propagation?

This is similar to image captioning with a CNN & LSTM, where the output of the CNN is flattened and used as the initial hidden/cell state, and the stacked network is trained end-to-end so that even the CNN part is updated through back-propagation.

I tried an architecture in Keras to achieve the same. Please find the code here.

But the weights of the MLP are not being updated. I understand this is more straightforward in TensorFlow, where we can explicitly specify which parameters to update with the loss, but can anyone help me with the Keras API?

Upvotes: 1

Views: 1482

Answers (1)

modesitt

Reputation: 7210

Yes, we can. Simply pass the MLP's output as the initial state. Remember that an LSTM has two states, the hidden state h and the cell state c. You can read more about this here. Note that you also do not have to create multiple Keras models; you can simply connect all the layers:

from keras.layers import Input, Dense, LSTM
from keras.models import Model

# placeholder dimensions for illustration; note that hidden_state_dim
# must equal lstm_dim so the MLP output matches the LSTM state shape
batch_size, seq_len, inp_size = 32, 20, 50
hidden_state_dim = lstm_dim = 64

# define mlp
mlp_inp = Input(batch_shape=(batch_size, hidden_state_dim))
mlp_dense = Dense(hidden_state_dim, activation='relu')(mlp_inp)

## Define LSTM model
lstm_inp = Input(batch_shape=(batch_size, seq_len, inp_size))
lstm_layer = LSTM(lstm_dim)(lstm_inp, initial_state=[mlp_dense, mlp_dense])
lstm_out = Dense(10, activation='softmax')(lstm_layer)

model = Model([mlp_inp, lstm_inp], lstm_out)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
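As a quick sanity check on the original problem (the MLP weights not updating), you can train the joint model on random dummy data and confirm that every trainable weight changes after a step. This is a minimal sketch using the placeholder dimensions defined above:

import numpy as np

# random dummy inputs (one array per Input layer) and one-hot targets
mlp_x = np.random.random((batch_size, hidden_state_dim))
lstm_x = np.random.random((batch_size, seq_len, inp_size))
y = np.eye(10)[np.random.randint(0, 10, size=batch_size)]

# snapshot all trainable weights, train one epoch, then compare
before = [w.copy() for w in model.get_weights()]
model.fit([mlp_x, lstm_x], y, epochs=1, batch_size=batch_size, verbose=0)
after = model.get_weights()
print(any(not np.allclose(b, a) for b, a in zip(before, after)))  # True

Because the MLP layers sit inside the same computation graph as the LSTM, their gradients flow through the initial state and they are updated along with everything else.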

However, because of the above fact about having two states, you may want to consider a separate MLP layer for each initial state:

# define mlp with a separate head for each of the two LSTM states
mlp_inp = Input(batch_shape=(batch_size, hidden_state_dim))
mlp_dense_h = Dense(hidden_state_dim, activation='relu')(mlp_inp)
mlp_dense_c = Dense(hidden_state_dim, activation='relu')(mlp_inp)

## Define LSTM model
lstm_inp = Input(batch_shape=(batch_size, seq_len, inp_size))
lstm_layer = LSTM(lstm_dim)(lstm_inp, initial_state=[mlp_dense_h, mlp_dense_c])
lstm_out = Dense(10, activation='softmax')(lstm_layer)

model = Model([mlp_inp, lstm_inp], lstm_out)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

Also, note that when you go about saving this model, you should use save_weights instead of save, because save_model cannot handle the initial-state passing.
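For example (a minimal sketch; the file name is arbitrary):

# save only the weights; per the note above, model.save() / save_model
# cannot serialize the initial_state wiring
model.save_weights('mlp_lstm_weights.h5')

# to restore, rebuild the same architecture first, then load the weights
model.load_weights('mlp_lstm_weights.h5')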

Upvotes: 2
