sonu patel

Reputation: 23

Passing a bool to a function via feed_dict does not work

I am trying to pass a bool value to a function through feed_dict:

def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):

    if (flag is True):
        print(msg1)
        vtotal = tf.add(a,b)
    else:
        print(msg2)
        vtotal = tf.multiply(a,b)

    return vtotal

When I call the function as sum(a, b), the default value flag=True is used.

But when I call the function as

sum(a, b, flag)

and feed the value of flag through feed_dict like this:

output = sess.run(total,feed_dict = {a: a_arr, b: b_arr, flag: True})

the value is not treated as True; the else branch of the function is executed instead.

The full code is below. Why is this happening?

import numpy as np
import tensorflow as tf

def initialize_placeholders():
    a = tf.placeholder(tf.float32,[3,None],name="a")
    b = tf.placeholder(tf.float32,[3,None],name ="b")
    flag = tf.placeholder(tf.bool, name="flag")

    return a, b, flag

def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):

    if (flag is True):
        print(msg1)
        vtotal = tf.add(a,b)
    else:
        print(msg2)
        vtotal = tf.multiply(a,b)

    return vtotal

def model(a_arr,b_arr):
    #print(a_arr)
    #print(b_arr)
    tf.reset_default_graph()
    a, b ,flag= initialize_placeholders()
    total = sum(a,b,flag)

    init = tf.global_variables_initializer()
    print(flag)

    with tf.Session() as sess:
        sess.run(init)
        output = sess.run(total,feed_dict = {a: a_arr, b: b_arr, flag: True})
        print(flag)
        unv = sess.run(tf.report_uninitialized_variables())
        sess.close()
    return output, unv

a_arr = np.arange(6)
a_arr = a_arr.reshape(3,2)
b_arr = np.array([2,4,6,8,10,12])
b_arr = b_arr.reshape(3,2)
output , unv = model(a_arr,b_arr)
print(output)
print(unv)

Upvotes: 1

Views: 33

Answers (1)

javidcf

Reputation: 59731

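You cannot use TensorFlow values in regular conditional Python statements (unless you are using something like AutoGraph). When sum receives the placeholder, flag is a tf.Tensor object, not the Python constant True, so flag is True evaluates to False while the graph is being built, and the else branch is chosen before feed_dict is ever consulted. You can do what you want with tf.cond like this: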

def sum(a, b, flag=True):
    flag = tf.convert_to_tensor(flag)
    return tf.cond(flag, lambda: tf.add(a, b), lambda: tf.multiply(a, b))
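To check that the tf.cond version above really follows the fed value, here is a minimal end-to-end sketch. It assumes TensorFlow 1.14+ or 2.x accessed through the tensorflow.compat.v1 module; the names sum_or_multiply and run_model are only illustrative (the function is renamed so it does not shadow the built-in sum).

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: the v1 compatibility API is available

tf.disable_eager_execution()  # placeholders and sessions need graph mode


def sum_or_multiply(a, b, flag=True):
    # Same body as the tf.cond version above, under an illustrative name.
    flag = tf.convert_to_tensor(flag)
    return tf.cond(flag, lambda: tf.add(a, b), lambda: tf.multiply(a, b))


def run_model(a_arr, b_arr, flag_value):
    # Build a fresh graph, then evaluate it with the given flag value.
    tf.reset_default_graph()
    a = tf.placeholder(tf.float32, [3, None], name="a")
    b = tf.placeholder(tf.float32, [3, None], name="b")
    flag = tf.placeholder(tf.bool, shape=[], name="flag")
    total = sum_or_multiply(a, b, flag)
    with tf.Session() as sess:
        return sess.run(total, feed_dict={a: a_arr, b: b_arr, flag: flag_value})


a_arr = np.arange(6).reshape(3, 2)
b_arr = np.array([2, 4, 6, 8, 10, 12]).reshape(3, 2)

added = run_model(a_arr, b_arr, True)        # element-wise a + b
multiplied = run_model(a_arr, b_arr, False)  # element-wise a * b
print(added)
print(multiplied)
```

Because the branch is now selected inside the graph, feeding True or False changes the result without rebuilding anything.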

You could also make it a bit more elaborate in order to avoid the tf.cond operation when the value of flag is fixed in advance. For example, you could have something like this:

def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):
    true_fn = lambda: tf.add(a, b)
    false_fn = lambda: tf.multiply(a, b)
    if flag is True:
        return true_fn()
    elif flag is False:
        return false_fn()
    else:  # Use TensorFlow conditional
        flag = tf.convert_to_tensor(flag)
        return tf.cond(flag, true_fn, false_fn)

I removed the print calls because they cannot be used directly inside TensorFlow conditionals, but you can still add tf.print operations if you want to see the messages printed when the graph is executed.
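One possible way to add such tf.print operations is to create them inside each branch function and attach them as control dependencies of the branch result. This is only a sketch, assuming TensorFlow 1.14+ or 2.x through tensorflow.compat.v1; sum_or_multiply is an illustrative name, and the messages go to standard error, which is the default output stream of tf.print.

```python
import tensorflow.compat.v1 as tf  # assumption: the v1 compatibility API is available

tf.disable_eager_execution()  # placeholders and sessions need graph mode


def sum_or_multiply(a, b, flag, msg1="Sum", msg2="Multiply"):
    def true_fn():
        # The print op is created inside the branch, so it runs only when flag is True.
        with tf.control_dependencies([tf.print(msg1)]):
            return tf.add(a, b)

    def false_fn():
        # Likewise, this message is printed only when flag is False.
        with tf.control_dependencies([tf.print(msg2)]):
            return tf.multiply(a, b)

    return tf.cond(tf.convert_to_tensor(flag), true_fn, false_fn)


a = tf.placeholder(tf.float32, [2], name="a")
b = tf.placeholder(tf.float32, [2], name="b")
flag = tf.placeholder(tf.bool, shape=[], name="flag")
total = sum_or_multiply(a, b, flag)

with tf.Session() as sess:
    added = sess.run(total, feed_dict={a: [1.0, 2.0], b: [3.0, 4.0], flag: True})
    multiplied = sess.run(total, feed_dict={a: [1.0, 2.0], b: [3.0, 4.0], flag: False})
```

A plain Python print placed in the same functions would run once while the graph is being built, not each time the branch executes, which is why the graph-level tf.print is used here.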

Upvotes: 1
