Reputation: 369
Can we use the output of an MLP as the cell state in an LSTM network, and train the MLP too with backpropagation?
This is similar to image captioning with a CNN & LSTM, where the output of the CNN is flattened and used as the initial hidden/cell state, and the stacked network is trained end-to-end so that even the CNN part is updated through backpropagation.
I tried an architecture in Keras to achieve the same. Please find the code here.
But the weights of the MLP are not being updated. I understand this is more straightforward in TensorFlow, where we can explicitly specify which parameters to update with the loss, but can anyone help me with the Keras API?
Upvotes: 1
Views: 1482
Reputation: 7210
Yes, we can. Simply pass the MLP output as the initial hidden state. Remember that an LSTM has two hidden states, h and c. You can read more about this here. Note that you also do not have to create multiple Keras models; you can simply connect all the layers:
from keras.layers import Input, Dense, LSTM
from keras.models import Model

# Define the MLP that produces the initial state
mlp_inp = Input(batch_shape=(batch_size, hidden_state_dim))
mlp_dense = Dense(hidden_state_dim, activation='relu')(mlp_inp)

# Define the LSTM model, seeded with the MLP output as both h and c
# (hidden_state_dim must equal lstm_dim for this to work)
lstm_inp = Input(batch_shape=(batch_size, seq_len, inp_size))
lstm_layer = LSTM(lstm_dim)(lstm_inp, initial_state=[mlp_dense, mlp_dense])
lstm_out = Dense(10, activation='softmax')(lstm_layer)

model = Model([mlp_inp, lstm_inp], lstm_out)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
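Because the MLP and LSTM now live in a single model, one fit call backpropagates through both parts, which is exactly what updates the MLP weights. A minimal sketch with made-up toy shapes and random data (all dimensions and arrays below are illustrative, not from the question, and the dimensions would be defined before building the model above):

import numpy as np

# Toy dimensions; hidden_state_dim must equal lstm_dim so the MLP
# output can serve as an LSTM state vector
batch_size, seq_len, inp_size = 32, 10, 8
hidden_state_dim = lstm_dim = 64

x_mlp = np.random.rand(batch_size, hidden_state_dim)   # e.g. flattened image features
x_seq = np.random.rand(batch_size, seq_len, inp_size)  # the input sequence
y = np.eye(10)[np.random.randint(0, 10, batch_size)]   # random one-hot targets

# A single fit call trains the MLP and the LSTM jointly
model.fit([x_mlp, x_seq], y, batch_size=batch_size, epochs=1)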
However, because the LSTM has two states, you may want to use a separate MLP layer for each initial state:
# Define the MLP with a separate head for each initial state
mlp_inp = Input(batch_shape=(batch_size, hidden_state_dim))
mlp_dense_h = Dense(hidden_state_dim, activation='relu')(mlp_inp)
mlp_dense_c = Dense(hidden_state_dim, activation='relu')(mlp_inp)

# Define the LSTM model with distinct initial h and c
lstm_inp = Input(batch_shape=(batch_size, seq_len, inp_size))
lstm_layer = LSTM(lstm_dim)(lstm_inp, initial_state=[mlp_dense_h, mlp_dense_c])
lstm_out = Dense(10, activation='softmax')(lstm_layer)

model = Model([mlp_inp, lstm_inp], lstm_out)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
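To verify that gradients actually reach the MLP, you can compare its weights before and after a training step. A quick sanity check, reusing the hypothetical toy data from the sketch above (it assumes the first Dense layer in model.layers is one of the MLP heads, which holds for this topology):

dense_layers = [l for l in model.layers if isinstance(l, Dense)]
w_before = [w.copy() for w in dense_layers[0].get_weights()]

model.train_on_batch([x_mlp, x_seq], y)

w_after = dense_layers[0].get_weights()
# Prints True if the MLP Dense layer was updated by backpropagation
print(any(not np.allclose(b, a) for b, a in zip(w_before, w_after)))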
Also, note that when you go to save this model, use save_weights instead of save, because save_model cannot handle the initial state passing.
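In practice that looks like the following (the file name is just an example):

# Save only the weights; the architecture is rebuilt in code
model.save_weights('mlp_lstm_weights.h5')

# Later: reconstruct the exact same model definition, then
model.load_weights('mlp_lstm_weights.h5')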
Upvotes: 2