I would like to "sort" a multidimensional numpy array based on a multidimensional index I have.
So let's start with an example:
import numpy as np

# The array I would like to sort
A = np.ones([3, 10, 2])
# The index array
i = np.ones([10, 2], dtype=int)
# Bring some life into sample data
samples = np.arange(10, dtype=int)
A[2, :, 0] = A[2, :, 0] * samples
np.random.shuffle(samples)
A[2, :, 1] = A[2, :, 1] * samples
i[:, 0] = i[:, 0] * samples
np.random.shuffle(samples)
i[:, 1] = i[:, 1] * samples
So my array A contains 2 slices of 10 sets of 3 values. What I want to do is sort each slice individually while keeping each set together, i.e. reorder the 10 sets of slice k according to column k of i.
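To pin down what "sort each slice individually" means, here is a minimal plain-loop reference. It assumes that column k of i is a permutation of 0..9 intended for slice k. The random generator and seed are assumptions added only so the sample data is reproducible; they replace the `np.random.shuffle` calls of the original setup.

```python
import numpy as np

# Sample data in the spirit of the question (seeded generator is an assumption).
rng = np.random.default_rng(0)
A = np.ones([3, 10, 2])
A[2, :, 0] = np.arange(10)
A[2, :, 1] = rng.permutation(10)
# Each column of i is a permutation of 0..9, one column per slice.
i = np.stack([rng.permutation(10), rng.permutation(10)], axis=1)

# Plain-loop reference: reorder the 10 sets of slice k using column k of i.
# A[:, i[:, k], k] has shape (3, 10), so each set of 3 values moves together.
expected = np.empty_like(A)
for k in range(A.shape[2]):
    expected[:, :, k] = A[:, i[:, k], k]

print(expected.shape)  # (3, 10, 2)
```

The loop is slow for many slices, but it is a useful ground truth to check any vectorized version against.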
Having the index array i, my solution is:
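A = A[:, i]  # shape (3, 10, 2, 2): every column of i applied to every slice
shape = A.shape
# Merge the last two axes; the new length is their product, shape (3, 10, 4)
A = A.reshape([shape[0], shape[1], shape[2] * shape[3]])
# Keep (column 0 on slice 0) and (column 1 on slice 1)
A = A[:, :, [0, 3]]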
where I first use i to index A. This creates a new dimension in which each column of i is applied to every slice of A, ending up with an array of shape (3, 10, 2, 2). Since I only need two of the four results (column k of i applied to slice k), I reshape the array and drop the ones I don't need.
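A small sketch to make the intermediate array concrete, using random sample data (an assumption, not the question's data). Since `full[a, j, k, c] == A[a, i[j, k], c]`, the wanted entries are those with `k == c`, i.e. the diagonal of the last two axes. `np.diagonal` is shown here as one alternative to the hard-coded `[0, 3]`, which only works for exactly two slices; both still build the full 4-D intermediate.

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.random((3, 10, 2))
# Hypothetical index array: one permutation of 0..9 per slice (column).
i = np.stack([rng.permutation(10) for _ in range(A.shape[2])], axis=1)

full = A[:, i]     # every column of i applied to every slice
print(full.shape)  # (3, 10, 2, 2)

# Question's approach: merge the last two axes, keep flat positions 0 and 3.
via_reshape = full.reshape(full.shape[0], full.shape[1], -1)[:, :, [0, 3]]
# Diagonal of axes 2 and 3; NumPy appends the diagonal as the last axis.
via_diagonal = np.diagonal(full, axis1=2, axis2=3)

print(np.array_equal(via_reshape, via_diagonal))  # True
```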
This approach works fine, but I wonder if there is a more efficient or elegant solution for this.
Julz
Upvotes: 2
Views: 71
You can use advanced indexing here:
A[:,i, np.arange(A.shape[2])]
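Why this works, as a rough sketch with random sample data (an assumption): the two index arrays `i` (shape `(10, 2)`) and `np.arange(2)` (shape `(2,)`) are broadcast against each other, so entry `(j, k)` pairs row index `i[j, k]` with slice `k`. The names `ks`, `rows` and `cols` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.random((3, 10, 2))
i = np.stack([rng.permutation(10) for _ in range(A.shape[2])], axis=1)  # (10, 2)
ks = np.arange(A.shape[2])                                              # (2,)

# Broadcasting the index arrays shows which (row, slice) pairs are used.
rows, cols = np.broadcast_arrays(i, ks)
print(rows.shape, cols.shape)  # (10, 2) (10, 2)
# cols[j, k] == k for every j, so column k of i is only used on slice k.
print(np.array_equal(cols, np.tile(ks, (10, 1))))  # True

# The sliced first axis is kept, followed by the broadcast index shape.
out = A[:, i, ks]
print(out.shape)  # (3, 10, 2)
```

No 4-D intermediate is created, which is where the saving over the reshape approach comes from.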
Comparing with the current approach:
out = A[:, i]
shape = out.shape
out = out.reshape([shape[0], shape[1], shape[2] * shape[3]])
out = out[:, :, [0, 3]]
np.allclose(A[:,i, np.arange(A.shape[2])], out)
#True
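Another option, not from the original answer, is `np.take_along_axis`, which is commonly paired with `np.argsort` output. A minimal sketch with random sample data (an assumption): the indices must have the same number of dimensions as `A`, so `i` gets a leading axis of length 1 that broadcasts over the first axis of `A`.

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.random((3, 10, 2))
i = np.stack([rng.permutation(10) for _ in range(A.shape[2])], axis=1)  # (10, 2)

# i[None, :, :] has shape (1, 10, 2); it selects along axis 1 of A,
# using column k of i for slice k and broadcasting over axis 0.
out_take = np.take_along_axis(A, i[None, :, :], axis=1)

out_fancy = A[:, i, np.arange(A.shape[2])]
print(out_take.shape)                       # (3, 10, 2)
print(np.array_equal(out_take, out_fancy))  # True
```

If i itself comes from sorting, e.g. `np.argsort(A[2], axis=0)`, this form reads naturally as "sort along axis 1 by those keys".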
Upvotes: 1