Reputation: 11
I'm trying to use tf.case to route each index value in a tensor to a different part of the network, compute a different loss for each, and then sum those losses into the final training loss. As a simple example, I want to map each value in a list to a different output, e.g. [0, 1, 2, 3] -> [0, 7, 10, 13], where case 0 outputs 0, case 1 outputs 7, case 2 outputs 10 and case 3 outputs 13. However, tf.cond and tf.case seem to work only on a scalar predicate. How can I achieve this?
Upvotes: 1
Views: 1806
Reputation: 1304
Try this:
import tensorflow as tf
value = [0, 1, 2, 3]
ones = tf.ones_like(value)
# Nested tf.where: each level tests one case label element-wise and
# falls through to the next level; -1 marks values that match no case.
out = tf.where(tf.equal(value, 0), ones * 0,
      tf.where(tf.equal(value, 1), ones * 7,
      tf.where(tf.equal(value, 2), ones * 10,
      tf.where(tf.equal(value, 3), ones * 13, ones * -1))))
with tf.Session() as sess:  # TF 1.x graph mode
    print(sess.run(out))  # [ 0  7 10 13]
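In TensorFlow 2.x, tf.Session no longer exists at the top level, so here is a minimal sketch of the same nested tf.where running eagerly (assuming TF 2.x is installed). It also shows tf.gather on a lookup table as an alternative, which only works because the case labels here are exactly 0..3. Unlike the nested version it has no default branch: on CPU an out-of-range label raises an error.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution, no Session)

value = tf.constant([0, 1, 2, 3])
ones = tf.ones_like(value)

# Same nested tf.where as above: each level handles one case label and
# falls through to the next level; -1 marks unmatched labels.
out_where = tf.where(tf.equal(value, 0), ones * 0,
            tf.where(tf.equal(value, 1), ones * 7,
            tf.where(tf.equal(value, 2), ones * 10,
            tf.where(tf.equal(value, 3), ones * 13, ones * -1))))

# Alternative when the labels are exactly 0..N-1: index into a lookup table.
table = tf.constant([0, 7, 10, 13])
out_gather = tf.gather(table, value)

print(out_where.numpy().tolist())   # [0, 7, 10, 13]
print(out_gather.numpy().tolist())  # [0, 7, 10, 13]
```

Both forms are fully vectorized, so no per-element control flow is built.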
Upvotes: 0
Reputation: 756
The only operation I'm aware of that evaluates a condition separately on each element of a vector is tf.where. If you leave x=None and y=None, it returns the indices of the elements where the condition is true:
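import tensorflow as tf
t_orig = tf.constant([0, 1, 2, 3, 1])
# With only a condition, tf.where returns the indices of the True elements.
t_filt = tf.where(tf.equal(t_orig, 1))
with tf.Session() as sess:
    print(sess.run(t_filt))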
Output:
[[1]
[4]]
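The same distinction holds in TensorFlow 2.x without a session. A minimal sketch, assuming TF 2.x: with only a condition, tf.where returns an int64 tensor of shape [num_true, rank] holding the coordinates of the True elements; with x and y it becomes an element-wise select with the same shape as the input, which is the form needed for a per-element mapping.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

t_orig = tf.constant([0, 1, 2, 3, 1])
mask = tf.equal(t_orig, 1)

# Condition only: coordinates of the True elements, shape [num_true, 1], int64.
indices = tf.where(mask)
print(indices.numpy().tolist())  # [[1], [4]]

# Condition plus x and y: element-wise select with the same shape as t_orig;
# here every 1 is replaced by 7 and all other values are kept.
selected = tf.where(mask, 7 * tf.ones_like(t_orig), t_orig)
print(selected.numpy().tolist())  # [0, 7, 2, 3, 7]
```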
However, this only evaluates a single condition. If you want to evaluate several conditions on each element of a vector, I think you'll have to combine tf.map_fn with tf.case. As far as I know, tf.case is the only operation that evaluates many conditions on a given value:
import tensorflow as tf
t_orig = tf.constant([0, 1, 2, 3])
# tf.map_fn applies the lambda to each scalar element; tf.case then runs
# the branch whose predicate is true, falling back to the default (-1).
t_new = tf.map_fn(
    lambda x: tf.case(
        pred_fn_pairs=[
            (tf.equal(x, 0), lambda: tf.constant(0)),
            (tf.equal(x, 1), lambda: tf.constant(7)),
            (tf.equal(x, 2), lambda: tf.constant(10)),
            (tf.equal(x, 3), lambda: tf.constant(13))],
        default=lambda: tf.constant(-1)),
    t_orig)
with tf.Session() as sess:
    print(sess.run(t_new))
Output:
[ 0  7 10 13]
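When the case labels are consecutive integers, tf.switch_case (available since TF 1.14 and in TF 2.x) can replace the chain of tf.equal predicates given to tf.case. A minimal sketch assuming TF 2.x with eager execution: tf.switch_case takes a scalar int32 index, runs the matching branch, and falls back to default when the index is out of range. The helper name pick is only illustrative.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

t_orig = tf.constant([0, 1, 2, 3])

def pick(x):
    # Hypothetical helper: dispatch on the scalar int32 index x.
    # Indices outside 0..3 run the default branch instead.
    return tf.switch_case(
        x,
        branch_fns=[lambda: tf.constant(0),
                    lambda: tf.constant(7),
                    lambda: tf.constant(10),
                    lambda: tf.constant(13)],
        default=lambda: tf.constant(-1))

# tf.map_fn calls pick once per scalar element of t_orig.
t_new = tf.map_fn(pick, t_orig)
print(t_new.numpy().tolist())  # [0, 7, 10, 13]
```

tf.map_fn still processes the elements one at a time, so for a plain value lookup a vectorized tf.where or tf.gather is likely faster; the per-element form is mainly useful when each branch has to build different ops.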
Upvotes: 2