Covey

Reputation: 137

How to pass latent vector to decoder in LSTM Variational Autoencoder

I'm trying to write my own LSTM Variational Autoencoder for text. I have an OK understanding of how the encoding step works and how to sample the latent vector Z. The problem is now how I should pass Z to the decoder. As input to the decoder I have a start token <s>, which leaves the hidden state h and the cell state c of the decoder's LSTM cell still to be set.

Should I make both the initial states h and c equal to Z, just one of them, or something else?
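One common alternative to repeating Z along the time axis is to project Z through two separate Dense layers and use the results as the decoder LSTM's initial h and c. Below is a minimal, hedged Keras sketch of that idea; all dimensions and layer names are made-up placeholders, not something prescribed by the question:

```python
import numpy as np
from tensorflow.keras.layers import Input, Dense, LSTM
from tensorflow.keras.models import Model

# Placeholder sizes, chosen only for illustration.
latent_dim, hidden_dim, timesteps, emb_dim = 16, 32, 10, 8

z = Input(shape=(latent_dim,), name="z")                    # sampled latent vector
token_inputs = Input(shape=(timesteps, emb_dim), name="tokens")  # embedded <s> + tokens

# Project z into the decoder's initial hidden state and cell state.
h0 = Dense(hidden_dim, activation="tanh")(z)
c0 = Dense(hidden_dim, activation="tanh")(z)

# The decoder LSTM starts from (h0, c0) instead of zeros.
decoded = LSTM(hidden_dim, return_sequences=True)(token_inputs, initial_state=[h0, c0])
decoder = Model([z, token_inputs], decoded)

out = decoder.predict([np.zeros((2, latent_dim)), np.zeros((2, timesteps, emb_dim))])
print(out.shape)  # (2, 10, 32): one hidden vector per timestep
```

Using two separate projections lets h and c carry different information, rather than forcing both to equal Z directly (which also fails whenever the latent size differs from the LSTM's hidden size).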

Upvotes: 0

Views: 810

Answers (1)

Amir

Reputation: 16607

Using RepeatVector you can repeat the latent vector once per timestep and then feed the resulting sequence into the LSTM. Here is a minimal example:

 # Assumes: from keras.layers import Input, LSTM, RepeatVector
 #          from keras.models import Model
 # latent_dim: int, size of the latent z-layer.
 decoder_input = Input(shape=(latent_dim,))

 # Repeat the latent vector `timesteps` times so the LSTM
 # receives a sequence rather than a single vector.
 _h_decoded = RepeatVector(timesteps)(decoder_input)
 decoder_h = LSTM(intermediate_dim, return_sequences=True)
 _h_decoded = decoder_h(_h_decoded)

 # Second LSTM maps the hidden sequence back to the input dimension.
 decoder_mean = LSTM(input_dim, return_sequences=True)
 _x_decoded_mean = decoder_mean(_h_decoded)

 decoder = Model(decoder_input, _x_decoded_mean)

This is explained in the Keras documentation.
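To show how the decoder above would be used, here is a self-contained, runnable sketch with placeholder dimensions (all sizes are invented for illustration), feeding it a single sampled latent vector:

```python
import numpy as np
from tensorflow.keras.layers import Input, LSTM, RepeatVector
from tensorflow.keras.models import Model

# Placeholder sizes, chosen only for illustration.
latent_dim, intermediate_dim, input_dim, timesteps = 16, 32, 8, 10

decoder_input = Input(shape=(latent_dim,))
_h_decoded = RepeatVector(timesteps)(decoder_input)   # (batch, timesteps, latent_dim)
_h_decoded = LSTM(intermediate_dim, return_sequences=True)(_h_decoded)
_x_decoded_mean = LSTM(input_dim, return_sequences=True)(_h_decoded)
decoder = Model(decoder_input, _x_decoded_mean)

# One latent vector sampled from a standard normal, as in a VAE's prior.
z_sample = np.random.normal(size=(1, latent_dim))
x_decoded = decoder.predict(z_sample)
print(x_decoded.shape)  # (1, 10, 8): one output vector per timestep
```

Note that in this design z is the decoder's input at every timestep, so the LSTM's initial h and c are simply left at their zero defaults.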

Upvotes: 1
