learner

Reputation: 3472

Constantly update tf.cond based on bool value

I am using tf.cond to control the flow of a TensorFlow graph. I went through the documentation and was able to implement tf.cond-based branching successfully. My concern is that the value of the bool variable is checked once, while the graph is being built, and the branching decision is made at that initialization step. Any further changes to the bool are not tracked. The following MWE describes the problem:

import tensorflow as tf

def funa():
    return tf.constant(32)

def funb():
    return tf.constant(25)

foo = True
x = tf.cond(tf.convert_to_tensor(foo), lambda: funa(), lambda: funb())

sess = tf.Session()
for i in range(20):
    if i > 10:
        foo = False  # flip the flag once i exceeds 10
    print(sess.run(x))

This prints only 32s.

I also tried eager execution, with the following code:

import tensorflow as tf

tf.enable_eager_execution()
def funa():
    return tf.constant(32)

def funb():
    return tf.constant(21)

foo = True
x = tf.cond(tf.convert_to_tensor(foo), lambda: funa(), lambda: funb())
for i in range(20):
    if i > 10:
        foo = False
    print(x)

Still the same result.

So my question is: how can I write code so that one part of the graph is chosen dynamically, based on updates to the bool variable (if that is possible)? Thanks. I am using TensorFlow v1.14.

Upvotes: 1

Views: 54

Answers (1)

Abhinav Goyal

Reputation: 1455

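You can create a placeholder for foo and feed its value when running the session. The predicate is then supplied on every sess.run call instead of being fixed when the graph is built, so tf.cond picks the branch at run time. Modified code: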

import tensorflow as tf

def funa():
    return tf.constant(32)

def funb():
    return tf.constant(25)

foo = True
foo_p = tf.placeholder(tf.bool)

sess = tf.Session()

x = tf.cond(foo_p, lambda: funa(), lambda: funb())
for i in range(20):
    if i > 10:
        foo = False
    print(sess.run(x, {foo_p:foo}))
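
If you would rather keep the flag inside the graph than feed it on every call, one alternative is a non-trainable boolean tf.Variable that you flip with an assign op. Below is a minimal sketch of that idea, not something I have verified on 1.14 specifically. The function name run_with_flag_variable and its parameters are made up for illustration, and the calls go through tf.compat.v1 on the assumption that this keeps the code usable on both 1.14 and 2.x.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API; assumed available in your install


def run_with_flag_variable(steps=20, switch_after=10):
    """Hypothetical helper: evaluate a tf.cond whose predicate is a variable."""
    graph = tf.Graph()
    with graph.as_default():
        # The flag lives in the graph, so its current value is read on each run.
        flag = tf1.Variable(True, dtype=tf.bool, trainable=False)
        x = tf1.cond(flag.read_value(),
                     lambda: tf.constant(32),
                     lambda: tf.constant(25))
        set_false = tf1.assign(flag, False)  # op that flips the flag
        init = tf1.global_variables_initializer()

    outputs = []
    with tf1.Session(graph=graph) as sess:
        sess.run(init)
        for i in range(steps):
            if i > switch_after:
                sess.run(set_false)
            outputs.append(int(sess.run(x)))
    return outputs


if __name__ == "__main__":
    print(run_with_flag_variable())
```

With the defaults this should give 32 for i = 0 through 10 and 25 afterwards. The placeholder version is simpler when the flag comes from Python anyway; the variable version may be handier if other ops in the graph need to change the flag.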

Upvotes: 1
