NicoJ

Reputation: 98

Tensorflow: using argmax to slice a tensor

I have a tensor t1 with shape tf.shape(t1) = [1, 1000, 400], and I obtain the indices of the maxima along the third (last) dimension using max_ind = tf.argmax(t1, axis=-1), which has shape [1, 1000]. I also have a second tensor t2 with the same shape as t1: tf.shape(t2) = [1, 1000, 400].

I want to use the indices of the maxima of t1 to slice t2, so that the output has shape [1, 1000].

In other words, the result should look like tf.reduce_max(t2, axis=-1), except that the element picked at each position is the one where t1 (not t2) has its maximum.
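To pin down the intended semantics, here is a small NumPy reference (an illustration, not TensorFlow code), assuming toy shapes in place of [1, 1000, 400]: for every position (i, j) it picks t2[i, j, argmax(t1[i, j, :])]. np.take_along_axis is NumPy's built-in for this gather-along-one-axis pattern, so it can serve as ground truth when checking a TensorFlow version.

```python
import numpy as np

# Small stand-ins for t1 and t2 (assumed toy shapes; the real ones are [1, 1000, 400]).
rng = np.random.default_rng(0)
t1 = rng.random((1, 5, 4))
t2 = rng.random((1, 5, 4))

max_ind = np.argmax(t1, axis=-1)  # shape (1, 5): where t1 peaks along the last axis

# For every (i, j), take t2[i, j, max_ind[i, j]]; the trailing [..., 0] drops the
# length-1 axis that take_along_axis keeps.
result = np.take_along_axis(t2, max_ind[..., None], axis=-1)[..., 0]
print(result.shape)  # (1, 5)

# The same selection written as an explicit loop, to make the semantics obvious.
expected = np.empty(max_ind.shape)
for i in range(t1.shape[0]):
    for j in range(t1.shape[1]):
        expected[i, j] = t2[i, j, max_ind[i, j]]
assert np.array_equal(result, expected)
```

If t2 were t1 itself, this would reduce to t1.max(axis=-1), which is the tf.reduce_max analogy from the question.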

Upvotes: 1

Views: 979

Answers (2)

lovetl2002

Reputation: 1068

You can use tf.reshape to avoid generating an index grid. I find this easier to understand (especially if you are familiar with reshape logic). The basic idea is to unfold the tensor into a 2D matrix, so that the row index is simply a tf.range and only the column index comes from argmax.

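import tensorflow as tf

#generate a 3D tensor as an example; here `a` plays the role of both t1 and t2
#(for the question, take the argmax of t1 and gather from t2 instead)
a  = tf.random.uniform(shape=(8,9,10), minval=0, maxval=100, dtype=tf.int32)

ix = tf.argmax(a, axis=-1, output_type=tf.int64)  # shape (8, 9)
#indices for unfolded a: one [row, column] pair per row of the (72, 10) view
ix2 = tf.stack((tf.range(8*9, dtype=tf.int64), tf.reshape(ix, (8*9,))), axis=1)

#get values based on ix2, then fold the result back to (8, 9)
b = tf.reshape(tf.gather_nd(tf.reshape(a, (8*9, -1)), ix2), (8,9))

#verify: since the argmax was taken on a itself, b must equal the row maxima
tf.reduce_all(tf.reduce_max(a, -1) == b)
#<tf.Tensor: shape=(), dtype=bool, numpy=True>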
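The same unfold idea carries over to the question's two-tensor case, where the argmax comes from t1 but the values come from t2. The sketch below uses random stand-in tensors with the question's shapes (an assumption) and goes one step further: it flattens t2 to 1-D, so a plain tf.gather with computed offsets replaces the [row, column] pairs fed to tf.gather_nd. It assumes the total element count fits in int32, which holds for 1 × 1000 × 400.

```python
import tensorflow as tf

# Hypothetical inputs with the question's shapes: argmax on t1, values from t2.
t1 = tf.random.uniform((1, 1000, 400))
t2 = tf.random.uniform((1, 1000, 400))

depth = tf.shape(t2)[-1]                                # 400
max_ind = tf.argmax(t1, axis=-1, output_type=tf.int32)  # shape [1, 1000]

# Unfold t2 into a 1-D vector. Row r of the 2-D view [1000, 400] starts at
# offset r * depth, so the wanted element sits at r * depth + max_ind[r].
flat_t2 = tf.reshape(t2, [-1])
rows = tf.range(tf.size(max_ind))                       # 0 .. 999
flat_ind = rows * depth + tf.reshape(max_ind, [-1])

# Gather the 1000 selected values and fold them back to [1, 1000].
sliced_t2 = tf.reshape(tf.gather(flat_t2, flat_ind), tf.shape(max_ind))
```

Flattening all the way to 1-D trades the index pairs for a bit of offset arithmetic; which one reads better is mostly a matter of taste.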

Upvotes: 0

P-Gn

Reputation: 24581

You can achieve this with tf.gather_nd by building a full [i, j, k] index for every output position, although it is not really straightforward. For example,

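import numpy as np
import tensorflow as tf

shape = t1.shape.as_list()
# [i, j] coordinates of every position in the first two axes, shape [1, 1000, 2];
# cast to int64 to match the default output_type of tf.argmax
xy_ind = np.stack(np.mgrid[:shape[0], :shape[1]], axis=-1).astype(np.int64)
# append the argmax as the third coordinate: [i, j, max_ind[i, j]]
gather_ind = tf.concat([xy_ind, max_ind[..., None]], axis=-1)
sliced_t2 = tf.gather_nd(t2, gather_ind)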

If, on the other hand, the shape of your input is unknown at graph construction time, you could use

# int64 so that the grid has the same dtype as max_ind (tf.argmax's default)
shape = tf.shape(t1, out_type=tf.int64)
xy_ind = tf.stack(tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                              indexing='ij'), axis=-1)

and the remainder is the same as above.
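In TF 2.x, tf.gather also accepts a batch_dims argument, which makes the index grid unnecessary: the leading batch_dims axes of t2 and max_ind are paired up, and only the remaining axis is indexed. This is a minimal sketch of that alternative to the gather_nd approach, with random stand-in tensors assumed for t1 and t2; since it never reads the shape explicitly, it should not need the static/dynamic distinction.

```python
import tensorflow as tf

# Hypothetical inputs with the question's shapes.
t1 = tf.random.uniform((1, 1000, 400))
t2 = tf.random.uniform((1, 1000, 400))

max_ind = tf.argmax(t1, axis=-1)  # int64, shape [1, 1000]

# batch_dims=2 pairs axes 0 and 1 of t2 with axes 0 and 1 of max_ind, so each
# max_ind[i, j] indexes only into t2[i, j, :]. Output shape: [1, 1000].
sliced_t2 = tf.gather(t2, max_ind, axis=2, batch_dims=2)

# Sanity check: gathering from t1 itself must reproduce tf.reduce_max(t1, -1).
check = tf.gather(t1, max_ind, axis=2, batch_dims=2)
print(bool(tf.reduce_all(check == tf.reduce_max(t1, axis=-1))))  # True
```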

Upvotes: 2
