Lemon

Reputation: 1394

tensorflow - how to use variational recurrent dropout correctly

TensorFlow's DropoutWrapper has three different keep probabilities that can be set: input_keep_prob, output_keep_prob, and state_keep_prob.

I want to use variational dropout for my LSTM units, by setting the variational_recurrent argument to true. However, I don't know which of the three keep probabilities I have to set for variational dropout to function correctly.

Can someone provide help?

Upvotes: 2

Views: 2784

Answers (1)

Nipun Wijerathne

Reputation: 1829

According to the paper https://arxiv.org/abs/1512.05287 (Gal and Ghahramani), on which the variational_recurrent implementation is based, you can think of the three parameters as follows:

  • input_keep_prob - keep probability for the input connections (each input connection is dropped with probability 1 - input_keep_prob).

  • output_keep_prob - keep probability for the output connections.

  • state_keep_prob - keep probability for the recurrent (state) connections.

See the diagram below:

[Figure: naive dropout RNN (left) vs. variational RNN (right), reproduced from the paper]

If you set variational_recurrent to true, you get an RNN similar to the model on the right; otherwise, you get the one on the left.

The basic differences between the two models are:

  • The variational RNN reuses the same dropout mask at each time step for the inputs, outputs, and recurrent connections (the same network units are dropped at every time step; see the sketch below).

  • The naive dropout RNN uses a different dropout mask at each time step for the inputs and outputs alone (no dropout is applied to the recurrent connections, since using different masks there leads to deteriorated performance).

In the above diagram, coloured connections represent the dropped-out connections, with different colours corresponding to different dropout masks. Dashed lines correspond to standard connections with no dropout.
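
To make the masking difference concrete, here is a toy NumPy sketch of the two schemes (the keep probability and shapes are made up for illustration; this is not the TensorFlow internals):

    import numpy as np

    rng = np.random.RandomState(0)
    keep_prob, time_steps, units = 0.8, 5, 4

    # Variational dropout: sample ONE mask and reuse it at every time step,
    # so the same units are dropped throughout the whole sequence.
    variational_mask = (rng.rand(units) < keep_prob) / keep_prob
    variational_masks = [variational_mask] * time_steps

    # Naive dropout: sample a FRESH mask at each time step, so different
    # units are dropped at different steps.
    naive_masks = [(rng.rand(units) < keep_prob) / keep_prob
                   for _ in range(time_steps)]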

Therefore, if you use a variational RNN, you can set all three keep probabilities according to your requirements.
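
As a minimal sketch of wiring this up (TF 1.x API; the cell size, input depth, and keep probabilities below are illustrative, not from the question):

    import tensorflow as tf

    num_units = 128   # LSTM units (illustrative)
    input_size = 64   # depth of the inputs fed to the cell (illustrative)

    cell = tf.nn.rnn_cell.LSTMCell(num_units)

    # With variational_recurrent=True, the wrapper samples one dropout mask
    # per sequence and reuses it at every time step. When input_keep_prob < 1,
    # input_size and dtype must be supplied so the fixed input mask can be built.
    cell = tf.nn.rnn_cell.DropoutWrapper(
        cell,
        input_keep_prob=0.8,    # keep probability for input connections
        output_keep_prob=0.8,   # keep probability for output connections
        state_keep_prob=0.8,    # keep probability for recurrent connections
        variational_recurrent=True,
        input_size=input_size,
        dtype=tf.float32)

    inputs = tf.placeholder(tf.float32, [None, None, input_size])  # [batch, time, depth]
    outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

Note that input_size and dtype are only required because variational_recurrent=True with input_keep_prob < 1 needs to construct the per-sequence input mask up front.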

Hope this helps.

Upvotes: 5
