BlindDriver

Reputation: 681

indexing in numpy (related to max/argmax)

Suppose I have an N-dimensional numpy array x and an (N-1)-dimensional index array m (for example, m = x.argmax(axis=-1)). I'd like to construct an (N-1)-dimensional array y such that y[i_1, ..., i_(N-1)] = x[i_1, ..., i_(N-1), m[i_1, ..., i_(N-1)]]. For the argmax example above, this is equivalent to y = x.max(axis=-1). For N=3 I can achieve what I want with

y = x[np.arange(x.shape[0])[:, np.newaxis], np.arange(x.shape[1]), m]
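A self-contained check of the N=3 snippet, as a sketch. The shape (2, 3, 4) and the random seed are arbitrary choices for illustration. The row index has shape (2, 1) and the column index has shape (3,), so together with m (shape (2, 3)) they broadcast to (2, 3), and each output element picks x[i, j, m[i, j]].

```python
import numpy as np

# Arbitrary 3-D example array; shape (2, 3, 4) chosen only for illustration
rng = np.random.default_rng(0)
x = rng.integers(0, 9, size=(2, 3, 4))
m = x.argmax(axis=-1)  # shape (2, 3)

# (2, 1) row index, (3,) column index and (2, 3) m broadcast to (2, 3)
y = x[np.arange(x.shape[0])[:, np.newaxis], np.arange(x.shape[1]), m]

print(y.shape)                            # (2, 3)
print(np.array_equal(y, x.max(axis=-1)))  # True
```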

The question is, how do I do this for an arbitrary N?

Upvotes: 3

Views: 1036

Answers (2)

B. M.

Reputation: 18628

You can use np.indices to build one index array per leading axis:

firstdims = np.indices(x.shape[:-1])

and append your own index array m for the last axis:

ind = tuple(firstdims) + (m,)

Then x[ind] is what you want.

In [228]: np.allclose(x.max(-1), x[ind])
Out[228]: True
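A self-contained version of this approach, as a sketch. The helper name gather_last is hypothetical. One variation goes beyond the original answer and is an assumption on my part about what readers may want: np.indices also accepts sparse=True (NumPy 1.17 and later), which returns open-grid index arrays that broadcast against m instead of N-1 full-size arrays.

```python
import numpy as np

def gather_last(x, m, sparse=False):
    """Hypothetical helper: y[i_1, ..., i_(N-1)] = x[i_1, ..., i_(N-1), m[i_1, ..., i_(N-1)]]."""
    # One index array per leading axis. With sparse=True each array has
    # length-1 dimensions everywhere except its own axis, so they broadcast.
    firstdims = np.indices(x.shape[:-1], sparse=sparse)
    ind = tuple(firstdims) + (m,)  # m supplies the index along the last axis
    return x[ind]

rng = np.random.default_rng(1)
x = rng.integers(0, 9, size=(4, 5, 3, 2))
m = x.argmax(axis=-1)

print(np.array_equal(gather_last(x, m), x.max(axis=-1)))               # True
print(np.array_equal(gather_last(x, m, sparse=True), x.max(axis=-1)))  # True
```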

Upvotes: 2

Divakar

Reputation: 221564

Here's one approach that uses reshaping and linear indexing and works for arrays with any number of dimensions -

shp = x.shape[:-1]
n_ele = np.prod(shp)
y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
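The reshape collapses the leading axes into a single axis of length n_ele, so row r of the 2-D view lines up with element r of m.ravel() (both use C order). As a sketch of what "linear indexing" means here, the same gather can be written with explicit flat positions into x.ravel(): element (r, k) of the reshaped array sits at flat position r * K + k, where K is the length of the last axis. The name K is introduced only for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.integers(0, 9, size=(4, 5, 3, 3, 2, 4))
m = x.argmax(axis=-1)

shp = x.shape[:-1]
K = x.shape[-1]              # length of the last axis
n_ele = int(np.prod(shp))    # number of positions over the leading axes

# Flat (C-order) position of x[..., m]: row_number * K + column
flat_idx = np.arange(n_ele) * K + m.ravel()
y_linear = x.ravel()[flat_idx].reshape(shp)

print(np.array_equal(y_linear, x.max(axis=-1)))  # True
```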

Let's take a sample case with a 6-dimensional ndarray, using m = x.argmax(axis=-1) to index into the last dimension, so the output should equal x.max(-1). Let's verify this for the proposed solution -

In [121]: x = np.random.randint(0,9,(4,5,3,3,2,4))

In [122]: m = x.argmax(axis=-1)

In [123]: shp = x.shape[:-1]
     ...: n_ele = np.prod(shp)
     ...: y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
     ...: 

In [124]: np.allclose(x.max(-1),y_out)
Out[124]: True

I like @B. M.'s solution for its elegance, so here's a runtime test benchmarking the two -

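import numpy as np

def reshape_based(x, m):
    # Collapse the leading axes into one, pick one column per row, restore shape
    shp = x.shape[:-1]
    n_ele = np.prod(shp)
    return x.reshape(n_ele, -1)[np.arange(n_ele), m.ravel()].reshape(shp)

def indices_based(x, m):  ## @B. M.'s solution
    # Full index grids for the leading axes, plus m for the last axis
    firstdims = np.indices(x.shape[:-1])
    ind = tuple(firstdims) + (m,)
    return x[ind]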
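A third candidate that is not part of this benchmark, and so has no timings here: np.take_along_axis (NumPy 1.15 and later) performs exactly this "gather along one axis" operation. A sketch; the function name take_along_based is hypothetical and simply follows the naming pattern above. take_along_axis requires the index array to have the same number of dimensions as x, hence the added trailing axis.

```python
import numpy as np

def take_along_based(x, m):
    # Give m a trailing length-1 axis so its ndim matches x,
    # gather along the last axis, then drop that length-1 axis again
    return np.take_along_axis(x, m[..., np.newaxis], axis=-1)[..., 0]

rng = np.random.default_rng(3)
x = rng.integers(0, 9, size=(4, 5, 3, 3, 2))
m = x.argmax(axis=-1)

print(np.array_equal(take_along_based(x, m), x.max(axis=-1)))  # True
```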

Timings -

In [152]: x = np.random.randint(0,9,(4,5,3,3,4,3,6,2,4,2,5))
     ...: m = x.argmax(axis=-1)
     ...: 

In [153]: %timeit indices_based(x,m)
10 loops, best of 3: 30.2 ms per loop

In [154]: %timeit reshape_based(x,m)
100 loops, best of 3: 5.14 ms per loop
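The figures above come from IPython's %timeit. To rerun the comparison as a plain script, here is a sketch using the standard-library timeit module. The repeat and number values are arbitrary choices, and no particular results are implied: timings will depend on the machine and NumPy version.

```python
import timeit
import numpy as np

def reshape_based(x, m):
    shp = x.shape[:-1]
    n_ele = np.prod(shp)
    return x.reshape(n_ele, -1)[np.arange(n_ele), m.ravel()].reshape(shp)

def indices_based(x, m):
    firstdims = np.indices(x.shape[:-1])
    return x[tuple(firstdims) + (m,)]

x = np.random.randint(0, 9, (4, 5, 3, 3, 4, 3, 6, 2, 4, 2, 5))
m = x.argmax(axis=-1)

# The two functions must agree before their timings are worth comparing
assert np.array_equal(reshape_based(x, m), indices_based(x, m))

for f in (indices_based, reshape_based):
    # Best of 3 repeats, each calling the function 10 times; report per-call time
    best = min(timeit.repeat(lambda: f(x, m), repeat=3, number=10)) / 10
    print(f"{f.__name__}: {best * 1e3:.2f} ms per call")
```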

Upvotes: 1
