Reputation: 161
I want a multi-layer LSTM model that, in each mini-batch, computes the outputs for two different inputs with the same weights, because the two outputs will later be used differently.
I tried to implement this myself as follows:
import tensorflow as tf  # TensorFlow 1.x (tf.contrib is not available in 2.x)
with tf.name_scope('placeholders'):
    X = tf.placeholder(tf.float64, shape=[batch_size, max_length, dim])
    Y = tf.placeholder(tf.float64, shape=[batch_size, max_length, dim])
    seq_length1 = tf.placeholder(tf.int32, [batch_size], name="len1")
    seq_length2 = tf.placeholder(tf.int32, [batch_size], name="len2")
with tf.variable_scope("model") as scope:
    layers = [
        tf.contrib.rnn.BasicLSTMCell(num_units=num, activation=tf.nn.relu, name="e_lstm")
        for num in neurons
    ]
    if training:  # apply dropout during training
        # Reassign `layers` so the dropout-wrapped cells are actually used below
        layers = [
            tf.contrib.rnn.DropoutWrapper(layer, input_keep_prob=keep_prob)
            for layer in layers
        ]
    multi_layer_cell = tf.contrib.rnn.MultiRNNCell(layers)
    # Run the same stacked cell on both inputs; only the final states are kept
    _, states_s = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float64, sequence_length=seq_length1)
    _, states_o = tf.nn.dynamic_rnn(multi_layer_cell, Y, dtype=tf.float64, sequence_length=seq_length2)
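To pin down the desired behavior independently of TensorFlow, here is a minimal NumPy sketch (not TensorFlow code, and not how TensorFlow implements it) that creates one set of stacked-LSTM weights and applies that same set to both X and Y. The helpers init_stacked_lstm and run_stacked_lstm are hypothetical. The fused gate layout and forget_bias=1.0 loosely follow BasicLSTMCell, the initialization and sizes are arbitrary assumptions, and dropout is omitted. Each example's state is frozen after its sequence length, roughly as tf.nn.dynamic_rnn does with sequence_length. The default activation is tanh; the question's ReLU could be passed as activation=lambda v: np.maximum(v, 0.0).

```python
import numpy as np


def init_stacked_lstm(input_dim, neurons, seed=0):
    """Hypothetical helper: create ONE set of stacked-LSTM parameters.

    Each layer gets a kernel W of shape (in_dim + units, 4 * units) and a
    bias b of shape (4 * units,), loosely following BasicLSTMCell's layout.
    """
    rng = np.random.default_rng(seed)
    params = []
    in_dim = input_dim
    for units in neurons:
        W = rng.normal(scale=0.1, size=(in_dim + units, 4 * units))
        b = np.zeros(4 * units)
        params.append((W, b))
        in_dim = units  # the next layer consumes this layer's output
    return params


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def run_stacked_lstm(params, inputs, seq_length, activation=np.tanh, forget_bias=1.0):
    """Hypothetical helper: run the stacked LSTM over a padded batch.

    inputs: (batch, max_length, dim); seq_length: (batch,) ints.
    Returns one (c, h) pair per layer. An example's state stops updating
    after its last valid step, similar to sequence_length in dynamic_rnn.
    """
    batch, max_length, _ = inputs.shape
    states = []
    for W, b in params:
        units = b.shape[0] // 4
        states.append((np.zeros((batch, units)), np.zeros((batch, units))))
    for t in range(max_length):
        active = (t < seq_length)[:, None]  # (batch, 1): step t is within the length
        x = inputs[:, t, :]
        new_states = []
        for (W, b), (c, h) in zip(params, states):
            z = np.concatenate([x, h], axis=1) @ W + b
            i, j, f, o = np.split(z, 4, axis=1)  # input, candidate, forget, output
            c_new = c * sigmoid(f + forget_bias) + sigmoid(i) * activation(j)
            h_new = activation(c_new) * sigmoid(o)
            c = np.where(active, c_new, c)  # keep old state past the sequence end
            h = np.where(active, h_new, h)
            new_states.append((c, h))
            x = h  # this layer's output is the next layer's input
        states = new_states
    return states


# The same `params` object is used for both inputs: that is the weight sharing.
dim, max_length, batch_size = 3, 5, 2
neurons = [4, 6]
params = init_stacked_lstm(dim, neurons)
rng = np.random.default_rng(1)
X = rng.normal(size=(batch_size, max_length, dim))
Y = rng.normal(size=(batch_size, max_length, dim))
states_s = run_stacked_lstm(params, X, np.array([5, 3]))
states_o = run_stacked_lstm(params, Y, np.array([2, 4]))
```

The parameters are created exactly once and merely read by each call, which is the property the TensorFlow graph should also have.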
However, the graph visualized in TensorBoard shows that my TensorFlow code actually builds two different RNNs in the "model" scope, and the output of one RNN becomes the input of the other and vice versa, which is not the behavior I want.
Can anyone tell me how I should modify the code to get the desired behavior?
Thank you.
Upvotes: 0
Views: 392
Reputation: 963
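Add two lines, wrapping each tf.nn.dynamic_rnn call in a variable scope named 'rnn' and setting reuse=True for the second one, so that the second call reuses the variables created by the first: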
with tf.variable_scope('rnn'):
    _, states_s = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float64, sequence_length=seq_length1)
with tf.variable_scope('rnn', reuse=True):
    _, states_o = tf.nn.dynamic_rnn(multi_layer_cell, Y, dtype=tf.float64, sequence_length=seq_length2)
I think the version below is a better way, though I'm not sure; advice is welcome!
with tf.variable_scope('rnn', reuse=tf.AUTO_REUSE):
    _, states_s = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float64, sequence_length=seq_length1)
    _, states_o = tf.nn.dynamic_rnn(multi_layer_cell, Y, dtype=tf.float64, sequence_length=seq_length2)
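The difference between the two snippets comes down to the reuse rules of tf.get_variable in TF 1.x. With reuse unset, requesting a variable that already exists raises a ValueError. With reuse=True, requesting a variable that does not exist yet raises a ValueError. With tf.AUTO_REUSE, the variable is created if missing and returned otherwise. The sketch below is a toy, plain-Python model of just those three rules. ToyVariableStore and build_rnn are hypothetical names; this is not TensorFlow's implementation and it ignores nested scopes and reuse inheritance.

```python
class ToyVariableStore:
    """Hypothetical, simplified model of TF 1.x get_variable reuse rules.

    Not TensorFlow's implementation: no nested scopes, no reuse inheritance.
    """

    AUTO_REUSE = object()  # stand-in for tf.AUTO_REUSE

    def __init__(self):
        self._vars = {}
        self.created = []  # variable names, in creation order

    def get_variable(self, name, reuse=False):
        exists = name in self._vars
        if reuse is True and not exists:
            raise ValueError(f"Variable {name} does not exist; cannot reuse it")
        if reuse is False and exists:
            raise ValueError(f"Variable {name} already exists, disallowed")
        if not exists:
            self._vars[name] = {"name": name}  # placeholder for a real tensor
            self.created.append(name)
        return self._vars[name]


def build_rnn(store, scope, reuse):
    """Hypothetical stand-in for one dynamic_rnn call: fetch its kernel and bias."""
    kernel = store.get_variable(scope + "/kernel", reuse=reuse)
    bias = store.get_variable(scope + "/bias", reuse=reuse)
    return kernel, bias


# First snippet's pattern: create the variables, then reuse them explicitly.
store = ToyVariableStore()
vars_s = build_rnn(store, "rnn", reuse=False)
vars_o = build_rnn(store, "rnn", reuse=True)

# Second snippet's pattern: AUTO_REUSE for both calls.
auto_store = ToyVariableStore()
auto_s = build_rnn(auto_store, "rnn", reuse=ToyVariableStore.AUTO_REUSE)
auto_o = build_rnn(auto_store, "rnn", reuse=ToyVariableStore.AUTO_REUSE)
```

In both patterns the two calls end up with the same variable objects. With AUTO_REUSE both calls go through identical code, so the block does not need to know which call comes first; the explicit reuse=True version instead fails loudly if the variables were never created, which can catch scoping mistakes.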
Upvotes: 1