I have a list of indices and a tensor with shape:
shape = [batch_size, d_0, d_1, ..., d_k]
idx = [i_0, i_1, ..., i_k]
Is there a way to efficiently index the tensor on each dim d_0, ..., d_k with the indices i_0, ..., i_k? (k is available only at run time.)
The result should be:
tensor[:, i_0, i_1, ..., i_k]  # result has shape [batch_size]
At the moment I'm creating a tuple of slices, one for each dimension:
idx = (slice(tensor.shape[0]),) + tuple(slice(i, i+1) for i in idx)
tensor[idx]  # note: slice(i, i+1) keeps each indexed dim as size 1, so this has shape [batch_size, 1, ..., 1] rather than [batch_size]
but I would prefer something like:
tensor[:, *idx]
Example:
a = torch.randint(0,10,[3,3,3,3])
indexes = torch.LongTensor([1,1,1])
I would like to index only the last len(indexes) dimensions like:
a[:, indexes[0], indexes[1], indexes[2]]
but in the general case, where I don't know how long indexes is.
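For reference, a complete runnable version of the example (the printed shape is what the generic solution should reproduce):
import torch
a = torch.randint(0, 10, [3, 3, 3, 3])
indexes = torch.LongTensor([1, 1, 1])
# hard-coded for exactly three trailing dimensions
out = a[:, indexes[0], indexes[1], indexes[2]]
print(out.shape)  # torch.Size([3]), i.e. [batch_size]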
Note: this answer does not help, since it indexes all the dimensions and does not work for a proper subset of them!
Unfortunately you can't pass a mix of slices and star-unpacked values in a single indexing expression (e.g. a[:, *idx]), at least not before Python 3.11, which made star-unpacking inside subscripts legal. However, you can achieve the same thing by wrapping everything in parentheses, packing it into a tuple:
a[(slice(None), *idx)]
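For example, applied to the tensor from the question, this sketch reproduces the hard-coded indexing for any length of indexes:
import torch
a = torch.randint(0, 10, [3, 3, 3, 3])
indexes = torch.LongTensor([1, 1, 1])
out = a[(slice(None), *indexes)]  # works for any len(indexes)
print(out.shape)  # torch.Size([3])
print(torch.equal(out, a[:, indexes[0], indexes[1], indexes[2]]))  # True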
In Python, x[(exp1, exp2, ..., expN)] is equivalent to x[exp1, exp2, ..., expN]; the latter is just syntactic sugar for the former.
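A quick check of that equivalence:
import torch
a = torch.arange(8).reshape(2, 2, 2)
print(torch.equal(a[(0, 1)], a[0, 1]))  # True: the explicit tuple and the comma form index identically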