Reputation: 47
I would like to create a pair of recurrent neural networks, say NN1 and NN2, where NN2 reproduces its output from the previous time step and does not update its weights at the current time step whenever NN1's output differs from its output at the previous time step.
To do this, I was planning to use tf.cond() together with tf.stop_gradient(). However, in every toy example I have tried, I cannot get tf.gradients() to pass through tf.cond(): tf.gradients() simply returns [None].
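For reference, here is roughly the wiring I have in mind, stripped down to scalars rather than full RNNs (all names below are just placeholders for illustration):

import tensorflow as tf

nn1_out = tf.placeholder(tf.float32, shape=())   # NN1 output at the current step
nn1_prev = tf.placeholder(tf.float32, shape=())  # NN1 output at the previous step
nn2_out = tf.Variable(2.0)                       # NN2 candidate output at the current step
nn2_prev = tf.Variable(1.0)                      # NN2 output at the previous step

# If NN1's output changed, NN2 should re-emit its previous output and
# block gradients (so its weights are not updated); otherwise it should
# emit its current output and learn as usual.
nn1_changed = tf.not_equal(nn1_out, nn1_prev)
nn2_final = tf.cond(nn1_changed,
                    true_fn=lambda: tf.stop_gradient(nn2_prev),
                    false_fn=lambda: nn2_out)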
Here is a simple toy example demonstrating the problem:
import tensorflow as tf
x = tf.constant(5)
y = tf.constant(3)
mult = tf.multiply(x, y)
cond = tf.cond(pred=tf.constant(True),
               true_fn=lambda: mult,
               false_fn=lambda: mult)
grad = tf.gradients(cond, x) # Returns [None]
Here is another simple toy example where I define true_fn and false_fn inside tf.cond() (still no dice):
import tensorflow as tf
x = tf.constant(5)
y = tf.constant(3)
z = tf.constant(8)
cond = tf.cond(pred=x < y,
               true_fn=lambda: tf.add(x, z),
               false_fn=lambda: tf.square(y))
tf.gradients(cond, z) # Returns [None]
I originally thought that the gradient should flow through both true_fn and false_fn, but clearly no gradient is flowing at all. Is this the expected behavior of gradients computed through tf.cond()? Might there be a way around this issue?
Upvotes: 4
Views: 903
Reputation: 8585
Yes, the gradients will pass through tf.cond(). You just need to use floats instead of integers and, preferably, variables instead of constants:
import tensorflow as tf
x = tf.Variable(5.0, dtype=tf.float32)
y = tf.Variable(6.0, dtype=tf.float32)
z = tf.Variable(8.0, dtype=tf.float32)
cond = tf.cond(pred=x < y,
               true_fn=lambda: tf.add(x, z),
               false_fn=lambda: tf.square(y))
op = tf.gradients(cond, z)
# Returns [<tf.Tensor 'gradients_1/cond_1/Add/Switch_1_grad/cond_grad:0' shape=() dtype=float32>]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op))  # [1.0]
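For completeness, here is a sketch of the same fix applied to the first toy example from the question. Since the predicate is a constant True, I would expect only the true branch to contribute, giving the gradient of x*y with respect to x:

import tensorflow as tf

x = tf.Variable(5.0, dtype=tf.float32)
y = tf.Variable(3.0, dtype=tf.float32)
mult = tf.multiply(x, y)

cond = tf.cond(pred=tf.constant(True),
               true_fn=lambda: mult,
               false_fn=lambda: mult)

grad = tf.gradients(cond, x)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grad))  # expect [3.0], i.e. d(x*y)/dx = y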
Upvotes: 1