ganto

Reputation: 222

How to index row elements of a Matrix with a Matrix of indices for each row?

I have a Matrix of indices I e.g.

I = np.array([[1, 0, 2], [2, 1, 0]])

Each entry in the i-th row of I is a column index that selects an element from the i-th row of another matrix M.

So having M e.g.

M = np.array([[6, 7, 8], [9, 10, 11])

indexing M with I should select:

[[7, 6, 8], [11, 10, 9]]

I could have:

I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])
I2 = np.ravel(I)
Result = M[I1, I2].reshape(I.shape)

but this looks very complicated, and I am looking for a more elegant solution, preferably without flattening and reshaping.

In the example I used NumPy, but I am actually using JAX, so if there is a more efficient solution in JAX, feel free to share.

Upvotes: 2

Views: 678

Answers (3)

Mustafa Aydın

Reputation: 18315

np.take_along_axis can also be used here to take values of M using indices I over axis=1:

>>> np.take_along_axis(M, I, axis=1)

array([[ 7,  6,  8],
       [11, 10,  9]])
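
As a rough sketch of what take_along_axis computes here, the check below compares it against the question's repeat/ravel approach on a slightly larger example. The array values and the names via_take, via_flat and rows are made up for illustration and do not come from the question.

```python
import numpy as np

# Illustrative data (not from the question): 3 rows, 4 columns.
M = np.arange(12).reshape(3, 4)
I = np.array([[3, 0, 1, 2],
              [0, 0, 2, 1],
              [1, 3, 3, 0]])

# For axis=1 the result satisfies via_take[i, j] == M[i, I[i, j]].
via_take = np.take_along_axis(M, I, axis=1)

# The question's flatten-and-reshape version, for comparison.
rows = np.repeat(np.arange(I.shape[0]), I.shape[1])
via_flat = M[rows, I.ravel()].reshape(I.shape)

print(via_take)
print(np.array_equal(via_take, via_flat))
```

Note that I needs the same number of dimensions as M for this call. jax.numpy provides a function with the same name, so the same line should carry over to JAX arrays, though that is not tested here.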

Upvotes: 0

hpaulj

Reputation: 231665

I had to add a ']' to M:

In [108]: I = np.array([[1, 0, 2], [2, 1, 0]])
     ...: M = np.array([[6, 7, 8], [9, 10, 11]])
     ...: 
     ...: I,M

Out[108]: 
(array([[1, 0, 2],
        [2, 1, 0]]),
 array([[ 6,  7,  8],
        [ 9, 10, 11]]))

Advanced indexing with broadcasting:

In [110]: M[np.arange(2)[:,None],I]
Out[110]: 
array([[ 7,  6,  8],
       [11, 10,  9]])

The first index has shape (2,1), which broadcasts with the (2,3) shape of I to select a (2,3) block of values.
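
To make the shape pairing concrete, here is a minimal sketch that prints the shapes involved and builds the row index from M.shape[0] rather than a literal 2. The names row_idx and result are illustrative only.

```python
import numpy as np

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

# Column vector of row numbers; its length follows M instead of a literal 2.
row_idx = np.arange(M.shape[0])[:, None]

print(row_idx.shape, I.shape)          # shapes of the two index arrays
print(np.broadcast(row_idx, I).shape)  # shape they broadcast to

# Each output position (i, j) pairs row_idx[i, 0] with I[i, j],
# so result[i, j] == M[i, I[i, j]].
result = M[row_idx, I]
print(result)
```

Deriving the row index from M.shape[0] keeps the expression valid for any number of rows.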

Upvotes: 1

Shaun Han

Reputation: 2777

How about this one-liner? The idea is to enumerate the rows of M together with their row numbers, so that each row can be indexed with the corresponding row of the index matrix.

import numpy as np

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

Result = np.array([row[I[i]] for i, row in enumerate(M)])
print(Result)

Output:

[[ 7  6  8]
 [11 10  9]]

Upvotes: 1
