Reputation: 23
I am trying to feed a bool value through feed_dict to a function:
def sum(a, b, flag=True, msg1="Sum", msg2="Multiply "):
    if (flag is True):
        print(msg1)
        vtotal = tf.add(a, b)
    else:
        print(msg2)
        vtotal = tf.multiply(a, b)
    return vtotal
When I call the function as sum(a, b), the default value flag = True is used.
But when I call the function as
sum(a, b, flag)
and feed the value of flag through feed_dict like this:
output = sess.run(total, feed_dict={a: a_arr, b: b_arr, flag: True})
it does not take the value True; instead, it executes the else branch of the function.
The full code is below. Why is this happening?
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

def initialize_placeholders():
    a = tf.placeholder(tf.float32, [3, None], name="a")
    b = tf.placeholder(tf.float32, [3, None], name="b")
    flag = tf.placeholder(tf.bool, name="flag")
    return a, b, flag

def sum(a, b, flag=True, msg1="Sum", msg2="Multiply "):
    if (flag is True):
        print(msg1)
        vtotal = tf.add(a, b)
    else:
        print(msg2)
        vtotal = tf.multiply(a, b)
    return vtotal

def model(a_arr, b_arr):
    #print(a_arr)
    #print(b_arr)
    tf.reset_default_graph()
    a, b, flag = initialize_placeholders()
    total = sum(a, b, flag)
    init = tf.global_variables_initializer()
    print(flag)
    with tf.Session() as sess:
        sess.run(init)
        output = sess.run(total, feed_dict={a: a_arr, b: b_arr, flag: True})
        print(flag)
        unv = sess.run(tf.report_uninitialized_variables())
        sess.close()
    return output, unv

a_arr = np.arange(6)
a_arr = a_arr.reshape(3, 2)
b_arr = np.array([2, 4, 6, 8, 10, 12])
b_arr = b_arr.reshape(3, 2)
output, unv = model(a_arr, b_arr)
print(output)
print(unv)
Upvotes: 1
Views: 33
Reputation: 59731
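You cannot use TensorFlow values in regular Python conditional statements (unless you are using something like AutoGraph). The Python if in sum runs only once, while the graph is being built. At that point flag is a placeholder tf.Tensor object, not the Python True singleton, so flag is True evaluates to False and the multiply branch is what ends up in the graph. The value you feed later through feed_dict never reaches that if. You can do what you want with tf.cond, like this: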
def sum(a, b, flag=True):
    flag = tf.convert_to_tensor(flag)
    return tf.cond(flag, lambda: tf.add(a, b), lambda: tf.multiply(a, b))
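To make the build-time versus run-time distinction concrete, here is a pure-Python toy model of graph mode. The names Placeholder, Add, Mul, Cond and run are hypothetical stand-ins, not TensorFlow API; the sketch only assumes that, as in TensorFlow 1.x graph mode, building the graph and feeding values happen at different times.

```python
# Toy model of TF1 graph mode: building creates nodes, run() evaluates them later.
# Placeholder, Add, Mul, Cond and run() are hypothetical, not TensorFlow API.

class Placeholder:
    """Value unknown at build time; supplied later through feed_dict."""
    def __init__(self, name):
        self.name = name

class Add:
    def __init__(self, a, b):
        self.a, self.b = a, b

class Mul:
    def __init__(self, a, b):
        self.a, self.b = a, b

class Cond:
    """Keeps both branches; the choice is made only when run() sees the fed value."""
    def __init__(self, pred, true_node, false_node):
        self.pred, self.true_node, self.false_node = pred, true_node, false_node

def run(node, feed_dict):
    # Recursively evaluate a node using the values fed for the placeholders.
    if isinstance(node, Placeholder):
        return feed_dict[node]
    if isinstance(node, Add):
        return run(node.a, feed_dict) + run(node.b, feed_dict)
    if isinstance(node, Mul):
        return run(node.a, feed_dict) * run(node.b, feed_dict)
    if isinstance(node, Cond):
        branch = node.true_node if run(node.pred, feed_dict) else node.false_node
        return run(branch, feed_dict)
    raise TypeError(f"unknown node: {node!r}")

def sum_python_if(a, b, flag):
    # Runs once at build time: flag is a Placeholder object, never the True singleton.
    if flag is True:
        return Add(a, b)
    return Mul(a, b)

def sum_cond(a, b, flag):
    # Both branches go into the graph; the fed value picks one at run time.
    return Cond(flag, Add(a, b), Mul(a, b))

a, b, flag = Placeholder("a"), Placeholder("b"), Placeholder("flag")
feed = {a: 3, b: 4, flag: True}
print(run(sum_python_if(a, b, flag), feed))  # 12: the multiply branch was baked in
print(run(sum_cond(a, b, flag), feed))       # 7: the add branch is chosen by the fed flag
```

Both calls feed flag: True, but only the Cond version reacts to it, which mirrors what tf.cond does in a real graph.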
You could also make it a bit more complicated in order to avoid the tf.cond operation when the value of flag is a plain Python bool known in advance. For example, you could have something like this:
def sum(a, b, flag=True, msg1="Sum", msg2="Multiply "):
    true_fn = lambda: tf.add(a, b)
    false_fn = lambda: tf.multiply(a, b)
    if flag is True:  # Python bool known at build time: build only one branch
        return true_fn()
    elif flag is False:
        return false_fn()
    else:  # Use TensorFlow conditional
        flag = tf.convert_to_tensor(flag)
        return tf.cond(flag, true_fn, false_fn)
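What the Python-bool shortcut saves is graph nodes: with a plain True or False only one branch is built, while a tensor flag needs both branches plus a conditional. The pure-Python toy below records which nodes each call would add; Placeholder, make_node and build_sum are hypothetical names for illustration, not TensorFlow API.

```python
# Toy graph builder: records the ops each call would add to a graph.
# Placeholder, make_node and build_sum are hypothetical, not TensorFlow API.
built = []

class Placeholder:
    """Stands in for a tf.placeholder: its value is unknown at build time."""
    def __init__(self, name):
        self.name = name

def make_node(op):
    built.append(op)  # every call adds one node to the toy graph
    return op

def build_sum(flag):
    true_fn = lambda: make_node("add")
    false_fn = lambda: make_node("multiply")
    if flag is True:   # Python bool: decide now and build one branch only
        return true_fn()
    elif flag is False:
        return false_fn()
    else:              # tensor-like: build both branches plus a cond node
        return make_node(("cond", true_fn(), false_fn()))

for f in (True, False, Placeholder("flag")):
    built.clear()
    build_sum(f)
    print(type(f).__name__, built)
```

The last case builds three nodes instead of one, which is the cost the Python-bool check avoids when the flag is fixed in advance.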
I removed the print calls because a Python print runs only once, while the graph is being built, so it cannot be used directly inside TensorFlow conditionals. You can still add tf.print operations if you want to see the messages when the graph is executed.
Upvotes: 1