HenrySky

Reputation: 11

Is there an alternative way to apply tf.case, tf.cond, etc. to a tensor?

I'm trying to use tf.case with an index value stored in a tensor to route each element to a different part of the network, compute different losses, and then sum them into the final training loss. As a simple example, I want to check each value in a list and output a different value: [0,1,2,3] -> [0,7,10,13], where case 0 outputs 0, case 1 outputs 7, case 2 outputs 10, and case 3 outputs 13. However, tf.cond and tf.case seem to work only on a scalar. How can I achieve this?

Upvotes: 1

Views: 1806

Answers (2)

FelixHo

Reputation: 1304

Try this:

import tensorflow as tf  # TF 1.x API (tf.Session)

value = [0, 1, 2, 3]
ones = tf.ones_like(value)

# Nested tf.where: each level tests one case; -1 is the fallback value.
out = tf.where(tf.equal(value, 0), ones * 0,
      tf.where(tf.equal(value, 1), ones * 7,
      tf.where(tf.equal(value, 2), ones * 10,
      tf.where(tf.equal(value, 3), ones * 13,
               ones * -1))))

with tf.Session() as sess:
    print(sess.run(out))  # [ 0  7 10 13]
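Since the case labels in the question are the contiguous integers 0..3, a lookup table is another option: put the outputs in a constant tensor and index it with tf.gather, which handles every element in one vectorized op. This is a minimal sketch assuming TensorFlow 2.x eager execution; the table contents come from the question's example. Labels outside 0..3 are not handled: per the tf.gather documentation, an out-of-range index raises an error on CPU and yields 0 on GPU.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Lookup table: position i holds the output for case label i.
# Assumption: labels are the contiguous integers 0..3.
table = tf.constant([0, 7, 10, 13])

value = tf.constant([0, 1, 2, 3, 1])

# tf.gather picks table[value[j]] for every element j at once.
out = tf.gather(table, value)
print(out.numpy())  # [ 0  7 10 13  7]
```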

Upvotes: 0

jwayne

Reputation: 756

The only operation I'm aware of that evaluates a condition separately on each element of a vector is tf.where. Called with only the condition (leaving x=None, y=None), it returns the indices of the elements where the condition is True:

import tensorflow as tf  # TF 1.x API (tf.Session)

t_orig = tf.constant([0, 1, 2, 3, 1])
t_filt = tf.where(tf.equal(t_orig, 1))  # indices where t_orig == 1
with tf.Session() as sess:
    print(sess.run(t_filt))

Output:

[[1]
 [4]]
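For comparison, passing x and y to tf.where switches it from returning coordinates to choosing element-wise between x and y. This is a minimal sketch assuming TensorFlow 2.x, where tf.where broadcasts scalar x and y against the condition; the TF 1.x version required x and y to match the condition's shape, which is why the other answer multiplies tf.ones_like(value) by each constant. The values 7 and -1 are arbitrary illustrations.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

t_orig = tf.constant([0, 1, 2, 3, 1])

# One-argument form: coordinates of the True elements (int64, shape [n, 1]).
idx = tf.where(tf.equal(t_orig, 1))

# Three-argument form: element-wise choice between x (7) and y (-1).
# TF 2.x broadcasts the scalar x and y against the condition.
picked = tf.where(tf.equal(t_orig, 1), 7, -1)

print(idx.numpy().ravel())  # [1 4]
print(picked.numpy())       # [-1  7 -1 -1  7]
```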

However, this only evaluates the truth of a single condition. If you want to evaluate the truth of multiple conditions, over each element of a vector, I think you'll have to use tf.map_fn combined with tf.case. AFAIK, tf.case is the only operation that evaluates the truth of many conditions on a given value:

import tensorflow as tf  # TF 1.x API (tf.Session)

t_orig = tf.constant([0, 1, 2, 3])
# tf.map_fn applies the scalar tf.case to each element in turn.
t_new = tf.map_fn(
        lambda x: tf.case(
            pred_fn_pairs=[
                (tf.equal(x, 0), lambda: tf.constant(0)),
                (tf.equal(x, 1), lambda: tf.constant(7)),
                (tf.equal(x, 2), lambda: tf.constant(10)),
                (tf.equal(x, 3), lambda: tf.constant(13))],
            default=lambda: tf.constant(-1)),
        t_orig)
with tf.Session() as sess:
    print(sess.run(t_new))

Output:

[ 0  7 10 13]
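The question's underlying goal is to route each example to a different branch and sum the resulting losses. One vectorized way to do that is to compute every branch's per-example loss, stack them, and keep only the loss selected by each example's index through a one-hot mask. This is a minimal sketch assuming TensorFlow 2.x; the branch losses loss_a, loss_b, and loss_c are hypothetical stand-ins for real network branches.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

# Hypothetical setup: 3 branches, each producing a per-example loss.
branch_idx = tf.constant([0, 2, 1, 2])  # which branch each example uses
x = tf.constant([1.0, 2.0, 3.0, 4.0])

# Hypothetical branch losses, computed for every example.
loss_a = x * 1.0
loss_b = x * 10.0
loss_c = x * 100.0
all_losses = tf.stack([loss_a, loss_b, loss_c], axis=1)  # shape [4, 3]

# The one-hot mask keeps only the loss of each example's own branch.
mask = tf.one_hot(branch_idx, depth=3, dtype=all_losses.dtype)
per_example = tf.reduce_sum(all_losses * mask, axis=1)  # [1., 200., 30., 400.]
total_loss = tf.reduce_sum(per_example)
print(total_loss.numpy())  # 631.0
```

This trades extra computation (every branch runs on every example) for avoiding per-element control flow. Unselected branches receive zero gradient through the mask, but a NaN or inf in any branch's loss still leaks into the sum, because 0 * inf is NaN.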

Upvotes: 2
