Clara Tump

Reputation: 63

Indexing k-th dimension of tensor with another tensor in Tensorflow 2.0

I have a tensor probs with shape (None, None, 110), representing (batch_size, sequence_length, 110) in an LSTM. I have another tensor indices with shape (None, None) that contains, for each batch element and time step, the index of the element to select from the third dimension of probs.

I want to use indices to index probs, so that the result has shape (batch_size, sequence_length).

NumPy equivalent:

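import numpy as np
# j[b, s] = b (batch index), k[b, s] = s (time-step index); both have shape (batch_size, sequence_length)
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]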
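A small self-contained check of the NumPy version, with made-up sizes (B, S and V below are arbitrary, not from the question), compares the meshgrid indexing against an explicit loop. It also shows that broadcasting a (B, 1) and a (1, S) arange avoids building the meshgrid at all.

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, V = 2, 3, 5  # hypothetical batch size, sequence length and last-axis size
probs = rng.random((B, S, V))
indices = rng.integers(0, V, size=(B, S))

# Question's approach: meshgrid with the default 'xy' indexing returns (B, S) grids.
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]

# Same result via broadcasting: (B, 1) and (1, S) index arrays broadcast to (B, S).
indexed_bcast = probs[np.arange(B)[:, None], np.arange(S)[None, :], indices]

# Reference: explicit loop over batch elements and time steps.
expected = np.empty((B, S))
for b in range(B):
    for s in range(S):
        expected[b, s] = probs[b, s, indices[b, s]]

print(indexed_probs.shape)  # (2, 3)
```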

Since shape[0] and shape[1] of probs are not known, tf.meshgrid() is not an option. I found tf.gather, tf.gather_nd and tf.batch_gather, but none of them seems to do what I want.

Does anybody know how to do this?

Upvotes: 2

Views: 392

Answers (1)

javidcf

Reputation: 59731

You can do that with tf.gather_nd like this:

import tensorflow as tf
indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)
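With batch_dims=2, tf.gather_nd treats the first two axes as batch axes: for every (b, s) pair it looks up the index vector indices[b, s, :] inside the slice probs[b, s]. The sketch below emulates that behaviour in NumPy so it runs without TensorFlow; gather_nd_batch2_reference is a hypothetical helper (not a TensorFlow or NumPy function), and the sizes are made up.

```python
import numpy as np

def gather_nd_batch2_reference(params, indices):
    """Hypothetical NumPy emulation of tf.gather_nd(params, indices, batch_dims=2).

    params:  shape (B, S, ...), indices: integer array of shape (B, S, L).
    Output shape is indices.shape[:-1] + params.shape[2 + L:].
    """
    B, S, L = indices.shape
    out = np.empty((B, S) + params.shape[2 + L:], dtype=params.dtype)
    for b in range(B):
        for s in range(S):
            # Each (b, s) pair indexes only its own slice params[b, s].
            out[b, s] = params[b, s][tuple(indices[b, s])]
    return out

rng = np.random.default_rng(1)
probs = rng.random((4, 7, 110))               # (batch_size, sequence_length, 110)
indices = rng.integers(0, 110, size=(4, 7))   # (batch_size, sequence_length)

# Mirrors the answer: add a trailing axis of length 1, so L = 1 and the result is (4, 7).
indexed_probs = gather_nd_batch2_reference(probs, np.expand_dims(indices, axis=-1))
```

In TensorFlow itself, indices must be an int32 or int64 tensor, so cast it with tf.cast first if it holds another dtype.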

By the way, in NumPy you can use np.take_along_axis to do the same:

import numpy as np
indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]
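np.take_along_axis requires the index array to have the same number of dimensions as the data, which is why the answer adds a trailing axis and then drops it with [..., 0]. The following sketch, using arbitrary sizes, walks through the shapes and cross-checks the result against the meshgrid indexing from the question.

```python
import numpy as np

rng = np.random.default_rng(2)
probs = rng.random((3, 4, 110))               # (batch_size, sequence_length, 110), sizes arbitrary
indices = rng.integers(0, 110, size=(3, 4))   # (batch_size, sequence_length)

# Add a trailing axis of length 1 so indices has the same ndim as probs: (3, 4) -> (3, 4, 1).
idx3 = np.expand_dims(indices, axis=-1)
picked = np.take_along_axis(probs, idx3, axis=-1)  # shape (3, 4, 1)
indexed_probs = picked[..., 0]                      # drop the length-1 axis -> (3, 4)

# Cross-check against the meshgrid-based indexing from the question.
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
reference = probs[j, k, indices]
```

np.squeeze(picked, axis=-1) would drop the trailing axis equally well; [..., 0] just keeps it to one expression.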

Upvotes: 2
