Reputation: 137
I'm trying to write my own LSTM variational autoencoder for text, and I have an OK understanding of how the encoding step works and how I sample the latent vector Z. The problem now is how to pass Z on to the decoder. As input to the decoder I have a start token <s>, which leaves the hidden state h and the cell state c of the decoder's LSTM cell to be initialized.
Should I set both initial states h and c equal to Z, just one of them, or something else?
Upvotes: 0
Views: 810
Reputation: 16607
With RepeatVector you can repeat the latent vector n times and then feed the resulting sequence into the LSTM. Here is a minimal example:
from keras.layers import Input, RepeatVector, LSTM
from keras.models import Model

# latent_dim: int, size of the latent vector z.
decoder_input = Input(shape=(latent_dim,))
# Repeat z so the LSTM receives it at each of the `timesteps` steps.
_h_decoded = RepeatVector(timesteps)(decoder_input)
decoder_h = LSTM(intermediate_dim, return_sequences=True)
_h_decoded = decoder_h(_h_decoded)
# Map back to the original feature dimension at every timestep.
decoder_mean = LSTM(input_dim, return_sequences=True)
_x_decoded_mean = decoder_mean(_h_decoded)
decoder = Model(decoder_input, _x_decoded_mean)
This pattern is explained in the Keras documentation.
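As a rough usage sketch (assuming the decoder above has been trained as part of the full VAE), you can then generate a sequence from a latent vector sampled from the standard normal prior:
import numpy as np

# Sample z ~ N(0, I) and decode it.
z_sample = np.random.normal(size=(1, latent_dim))
x_decoded = decoder.predict(z_sample)   # shape: (1, timesteps, input_dim)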
Upvotes: 1