Denys Fridman

Selecting exactly one element along the specified dimension in TensorFlow

I have two tensors, X of shape (?, 32, 500) and indices of shape (?,). In both tensors, dimension 0 is the batch dimension. For each batch element, the corresponding entry of indices gives the position along dimension 1 of X to select. In the end, I'd like to get a tensor of shape (?, 500). In NumPy I would do it this way:

X[np.arange(len(X)), indices]

Does anyone know how to achieve the same in TensorFlow (version 1)? I already looked at some examples of tf.gather and tf.gather_nd but couldn't get my head around them. Thanks!
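
For reference, this sketch runs the NumPy one-liner above on a small random batch and checks it against an explicit loop. The batch size of 4 and the random data are only illustrative. Integer-array indexing broadcasts np.arange(len(X)) together with indices, so row b of the result is X[b, indices[b], :].

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 32, 500))      # illustrative batch of 4
indices = rng.integers(0, 32, size=4)  # one position along dim 1 per batch element

# Advanced indexing: the two index arrays are paired element-wise,
# so result[b] == X[b, indices[b], :]
result = X[np.arange(len(X)), indices]

# Equivalent explicit loop, for comparison
expected = np.stack([X[b, indices[b]] for b in range(len(X))])

print(result.shape)  # (4, 500)
```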

Answers (1)

Mustafa Aydın

We can build (batch position, index) pairs with tf.range and tf.stack and pass them to tf.gather_nd:

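import tensorflow as tf

def fancy_index_arange(X, indices):
    # Batch positions 0..N-1; tf.shape also works for an unknown (?) batch size, unlike len(X)
    arange = tf.range(tf.shape(X)[0])
    # Pair each batch position with its index (cast so both dtypes match): shape (N, 2)
    fancy_index = tf.stack([arange, tf.cast(indices, arange.dtype)], axis=1)
    # Each (b, i) pair selects the slice X[b, i, :], giving shape (N, 500)
    result = tf.gather_nd(X, fancy_index)
    return result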

Verify the shape:

>>> X = tf.random.normal((10, 32, 500))
>>> indices = tf.random.uniform((10,), minval=0, maxval=32, dtype=tf.int32)

>>> fancy_index_arange(X, indices).shape
TensorShape([10, 500])

Tested with tf.__version__ == "2.3.0".
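
To make the tf.gather_nd step less opaque, here is a plain NumPy emulation of what it does with the (N, 2) index built above. The helper gather_nd_emulated is hypothetical (it is not a TensorFlow or NumPy API) and only covers this case, where every index row has two entries. The sizes are illustrative.

```python
import numpy as np

def gather_nd_emulated(params, index_pairs):
    # Hypothetical helper: mimics tf.gather_nd for an index of shape (N, 2).
    # Each row (b, i) selects the whole slice params[b, i, ...].
    return np.stack([params[b, i] for b, i in index_pairs])

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 32, 500))
indices = rng.integers(0, 32, size=4)

# Same construction as in fancy_index_arange: pair each batch position with its index
fancy_index = np.stack([np.arange(len(X)), indices], axis=1)  # shape (4, 2)
result = gather_nd_emulated(X, fancy_index)

print(fancy_index.shape, result.shape)  # (4, 2) (4, 500)
```

Because each index row has length 2 while X has rank 3, every row addresses a full length-500 slice. That is why the result has shape (N, 500) rather than (N,).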