Reputation: 63
I have a tensor probs which has shape (None, None, 110), representing (batch_size, sequence_length, 110) in an LSTM. I have another tensor, indices, which has shape (None, None) and contains the indices of the elements to select from the third dimension of probs.
I want to use indices to index the tensor probs, that is, to pick one of the 110 entries of probs for every batch element and time step.
NumPy equivalent:
import numpy as np
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]
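A small worked example of that NumPy indexing, with made-up sizes (batch 2, sequence length 3, and 4 classes instead of 110), checks that it selects probs[b, s, indices[b, s]] for every b and s. It also shows an equivalent form that broadcasts two np.arange index arrays instead of calling meshgrid. The sizes and random data are assumptions for illustration only.

```python
import numpy as np

# Hypothetical sizes: batch B=2, sequence S=3, C=4 classes (110 in the question)
B, S, C = 2, 3, 4
rng = np.random.default_rng(0)
probs = rng.random((B, S, C))
indices = rng.integers(0, C, size=(B, S))

# Meshgrid version from the question: with the default 'xy' indexing,
# k holds the time index and j the batch index, both with shape (B, S)
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
via_meshgrid = probs[j, k, indices]

# Same selection with broadcasting: a (B, 1) and a (1, S) index array
# broadcast against the (B, S) indices array
via_broadcast = probs[np.arange(B)[:, None], np.arange(S)[None, :], indices]

# Reference: explicit loop over batch elements and time steps
via_loop = np.empty((B, S))
for b in range(B):
    for s in range(S):
        via_loop[b, s] = probs[b, s, indices[b, s]]

print(via_meshgrid.shape)  # (2, 3)
print(np.array_equal(via_meshgrid, via_loop), np.array_equal(via_broadcast, via_loop))  # True True
```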
Since shape[0] and shape[1] of probs are not known, tf.meshgrid() is not an option.
I found tf.gather, tf.gather_nd and tf.batch_gather, but none of them seems to do what I want.
Does anybody know how to do this?
Upvotes: 2
Views: 392
Reputation: 59731
You can do that with tf.gather_nd like this:
indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)
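With batch_dims=2, the first two axes of probs and of the index tensor are treated as batch axes, and the last axis of the index tensor (size 1 after expand_dims) indexes into the remaining axis of probs, so the result has shape (batch_size, sequence_length). The NumPy loop below is a sketch that spells out this rule under that assumption; gather_nd_batch2 is a hypothetical helper written for illustration, not a TensorFlow function, and the small input arrays are made up.

```python
import numpy as np

def gather_nd_batch2(params, idx):
    """Hypothetical NumPy reference for tf.gather_nd(params, idx, batch_dims=2).

    params: array of shape (B, S, C); idx: integer array of shape (B, S, 1).
    """
    B, S = idx.shape[:2]
    out = np.empty((B, S), dtype=params.dtype)
    for b in range(B):
        for s in range(S):
            # idx[b, s] is a length-1 index tuple into params[b, s], which has shape (C,)
            out[b, s] = params[b, s][tuple(idx[b, s])]
    return out

# Made-up data: probs[b, s, c] == 12 * b + 4 * s + c
probs = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
indices = np.array([[0, 1, 2], [3, 0, 1]])
result = gather_nd_batch2(probs, np.expand_dims(indices, axis=-1))
print(result.shape)  # (2, 3)
# values: [[0, 5, 10], [15, 16, 21]]
```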
By the way, in NumPy you can use np.take_along_axis to do the same:
indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]
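np.take_along_axis requires the index array to have the same number of dimensions as probs, which is why indices gets a trailing axis of size 1; the result then also has that trailing axis, and [..., 0] removes it. A short check on made-up data (batch 2, sequence length 3, 4 classes) that this agrees with the meshgrid indexing from the question:

```python
import numpy as np

# Made-up example: B=2, S=3, C=4
probs = np.arange(24).reshape(2, 3, 4)
indices = np.array([[0, 1, 2], [3, 0, 1]])

# Index array of shape (2, 3, 1) to match probs.ndim == 3
picked = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)
print(picked.shape)  # (2, 3, 1)

indexed_probs = picked[..., 0]  # drop the size-1 axis -> shape (2, 3)

# Compare with the meshgrid indexing from the question
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
print(np.array_equal(indexed_probs, probs[j, k, indices]))  # True
```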
Upvotes: 2