Reputation: 9645
I have a 3D array a of data and a 2D array b of indices. I need to take a sub-array of a along the 3rd axis, using the indices from b. I can do it with np.take like this:
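import numpy as np
a = np.arange(24).reshape((2,3,4))       # data, shape (2, 3, 4)
b = np.array([0,2,1,3]).reshape((2,2))   # column indices for each block, shape (2, 2)
# for each block a_ (shape (3, 4)), take the columns listed in b_
np.array([np.take(a_,b_,axis=1) for (a_,b_) in zip(a,b)])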
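For reference, the wanted result can be written element by element as out[i, j, k] = a[i, j, b[i, k]]. A plain nested-loop sketch of that definition (the name expected is just illustrative) is handy for checking any vectorized version against:

```python
import numpy as np

a = np.arange(24).reshape((2, 3, 4))
b = np.array([0, 2, 1, 3]).reshape((2, 2))

# out[i, j, k] = a[i, j, b[i, k]]: block i, row j, k-th column chosen for block i
n, m, _ = a.shape
expected = np.empty((n, m, b.shape[1]), dtype=a.dtype)
for i in range(n):
    for j in range(m):
        for k in range(b.shape[1]):
            expected[i, j, k] = a[i, j, b[i, k]]

# matches the np.take / list-comprehension version
assert np.array_equal(
    expected, np.array([np.take(a_, b_, axis=1) for (a_, b_) in zip(a, b)])
)
```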
Can I do it without a list comprehension, using some fancy indexing? I am worried about efficiency, so if fancy indexing is not more efficient in this case, I would like to know that too.
EDIT: The first thing I tried was a[[0,1],:,b], but it doesn't give the sub-array I need.
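A quick sketch of why a[[0,1],:,b] misbehaves: the list [0, 1] has shape (2,) and broadcasts against the last axis of b, so the block indices come out as [[0, 1], [0, 1]] rather than [[0, 0], [1, 1]], and block numbers get paired with the wrong columns. The name rows below is only for illustration.

```python
import numpy as np

a = np.arange(24).reshape((2, 3, 4))
b = np.array([0, 2, 1, 3]).reshape((2, 2))

# [0, 1] broadcasts along b's LAST axis, not its first
rows = np.broadcast_to(np.array([0, 1]), b.shape)
print(rows)          # [[0 1]
                     #  [0 1]]

wrong = a[[0, 1], :, b]
print(wrong.shape)   # (2, 2, 3): broadcast index shape first, then the slice
# wrong[i, j] == a[rows[i, j], :, b[i, j]], e.g. wrong[0, 1] == a[1, :, 2]
print(wrong[0, 1])   # [14 18 22]
```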
Upvotes: 3
Views: 3203
Reputation: 19947
This is my first try. I will see if I can do better.
# index each block separately, join the two (3, 2) pieces with np.r_, then reshape.
np.r_[a[0][:,b[0]],a[1][:,b[1]]].reshape(2,3,2)
Out[300]:
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])
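Here np.r_ simply concatenates the two (3, 2) pieces along the first axis into a (6, 2) array, which is why the reshape is needed. A sketch of the same idea with np.stack, which adds the block axis directly (it still needs one expression per block, so it does not remove the loop):

```python
import numpy as np

a = np.arange(24).reshape((2, 3, 4))
b = np.array([0, 2, 1, 3]).reshape((2, 2))

pieces = [a[0][:, b[0]], a[1][:, b[1]]]   # each piece has shape (3, 2)
print(np.r_[pieces[0], pieces[1]].shape)  # (6, 2): concatenated, hence the reshape
out = np.stack(pieces)                    # (2, 3, 2): new leading axis, no reshape
```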
Second try:
# flatten a to 2D (6, 4), repeat each row of b 3 times to get (6, 2), pick row by row only the columns given by b, then reshape back.
a.reshape(6,4)[np.arange(6)[:,None],b.repeat(3,0)].reshape(2,3,2)
Out[429]:
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])
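The second try hard-codes 6, 3 and 2. A sketch of the same reshape-and-index approach with the sizes read from the inputs instead; take_per_block is a hypothetical helper name, and it assumes a is 3D and b holds one row of column indices per block:

```python
import numpy as np

def take_per_block(a, b):
    """Hypothetical helper: a has shape (n, m, k), b has shape (n, p).

    Returns shape (n, m, p) with out[i, j, q] == a[i, j, b[i, q]].
    """
    n, m, k = a.shape
    flat = a.reshape(n * m, k)             # one row per (block, row) pair
    rows = np.arange(n * m)[:, None]       # (n*m, 1), broadcasts across columns
    cols = np.repeat(b, m, axis=0)         # (n*m, p): block i's indices repeated m times
    return flat[rows, cols].reshape(n, m, b.shape[1])

a = np.arange(24).reshape((2, 3, 4))
b = np.array([0, 2, 1, 3]).reshape((2, 2))
out = take_per_block(a, b)
```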
Upvotes: 1
Reputation: 231335
In [317]: a
Out[317]:
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
In [318]: a = np.arange(24).reshape((2,3,4))
...: b = np.array([0,2,1,3]).reshape((2,2))
...: np.array([np.take(a_,b_,axis=1) for (a_,b_) in zip(a,b)])
...:
Out[318]:
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])
So you want columns 0 and 2 from the first block, and columns 1 and 3 from the second.
Make a c that matches b in shape and embodies this observation:
In [319]: c=np.array([[0,0],[1,1]])
In [320]: c
Out[320]:
array([[0, 0],
       [1, 1]])
In [321]: b
Out[321]:
array([[0, 2],
       [1, 3]])
In [322]: a[c,:,b]
Out[322]:
array([[[ 0,  4,  8],
        [ 2,  6, 10]],

       [[13, 17, 21],
        [15, 19, 23]]])
Those are the right numbers, but not the right shape.
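Why this shape: c and b both have shape (2, 2), and because the two index arrays are separated by a slice, NumPy puts their broadcast shape (2, 2) first and the sliced axis (length 3) after it, so out[i, j, :] == a[c[i, j], :, b[i, j]]. A small check of that reading:

```python
import numpy as np

a = np.arange(24).reshape((2, 3, 4))
b = np.array([[0, 2], [1, 3]])
c = np.array([[0, 0], [1, 1]])

out = a[c, :, b]
print(out.shape)   # (2, 2, 3): broadcast index shape, then the slice

# each (i, j) picks one full column of one block
for i in range(2):
    for j in range(2):
        assert np.array_equal(out[i, j], a[c[i, j], :, b[i, j]])
```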
A column vector can be used instead of c, since a (2, 1) array broadcasts against b to the same (2, 2) block indices.
In [323]: a[np.arange(2)[:,None],:,b] # or a[[[0],[1]],:,b]
Out[323]:
array([[[ 0,  4,  8],
        [ 2,  6, 10]],

       [[13, 17, 21],
        [15, 19, 23]]])
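To see that the column vector really stands in for c, np.broadcast_arrays shows what the (2, 1) index turns into once it is broadcast against b (col and col_b are illustrative names):

```python
import numpy as np

b = np.array([[0, 2], [1, 3]])
col = np.arange(2)[:, None]           # shape (2, 1)

# broadcasting stretches the (2, 1) column across b's 2 columns
col_b, _ = np.broadcast_arrays(col, b)
print(col_b)   # [[0 0]
               #  [1 1]]  -> the same as c
```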
As for the shape, we can transpose the last two axes:
In [324]: a[np.arange(2)[:,None],:,b].transpose(0,2,1)
Out[324]:
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])
This transpose is required because we have a slice between two index arrays, a mix of basic and advanced indexing. It's documented, but nevertheless often puzzling: the broadcast shape of the index arrays (2, 2) comes first and the slice dimension (3) is put last, so we have to transpose it back.
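One way to avoid the transpose, sketched here as an alternative rather than something from the answer itself, is to replace the slice with an index array too. With all three indices advanced, the result simply takes their broadcast shape, which can be arranged to be (2, 3, 2) directly (blk, row and col are illustrative names):

```python
import numpy as np

a = np.arange(24).reshape((2, 3, 4))
b = np.array([0, 2, 1, 3]).reshape((2, 2))

n, m, _ = a.shape
blk = np.arange(n)[:, None, None]   # (n, 1, 1): block index
row = np.arange(m)[None, :, None]   # (1, m, 1): row index, replaces the ':' slice
col = b[:, None, :]                 # (n, 1, p): per-block column indices

# all indices are advanced, so the result has their broadcast shape (n, m, p)
out = a[blk, row, col]              # out[i, j, q] == a[i, j, b[i, q]]
```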
Nice little indexing puzzle!
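For what it's worth, NumPy 1.15 and later also have np.take_along_axis, which expresses "take these indices along one axis, per position in the other axes" directly. A sketch, assuming b gets a length-1 axis inserted so it has the same number of dimensions as a:

```python
import numpy as np

a = np.arange(24).reshape((2, 3, 4))
b = np.array([0, 2, 1, 3]).reshape((2, 2))

# indices must have the same ndim as a; b[:, None, :] has shape (2, 1, 2),
# and its length-1 middle axis broadcasts over a's 3 rows
out = np.take_along_axis(a, b[:, None, :], axis=2)
print(out.shape)   # (2, 3, 2)
```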
The latest question and explanation of this advanced/basic transpose: Indexing numpy multidimensional arrays depends on a slicing method
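On the efficiency part of the question: the list comprehension does Python-level work once per block, so its overhead grows with a.shape[0], while the indexing version is a fixed number of NumPy calls. Actual numbers depend on the machine and on the real array sizes, so a rough timing sketch like this (with_loop and with_fancy_indexing are illustrative names) is worth rerunning with realistic shapes rather than trusting the tiny example:

```python
import timeit
import numpy as np

a = np.arange(24).reshape((2, 3, 4))
b = np.array([0, 2, 1, 3]).reshape((2, 2))

def with_loop():
    # the list-comprehension version from the question
    return np.array([np.take(a_, b_, axis=1) for (a_, b_) in zip(a, b)])

def with_fancy_indexing():
    # the advanced-indexing version from this answer
    return a[np.arange(a.shape[0])[:, None], :, b].transpose(0, 2, 1)

assert np.array_equal(with_loop(), with_fancy_indexing())

for f in (with_loop, with_fancy_indexing):
    t = timeit.timeit(f, number=10_000)
    print(f"{f.__name__}: {t:.3f} s per 10,000 calls")
```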
Upvotes: 3