Roger

Reputation: 161

Reuse MultiRNNCell for two different inputs in Tensorflow

I want a multi-layer LSTM model that, in each mini-batch, computes outputs for two different inputs, as the two outputs will later be used differently.

I tried to implement this by myself as follows:

with tf.name_scope('placeholders'):
    X = tf.placeholder(tf.float64, shape=[batch_size, max_length, dim])
    Y = tf.placeholder(tf.float64, shape=[batch_size, max_length, dim])
    seq_length1 = tf.placeholder(tf.int32, [batch_size], name="len1")
    seq_length2 = tf.placeholder(tf.int32, [batch_size], name="len2")

with tf.variable_scope("model") as scope:
    layers = [
        tf.contrib.rnn.BasicLSTMCell(num_units=num, activation=tf.nn.relu, name="e_lstm")
        for num in neurons
    ]
    if training:    # apply dropout during training
        layers = [
            tf.contrib.rnn.DropoutWrapper(layer, input_keep_prob=keep_prob)
            for layer in layers
        ]
    multi_layer_cell = tf.contrib.rnn.MultiRNNCell(layers)
    _, states_s = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float64, sequence_length=seq_length1)  

    _, states_o = tf.nn.dynamic_rnn(multi_layer_cell, Y, dtype=tf.float64, sequence_length=seq_length2)

But in the TensorBoard graph, this actually builds two separate RNNs inside the model scope, and the output of one RNN becomes the input of the other and vice versa, which is not the desired behavior.

Can anyone tell me how I should modify the code to get the desired behavior?

Thank you.

Upvotes: 0

Views: 392

Answers (1)

Ink

Reputation: 963

Add two `tf.variable_scope` lines, so the second `dynamic_rnn` call reuses the first call's variables:

with tf.variable_scope('rnn'):
    _, states_s = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float64, sequence_length=seq_length1)  
with tf.variable_scope('rnn', reuse=True):
    _, states_o = tf.nn.dynamic_rnn(multi_layer_cell, Y, dtype=tf.float64, sequence_length=seq_length2)

I think the code below is an even cleaner way, but I'm not sure; advice is welcome!

with tf.variable_scope('rnn', reuse=tf.AUTO_REUSE):
    _, states_s = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float64, sequence_length=seq_length1)  
    _, states_o = tf.nn.dynamic_rnn(multi_layer_cell, Y, dtype=tf.float64, sequence_length=seq_length2)
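For intuition, the mechanism behind `reuse` can be sketched in plain Python without TensorFlow. The `VariableStore` class below is a hypothetical stand-in for TF's variable scoping, not its real implementation: the store is keyed by variable name, and with `reuse=True` a second request for the same name hands back the existing variable instead of creating (or erroring on) a duplicate.

```python
# Hypothetical sketch of what reuse=True does conceptually: a name-keyed
# variable store that returns the existing variable on a repeated request.
class VariableStore:
    def __init__(self):
        self._vars = {}

    def get_variable(self, name, initializer, reuse=False):
        if name in self._vars:
            if not reuse:
                # mirrors TF's "Variable ... already exists" error
                raise ValueError("Variable %s already exists" % name)
            return self._vars[name]
        var = initializer()          # create the variable on first request
        self._vars[name] = var
        return var


store = VariableStore()
w1 = store.get_variable("rnn/kernel", initializer=lambda: [0.0] * 4)
w2 = store.get_variable("rnn/kernel", initializer=lambda: [0.0] * 4, reuse=True)
assert w1 is w2  # both "calls" share the same underlying weights
```

This is why the first `dynamic_rnn` call must run in a plain scope (creating the weights) and the second in the same scope with `reuse=True` (fetching them); `tf.AUTO_REUSE` simply picks the right mode automatically.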

Upvotes: 1
