GT GT

Reputation: 73

Behaviour of multiple MultiRNNCell?

I'm trying to understand an RNN implementation in TensorFlow. Does MultiRNNCell create a new object each time, or does it add layers to the previous object?

cells_fwd_list = []

for num_cells in [256, 128, 64]:
    cells_fwd_list.append(tf.nn.rnn_cell.LSTMCell(num_units=num_cells, activation=tf.nn.tanh))
    cells_fwd = tf.nn.rnn_cell.MultiRNNCell(cells_fwd_list, state_is_tuple=True)

Does this mean cells_fwd has 3 layers of [256, 128, 64] units, or 6 layers of [256, 256, 128, 256, 128, 64] units?

Upvotes: 2

Views: 224

Answers (1)

Prasad

Reputation: 6034

MultiRNNCell abstracts a sequence of RNN cells into one composite cell. In your code, you create a new MultiRNNCell every time you create an LSTMCell, i.e. on each loop iteration. What you want is to build the MultiRNNCell once, after the loop:

cells_fwd_list = []

for num_cells in [256, 128, 64]:
    cells_fwd_list.append(tf.nn.rnn_cell.LSTMCell(num_units=num_cells, activation=tf.nn.tanh))
cells_fwd = tf.nn.rnn_cell.MultiRNNCell(cells_fwd_list, state_is_tuple=True)

Now cells_fwd will internally hold three LSTM layers of 256, 128 and 64 units respectively. (In your original loop, each iteration's MultiRNNCell simply wraps the cells accumulated so far, and cells_fwd is rebound each time, so after the loop the surviving object has 3 layers, never 6; the intermediate wrappers are discarded.)
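To see why only the final wrapper matters, here is a minimal pure-Python sketch; FakeMultiRNNCell is a hypothetical stand-in (not the real TensorFlow class) that just records which cells it wraps at construction time:

```python
class FakeMultiRNNCell:
    """Hypothetical stand-in: snapshots the cells it wraps when constructed."""
    def __init__(self, cells):
        self.cells = list(cells)  # copy of the list's current contents

cells_fwd_list = []
for num_cells in [256, 128, 64]:
    cells_fwd_list.append(num_cells)              # stand-in for an LSTMCell
    cells_fwd = FakeMultiRNNCell(cells_fwd_list)  # new wrapper each iteration

# cells_fwd is rebound on every iteration, so only the last wrapper survives,
# and it wraps exactly the three cells accumulated in the list.
print(cells_fwd.cells)  # [256, 128, 64]
```

The intermediate wrappers built on earlier iterations (over [256] and [256, 128]) are never used again, which is why building the MultiRNNCell once after the loop is the cleaner equivalent.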

Also, since TensorFlow 1.x will soon be deprecated, the TensorFlow 2.0 equivalent would look something like this:

import tensorflow as tf

timesteps, input_dim = 60, 100
inputs = tf.keras.Input((timesteps, input_dim))
# A list of cells passed to tf.keras.layers.RNN is stacked into one layer
lstm_cells = [tf.keras.layers.LSTMCell(units) for units in [256, 128, 64]]
stacked_rnn_output = tf.keras.layers.RNN(lstm_cells)(inputs)
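As a quick sanity check of the stacking, you can wrap the tensors above in a Model and inspect the output shape; with return_sequences left at its default of False, the output comes from the last cell in the stack (64 units), so the shape should be (None, 64):

```python
import tensorflow as tf

timesteps, input_dim = 60, 100
inputs = tf.keras.Input((timesteps, input_dim))
lstm_cells = [tf.keras.layers.LSTMCell(units) for units in [256, 128, 64]]
outputs = tf.keras.layers.RNN(lstm_cells)(inputs)

# Build a model to confirm the stacked RNN's output dimensionality
model = tf.keras.Model(inputs, outputs)
print(model.output_shape)  # (None, 64): batch dim is None, last cell has 64 units
```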

Upvotes: 2
