Reputation: 98
I have a tensor t1 with shape tf.shape(t1) = [1, 1000, 400], and I obtain the indices of the maxima along the 3rd dimension using max_ind = tf.argmax(t1, axis=-1), which has shape [1, 1000]. Now I have a second tensor t2 with the same shape as t1: tf.shape(t2) = [1, 1000, 400].
I want to use the maxima indices from t1 to slice t2, so that the output has shape [1, 1000].
A more visual description: the resulting tensor should look like the result of tf.reduce_max(t2, axis=-1), but with the values taken at the locations of the maxima in t1 rather than the maxima in t2.
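To pin down the intended semantics, here is a minimal NumPy reference with explicit loops. The small shape (1, 4, 5), the random inputs and the helper name select_at_argmax_reference are assumptions for illustration only; the loop version is slow and is meant as a ground truth for checking a vectorized TensorFlow solution.

```python
import numpy as np

def select_at_argmax_reference(t1, t2):
    """Hypothetical helper: out[i, j] = t2[i, j, argmax(t1[i, j, :])]."""
    out = np.empty(t1.shape[:2], dtype=t2.dtype)
    for i in range(t1.shape[0]):
        for j in range(t1.shape[1]):
            k = np.argmax(t1[i, j])  # position of the maximum along t1's last axis
            out[i, j] = t2[i, j, k]  # read t2 at that same position
    return out

rng = np.random.default_rng(0)
t1 = rng.random((1, 4, 5))
t2 = rng.random((1, 4, 5))
result = select_at_argmax_reference(t1, t2)
print(result.shape)  # (1, 4)

# Vectorized NumPy equivalent: add a trailing axis to the indices, then drop it again.
vectorized = np.take_along_axis(t2, np.argmax(t1, axis=-1)[..., None], axis=-1)[..., 0]
assert np.array_equal(result, vectorized)

# When t1 and t2 are the same tensor, this is just a max over the last axis.
assert np.array_equal(select_at_argmax_reference(t2, t2), t2.max(axis=-1))
```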
Upvotes: 1
Views: 979
Reputation: 1068
You can use tf.reshape to avoid generating an index grid. I find this easier to understand (especially if you are familiar with reshape logic). The basic idea is to unfold the tensor into a 2D matrix.
import tensorflow as tf

# Generate a 3D tensor as an example
a = tf.random.uniform(shape=(8, 9, 10), minval=0, maxval=100, dtype=tf.int32)
ix = tf.argmax(a, axis=-1, output_type=tf.int64)
# Indices into the unfolded a: one (row, column) pair per row of the (8*9, 10) matrix
ix2 = tf.stack((tf.range(8*9, dtype=tf.int64), tf.reshape(ix, (8*9,))), axis=1)
# Gather the values at ix2 and fold the result back to shape (8, 9)
b = tf.reshape(tf.gather_nd(tf.reshape(a, (8*9, -1)), ix2), (8, 9))
# Verify against reduce_max (here a is both the argmax source and the gather target)
tf.reduce_all(tf.reduce_max(a, -1) == b)
#<tf.Tensor: shape=(), dtype=bool, numpy=True>
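The example above takes the argmax and gathers from the same tensor. A minimal sketch applying the same reshape idea to two tensors, as in the question, follows; the helper name gather_at_argmax_reshape is hypothetical, and it reads the sizes with tf.shape instead of hard-coding 8 and 9, assuming t1 and t2 have identical shapes.

```python
import tensorflow as tf

def gather_at_argmax_reshape(t1, t2):
    """Hypothetical helper: t2's values at the argmax positions of t1 (last axis)."""
    shape = tf.shape(t2, out_type=tf.int64)            # dynamic [d0, d1, d2]
    n_rows = shape[0] * shape[1]                        # rows after unfolding to 2D
    ix = tf.argmax(t1, axis=-1, output_type=tf.int64)   # [d0, d1]
    # One (row, column) pair per row of the unfolded [d0*d1, d2] matrix.
    ix2 = tf.stack([tf.range(n_rows, dtype=tf.int64), tf.reshape(ix, [-1])], axis=1)
    flat_t2 = tf.reshape(t2, tf.stack([n_rows, shape[2]]))
    # Gather one value per row, then fold back to [d0, d1].
    return tf.reshape(tf.gather_nd(flat_t2, ix2), shape[:2])

t1 = tf.random.uniform((1, 6, 7))
t2 = tf.random.uniform((1, 6, 7))
out = gather_at_argmax_reshape(t1, t2)
print(out.shape)  # (1, 6)
```

Because the index grid is replaced by a single tf.range over the unfolded rows, this works for any leading dimensions as long as they can be flattened together.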
Upvotes: 0
Reputation: 24581
You can achieve this with tf.gather_nd, although it is not really straightforward. For example:
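import numpy as np
import tensorflow as tf

shape = t1.shape.as_list()
# (i, j) grid of shape [shape[0], shape[1], 2]; cast to int64 to match tf.argmax's default output dtype
xy_ind = np.stack(np.mgrid[:shape[0], :shape[1]], axis=-1).astype(np.int64)
# Append the argmax index as the third coordinate -> [shape[0], shape[1], 3]
gather_ind = tf.concat([xy_ind, max_ind[..., None]], axis=-1)
sliced_t2 = tf.gather_nd(t2, gather_ind)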
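To see what tf.gather_nd actually receives, here is a minimal run of the static-shape version on a tiny (1, 2, 3) input; the concrete values are made up for illustration. Each entry of gather_ind is a full [i, j, k] coordinate, so tf.gather_nd returns one scalar per (i, j) position.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy inputs of shape (1, 2, 3).
t1 = tf.constant([[[0.1, 0.9, 0.3],
                   [0.7, 0.2, 0.5]]])
t2 = tf.constant([[[10., 11., 12.],
                   [20., 21., 22.]]])

max_ind = tf.argmax(t1, axis=-1)  # [[1, 0]], int64 by default
shape = t1.shape.as_list()        # [1, 2, 3], known statically
xy_ind = np.stack(np.mgrid[:shape[0], :shape[1]], axis=-1).astype(np.int64)
gather_ind = tf.concat([xy_ind, max_ind[..., None]], axis=-1)
print(gather_ind.numpy().tolist())  # [[[0, 0, 1], [0, 1, 0]]]

sliced_t2 = tf.gather_nd(t2, gather_ind)
print(sliced_t2.numpy().tolist())   # [[11.0, 20.0]]
```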
If, on the other hand, the shape of your input is unknown at graph construction time, you could use
# out_type=tf.int64 makes the grid int64 like max_ind; tf.concat rejects mixed int32/int64 inputs
shape = tf.shape(t1, out_type=tf.int64)
xy_ind = tf.stack(tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                              indexing='ij'), axis=-1)
and the remainder is the same as above.
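A minimal sketch of the dynamic-shape variant assembled end to end, traced with tf.function under an input signature whose dimensions are all None, so no static shape is available while the graph is built. The function name gather_at_argmax_dynamic and the float32 input dtype are assumptions; tf.shape(..., out_type=tf.int64) keeps the index grid in the same int64 dtype that tf.argmax returns by default.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, None, None], tf.float32),
                              tf.TensorSpec([None, None, None], tf.float32)])
def gather_at_argmax_dynamic(t1, t2):
    """Hypothetical helper: t2's values at the argmax positions of t1 (last axis)."""
    max_ind = tf.argmax(t1, axis=-1)              # [d0, d1], int64 by default
    shape = tf.shape(t1, out_type=tf.int64)       # shape read at run time
    xy_ind = tf.stack(tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                                  indexing='ij'), axis=-1)         # [d0, d1, 2]
    gather_ind = tf.concat([xy_ind, max_ind[..., None]], axis=-1)  # [d0, d1, 3]
    return tf.gather_nd(t2, gather_ind)           # [d0, d1]

t1 = tf.random.uniform((2, 5, 4))
t2 = tf.random.uniform((2, 5, 4))
out = gather_at_argmax_dynamic(t1, t2)
print(out.shape)  # (2, 5)
```

Since the signature leaves every dimension unknown, the same traced graph serves inputs of any size, for example the question's [1, 1000, 400] tensors.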
Upvotes: 2