Kion

Reputation: 1

Taking a Different Subset of Indices per Row in Numpy Using Fancy Indexing

I have an array of pairwise differences of features:

diff.shape = (200, 200, 2)

From this array I want to keep only the columns corresponding to the 50 closest points. For each row, the indices of the 50 closest points are stored as:

dist_idx.shape = (200, 50).

How can I index the 50 closest entries (different indices per row) using fancy indexing? I have tried:

diff[dist_idx].shape = (200, 50, 200, 2)
diff[np.arange(200), dist_idx] -> IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (200,) (200,50) 
diff[np.arange(200), dist_idx[np.arange(200)]] -> IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (200,) (200,50) 

An iterative solution that works is:

X_diff = np.zeros((200, 50, 2))
for i in range(200):
    X_diff[i] = diff[i, dist_idx[i]]

Upvotes: 0

Views: 189

Answers (1)

hpaulj

Reputation: 231665

Start with a smaller example:

In [159]: arr = np.arange(24).reshape(3, 4, 2)
In [160]: arr
Out[160]: 
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7]],

       [[ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]],

       [[16, 17],
        [18, 19],
        [20, 21],
        [22, 23]]])
In [161]: idx = np.array([[0, 1], [0, 2], [1, 3]])
In [162]: idx.shape
Out[162]: (3, 2)

Your iterative approach:

In [164]: out = np.zeros((3, 2, 2), int)
In [165]: for i in range(3):
     ...:     out[i] = arr[i, idx[i]]
In [166]: out
Out[166]: 
array([[[ 0,  1],
        [ 2,  3]],

       [[ 8,  9],
        [12, 13]],

       [[18, 19],
        [22, 23]]])

Your first try doesn't work because it applies idx to the 1st dimension of the array only. In my case I get an error because the first 2 dimensions have different sizes (3 and 4), so the index 3 in idx is out of bounds for axis 0:

In [167]: arr[idx]
Traceback (most recent call last):
  Input In [167] in <module>
    arr[idx]
IndexError: index 3 is out of bounds for axis 0 with size 3
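
To see where the (200, 50, 200, 2) shape in the question comes from, here is a minimal sketch with an index array whose values all fit axis 0. The name idx_ok and its values are made up for illustration; they are not part of the example above.

```python
import numpy as np

arr = np.arange(24).reshape(3, 4, 2)

# Hypothetical index array: same shape as idx, but every value is < 3,
# so it is valid along axis 0.
idx_ok = np.array([[0, 1], [0, 2], [1, 2]])

# A single index array is applied to axis 0 only. Each entry of idx_ok
# picks a whole (4, 2) block, so the result has shape idx_ok.shape + (4, 2).
picked = arr[idx_ok]
print(picked.shape)
```

This is the same pattern as diff[dist_idx] in the question: (200, 50) from the index array, followed by the untouched trailing (200, 2).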

But if we use a (3,1) array as the first index, it broadcasts against the (3,2) idx array used for the 2nd axis, so each row index is paired with that row's own column indices.

In [168]: arr[np.arange(3)[:, None], idx, :]
Out[168]: 
array([[[ 0,  1],
        [ 2,  3]],

       [[ 8,  9],
        [12, 13]],

       [[18, 19],
        [22, 23]]])
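
Applied to the shapes in the question, the same idea could look like the sketch below. The data here is random stand-in data, not your real diff or dist_idx, and the variable names rows, expected and alt are mine. The last line uses np.take_along_axis as a possible alternative; it should give the same result, but treat it as an option to check rather than the required approach.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in data with the shapes from the question (values are made up).
diff = rng.normal(size=(200, 200, 2))
# 50 distinct column indices per row, as a placeholder for the real
# nearest-neighbour indices.
dist_idx = np.argsort(rng.random((200, 200)), axis=1)[:, :50]

# A (200, 1) row index broadcasts against the (200, 50) column index.
rows = np.arange(diff.shape[0])[:, None]
X_diff = diff[rows, dist_idx]          # shape (200, 50, 2)

# Compare with the iterative version from the question.
expected = np.zeros((200, 50, 2))
for i in range(200):
    expected[i] = diff[i, dist_idx[i]]
assert np.array_equal(X_diff, expected)

# Possible alternative: the index array needs the same number of
# dimensions as diff, hence the added trailing axis.
alt = np.take_along_axis(diff, dist_idx[:, :, None], axis=1)
assert np.array_equal(alt, expected)
```

The trailing ", :" in the indexing expression is optional; leaving it out selects the full last axis as well.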

Upvotes: 0
