Roy

Reputation: 43

NumPy Multidimensional Subsetting

I have searched long and hard for an answer to this question, but haven't found anything that quite fits the bill. I have a multidimensional numpy array containing data (3-dimensional in my case) and a second, 2-dimensional array that says which value I want along the last dimension of the first array.

Here is a simple example illustrating the problem. I have an array a of data, and another array b containing indices along dimension 2 of a. I want a new 2-dimensional array c where c[i, j] = a[i, j, b[i, j]]. The only way I can think of to do this is with a loop, as outlined below. However, this seems clunky and slow.

In [3]: a = np.arange(8).reshape((2, 2, 2))
In [4]: a
Out[4]: 
array([[[0, 1],
        [2, 3]],

       [[4, 5],
        [6, 7]]])

In [6]: b = np.array([[0, 1], [1, 1]])

In [8]: c = np.zeros_like(b)

In [9]: for i in range(2):
   ...:     for j in range(2):
   ...:         c[i, j] = a[i, j, b[i, j]]

In [10]: c
Out[10]: 
array([[0, 3],
       [5, 7]])

Is there a more pythonic way of doing this, perhaps some numpy indexing feature of which I am unaware?

Upvotes: 3

Views: 189

Answers (2)

hpaulj

Reputation: 231355

In [40]: a = np.arange(8).reshape((2, 2, 2))

In [41]: b = np.array([[0, 1], [1, 1]])

In [42]: i = np.array([[0,0],[1,1]])

In [43]: a[i,i.T,b]
Out[43]: 
array([[0, 3],
       [5, 7]])

or using ix_ to generate the indexes:

In [47]: j = np.ix_([0,1],[0,1])

In [48]: a[j[0],j[1],b]
Out[48]: 
array([[0, 3],
       [5, 7]])
In [49]: j
Out[49]: 
(array([[0],
       [1]]), array([[0, 1]]))

or with ogrid (the index must be a tuple, not a list, in current numpy):

In [101]: i = tuple(np.ogrid[0:2,0:2])

In [102]: i = i + (b,)

In [103]: a[i]
Out[103]: 
array([[0, 3],
       [5, 7]])
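
The variants above all rely on the same thing: the first two index arrays broadcast to full (2, 2) row and column grids that line up with b. Here is a minimal sketch that makes this explicit, assuming the same a and b as in the question (the names rows, cols, full_rows and full_cols are only illustrative):

```python
import numpy as np

a = np.arange(8).reshape((2, 2, 2))
b = np.array([[0, 1], [1, 1]])

# Open (sparse) index grids: shapes (2, 1) and (1, 2).
rows, cols = np.ix_([0, 1], [0, 1])

# Broadcasting expands them to full (2, 2) grids matching b's shape.
full_rows, full_cols, _ = np.broadcast_arrays(rows, cols, b)

# full_rows equals the hand-written i, and full_cols equals i.T,
# so both spellings select the same elements.
c = a[rows, cols, b]
```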

Upvotes: 0

Jaime

Reputation: 67427

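When you fancy-index a multidimensional array with multidimensional arrays, the indices for each dimension are broadcast together. Here rows[:, None] has shape (a.shape[0], 1) and cols has shape (a.shape[1],), so together with b they broadcast to b's 2-dimensional shape. With that in mind, you can do: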

>>> rows = np.arange(a.shape[0])
>>> cols = np.arange(a.shape[1])
>>> a[rows[:, None], cols, b]
array([[0, 3],
       [5, 7]])
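
As a possible alternative that is not part of the answer above: NumPy 1.15 and later provide np.take_along_axis, which should do the same lookup without building the row and column indices by hand. This is a sketch assuming the same a and b as in the question; picked is just an illustrative name.

```python
import numpy as np

a = np.arange(8).reshape((2, 2, 2))
b = np.array([[0, 1], [1, 1]])

# take_along_axis expects the index array to have the same number of
# dimensions as a, so give b a trailing length-1 axis ...
picked = np.take_along_axis(a, b[..., None], axis=-1)

# ... then drop that axis again to get the 2-D result.
c = picked[..., 0]
print(c)
```

Because the trailing axis is added and removed with `...`, the same two lines should also work when a has more leading dimensions, as long as b matches a's shape minus the last axis.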

Upvotes: 2
