Reputation: 103
Hoping someone can help me understand an issue I have been having using LSTMs with dynamic_rnn in TensorFlow. Per the MWE below, when I have a batch size of 1 with incomplete sequences (I pad the short tensors with NaNs rather than zeros to highlight the issue), everything operates as expected and the NaNs in the short sequences are ignored...
import tensorflow as tf
import numpy as np
batch_1 = np.random.randn(1, 10, 8)
batch_2 = np.random.randn(1, 10, 8)
batch_1[0, 6:] = np.nan  # make the single sample in batch_1 a short sequence of length 6 by padding with NaNs
seq_lengths_batch_1 = [6]
seq_lengths_batch_2 = [10]
tf.reset_default_graph()
input_vals = tf.placeholder(shape=[1, 10, 8], dtype=tf.float32)
lengths = tf.placeholder(shape=[1], dtype=tf.int32)
cell = tf.nn.rnn_cell.LSTMCell(num_units=5)
outputs, states = tf.nn.dynamic_rnn(cell=cell, dtype=tf.float32, sequence_length=lengths, inputs=input_vals)
last_relevant_value = states.h
fake_loss = tf.reduce_mean(last_relevant_value)
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(fake_loss)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
_, fl, lrv = sess.run([optimizer, fake_loss, last_relevant_value], feed_dict={input_vals: batch_1, lengths: seq_lengths_batch_1})
print(fl, lrv)
_, fl, lrv = sess.run([optimizer, fake_loss, last_relevant_value], feed_dict={input_vals: batch_2, lengths: seq_lengths_batch_2})
print(fl, lrv)
sess.close()
which outputs properly populated values along the lines of...
0.00659429 [[ 0.11608966 0.08498846 -0.02892204 -0.01945034 -0.1197343 ]]
-0.080244 [[-0.03018401 -0.18946587 -0.19128899 -0.10388547 0.11360413]]
However, when I increase the batch size to 3, for example, the first batch executes correctly, but then somehow the second batch causes NaNs to start propagating:
import tensorflow as tf
import numpy as np
batch_1 = np.random.randn(3, 10, 8)
batch_2 = np.random.randn(3, 10, 8)
batch_1[1, 6:] = np.nan
batch_2[0, 8:] = np.nan
seq_lengths_batch_1 = [10, 6, 10]
seq_lengths_batch_2 = [8, 10, 10]
tf.reset_default_graph()
input_vals = tf.placeholder(shape=[3, 10, 8], dtype=tf.float32)
lengths = tf.placeholder(shape=[3], dtype=tf.int32)
cell = tf.nn.rnn_cell.LSTMCell(num_units=5)
outputs, states = tf.nn.dynamic_rnn(cell=cell, dtype=tf.float32, sequence_length=lengths, inputs=input_vals)
last_relevant_value = states.h
fake_loss = tf.reduce_mean(last_relevant_value)
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(fake_loss)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
_, fl, lrv = sess.run([optimizer, fake_loss, last_relevant_value], feed_dict={input_vals: batch_1, lengths: seq_lengths_batch_1})
print(fl, lrv)
_, fl, lrv = sess.run([optimizer, fake_loss, last_relevant_value], feed_dict={input_vals: batch_2, lengths: seq_lengths_batch_2})
print(fl, lrv)
sess.close()
giving
0.0533635 [[ 0.33622459 -0.0284576 0.11914439 0.14402215 -0.20783389]
[ 0.20805927 0.17591488 -0.24977767 -0.03432769 0.2944448 ]
[-0.04508523 0.11878576 0.07287208 0.14114542 -0.24467923]]
nan [[ nan nan nan nan nan]
[ nan nan nan nan nan]
[ nan nan nan nan nan]]
I find this behavior quite strange: I expected all values beyond the sequence lengths to be ignored, just as happens with a batch size of 1, but that does not hold with a batch size of 2 or more.
Obviously, NaNs do not get propagated if I use 0 as my padding value, but this doesn't give me any confidence that dynamic_rnn is functioning as I expect it to.
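For reference, by zero padding I mean building the batches in the second example like this (everything else unchanged):
import numpy as np

batch_1 = np.random.randn(3, 10, 8)
batch_2 = np.random.randn(3, 10, 8)
batch_1[1, 6:] = 0.0  # zero padding instead of NaN padding
batch_2[0, 8:] = 0.0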
I should also mention that if I remove the optimisation step the issue doesn't occur, so now I'm properly confused. After a day of trying many different permutations, I can't see why batch size would make any difference here.
Upvotes: 0
Views: 209
Reputation: 3633
I did not trace it down to the exact operation but here is what I believe to be the case.
Why aren't values beyond sequence_length ignored? They are ignored in the sense that they are masked out, i.e. multiplied by 0, in some operations. Mathematically the result is always zero, so they should have no effect. Unfortunately, nan * 0 = nan, so if you feed NaN values they propagate. You might wonder why TensorFlow does not skip them entirely instead of just masking them; the reason is performance on modern hardware. It is much easier to run operations on one large regular shape with a bunch of zeros than on several small shapes (which is what you would get by decomposing an irregular shape).
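To make the arithmetic concrete, here is a tiny sketch of my own in plain NumPy, mimicking the kind of multiplicative masking involved:
import numpy as np

padded_steps = np.array([np.nan, np.nan])  # values beyond sequence_length
mask = np.array([0.0, 0.0])                # "ignored" by multiplying with 0

print(padded_steps * mask)                 # [nan nan] -- nan * 0 = nan, so NaNs survive
print(np.zeros(2) * mask)                  # [0. 0.]   -- zero padding is harmless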
Why does it only happen on the second batch? In the first batch, the loss and last hidden state are computed from the original variable values, so they are fine. But because you also do the optimizer update in the same sess.run(), the variables get updated and become NaN during that first call. In the second call, the NaNs in the variables spread to the loss and the hidden state.
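If you want to see this for yourself, one option (my own check, assuming the graph and session from your second example are still live right after the first sess.run with batch_1) is to inspect the trainable variables; they already contain NaNs even though the loss printed for that step looked fine:
# Assumes the second example's graph/session and the first training step
# with batch_1 have just run.
for v in tf.trainable_variables():
    print(v.name, "contains NaNs:", np.isnan(sess.run(v)).any())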
How can I be confident that the values beyond sequence_length are really masked out? I modified your example to reproduce the issue but also to make it deterministic:
import tensorflow as tf
import numpy as np

batch_1 = np.ones((3, 10, 2))
batch_1[1, 7:] = np.nan
seq_lengths_batch_1 = [10, 7, 10]

tf.reset_default_graph()
input_vals = tf.placeholder(shape=[3, 10, 2], dtype=tf.float32)
lengths = tf.placeholder(shape=[3], dtype=tf.int32)

# Deterministic setup: constant weight initializer and an all-ones initial state.
cell = tf.nn.rnn_cell.LSTMCell(num_units=3, initializer=tf.constant_initializer(1.0))
init_state = tf.nn.rnn_cell.LSTMStateTuple(*[tf.ones([3, c]) for c in cell.state_size])
outputs, states = tf.nn.dynamic_rnn(cell=cell, dtype=tf.float32, sequence_length=lengths, inputs=input_vals,
                                    initial_state=init_state)
last_relevant_value = states.h
fake_loss = tf.reduce_mean(last_relevant_value)
optimizer = tf.train.AdamOptimizer(learning_rate=0.1).minimize(fake_loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(1):
        _, fl, lrv = sess.run([optimizer, fake_loss, last_relevant_value],
                              feed_dict={input_vals: batch_1, lengths: seq_lengths_batch_1})
    print("VARIABLES:", sess.run(tf.trainable_variables()))
    print("LOSS and LAST HIDDEN:", fl, lrv)
If you replace the np.nan in batch_1[1, 7:] = np.nan with any number (e.g. try -1M, 1M, 0), you will see that the values you get are the same. You can also run the loop for more iterations. As a further sanity check, if you set seq_lengths_batch_1 to something "wrong", e.g. [10, 8, 10], you will see that the value you use in batch_1[1, 7:] = np.nan now affects the output.
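If you want to automate that comparison, here is a rough sketch of my own (it assumes the deterministic graph above is already built, and it only runs the forward pass so the variables stay untouched):
results = []
for pad_value in (np.nan, -1e6, 1e6, 0.0):
    padded = np.ones((3, 10, 2))
    padded[1, 7:] = pad_value   # only the masked-out steps differ between runs
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        fl, lrv = sess.run([fake_loss, last_relevant_value],
                           feed_dict={input_vals: padded, lengths: [10, 7, 10]})
    results.append((fl, lrv))

# All four pad values print identical numbers, because the steps beyond
# sequence_length are masked out of the forward computation.
for fl, lrv in results:
    print(fl, lrv)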
Upvotes: 1