Reputation: 109
I am trying to perform a local reduce over specified slices along a single axis of a 2D array.
I achieved this with NumPy's numpy.ufunc.reduceat (specifically numpy.add.reduceat), but I would like to do the same in TensorFlow, since the input to this reduce operation is the output of a TensorFlow convolution.
I came across tf.math.reduce_sum, but I am not sure how it can be used in my case.
It would be great if I could do the reduceat operation in TensorFlow, so that I can take advantage of a GPU.
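To make the target behaviour concrete, here is a minimal pure-Python sketch of what numpy.add.reduceat computes along axis 1. It assumes the indices are strictly increasing, and the helper name add_reduceat_rows is hypothetical, used only for illustration.

```python
def add_reduceat_rows(a, indices):
    """Sum each row of `a` over the slices [indices[i], indices[i + 1]).

    Assumes `indices` is strictly increasing. The last slice runs to the
    end of the row, mirroring np.add.reduceat(a, indices, axis=1).
    """
    # Slice boundaries: the given indices plus the row length as final stop.
    bounds = list(indices) + [len(a[0])]
    return [
        [sum(row[start:stop]) for start, stop in zip(bounds[:-1], bounds[1:])]
        for row in a
    ]


a = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]]
print(add_reduceat_rows(a, [2, 4, 7]))
# [[5, 15, 24], [25, 45, 54]]
```

Elements before the first index (columns 0 and 1 here) do not contribute to any output slot.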
Upvotes: 0
Views: 200
Reputation: 59731
You can do almost the same using tf.math.segment_sum:
import tensorflow as tf
import numpy as np

def add_reduceat_tf(a, indices, axis=0):
    a = tf.convert_to_tensor(a)
    indices = tf.convert_to_tensor(indices)
    # Transpose if necessary
    transpose = not (isinstance(axis, int) and axis == 0)
    if transpose:
        axis = tf.convert_to_tensor(axis)
        ndims = tf.cast(tf.rank(a), axis.dtype)
        a = tf.transpose(a, tf.concat([[axis], tf.range(axis),
                                       tf.range(axis + 1, ndims)], axis=0))
    # Make segment ids
    r = tf.range(tf.shape(a, out_type=indices.dtype)[0])
    segments = tf.searchsorted(indices, r, side='right')
    # Compute segmented sum and discard first unused segment
    out = tf.math.segment_sum(a, segments)[1:]
    # Transpose back if necessary
    if transpose:
        out = tf.transpose(out, tf.concat([tf.range(1, axis + 1), [0],
                                           tf.range(axis + 1, ndims)], axis=0))
    return out

# Test
np.random.seed(0)
a = np.random.rand(5, 10).astype(np.float32)
indices = [2, 4, 7]
axis = 1
# NumPy computation
out_np = np.add.reduceat(a, indices, axis=axis)
# TF computation
with tf.Graph().as_default(), tf.Session() as sess:
    out = add_reduceat_tf(a, indices, axis=axis)
    out_tf = sess.run(out)
# Check result
print(np.allclose(out_np, out_tf))
# True
You can replace tf.math.segment_sum above with the segment reduction you want to use (for example tf.math.segment_max). The only difference between this and the actual np.ufunc.reduceat is the special case where indices[i] >= indices[i + 1]. The posted function requires indices to be sorted, and if there were a case where indices[i] == indices[i + 1], the corresponding position i in the output would be zero, not a[indices[i]].
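The segment-id construction and the repeated-index edge case can be illustrated without TensorFlow. The following is a pure-Python sketch in which bisect_right stands in for tf.searchsorted(..., side='right') and a plain accumulation loop stands in for tf.math.segment_sum. The helper names segment_ids and add_reduceat_1d are hypothetical, and the sketch assumes sorted indices that are all smaller than the input length.

```python
from bisect import bisect_right


def segment_ids(indices, length):
    # For each position, count how many indices are <= that position.
    # Positions before the first index get id 0 (the unused segment).
    return [bisect_right(indices, pos) for pos in range(length)]


def add_reduceat_1d(values, indices):
    ids = segment_ids(indices, len(values))
    # One accumulator per segment id 0..len(indices); a segment that
    # receives no elements stays at zero, as with a segmented sum.
    sums = [0] * (len(indices) + 1)
    for value, seg in zip(values, ids):
        sums[seg] += value
    # Drop segment 0, which holds the elements before the first index.
    return sums[1:]


values = list(range(10))
print(segment_ids([2, 4, 7], len(values)))
# [0, 0, 1, 1, 2, 2, 2, 3, 3, 3]
print(add_reduceat_1d(values, [2, 4, 7]))
# [5, 15, 24]
print(add_reduceat_1d(values, [2, 4, 4, 7]))
# [5, 0, 15, 24]
```

In the last call the repeated index 4 produces an empty segment, so the second output is 0, whereas np.add.reduceat would return values[4] (that is, 4) in that position.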
Upvotes: 1