Itamar Katz

Reputation: 9645

numpy `take` along 2 axes

I have a 3D array a of data and a 2D array b of indices. For each block a[i], I need to take a sub-array along the 3rd axis, using the indices in b[i]. I can do it with take like this:

import numpy as np

a = np.arange(24).reshape((2,3,4))
b = np.array([0,2,1,3]).reshape((2,2))
np.array([np.take(a_,b_,axis=1) for (a_,b_) in zip(a,b)])

Can I do it without a list comprehension, using some fancy indexing? I am worried about efficiency, so if fancy indexing is not more efficient in this case, I would like to know that too.

EDIT: The first thing I tried was a[[0,1],:,b], but it doesn't give the sub-array I need.

Upvotes: 3

Views: 3203

Answers (2)

Allen Qin

Reputation: 19947

This is my first try. I will see if I can do better.

# index each block with its own row of b, concatenate with np.r_, then reshape.
np.r_[a[0][:,b[0]],a[1][:,b[1]]].reshape(2,3,2)
Out[300]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])

Second try:

# flatten a to a 2d array (6 rows), repeat each row of b 3 times, then take all rows and only the columns determined by b.
a.reshape(6,4)[np.arange(6)[:,None],b.repeat(3,0)].reshape(2,3,2)
Out[429]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])
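
To see why the second attempt works, it may help to look at the intermediate shapes. The sketch below uses the a and b from the question; the names flat, rows, cols and loop are introduced here only for illustration. Note that the literal 3 and 6 are tied to this particular shape of a.

```python
import numpy as np

a = np.arange(24).reshape((2, 3, 4))
b = np.array([0, 2, 1, 3]).reshape((2, 2))

flat = a.reshape(6, 4)          # stack the two 3x4 blocks into 6 rows
rows = np.arange(6)[:, None]    # shape (6, 1): one row index per flattened row
cols = b.repeat(3, 0)           # shape (6, 2): b[0] three times, then b[1] three times

# rows and cols broadcast to (6, 2), so every row picks its own pair of columns
result = flat[rows, cols].reshape(2, 3, 2)

# Compare against the list comprehension from the question
loop = np.array([np.take(a_, b_, axis=1) for (a_, b_) in zip(a, b)])
print(np.array_equal(result, loop))
```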

Upvotes: 1

hpaulj

Reputation: 231335

In [317]: a
Out[317]: 
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
In [318]: a = np.arange(24).reshape((2,3,4))
     ...: b = np.array([0,2,1,3]).reshape((2,2))
     ...: np.array([np.take(a_,b_,axis=1) for (a_,b_) in zip(a,b)])
     ...: 
Out[318]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])

So you want columns 0 & 2 from the 1st block, and columns 1 & 3 from the second.

Make a c that matches b in shape and embodies this observation (it holds the block index for each entry of b):

In [319]: c=np.array([[0,0],[1,1]])
In [320]: c
Out[320]: 
array([[0, 0],
       [1, 1]])
In [321]: b
Out[321]: 
array([[0, 2],
       [1, 3]])

In [322]: a[c,:,b]
Out[322]: 
array([[[ 0,  4,  8],
        [ 2,  6, 10]],

       [[13, 17, 21],
        [15, 19, 23]]])

That's the right numbers, but not the right shape.

A column vector can be used instead of c.

In [323]: a[np.arange(2)[:,None],:,b]  # or a[[[0],[1]],:,b]
Out[323]: 
array([[[ 0,  4,  8],
        [ 2,  6, 10]],

       [[13, 17, 21],
        [15, 19, 23]]])
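
The column vector works because index arrays are broadcast against each other before indexing. A small sketch of that step, using np.broadcast_arrays to make the expansion visible (the names rows, rows_full and b_full are illustrative only):

```python
import numpy as np

b = np.array([[0, 2], [1, 3]])
c = np.array([[0, 0], [1, 1]])
rows = np.arange(2)[:, None]    # shape (2, 1)

# Broadcasting stretches the column vector across b's columns,
# which reproduces c without writing it out by hand
rows_full, b_full = np.broadcast_arrays(rows, b)
print(rows_full.shape)
print(np.array_equal(rows_full, c))
```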

As for the shape, we can transpose the last two axes:

In [324]: a[np.arange(2)[:,None],:,b].transpose(0,2,1)
Out[324]: 
array([[[ 0,  2],
        [ 4,  6],
        [ 8, 10]],

       [[13, 15],
        [17, 19],
        [21, 23]]])

This transpose is required because we have a slice between two index arrays, a mix of basic and advanced indexing. It's documented, but nevertheless often puzzling. NumPy puts the slice dimension (of size 3) last, and we have to transpose it back.
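
If the transpose feels awkward, two alternatives avoid mixing a slice with index arrays. This is only a sketch: the first indexes every axis with an array, so the result takes the broadcast shape of the indices; the second uses np.take_along_axis, which assumes NumPy 1.15 or later. The names i, j, k, all_advanced and along are illustrative.

```python
import numpy as np

a = np.arange(24).reshape((2, 3, 4))
b = np.array([0, 2, 1, 3]).reshape((2, 2))

# Option 1: an index array for every axis; they broadcast to (2, 3, 2),
# which is directly the shape of the result
i = np.arange(2)[:, None, None]   # (2, 1, 1)
j = np.arange(3)[None, :, None]   # (1, 3, 1)
k = b[:, None, :]                 # (2, 1, 2)
all_advanced = a[i, j, k]

# Option 2: take_along_axis; the indices need the same number of
# dimensions as a and are broadcast over the middle axis
along = np.take_along_axis(a, b[:, None, :], axis=2)

# Reference: the slice-plus-transpose version from this answer
transposed = a[np.arange(2)[:, None], :, b].transpose(0, 2, 1)
print(np.array_equal(all_advanced, transposed), np.array_equal(along, transposed))
```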

Nice little indexing puzzle!
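
On the efficiency concern raised in the question, one rough way to compare the list comprehension with the indexing version is timeit. The array sizes below are hypothetical and chosen only for illustration; results will depend on the shapes and the machine, so no timings are claimed here.

```python
import timeit
import numpy as np

# Hypothetical sizes, larger than the toy example in the question
n, m, p, q = 200, 30, 40, 10
rng = np.random.default_rng(0)
a = rng.random((n, m, p))
b = rng.integers(0, p, size=(n, q))

def with_loop():
    # List comprehension from the question
    return np.array([np.take(a_, b_, axis=1) for (a_, b_) in zip(a, b)])

def with_indexing():
    # Column vector plus transpose, as in this answer
    return a[np.arange(n)[:, None], :, b].transpose(0, 2, 1)

# Check that both give the same result before timing them
same = np.array_equal(with_loop(), with_indexing())
t_loop = timeit.timeit(with_loop, number=20)
t_index = timeit.timeit(with_indexing, number=20)
print(same, t_loop, t_index)
```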

The latest question and explanation of this advanced/basic transpose:

Indexing numpy multidimensional arrays depends on a slicing method

Upvotes: 3
