EcoView

Reputation: 78

Feeding an array (rank-1 tensor) to TensorFlow tf.case

Following this example from the tf.case documentation:

def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
            default=f3, exclusive=True)

I want to do the same, but supply the inputs through a feed_dict, as illustrated by this snippet:

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

x = tf.placeholder(tf.float32, shape=[None])
y = tf.placeholder(tf.float32, shape=[None])
z = tf.placeholder(tf.float32, shape=[None])
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
            default=f3, exclusive=True)
with tf.Session() as sess:
    print(sess.run(r, feed_dict={x: [0, 1, 2, 3], y: [1, 1, 1, 1], z: [2, 2, 2, 2]}))
# result should be [17, -1, -1, 23]

Basically, I want to feed three integer arrays of equal length and get back an array of integers in which each element is 17, 23, or -1. Unfortunately, the code above gives an error:

ValueError: Shape must be rank 0 but is rank 1 for 'case/cond/Switch' (op: 'Switch') with input shapes: [?], [?].

I understand that tf.case requires its predicates to be scalar boolean tensors, but is there any way to achieve what I want? I also tried tf.cond without success.
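For reference, the per-element behaviour being asked for can be written in NumPy with np.select, which checks a list of conditions for each element and falls back to a default where none match. This is only a sketch of the intended semantics (a swapped-in NumPy technique, not a TensorFlow solution), and the helper name `case_elementwise` is hypothetical. One difference to keep in mind: np.select picks the first matching condition, whereas tf.case with exclusive=True raises an error when more than one predicate is true.

```python
import numpy as np

def case_elementwise(x, y, z):
    """Hypothetical helper: a per-element version of the tf.case example."""
    x, y, z = (np.asarray(a, dtype=np.float32) for a in (x, y, z))
    conditions = [x < y, x > z]  # the two predicates from the tf.case example
    choices = [17, 23]           # value used where the matching condition holds
    # Elements matching neither condition get the default, -1 (like f3)
    return np.select(conditions, choices, default=-1)

print(case_elementwise([0, 1, 2, 3], [1, 1, 1, 1], [2, 2, 2, 2]))  # [17 -1 -1 23]
```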

Upvotes: 2

Views: 201

Answers (1)

javidcf

Reputation: 59731

Use tf.where for that, since it selects element-wise instead of requiring a scalar predicate, for example like this. (Broadcasting support for tf.where seems to be on its way but, as far as I can tell, is not there yet, so you have to make sure all arguments have the same shape, e.g. by multiplying a vector of ones, or by using tf.fill, tf.tile, ...)

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=[None])
    y = tf.placeholder(tf.float32, shape=[None])
    z = tf.placeholder(tf.float32, shape=[None])
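    ones = tf.ones_like(x)  # same dynamic shape as x, so all tf.where arguments match
    # Nested tf.where: 17 where x < y, otherwise 23 where x > z, otherwise -1
    r = tf.where(x < y, 17 * ones, tf.where(x > z, 23 * ones, -ones))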
    print(sess.run(r, feed_dict={x: [0, 1, 2, 3], y: [1, 1, 1, 1], z: [2, 2, 2, 2]}))
    # [17. -1. -1. 23.]
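In TensorFlow 2.x, tf.where broadcasts its condition and both value arguments, so the ones_like workaround is no longer needed. A minimal sketch, assuming TF 2.x with eager execution; `select_values` is a hypothetical name, and tf.function stands in for the placeholder/feed_dict pattern, which TF 2 only offers under tf.compat.v1.

```python
import tensorflow as tf

@tf.function
def select_values(x, y, z):
    # TF 2.x tf.where broadcasts, so scalar branch values work with rank-1 conditions
    return tf.where(x < y, 17.0, tf.where(x > z, 23.0, -1.0))

x = tf.constant([0, 1, 2, 3], dtype=tf.float32)
y = tf.ones_like(x)
z = tf.fill(tf.shape(x), 2.0)
print(select_values(x, y, z).numpy())  # [17. -1. -1. 23.]
```

Because the branch values are scalars, the same function works for inputs of any length without building matching tensors by hand.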

Upvotes: 1
