Reputation: 222
I have a matrix of indices I, e.g.:
I = np.array([[1, 0, 2], [2, 1, 0]])
Each index in the i-th row of I selects an element from the i-th row of another matrix M.
So, given M, e.g.:
M = np.array([[6, 7, 8], [9, 10, 11]])
indexing M with I in this way should select:
[[7, 6, 8], [11, 10, 9]]
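Put element-wise, with R the desired result, each entry is picked from the same row of M:

$$R_{ij} = M_{i,\,I_{ij}}$$

For example, R_{00} = M_{0,1} = 7 and R_{12} = M_{1,0} = 9.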
I could do:
I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])
I2 = np.ravel(I)
Result = M[I1, I2].reshape(I.shape)
but this looks overly complicated, and I am looking for a more elegant solution, preferably without flattening and reshaping.
In the example I used NumPy, but I am actually using JAX. So if there is a more efficient solution in JAX, feel free to share.
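To pin down the intended semantics, here is a deliberately slow reference version written as an explicit double loop (the name gather_rows_loop is hypothetical, made up for this sketch). It can serve as a correctness check for the flatten-and-reshape approach or for any vectorized answer:

```python
import numpy as np

def gather_rows_loop(M, I):
    """Slow reference: R[i, j] = M[i, I[i, j]] (hypothetical helper)."""
    R = np.empty(I.shape, dtype=M.dtype)
    for i in range(I.shape[0]):
        for j in range(I.shape[1]):
            # Row i of M, column chosen by I[i, j]
            R[i, j] = M[i, I[i, j]]
    return R

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

# Flatten-and-reshape approach from the question
I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])
I2 = np.ravel(I)
Result = M[I1, I2].reshape(I.shape)

assert np.array_equal(Result, gather_rows_loop(M, I))
print(Result)
```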
Upvotes: 2
Views: 678
Reputation: 18315
np.take_along_axis can also be used here to take values of M using indices I along axis=1:
>>> np.take_along_axis(M, I, axis=1)
array([[ 7,  6,  8],
       [11, 10,  9]])
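A common pairing, shown here as a sketch beyond the original question, is take_along_axis with argsort, which reorders each row independently. The second call illustrates that the index array must have the same number of dimensions as M, while its length along axis=1 may differ. JAX's jax.numpy module also provides take_along_axis with a NumPy-compatible signature, so it should work as a drop-in replacement for this case (not benchmarked here):

```python
import numpy as np

M = np.array([[6, 7, 8], [9, 10, 11]])

# argsort(-M) gives, per row, the column order that sorts the row in descending order
order = np.argsort(-M, axis=1)
desc = np.take_along_axis(M, order, axis=1)
print(desc)

# Indices must have the same ndim as M: one index per row
# is written with shape (2, 1), not (2,)
one_per_row = np.take_along_axis(M, np.array([[2], [0]]), axis=1)
print(one_per_row)
```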
Upvotes: 0
Reputation: 231665
I had to add a ']' to M:
In [108]: I = np.array([[1, 0, 2], [2, 1, 0]])
   ...: M = np.array([[6, 7, 8], [9, 10, 11]])
   ...:
   ...: I,M
Out[108]:
(array([[1, 0, 2],
        [2, 1, 0]]),
 array([[ 6,  7,  8],
        [ 9, 10, 11]]))
Advanced indexing with broadcasting:
In [110]: M[np.arange(2)[:,None],I]
Out[110]:
array([[ 7,  6,  8],
       [11, 10,  9]])
The first index has shape (2, 1), which broadcasts against the (2, 3) shape of I to select a (2, 3) block of values.
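The hard-coded np.arange(2) only covers two rows. A sketch of the same broadcasting trick for any number of rows (gather_per_row is a hypothetical helper name), with np.broadcast_shapes (NumPy 1.20+) confirming that the (2, 1) and (2, 3) index shapes combine into (2, 3):

```python
import numpy as np

def gather_per_row(M, I):
    # Column vector of row numbers, shape (n, 1)
    rows = np.arange(I.shape[0])[:, None]
    # rows (n, 1) and I (n, k) broadcast to (n, k): result[i, j] = M[i, I[i, j]]
    return M[rows, I]

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

print(np.broadcast_shapes((I.shape[0], 1), I.shape))  # (2, 3)
print(gather_per_row(M, I))
```

The same helper works when I has a different number of columns than M, e.g. one index per row with shape (2, 1).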
Upvotes: 1
Reputation: 2777
How about this one-liner? The idea is to enumerate the rows of M together with their row numbers, so that each row of M can be indexed with the corresponding row of the index matrix I.
import numpy as np
I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])
Result = np.array([row[I[i]] for i, row in enumerate(M)])
print(Result)
Output:
[[ 7  6  8]
 [11 10  9]]
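The comprehension loops over rows in Python, so it is typically slower than the vectorized answers for large inputs. A quick sketch (sizes and seed chosen arbitrarily) checking that it agrees with take_along_axis and with broadcasted advanced indexing, including when I has a different number of columns than M:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, k = 4, 5, 3  # rows, columns of M, columns of I
M = rng.integers(0, 100, size=(n, m))
I = rng.integers(0, m, size=(n, k))  # valid column indices into M

loop = np.array([row[I[i]] for i, row in enumerate(M)])  # Python-level loop
along = np.take_along_axis(M, I, axis=1)                  # vectorized gather
bcast = M[np.arange(n)[:, None], I]                       # broadcasted indexing

assert np.array_equal(loop, along)
assert np.array_equal(loop, bcast)
print(loop.shape)
```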
Upvotes: 1