Ziyuan

Reputation: 4558

Is there something similar to tf.cond but for vector predicates?

Say I have a 5D tensor x and a 1D boolean mask m where m.shape[0] == x.shape[0], and I want to decide which sub-network to apply to each 4D sample in x based on the corresponding boolean entry of m.

To my knowledge, tf.cond accepts only a scalar predicate. tf.boolean_mask could split the samples in a batch into two subsets according to m, but I am not sure how to pack the outputs back into one 5D tensor without messing up the original sample order. Any hints?

Upvotes: 1

Views: 228

Answers (1)

javidcf

Reputation: 59711

The simplest thing would be to evaluate the data on both models and then use tf.where to select the final output for each sample.

import tensorflow as tf  # TensorFlow 1.x API (placeholders and sessions)

def model1(x):
    return 2 * x

def model2(x):
    return -3 * x

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder(tf.float32, [None, None])  # 2D for simplicity
    m = tf.placeholder(tf.bool, [None])
    # Run both models on the whole batch
    y1 = model1(x)
    y2 = model2(x)
    # In TF 1.x, a rank-1 condition selects whole rows:
    # row i is taken from y1 if m[i] is True, otherwise from y2
    y = tf.where(m, y1, y2)
    print(sess.run(y, feed_dict={x: [[1, 2], [3, 4], [5, 6]], m: [True, False, True]}))
    # [[  2.   4.]
    #  [ -9. -12.]
    #  [ 10.  12.]]
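The example above relies on TF 1.x behavior, where a rank-1 condition picks whole rows. In TF 2.x, tf.where broadcasts condition, x and y against each other instead, so a [batch] mask must first be reshaped to [batch, 1, ..., 1]. A minimal sketch, assuming TF 2.x with eager execution; select_per_sample is a hypothetical helper name, and model1/model2 are the same toy functions as above:

```python
import tensorflow as tf  # assumes TensorFlow 2.x, eager mode

def model1(x):
    return 2 * x

def model2(x):
    return -3 * x

def select_per_sample(m, y1, y2):
    # Hypothetical helper: reshape the [batch] mask to [batch, 1, ..., 1]
    # so TF 2.x tf.where broadcasts it over every non-batch dimension.
    mask_shape = [-1] + [1] * (y1.shape.rank - 1)
    return tf.where(tf.reshape(m, mask_shape), y1, y2)

x = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
m = tf.constant([True, False, True])
# Both models still run on the full batch; the mask only picks the rows.
y = select_per_sample(m, model1(x), model2(x))
print(y.numpy())
```

The same helper works unchanged for the 5D case in the question, since the reshape adapts to the rank of the model outputs.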

If you want to avoid running both models on every sample, you can use tf.boolean_mask to split the data, keep track of each sample's original index, and then recombine the outputs with tf.scatter_nd. This is one possible way:

import tensorflow as tf  # TensorFlow 1.x API (placeholders and sessions)

def model1(x):
    return 2 * x

def model2(x):
    return -3 * x

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder(tf.float32, [None, None])
    m = tf.placeholder(tf.bool, [None])
    n = tf.size(m)
    # Original position of every sample in the batch
    i = tf.range(n)
    # Samples (and their positions) where m is True go to model1
    x1 = tf.boolean_mask(x, m)
    i1 = tf.boolean_mask(i, m)
    y1 = model1(x1)
    # The remaining samples go to model2
    m_neg = ~m
    x2 = tf.boolean_mask(x, m_neg)
    i2 = tf.boolean_mask(i, m_neg)
    y2 = model2(x2)
    # Scatter every output row back to its original position;
    # the result has shape [n] + per-sample output shape
    y = tf.scatter_nd(tf.expand_dims(tf.concat([i1, i2], axis=0), 1),
                      tf.concat([y1, y2], axis=0),
                      tf.concat([[n], tf.shape(y1)[1:]], axis=0))
    print(sess.run(y, feed_dict={x: [[1, 2], [3, 4], [5, 6]], m: [True, False, True]}))
    # [[  2.   4.]
    #  [ -9. -12.]
    #  [ 10.  12.]]
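The same split-and-recombine idea can also be expressed with a different pair of ops, tf.dynamic_partition and tf.dynamic_stitch, which handle the index bookkeeping directly: dynamic_partition splits both the data and the original indices by mask value, and dynamic_stitch writes each output row back to its original index. A minimal sketch, assuming TF 2.x with eager execution; apply_by_mask is a hypothetical helper name, and model1/model2 are the same toy functions as above:

```python
import tensorflow as tf  # assumes TensorFlow 2.x, eager mode

def model1(x):
    return 2 * x

def model2(x):
    return -3 * x

def apply_by_mask(x, m):
    # Hypothetical helper. Partition 0 holds samples where m is False,
    # partition 1 holds samples where m is True.
    parts = tf.cast(m, tf.int32)
    # Original position of every sample in the batch (int32, as stitch needs)
    idx = tf.range(tf.shape(x)[0])
    x_false, x_true = tf.dynamic_partition(x, parts, num_partitions=2)
    i_false, i_true = tf.dynamic_partition(idx, parts, num_partitions=2)
    # Each sub-network only sees its own subset of the batch.
    y_false = model2(x_false)
    y_true = model1(x_true)
    # Row k of each output goes back to position i_*[k] in the result.
    return tf.dynamic_stitch([i_false, i_true], [y_false, y_true])

x = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
m = tf.constant([True, False, True])
y = apply_by_mask(x, m)
print(y.numpy())
```

Unlike the scatter_nd version, no output shape has to be assembled by hand: dynamic_stitch infers it from the largest index and the per-sample shape of the outputs.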

Upvotes: 1
