Luca Di Liello

Reputation: 1643

Index multidimensional torch tensor with array of variable length

I have a list of indices and a tensor with shape:

shape = [batch_size, d_0, d_1, ..., d_k]
idx = [i_0, i_1, ..., i_k]

Is there a way to efficiently index the tensor on each dim d_0, ..., d_k with the indices i_0, ..., i_k? (k is available only at run time)

The result should be:

tensor[:, i_0, i_1, ..., i_k]  # result shape = [batch_size]

At the moment I'm creating a tuple of slices, one for each dimension:

idx = (slice(tensor.shape[0]),) + tuple(slice(i, i+1) for i in idx)
tensor[idx]

but I would prefer something like:

tensor[:, *idx]

Example:

a = torch.randint(0,10,[3,3,3,3])
indexes = torch.LongTensor([1,1,1])

I would like to index only the last len(indexes) dimensions like:

a[:, indexes[0], indexes[1], indexes[2]]

but in the general case where I don't know how long indexes is.


Note: this answer does not help since it indexes all the dimensions, and does not work for a proper subset!

Upvotes: 3

Views: 1111

Answers (1)

iacob

Reputation: 24171

Unfortunately, on Python versions before 3.11 you can't mix slices and unpacked iterables directly inside an indexing expression (e.g. a[:, *idx] is a syntax error there). However, you can achieve the same thing by wrapping the index in parentheses, which builds an explicit tuple [1]:

a[(slice(None), *idx)]
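
As a minimal sketch of how this applies to the example in the question, the snippet below wraps the tuple index in a small helper. The name index_trailing_dims is hypothetical, and converting each index with int() is an assumption made here so that lists and 1-D tensors are handled the same way; it is not part of the original answer.

```python
import torch


def index_trailing_dims(tensor, idx):
    """Keep dim 0 whole and index the following len(idx) dims with integers.

    Hypothetical helper: the name and the int() conversion are illustrative.
    """
    # slice(None) is the programmatic spelling of ":". Unpacking idx into the
    # same tuple reproduces tensor[:, i_0, ..., i_k] for any number of indices.
    return tensor[(slice(None), *(int(i) for i in idx))]


a = torch.arange(81).reshape(3, 3, 3, 3)
indexes = torch.LongTensor([1, 1, 1])

result = index_trailing_dims(a, indexes)
expected = a[:, indexes[0], indexes[1], indexes[2]]

# Fewer indices than trailing dims also works; the remaining dims are kept.
partial = index_trailing_dims(a, [2])

# For comparison, the slice-based tuple from the question keeps every
# indexed dimension as a size-1 axis instead of dropping it.
sliced = a[(slice(None),) + tuple(slice(int(i), int(i) + 1) for i in indexes)]

print(result.shape, partial.shape, sliced.shape)
```

Integer indices drop the dimension they select, which is why this form returns a tensor of shape [batch_size], while the slice-based version would need an extra squeeze or reshape.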

  1. In Python, x[(exp1, exp2, ..., expN)] is equivalent to x[exp1, exp2, ..., expN]; the latter is just syntactic sugar for the former.

    https://numpy.org/doc/stable/reference/arrays.indexing.html
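
The equivalence in the footnote can be checked without any tensor library. The sketch below uses a hypothetical class, ShowKey, whose __getitem__ simply returns whatever key Python passes to it, so the explicit tuple and the comma-separated form can be compared directly.

```python
class ShowKey:
    """Hypothetical helper: returns the key object received by __getitem__."""

    def __getitem__(self, key):
        return key


x = ShowKey()
idx = [1, 2, 3]

explicit = x[(slice(None), *idx)]  # tuple built by hand
sugar = x[:, 1, 2, 3]              # tuple built by the subscript syntax

print(explicit == sugar)
```

Both forms hand the same tuple to __getitem__, which is why building the tuple programmatically works for an index list whose length is only known at run time.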

Upvotes: 2
