Reputation: 95
Case 1 (solved): Array A has shape (say) (300,50). Array B is an index array with shape (300,5), such that B[i,j] gives, for row i, the index of another row to "concatenate" next to row i. The end result is an array C with shape (300,5,50), such that C[i,j,:] = A[B[i,j],:]. This can be done by calling A[B,:].
Here is a small script example for case 1:
import numpy as np
## A is the data array
A = np.arange(20).reshape((5,4))
## B indicates, for each row, which rows to pull together
B = np.array([[0,2],[1,2],[2,0],[3,4],[4,1]])
A[B,:] #The desired result
Case 2 (unsolved): Same problem, only now A is shaped (100,300,50). If B is the index array shaped (100,300,5), the end result would be an array C with the shape (100,300,5,50) such that C[i,j,k,:] = A[i,B[i,j,k],:]. A[B,:] doesn't work anymore, because it results in a shape of (100,300,5,300,50): B indexes only the first axis of A, and the remaining axes are appended whole.
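The shape blow-up is easy to reproduce on a small stand-in. This is a minimal sketch, assuming toy sizes (4,6,3) for A and (4,6,2) for B in place of (100,300,50) and (100,300,5), and with the values in B kept small enough to be valid on the first axis so that A[B,:] does not raise an IndexError:

```python
import numpy as np

# Toy stand-ins for the case 2 shapes (sizes are assumptions, chosen small)
m, n, r, k = 4, 6, 3, 2
A = np.arange(m * n * r).reshape(m, n, r)

# Indices kept below min(m, n) so they are valid on either axis
rng = np.random.default_rng(0)
B = rng.integers(0, min(m, n), size=(m, n, k))

# B indexes only axis 0 of A, so the remaining axes (n, r) are appended whole
wrong = A[B, :]
print(wrong.shape)  # (4, 6, 2, 6, 3)
```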
How should I approach this with indexing?
Upvotes: 1
Views: 80
Reputation: 221584
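One approach would be to reshape A to 2D, keeping the number of columns intact, then index into the first axis with the flattened B indices, and finally reshape back to the desired shape. For case 2, the indices in B are relative to each block A[i], so each block's indices first need the offset i*A.shape[1] added to point at the right row of the flattened array.
Thus, for case 2, the implementation would be -
m, n, r = A.shape
idx = B + n*np.arange(m)[:,None,None]
A.reshape(-1,r)[idx.ravel()].reshape(B.shape + (r,))
For a contiguous A the reshape is merely a view, so only the indexing step copies data and this should be quite efficient.
For case 1 there is a single block, so no offset is needed. Here's a sample run for case #1 -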
1) Inputs :
In [667]: A = np.random.rand(3,4)
...: B = np.random.randint(0,3,(3,5))
...:
2) Original method :
In [668]: A[B,:]
Out[668]:
array([[[ 0.1 , 0.91, 0.1 , 0.98],
[ 0.1 , 0.91, 0.1 , 0.98],
[ 0.1 , 0.91, 0.1 , 0.98],
[ 0.45, 0.16, 0.02, 0.02],
[ 0.1 , 0.91, 0.1 , 0.98]],
[[ 0.45, 0.16, 0.02, 0.02],
[ 0.48, 0.6 , 0.96, 0.21],
[ 0.48, 0.6 , 0.96, 0.21],
[ 0.1 , 0.91, 0.1 , 0.98],
[ 0.45, 0.16, 0.02, 0.02]],
[[ 0.48, 0.6 , 0.96, 0.21],
[ 0.45, 0.16, 0.02, 0.02],
[ 0.48, 0.6 , 0.96, 0.21],
[ 0.45, 0.16, 0.02, 0.02],
[ 0.45, 0.16, 0.02, 0.02]]])
3) Proposed method :
In [669]: A.reshape(-1,A.shape[-1])[B.ravel()].reshape(3,5,4)
Out[669]:
array([[[ 0.1 , 0.91, 0.1 , 0.98],
[ 0.1 , 0.91, 0.1 , 0.98],
[ 0.1 , 0.91, 0.1 , 0.98],
[ 0.45, 0.16, 0.02, 0.02],
[ 0.1 , 0.91, 0.1 , 0.98]],
[[ 0.45, 0.16, 0.02, 0.02],
[ 0.48, 0.6 , 0.96, 0.21],
[ 0.48, 0.6 , 0.96, 0.21],
[ 0.1 , 0.91, 0.1 , 0.98],
[ 0.45, 0.16, 0.02, 0.02]],
[[ 0.48, 0.6 , 0.96, 0.21],
[ 0.45, 0.16, 0.02, 0.02],
[ 0.48, 0.6 , 0.96, 0.21],
[ 0.45, 0.16, 0.02, 0.02],
[ 0.45, 0.16, 0.02, 0.02]]])
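For case 2, the same result can probably also be obtained without reshaping, by passing an index array for the first axis that broadcasts against B. The sketch below uses toy sizes; the names m, n, r, k and the loop-built reference are illustrative assumptions, not part of the question. It checks both that form and a flattened-index variant that adds a per-block offset against the definition C[i,j,k,:] = A[i,B[i,j,k],:]:

```python
import numpy as np

# Toy sizes standing in for (100, 300, 50) and (100, 300, 5)
m, n, r, k = 4, 6, 3, 2
rng = np.random.default_rng(0)
A = rng.random((m, n, r))
B = rng.integers(0, n, size=(m, n, k))

# Reference built with explicit loops, straight from the definition
ref = np.empty((m, n, k, r))
for i in range(m):
    for j in range(n):
        for t in range(k):
            ref[i, j, t, :] = A[i, B[i, j, t], :]

# Advanced indexing: a (m,1,1) index for axis 0 broadcasts against B (m,n,k)
C1 = A[np.arange(m)[:, None, None], B]

# Flattened variant: shift each block's indices by i*n before indexing
idx = B + n * np.arange(m)[:, None, None]
C2 = A.reshape(-1, r)[idx.ravel()].reshape(m, n, k, r)

print(C1.shape, np.array_equal(C1, ref), np.array_equal(C2, ref))
```

The broadcasting form avoids computing offsets by hand, at the cost of being a little less obvious to read.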
Upvotes: 2