Julz

Reputation: 93

Efficient multidimensional indexing in numpy

I would like to "sort" a multidimensional numpy array based on a multidimensional index I have.

So let's start with an example:

import numpy as np

# The array I would like to sort
A = np.ones([3, 10, 2])

# The index array
i = np.ones([10, 2], dtype=int)

# Bring some life into sample data
samples = np.arange(10, dtype=int)
A[2, :, 0] = A[2, :, 0] * samples
np.random.shuffle(samples)
A[2, :, 1] = A[2, :, 1] * samples  
i[:, 0] = i[:, 0] * samples
np.random.shuffle(samples)
i[:, 1] = i[:, 1] * samples

So my array A contains 2 slices of 10 sets of 3 values. I want to reorder each slice individually while keeping each set of 3 values together. Given the index array i, my solution is:

A = A[:, i]
shape = A.shape
A = A.reshape([shape[0], shape[1], shape[2] + shape[3]])
A = A[:, :, [0, 3]]

where I first use i to index A. This creates a new dimension in which each column of i is applied to every slice of A, giving an array of shape (3, 10, 2, 2). Since I only need two of the 4 column/slice combinations (column 0 of i with slice 0, and column 1 of i with slice 1), I reshape the array and drop the rest.
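
To make the intermediate shapes concrete, here is a minimal sketch of the steps above. The random data, the seed, and the names step1, step2, result and diag are illustrative assumptions, not part of the original code:

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen only for reproducibility

A = rng.random((3, 10, 2))
# Each column of i is a permutation of 0..9, like the index array above
i = np.stack([rng.permutation(10), rng.permutation(10)], axis=1)

# step1[a, j, k, l] == A[a, i[j, k], l]
step1 = A[:, i]
print(step1.shape)  # (3, 10, 2, 2)

# Flatten the last two axes; the flat index is 2 * k + l
step2 = step1.reshape(3, 10, 4)

# Keep (k, l) = (0, 0) and (1, 1), i.e. flat indices 0 and 3
result = step2[:, :, [0, 3]]
print(result.shape)  # (3, 10, 2)

# The same selection without the reshape: take the "diagonal" k == l
diag = step1[:, :, [0, 1], [0, 1]]
print(np.array_equal(result, diag))
```

Either way, the intermediate array holds all four column/slice combinations before half of them are discarded, which is the overhead the question is about.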

This approach works fine, but I wonder if there is a more efficient or elegant solution for this.

Julz

Upvotes: 2

Views: 71

Answers (1)

yatu

Reputation: 88226

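You can use advanced indexing here. Indexing the last axis with np.arange(A.shape[2]) pairs each column of i with the matching slice of A, so only the values you need are selected:

A[:, i, np.arange(A.shape[2])]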

Comparing with the current approach:

out = A[:, i]
shape = out.shape
out = out.reshape([shape[0], shape[1], shape[2] + shape[3]])
out = out[:, :, [0, 3]]

np.allclose(A[:, i, np.arange(A.shape[2])], out)
#True
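
As a further cross-check, the sketch below compares the advanced-indexing expression against an explicit per-slice loop and against np.take_along_axis, which is an alternative not mentioned above. The random data, the seed, and the names expected and alt are assumptions made for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)  # seed chosen only for reproducibility

A = rng.random((3, 10, 2))
# Each column of i is a permutation of 0..9
i = np.stack([rng.permutation(10), rng.permutation(10)], axis=1)

# Advanced indexing: out[a, j, k] == A[a, i[j, k], k]
out = A[:, i, np.arange(A.shape[2])]

# Reference: reorder each slice k with its own column of i
expected = np.empty_like(A)
for k in range(A.shape[2]):
    expected[:, :, k] = A[:, i[:, k], k]

# Alternative: i[np.newaxis] has shape (1, 10, 2) and broadcasts over axis 0
alt = np.take_along_axis(A, i[np.newaxis], axis=1)

print(np.array_equal(out, expected))
print(np.array_equal(alt, expected))
```

Both variants avoid building the (3, 10, 2, 2) intermediate array. Which one reads better is largely a matter of taste.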

Upvotes: 1
