I am using tf.cond for controlling the flow of the TensorFlow graph. I went through the documentation and was able to implement tf.cond-based branching successfully. But my concern is that while the graph is being built, the value of the bool variable is checked and the branching decision is made at that initialization step itself. Any further changes to the bool are not tracked. Following is an MWE that better describes the problem:
import tensorflow as tf

def funa():
    return tf.constant(32)

def funb():
    return tf.constant(25)

foo = True
sess = tf.Session()
x = tf.cond(tf.convert_to_tensor(foo), lambda: funa(), lambda: funb())
for i in range(20):
    if i > 10:
        foo = False
    print(sess.run(x))
This prints 32 on every iteration, even after foo has been set to False.
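The behaviour comes from tf.convert_to_tensor(foo): it turns the Python value that foo holds at graph-construction time into a constant node, and running the graph later never looks at the Python variable again. The plain-Python analogy below (no TensorFlow required; build_constant_cond and build_fed_cond are hypothetical names used only for illustration) contrasts freezing a value when the "graph" is built with taking it as an input every time the "graph" is run, which is roughly what a placeholder does.

```python
def funa():
    return 32

def funb():
    return 25

def build_constant_cond(flag):
    # Analogous to tf.cond(tf.convert_to_tensor(foo), ...):
    # the flag's value is captured once, when the "graph" is built.
    frozen = bool(flag)

    def run():
        return funa() if frozen else funb()

    return run

def build_fed_cond():
    # Analogous to tf.cond(placeholder, ...):
    # the flag is supplied each time the "graph" is run.
    def run(flag):
        return funa() if flag else funb()

    return run

foo = True
x_const = build_constant_cond(foo)
x_fed = build_fed_cond()

const_results, fed_results = [], []
for i in range(20):
    if i > 10:
        foo = False
    const_results.append(x_const())   # never sees the new value of foo
    fed_results.append(x_fed(foo))    # reads foo on every call

print(const_results)
print(fed_results)
```

The first list stays at 32 for all 20 iterations, while the second switches to 25 from i = 11 onwards.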
I also tried eager execution with the following code:
import tensorflow as tf
tf.enable_eager_execution()

def funa():
    return tf.constant(32)

def funb():
    return tf.constant(21)

foo = True
x = tf.cond(tf.convert_to_tensor(foo), lambda: funa(), lambda: funb())
for i in range(20):
    if i > 10:
        foo = False
    print(x)
Still the same result.
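In eager mode, tf.cond executes as soon as its line runs, so x is already a concrete tensor holding 32, and print(x) only shows that stored result. A minimal sketch, assuming eager execution is acceptable for the use case, moves the tf.cond call inside the loop so the predicate is re-evaluated with the current value of foo on each iteration. The executing_eagerly() guard is an addition so the snippet also runs on TF 2.x, where eager execution is already the default and the TF 1.x switch lives under tf.compat.v1.

```python
import tensorflow as tf

# TF 1.x: eager execution must be enabled at program start-up.
# TF 2.x: it is already on, so the call is skipped.
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()

def funa():
    return tf.constant(32)

def funb():
    return tf.constant(21)

foo = True
results = []
for i in range(20):
    if i > 10:
        foo = False
    # Calling tf.cond inside the loop re-reads the current Python value of foo.
    x = tf.cond(tf.convert_to_tensor(foo), funa, funb)
    results.append(int(x.numpy()))

print(results)
```

In eager mode a plain Python "if foo:" would behave the same way; tf.cond matters mainly when the code is built into a graph.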
So my question is: how can I write code such that one branch of the graph is chosen dynamically, based on updates to the bool variable (if possible)? Thanks. I am using TensorFlow v1.14.
You can make a placeholder for foo and feed its value while running the session. Modified code:
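import tensorflow as tf

def funa():
    return tf.constant(32)

def funb():
    return tf.constant(25)

foo = True
# A placeholder is a graph input, so its value is supplied on every run
foo_p = tf.placeholder(tf.bool)
sess = tf.Session()
x = tf.cond(foo_p, lambda: funa(), lambda: funb())
for i in range(20):
    if i > 10:
        foo = False
    # Feed the current Python value of foo into the placeholder
    print(sess.run(x, feed_dict={foo_p: foo}))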
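Another option, sketched here under the assumption that the flag should live inside the graph instead of being fed on every call, is to keep it in a boolean variable and flip it with an assign op. The names foo_var and turn_off are hypothetical, and the tf.compat.v1 aliases plus the explicit tf.Graph() context are used so the sketch runs in graph mode on TF 1.14 as well as TF 2.x.

```python
import tensorflow as tf

tf1 = tf.compat.v1

def funa():
    return tf.constant(32)

def funb():
    return tf.constant(25)

graph = tf.Graph()
with graph.as_default():
    # The flag is stored in the graph as a variable, not baked in as a constant.
    foo_var = tf1.Variable(True, dtype=tf.bool, trainable=False, name="foo")
    # Op that sets the flag to False when it is run.
    turn_off = foo_var.assign(False)
    # read_value() is evaluated on every sess.run(x), so x sees the latest flag.
    x = tf.cond(foo_var.read_value(), funa, funb)
    init = tf1.global_variables_initializer()

results = []
with tf1.Session(graph=graph) as sess:
    sess.run(init)
    for i in range(20):
        if i == 11:
            # Mirrors "foo = False" from the question, but inside the graph.
            sess.run(turn_off)
        results.append(int(sess.run(x)))

print(results)
```

Compared with the placeholder, this keeps the current flag in the session state, at the cost of an extra initializer and an explicit assign whenever the flag changes.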