Reputation: 4926
I have the following simple placeholders:
x = tf.placeholder(tf.float32, shape=[1])
y = tf.placeholder(tf.float32, shape=[1])
z = tf.placeholder(tf.float32, shape=[1])
There are two functions, fn1 and fn2, defined as:
def fn1(a, b):
    return tf.mul(a, b)

def fn2(a, b):
    return tf.add(a, b)
Now I want to compute result based on the pred condition:
pred = tf.placeholder(tf.bool, shape=[1])
result = tf.cond(pred, fn1(x,y), fn2(y,z))
But it gives me an error saying that fn1 and fn2 must be callable.
How can I write fn1 and fn2 so that they can receive parameters at runtime?
I want to call the following:
sess.run(result, feed_dict={x:1,y:2,z:3,pred:True})
Upvotes: 12
Views: 13897
Reputation: 927
You can pass the parameters to the functions by wrapping each call in a lambda:
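import tensorflow as tf

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)

def fn1(a, b):
    return tf.multiply(a, b)  # named tf.mul in TensorFlow < 1.0

def fn2(a, b):
    return tf.add(a, b)

# tf.cond needs a scalar bool predicate, so pred is declared without shape=[1]
pred = tf.placeholder(tf.bool)
# Each lambda is a zero-argument callable; tf.cond calls it to build that branch
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))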
Then you can run it like this:
with tf.Session() as sess:
    print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True}))
    # The result is 2.0 (pred is True, so fn1 computes 1 * 2)
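To see why fn1(x, y) fails while lambda: fn1(x, y) works, here is a plain-Python sketch with no TensorFlow involved. The function cond below is a hypothetical stand-in that only mimics the one check relevant here (its branch arguments must be callable); it is a simplification, not how tf.cond is implemented. Numbers replace tensors, and functools.partial is shown as an alternative way to build a zero-argument callable.

```python
import functools

def fn1(a, b):
    # Stand-in for tf.multiply(a, b); plain numbers replace tensors here.
    return a * b

def fn2(a, b):
    # Stand-in for tf.add(a, b).
    return a + b

def cond(pred, true_fn, false_fn):
    """Hypothetical stand-in for tf.cond: rejects non-callable branches,
    then calls the branch selected by pred."""
    if not callable(true_fn) or not callable(false_fn):
        raise TypeError("true_fn and false_fn must be callable.")
    return true_fn() if pred else false_fn()

x, y, z = 1.0, 2.0, 3.0

# fn1(x, y) is evaluated immediately, so cond receives the value 2.0, not a function.
try:
    cond(True, fn1(x, y), fn2(y, z))
except TypeError as err:
    print("error:", err)

# A lambda delays the call: cond receives a zero-argument function.
print(cond(True, lambda: fn1(x, y), lambda: fn2(y, z)))  # 2.0

# functools.partial binds the arguments up front; the result is also callable.
print(cond(False, functools.partial(fn1, x, y), functools.partial(fn2, y, z)))  # 5.0
```

Since partial objects are callable, functools.partial(fn1, x, y) should also be accepted by tf.cond in place of the lambda; which one to use is mostly a matter of taste.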
Upvotes: 23
Reputation: 1889
The simplest option is to define the branch functions inline in the tf.cond call:
result = tf.cond(pred, lambda: tf.multiply(x, y), lambda: tf.add(y, z))
Upvotes: 4