Brans Ds

Reputation: 4117

How to use tf.cond for batch processing

I want to use tf.cond(pred, fn1, fn2, name=None) for conditional branching. Let's say I have two tensors, x and y. Each tensor is a batch of 0/1 values, and I want to use the element-wise comparison x < y as the source for tf.cond's pred argument:

pred: A scalar determining whether to return the result of fn1 or fn2.

But if I am working with batches, it looks like I need to iterate over the source tensor inside the graph, take a slice for every item in the batch, and apply tf.cond to each item (a sketch of what I mean is below). That looks suspicious to me. Why does tf.cond accept only a scalar and not a batch? Can you advise on the right way to use it with batches?
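Here is roughly the per-item graph I mean (just a sketch, assuming a fixed batch size; the lambdas are placeholders for fn1/fn2):

import tensorflow as tf

batch_size = 4
x = tf.constant([0., 1., 0., 1.])
y = tf.constant([1., 0., 0., 1.])

# One tf.cond per batch element: slice, compare, branch, then re-stack.
per_item = []
for i in range(batch_size):
  per_item.append(tf.cond(x[i] < y[i],
                          lambda: x[i] * 2.,   # fn1 for this element
                          lambda: x[i] - 1.))  # fn2 for this element
result = tf.stack(per_item)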

Upvotes: 13

Views: 5040

Answers (1)

Allen Lavoie

Reputation: 5808

tf.where sounds like what you want: a vectorized selection between Tensors.
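For the x < y case in the question, something like this should work (a minimal sketch; note that both branch tensors are computed for the whole batch, and tf.where just selects element-wise):

import tensorflow as tf

x = tf.constant([0., 1., 0., 1.])
y = tf.constant([1., 0., 0., 1.])

# Element-wise: take x * 2. where x < y, and x - 1. everywhere else.
selected = tf.where(x < y, x * 2., x - 1.)

with tf.Session():
  print(selected.eval())  # [ 0.  0. -1.  0.]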

tf.cond is a control flow modifier: it determines which ops are executed, and so it's difficult to think of useful batch semantics.
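For contrast, a minimal tf.cond usage with a scalar predicate (a sketch; both branch functions are traced into the graph, but only the selected branch's ops run at execution time):

import tensorflow as tf

pred = tf.placeholder(tf.bool, shape=[])  # pred must be a scalar
cond_result = tf.cond(pred,
                      lambda: tf.constant(2.),   # fn1: runs only when pred is True
                      lambda: tf.constant(-1.))  # fn2: runs only when pred is False

with tf.Session() as sess:
  print(sess.run(cond_result, feed_dict={pred: True}))   # 2.0
  print(sess.run(cond_result, feed_dict={pred: False}))  # -1.0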

We can also put together a mixture of the two: an operation that slices a batch based on a condition and passes the slices to two branch functions.

import tensorflow as tf
from tensorflow.python.util import nest

def slicing_where(condition, full_input, true_branch, false_branch):
  """Split `full_input` between `true_branch` and `false_branch` on `condition`.

  Args:
    condition: A boolean Tensor with shape [B_1, ..., B_N].
    full_input: A Tensor or nested tuple of Tensors of any dtype, each with
      shape [B_1, ..., B_N, ...], to be split between `true_branch` and
      `false_branch` based on `condition`.
    true_branch: A function taking a single argument, that argument having the
      same structure and number of batch dimensions as `full_input`. Receives
      slices of `full_input` corresponding to the True entries of
      `condition`. Returns a Tensor or nested tuple of Tensors, each with batch
      dimensions matching its inputs.
    false_branch: Like `true_branch`, but receives inputs corresponding to the
      false elements of `condition`. Returns a Tensor or nested tuple of Tensors
      (with the same structure as the return value of `true_branch`), but with
      batch dimensions matching its inputs.
  Returns:
    Interleaved outputs from `true_branch` and `false_branch`, each Tensor
    having shape [B_1, ..., B_N, ...].
  """
  full_input_flat = nest.flatten(full_input)
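  # Coordinates of the True and of the False entries of `condition`; these
  # drive both the gather (slicing) below and the scatter (reassembly) later.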
  true_indices = tf.where(condition)
  false_indices = tf.where(tf.logical_not(condition))
  true_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                     for input_tensor in full_input_flat])
  false_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                     for input_tensor in full_input_flat])
  true_outputs = true_branch(true_branch_inputs)
  false_outputs = false_branch(false_branch_inputs)
  nest.assert_same_structure(true_outputs, false_outputs)
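  # Scatter each branch's outputs back to their original positions in the
  # batch; the True/False index sets are disjoint, so adding the two
  # scattered tensors interleaves the results.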
  def scatter_outputs(true_output, false_output):
    batch_shape = tf.shape(condition)
    scattered_shape = tf.concat(
        [batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
        0)
    true_scatter = tf.scatter_nd(
        indices=tf.cast(true_indices, tf.int32),
        updates=true_output,
        shape=scattered_shape)
    false_scatter = tf.scatter_nd(
        indices=tf.cast(false_indices, tf.int32),
        updates=false_output,
        shape=scattered_shape)
    return true_scatter + false_scatter
  result = nest.pack_sequence_as(
      structure=true_outputs,
      flat_sequence=[
          scatter_outputs(true_single_output, false_single_output)
          for true_single_output, false_single_output
          in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
  return result

Some examples:

vector_test = slicing_where(
    condition=tf.equal(tf.range(10) % 2, 0),
    full_input=tf.range(10, dtype=tf.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)

cross_range = (tf.range(10, dtype=tf.float32)[:, None]
               * tf.range(10, dtype=tf.float32)[None, :])
matrix_test = slicing_where(
    condition=tf.equal(tf.range(10) % 3, 0),
    full_input=cross_range,
    true_branch=lambda x: -x,
    false_branch=lambda x: x + 0.1)

with tf.Session():
  print(vector_test.eval())
  print(matrix_test.eval())

Prints:

[ 0.2         1.10000002  2.20000005  3.0999999   4.19999981  5.0999999
  6.19999981  7.0999999   8.19999981  9.10000038]
[[  0.           0.           0.           0.           0.           0.
    0.           0.           0.           0.        ]
 [  0.1          1.10000002   2.0999999    3.0999999    4.0999999
    5.0999999    6.0999999    7.0999999    8.10000038   9.10000038]
 [  0.1          2.0999999    4.0999999    6.0999999    8.10000038
   10.10000038  12.10000038  14.10000038  16.10000038  18.10000038]
 [  0.          -3.          -6.          -9.         -12.         -15.
  -18.         -21.         -24.         -27.        ]
 [  0.1          4.0999999    8.10000038  12.10000038  16.10000038
   20.10000038  24.10000038  28.10000038  32.09999847  36.09999847]
 [  0.1          5.0999999   10.10000038  15.10000038  20.10000038
   25.10000038  30.10000038  35.09999847  40.09999847  45.09999847]
 [  0.          -6.         -12.         -18.         -24.         -30.
  -36.         -42.         -48.         -54.        ]
 [  0.1          7.0999999   14.10000038  21.10000038  28.10000038
   35.09999847  42.09999847  49.09999847  56.09999847  63.09999847]
 [  0.1          8.10000038  16.10000038  24.10000038  32.09999847
   40.09999847  48.09999847  56.09999847  64.09999847  72.09999847]
 [  0.          -9.         -18.         -27.         -36.         -45.
  -54.         -63.         -72.         -81.        ]]
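Applied to the question's x < y setup, a call to the slicing_where defined above might look like this (a sketch; the branch lambdas are placeholders for whatever fn1/fn2 should compute on each slice):

x = tf.constant([0., 1., 0., 1.])
y = tf.constant([1., 0., 0., 1.])

branched = slicing_where(
    condition=x < y,
    full_input=x,
    true_branch=lambda v: v + 1.,   # only sees elements where x < y
    false_branch=lambda v: v * 2.)  # only sees the remaining elements

with tf.Session():
  print(branched.eval())  # [ 1.  2.  0.  2.]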

Upvotes: 10
