Hoeze

Reputation: 716

NumPy multi-dimensional array indexing

Assume the following example:

>>> import numpy as np
>>> a = np.random.randint(0, 10, (3, 10, 200))
>>> print(a.shape)
(3, 10, 200)
>>> 
>>> idx = np.random.randint(0, 3, 10)
>>> print(idx)
[2 0 0 0 1 2 1 2 0 0]

a is an array of shape (K=3, J=10, I=200).

idx is an array whose length equals a.shape[1], i.e. it contains J = 10 elements. Each entry idx[j] denotes which index along the K axis should be chosen for position j.

Now I'd like to select from the first axis (K) using the indices in idx, so that row j of the result is a[idx[j], j, :] and the result has shape (J=10, I=200).

How can I accomplish this?

Upvotes: 1

Views: 53

Answers (1)

Divakar

Reputation: 221684

We use idx to index along the first axis, pair each of its entries with the matching position along the second axis, and take everything along the last axis. Thus, we can use advanced indexing, like so -

a[idx, np.arange(len(idx)),:]

Skipping the trailing : gives us a shorter version -

a[idx, np.arange(len(idx))]
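As a quick sanity check, here is a minimal sketch that compares the advanced-indexing result against an explicit loop. The seeded generator and the names result, expected and alt are only for illustration and are not part of the question. np.take_along_axis is shown as one possible alternative, assuming idx is reshaped so that it broadcasts against a.

```python
import numpy as np

rng = np.random.default_rng(0)               # seeded only so the sketch is repeatable
a = rng.integers(0, 10, size=(3, 10, 200))   # shape (K, J, I)
idx = rng.integers(0, 3, size=10)            # one K-index per position j

# Advanced indexing: idx and arange(J) are paired element-wise,
# so result[j] is a[idx[j], j, :]
result = a[idx, np.arange(len(idx))]

# Explicit loop, used purely as a reference for comparison
expected = np.stack([a[idx[j], j] for j in range(len(idx))])

# Alternative: reshape idx to (1, J, 1) so it broadcasts against a,
# gather along axis 0, then drop the leading length-1 axis
alt = np.take_along_axis(a, idx[None, :, None], axis=0)[0]

print(result.shape)
print(np.array_equal(result, expected), np.array_equal(alt, expected))
```

The two index arrays have the same length, so NumPy pairs them up one-to-one instead of forming an outer product; that is why the K and J axes collapse into a single axis of length J.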

Upvotes: 2
