I want to use tf.cond(pred, fn1, fn2, name=None) for conditional branching. Let's say I have two tensors, x and y. Each tensor is a batch of 0/1 values, and I want to use the comparison x < y as the pred argument of tf.cond:
pred: A scalar determining whether to return the result of fn1 or fn2.
But if I am working with batches, it looks like I would need to iterate over the source tensor inside the graph, slice out every item in the batch, and apply tf.cond to each item. That looks suspicious to me. Why does tf.cond accept only a scalar and not a batch? What is the right way to use it with a batch?
tf.where sounds like what you want: a vectorized selection between Tensors. tf.cond is a control-flow modifier: a single predicate determines which ops are executed at all, so it is hard to define useful batch semantics for it.
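To illustrate the difference, here is a small sketch in NumPy, whose three-argument np.where mirrors tf.where(pred, a, b) for same-shaped inputs. The batch values and the branch functions fn1/fn2 are hypothetical. It compares the per-item loop described in the question with a single vectorized select; note that the select evaluates both branches on the whole batch and then picks one result per element.

```python
import numpy as np

# Illustrative 0/1 batches (hypothetical values, not from the question).
x = np.array([0, 1, 1, 0, 1])
y = np.array([1, 1, 0, 1, 0])

def fn1(t):
    # Hypothetical "true" branch.
    return t + 10

def fn2(t):
    # Hypothetical "false" branch.
    return t - 1

pred = x < y  # one boolean predicate per batch item

# Per-item loop: what slicing the batch and applying a scalar cond amounts to.
looped = np.array([fn1(xi) if p else fn2(yi) for xi, yi, p in zip(x, y, pred)])

# Vectorized select: compute both branches for the whole batch, then pick.
selected = np.where(pred, fn1(x), fn2(y))
print(selected)
```

In TensorFlow the corresponding call would be tf.where(pred, fn1(x), fn2(y)); the trade-off is that both branches run on every element, which is fine for cheap elementwise branches.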
We can also combine the two ideas: an operation that slices the batch based on a condition and passes each slice to one of two branches.
import tensorflow as tf
from tensorflow.python.util import nest


def slicing_where(condition, full_input, true_branch, false_branch):
    """Split `full_input` between `true_branch` and `false_branch` on `condition`.

    Args:
      condition: A boolean Tensor with shape [B_1, ..., B_N].
      full_input: A Tensor or nested tuple of Tensors of any dtype, each with
        shape [B_1, ..., B_N, ...], to be split between `true_branch` and
        `false_branch` based on `condition`.
      true_branch: A function taking a single argument, that argument having the
        same structure as `full_input` but with the N batch dimensions
        flattened into a single leading one. Receives slices of `full_input`
        corresponding to the True entries of `condition`. Returns a Tensor or
        nested tuple of Tensors, each with a leading batch dimension matching
        its inputs.
      false_branch: Like `true_branch`, but receives inputs corresponding to the
        False entries of `condition`. Returns a Tensor or nested tuple of
        Tensors (with the same structure as the return value of
        `true_branch`), but with a leading batch dimension matching its inputs.

    Returns:
      Interleaved outputs from `true_branch` and `false_branch`, each Tensor
      having shape [B_1, ..., B_N, ...].
    """
    full_input_flat = nest.flatten(full_input)
    # Index matrices of shape [num_true, N] and [num_false, N].
    true_indices = tf.where(condition)
    false_indices = tf.where(tf.logical_not(condition))
    # gather_nd collapses the N batch dimensions into one leading dimension.
    true_branch_inputs = nest.pack_sequence_as(
        structure=full_input,
        flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                       for input_tensor in full_input_flat])
    false_branch_inputs = nest.pack_sequence_as(
        structure=full_input,
        flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                       for input_tensor in full_input_flat])
    true_outputs = true_branch(true_branch_inputs)
    false_outputs = false_branch(false_branch_inputs)
    nest.assert_same_structure(true_outputs, false_outputs)

    def scatter_outputs(true_output, false_output):
        batch_shape = tf.shape(condition)
        # Branch outputs have one gathered batch dimension; keep the
        # dimensions after it and prepend the original batch shape.
        scattered_shape = tf.concat(
            [batch_shape, tf.shape(true_output)[1:]],
            0)
        # scatter_nd needs indices of the same dtype as `shape` (int32 here).
        true_scatter = tf.scatter_nd(
            indices=tf.cast(true_indices, tf.int32),
            updates=true_output,
            shape=scattered_shape)
        false_scatter = tf.scatter_nd(
            indices=tf.cast(false_indices, tf.int32),
            updates=false_output,
            shape=scattered_shape)
        # The two scatters fill disjoint positions, so summing interleaves them.
        return true_scatter + false_scatter

    result = nest.pack_sequence_as(
        structure=true_outputs,
        flat_sequence=[
            scatter_outputs(true_single_output, false_single_output)
            for true_single_output, false_single_output
            in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
    return result
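To make the gather/scatter mechanics concrete, here is a NumPy sketch of the same technique. The helper name slicing_where_np is hypothetical, and it handles a single array rather than nested structures. np.argwhere plays the role of single-argument tf.where, fancy indexing the role of tf.gather_nd, and index assignment the role of tf.scatter_nd (equivalent here because every index is unique). It shows that each branch sees a single leading batch dimension of size num_true or num_false, regardless of how many batch dimensions `condition` has.

```python
import numpy as np

def slicing_where_np(condition, full_input, true_branch, false_branch):
    """Hypothetical NumPy sketch of slicing_where for a single array input."""
    true_idx = np.argwhere(condition)    # shape [num_true, N]
    false_idx = np.argwhere(~condition)  # shape [num_false, N]
    # Gather: the N batch dims collapse into one leading dim.
    true_in = full_input[tuple(true_idx.T)]    # shape [num_true, ...]
    false_in = full_input[tuple(false_idx.T)]  # shape [num_false, ...]
    true_out = true_branch(true_in)
    false_out = false_branch(false_in)
    # Scatter: write each branch's rows back to their original positions.
    out_shape = condition.shape + true_out.shape[1:]
    result = np.zeros(out_shape, dtype=np.result_type(true_out, false_out))
    result[tuple(true_idx.T)] = true_out
    result[tuple(false_idx.T)] = false_out
    return result

# Two batch dimensions, no trailing dimensions.
condition = np.array([[True, False], [False, True]])
full = np.arange(4.0).reshape(2, 2)
out = slicing_where_np(condition, full, lambda t: t * 2, lambda t: t + 100)
print(out)

# Two batch dimensions plus one trailing dimension of size 3.
full3 = np.arange(12).reshape(2, 2, 3)
out3 = slicing_where_np(condition, full3, lambda t: t * 0, lambda t: t)
print(out3)
```

Because the gathered tensors always have exactly one leading batch dimension, the trailing shape is everything after index 1, which is what the TensorFlow version relies on when it builds `scattered_shape`.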
Some examples:
vector_test = slicing_where(
    condition=tf.equal(tf.range(10) % 2, 0),
    full_input=tf.range(10, dtype=tf.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)

cross_range = (tf.range(10, dtype=tf.float32)[:, None]
               * tf.range(10, dtype=tf.float32)[None, :])
matrix_test = slicing_where(
    condition=tf.equal(tf.range(10) % 3, 0),
    full_input=cross_range,
    true_branch=lambda x: -x,
    false_branch=lambda x: x + 0.1)

# tf.Session and .eval() are TF 1.x graph-mode APIs.
with tf.Session():
    print(vector_test.eval())
    print(matrix_test.eval())
Prints:
[ 0.2 1.10000002 2.20000005 3.0999999 4.19999981 5.0999999
6.19999981 7.0999999 8.19999981 9.10000038]
[[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[ 0.1 1.10000002 2.0999999 3.0999999 4.0999999
5.0999999 6.0999999 7.0999999 8.10000038 9.10000038]
[ 0.1 2.0999999 4.0999999 6.0999999 8.10000038
10.10000038 12.10000038 14.10000038 16.10000038 18.10000038]
[ 0. -3. -6. -9. -12. -15.
-18. -21. -24. -27. ]
[ 0.1 4.0999999 8.10000038 12.10000038 16.10000038
20.10000038 24.10000038 28.10000038 32.09999847 36.09999847]
[ 0.1 5.0999999 10.10000038 15.10000038 20.10000038
25.10000038 30.10000038 35.09999847 40.09999847 45.09999847]
[ 0. -6. -12. -18. -24. -30.
-36. -42. -48. -54. ]
[ 0.1 7.0999999 14.10000038 21.10000038 28.10000038
35.09999847 42.09999847 49.09999847 56.09999847 63.09999847]
[ 0.1 8.10000038 16.10000038 24.10000038 32.09999847
40.09999847 48.09999847 56.09999847 64.09999847 72.09999847]
[ 0. -9. -18. -27. -36. -45.
-54. -63. -72. -81. ]]
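Because both branches in these examples are elementwise, the printed values can be cross-checked against a plain vectorized select. The NumPy snippet below is an independent check, not part of the original answer; np.where stands in for tf.where, and the condition on the matrix is broadcast over columns because slicing_where routes whole rows.

```python
import numpy as np

i = np.arange(10, dtype=np.float32)

# vector_test: even positions take the +0.2 branch, odd positions the +0.1 branch.
vector_expected = np.where(np.arange(10) % 2 == 0, i + 0.2, i + 0.1)

# matrix_test: rows whose index is a multiple of 3 are negated, the rest get +0.1.
cross = i[:, None] * i[None, :]
row_is_true = (np.arange(10) % 3 == 0)[:, None]  # shape [10, 1], broadcasts over columns
matrix_expected = np.where(row_is_true, -cross, cross + 0.1)

print(vector_expected)
print(matrix_expected)
```

For cheap elementwise branches like these, tf.where gives the same result more simply; slicing_where pays off when a branch is expensive or only valid on its own subset, since each branch runs only on the items routed to it.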