exAres

Reputation: 4926

How to pass parameters to functions inside tf.cond in TensorFlow?

I have the following simple placeholders:

x = tf.placeholder(tf.float32, shape=[1])
y = tf.placeholder(tf.float32, shape=[1])
z = tf.placeholder(tf.float32, shape=[1])

There are two functions, fn1 and fn2, defined as:

def fn1(a, b):
    return tf.multiply(a, b)

def fn2(a, b):
    return tf.add(a, b)

Now I want to compute result based on the pred condition:

pred = tf.placeholder(tf.bool, shape=[1])
result = tf.cond(pred, fn1(x,y), fn2(y,z))

But it gives me an error saying fn1 and fn2 must be callable.

How can I write fn1 and fn2 so that they can receive parameters at runtime? I want to call the following:

sess.run(result, feed_dict={x:1,y:2,z:3,pred:True})

Upvotes: 12

Views: 13897

Answers (2)

Kongsea

Reputation: 927

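tf.cond expects its two branches as callables that take no arguments. Writing fn1(x, y) calls the function immediately and passes the resulting Tensor instead, which is why you get the "must be callable" error. You can still pass parameters by wrapping each call in a lambda, as follows: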

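# TensorFlow 1.x graph-mode API (placeholders and sessions)
import tensorflow as tf

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)

def fn1(a, b):
  # tf.mul was renamed tf.multiply in TensorFlow 1.0
  return tf.multiply(a, b)

def fn2(a, b):
  return tf.add(a, b)

pred = tf.placeholder(tf.bool)
# Each lambda is a zero-argument callable that forwards the parameters to fn1/fn2
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))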

Then you can run it as follows:

with tf.Session() as sess:
  print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True}))
  # The result is 2.0 (pred is True, so fn1 computes 1 * 2)
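The callable requirement can be seen without TensorFlow in a small pure-Python sketch. Here `toy_cond` is a hypothetical stand-in that mimics only the relevant part of the `tf.cond` contract (both branches must be zero-argument callables); it is not TensorFlow's implementation, and unlike graph-mode `tf.cond`, which traces both branches, it simply calls the selected one.

```python
def toy_cond(pred, true_fn, false_fn):
    # Hypothetical stand-in for tf.cond: both branches must be callables taking no arguments.
    if not callable(true_fn) or not callable(false_fn):
        raise TypeError("true_fn and false_fn must be callable")
    # Only the selected branch is invoked here (graph-mode tf.cond traces both).
    return true_fn() if pred else false_fn()


def fn1(a, b):
    return a * b


def fn2(a, b):
    return a + b


x, y, z = 1.0, 2.0, 3.0

# fn1(x, y) is evaluated before toy_cond runs, so it receives the number 2.0, not a function.
try:
    toy_cond(True, fn1(x, y), fn2(y, z))
except TypeError as e:
    print("error:", e)

# Wrapping each call in a lambda defers it and produces a zero-argument callable.
print(toy_cond(True, lambda: fn1(x, y), lambda: fn2(y, z)))   # 2.0
print(toy_cond(False, lambda: fn1(x, y), lambda: fn2(y, z)))  # 5.0
```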

Upvotes: 23

Phillip Bock

Reputation: 1889

The easiest way is to define the functions inline in the call:

result = tf.cond(pred, lambda: tf.multiply(x, y), lambda: tf.add(y, z))
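A lambda is not the only option: any zero-argument callable works as a branch, and `functools.partial` from the standard library is one alternative. The sketch below uses plain Python floats in place of tensors so it runs without TensorFlow; the `tf.cond` line appears only as a comment.

```python
from functools import partial


def fn1(a, b):
    return a * b


def fn2(a, b):
    return a + b


x, y, z = 1.0, 2.0, 3.0

# partial binds the arguments now and returns a callable that takes no arguments,
# which is the kind of branch tf.cond calls, e.g.
#   tf.cond(pred, partial(fn1, x, y), partial(fn2, y, z))
true_branch = partial(fn1, x, y)
false_branch = partial(fn2, y, z)
print(true_branch(), false_branch())  # 2.0 5.0

# Design note: a lambda looks up its free variables when it is called,
# while partial stores the argument values when it is created.
late = [lambda: fn1(k, 10) for k in range(3)]
early = [partial(fn1, k, 10) for k in range(3)]
print([f() for f in late])   # [20, 20, 20]
print([f() for f in early])  # [0, 10, 20]
```

The difference only matters when branches are built in a loop; for the single call in the question, a lambda and partial behave the same.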

Upvotes: 4
