ben1806
Reputation: 174

Reduce sum with condition in tensorflow

I am given a 2D tensor with stochastic rows. After applying tf.math.greater() and tf.cast(..., tf.int32), I am left with a tensor of 0s and 1s. I now want to apply a reduce sum to that matrix, but with a condition: once at least one 1 has been summed and a 0 follows, all subsequent 1s in that row should be ignored as well, meaning 1 0 1 should result in 1 instead of 2.
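To pin down the intended behaviour, here is a minimal plain-Python sketch of the rule for a single row. It assumes that leading 0s are skipped and that counting stops at the first 0 that follows a 1; the function name `first_run_sum` is made up for illustration and is not part of TensorFlow.

```python
def first_run_sum(row):
    """Count the 1s in the first contiguous run of 1s and ignore the rest."""
    total = 0
    for value in row:
        if value == 1:
            total += 1
        elif total > 0:
            # A 0 after at least one 1 ends the run.
            break
    return total


rows = [[1, 0, 1], [0, 0, 0, 1, 0, 1], [0, 1, 1, 0]]
print([first_run_sum(r) for r in rows])
# [1, 1, 2]
```

This loop is only a reference for checking results; what I am looking for is a way to get the same values with TensorFlow ops.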

I have tried to solve the problem with tf.scan(), but I have not yet been able to come up with a function that handles leading 0s, because a row might look like: 0 0 0 1 0 1. One idea was to set the lower part of the matrix to one (because I know everything left of the diagonal will always be 0) and then run a function like tf.scan() to filter out the spots (see code and error message below).

Let z be the matrix after tf.cast.

helper = tf.matrix_band_part(tf.ones_like(z), -1, 0)
z = tf.math.logical_or(tf.cast(z, tf.bool), tf.cast(helper,tf.bool))
z = tf.cast(z, tf.int32)
z = tf.scan(lambda a, x: x if a == 1 else 0 ,z)

Resulting in:

ValueError: Incompatible shape for value ([]), expected ([5])

Upvotes: 4

Views: 1528

Answers (1)

javidcf
Reputation: 59731

IIUC, this is one way to do what you want without scanning or looping. It may be a bit convoluted, and it actually iterates over the columns twice (one cumsum and one cumprod), but since these are vectorized operations I think it is probably faster. The code is written for TF 2.x but runs the same in TF 1.x (except for the last line, obviously).

import tensorflow as tf

# Example data
a = tf.constant([[0, 0, 0, 0],
                 [1, 0, 0, 0],
                 [0, 1, 1, 0],
                 [0, 1, 0, 1],
                 [1, 1, 1, 0],
                 [1, 1, 0, 1],
                 [0, 1, 1, 1],
                 [1, 1, 1, 1]])
# Cumulative sum along each row
c = tf.math.cumsum(a, axis=1)
# Find the points where we should not sum anymore
# (cumsum is not zero and the current element is zero)
cutoff = tf.equal(a, 0) & tf.not_equal(c, 0)
# Make mask: 1 up to the first cutoff point, 0 from there on
mask = tf.math.cumprod(tf.dtypes.cast(~cutoff, tf.uint8), axis=1)
# Compute result
result = tf.reduce_max(c * tf.dtypes.cast(mask, c.dtype), axis=1)
print(result.numpy())
# [0 1 2 1 3 2 3 4]
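To see why the mask works, the same three steps (running sum, cutoff flags, running product) can be traced on one row at a time in plain Python. This is only an illustrative sketch using the standard library, not TensorFlow code; `masked_row_sum` is a hypothetical helper name.

```python
from itertools import accumulate
from operator import mul


def masked_row_sum(row):
    # Running sum of the row, analogous to tf.math.cumsum.
    c = list(accumulate(row))
    # True where the element is 0 but at least one 1 has already been seen.
    cutoff = [a == 0 and s != 0 for a, s in zip(row, c)]
    # Running product of the inverted flags: 1 before the first cutoff, 0 after.
    mask = list(accumulate((0 if cut else 1 for cut in cutoff), mul))
    # The largest surviving running sum is the length of the first run of 1s.
    return max((s * m for s, m in zip(c, mask)), default=0)


data = [[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 1, 0], [0, 1, 0, 1],
        [1, 1, 1, 0], [1, 1, 0, 1], [0, 1, 1, 1], [1, 1, 1, 1]]
print([masked_row_sum(r) for r in data])
# [0, 1, 2, 1, 3, 2, 3, 4]
```

For the row 0 1 0 1, the running sum is 0 1 1 2, the cutoff flags are F F T F, and the mask becomes 1 1 0 0, so the later 1 never contributes.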

Upvotes: 2
