Pritam

NumPy: selecting specific columns from each row

I have an array A of shape (p, q, r, r) and an integer index array I of shape (r, s). For each row k along the third axis of A, I want to keep only the s elements of the last dimension listed in I[k], so that the result has shape (p, q, r, s).

To simplify (ignoring the first two dimensions), let

>>> A
array([[5, 2, 5, 7],
       [2, 6, 4, 3],
       [4, 2, 3, 9],
       [6, 2, 4, 3]])
>>> I
array([[1, 2],
       [2, 2],
       [3, 1],
       [2, 1]])

I want the matrix

array([[2, 5],
       [4, 4],
       [9, 2],
       [4, 2]])
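Stated element by element, the rule behind this expected output is out[..., k, t] = A[..., k, I[k, t]]. The sketch below is a slow, loop-based reference for checking faster approaches; the function name select_per_row is hypothetical and not a NumPy API.

```python
import numpy as np

def select_per_row(A, I):
    """Hypothetical reference helper: out[..., k, t] = A[..., k, I[k, t]]."""
    r, s = I.shape
    out = np.empty(A.shape[:-1] + (s,), dtype=A.dtype)
    for k in range(r):          # row k along the second-to-last axis
        for t in range(s):      # t-th column picked for that row
            out[..., k, t] = A[..., k, I[k, t]]
    return out

A = np.array([[5, 2, 5, 7],
              [2, 6, 4, 3],
              [4, 2, 3, 9],
              [6, 2, 4, 3]])
I = np.array([[1, 2],
              [2, 2],
              [3, 1],
              [2, 1]])
print(select_per_row(A, I))
```

Because of the leading ..., the same helper also accepts the full (p, q, r, r) array.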

How can I do this? A[..., I] gives a (4, 4, 2) array that applies every row of I to every row of A. I can solve the problem by keeping only the diagonal of that result:

>>> c = np.arange(4)
>>> A[..., I][c, c, :]
array([[2, 5],
       [4, 4],
       [9, 2],
       [4, 2]])

But I think this does a lot of unnecessary computation, since it builds the full (4, 4, 2) intermediate and then keeps only its diagonal. Is there a more efficient way to do this?
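A small sketch of where the waste comes from, using the 2D example: A[..., I] materialises r * r * s elements (every row of A paired with every row of I), and the c, c index then keeps only the r * s elements on the diagonal.

```python
import numpy as np

A = np.array([[5, 2, 5, 7],
              [2, 6, 4, 3],
              [4, 2, 3, 9],
              [6, 2, 4, 3]])
I = np.array([[1, 2],
              [2, 2],
              [3, 1],
              [2, 1]])

full = A[..., I]              # row j of A indexed by row k of I, for all j, k
c = np.arange(A.shape[0])
diag = full[c, c, :]          # keep only the pairs where j == k

print(full.shape, full.size)  # (4, 4, 2) 32
print(diag.shape, diag.size)  # (4, 2) 8
```

For the (p, q, r, r) case the intermediate grows to p * q * r * r * s elements, so the overhead scales with r.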

Edit: For a higher-dimensional example, take I to be the same as before, and

>>> A
array([[[[12, 15,  6, 12],
         [16, 16,  4, 17],
         [ 6, 19, 10,  9],
         [ 5, 11, 18, 17]],

        [[13, 12,  5,  6],
         [12,  7,  5,  4],
         [ 9, 19, 12,  4],
         [15,  4, 16,  7]],

        [[13,  6,  5, 17],
         [ 8,  4, 10,  9],
         [ 3, 13, 16,  4],
         [ 3,  3,  4,  4]]],


       [[[ 8,  3,  8, 18],
         [ 7, 11,  8,  7],
         [10,  8, 14,  9],
         [ 8, 12, 16,  5]],

        [[ 9, 10, 10,  7],
         [11,  6, 10,  6],
         [16, 19, 10, 14],
         [ 9, 13, 13, 19]],

        [[10,  8, 19, 12],
         [ 9, 10, 17, 19],
         [ 4, 11, 12, 14],
         [ 8,  5, 16, 10]]]])

Expected output:

array([[[[15,  6],
         [ 4,  4],
         [ 9, 19],
         [18, 11]],

        [[12,  5],
         [ 5,  5],
         [ 4, 19],
         [16,  4]],

        [[ 6,  5],
         [10, 10],
         [ 4, 13],
         [ 4,  3]]],


       [[[ 3,  8],
         [ 8,  8],
         [ 9,  8],
         [16, 12]],

        [[10, 10],
         [10, 10],
         [14, 19],
         [13, 13]],

        [[ 8, 19],
         [17, 17],
         [14, 11],
         [16,  5]]]])

A[..., I][..., c, c, :] yields this result, with c = np.arange(4) as before.

Answers (1)

yatu

Since you're using integer array indexing, you'll need to specify which rows you want to select those columns from:

A[np.arange(A.shape[0])[:,None], I]
array([[2, 5],
       [4, 4],
       [9, 2],
       [4, 2]])
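To see what the broadcasting does here: the row index has shape (r, 1) and I has shape (r, s), so both broadcast to (r, s), and element [k, t] of the result is A[k, I[k, t]]. A short sketch that prints the broadcast row index:

```python
import numpy as np

A = np.array([[5, 2, 5, 7],
              [2, 6, 4, 3],
              [4, 2, 3, 9],
              [6, 2, 4, 3]])
I = np.array([[1, 2],
              [2, 2],
              [3, 1],
              [2, 1]])

rows = np.arange(A.shape[0])[:, None]       # shape (4, 1)
rows_b, I_b = np.broadcast_arrays(rows, I)  # both become shape (4, 2)
print(rows_b)   # row k is repeated s times: [[0 0] [1 1] [2 2] [3 3]]
print(A[rows, I])
```

Unlike the diagonal trick, no (r, r, s) intermediate is created; the index arrays themselves are only (r, s).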

Alternatively, you can use np.take_along_axis:

np.take_along_axis(A, I, 1)
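np.take_along_axis requires the index array to have the same number of dimensions as A, but its non-axis dimensions only need to broadcast against A. For the (p, q, r, r) case, one option is therefore to add two leading length-1 axes to I and take along the last axis. A sketch on random data (the values are arbitrary, not the question's array):

```python
import numpy as np

rng = np.random.default_rng(0)
p, q, r, s = 2, 3, 4, 2
A = rng.integers(0, 20, size=(p, q, r, r))
I = np.array([[1, 2], [2, 2], [3, 1], [2, 1]])

# I[None, None] has shape (1, 1, r, s); its leading axes broadcast against (p, q)
out = np.take_along_axis(A, I[None, None], axis=-1)
print(out.shape)  # (2, 3, 4, 2)
```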

For the full array of shape (p, q, r, r), use ... to take full slices along the leading axes and broadcast the row index against I on the last two axes in the same way:

A[..., np.arange(A.shape[2])[:, None], I]
array([[[[15,  6],
         [ 4,  4],
         [ 9, 19],
         [18, 11]],

        [[12,  5],
         [ 5,  5],
         [ 4, 19],
         [16,  4]],
        ...

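As a sanity check, a sketch comparing this indexing with the question's A[..., I][..., c, c, :] workaround on random data of the same (2, 3, 4, 4) shape (random values, not the question's array):

```python
import numpy as np

rng = np.random.default_rng(42)
A = rng.integers(0, 20, size=(2, 3, 4, 4))
I = np.array([[1, 2], [2, 2], [3, 1], [2, 1]])
r = A.shape[2]

# Answer: a (r, 1) row index broadcast against I (r, s) on the last two axes
rows = np.arange(r)[:, None]
direct = A[..., rows, I]

# Question's workaround: build the (p, q, r, r, s) intermediate, keep its diagonal
c = np.arange(r)
via_diag = A[..., I][..., c, c, :]

print(direct.shape, via_diag.shape)      # (2, 3, 4, 2) (2, 3, 4, 2)
print(np.array_equal(direct, via_diag))  # True
```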