ztl

Reputation: 2602

How to subset a `numpy.ndarray` where another one is max along some axis?

In Python/NumPy, how can I subset a multidimensional array at the positions where another array of the same shape is maximal along some axis (e.g. the first one)?

Suppose I have two 3*2*4 arrays, `a` and `b`. I want to obtain a 2*4 array containing the values of `b` at the locations where `a` has its maximal values along the first axis.

import numpy as np

np.random.seed(7)
a = np.random.rand(3*2*4).reshape((3,2,4))
b = np.random.rand(3*2*4).reshape((3,2,4))

print(a)
#[[[ 0.07630829  0.77991879  0.43840923  0.72346518]
#  [ 0.97798951  0.53849587  0.50112046  0.07205113]]
#
# [[ 0.26843898  0.4998825   0.67923     0.80373904]
#  [ 0.38094113  0.06593635  0.2881456   0.90959353]]
#
# [[ 0.21338535  0.45212396  0.93120602  0.02489923]
#  [ 0.60054892  0.9501295   0.23030288  0.54848992]]]

print(a.argmax(axis=0))  # I would like b at these locations along axis 0
#[[1 0 2 1]
# [0 2 0 1]]

I can do this really ugly manual subsetting:

index = list(zip(a.argmax(axis=0).flatten(),
                 [0]*a.shape[2]+[1]*a.shape[2], # a.shape[2] = 4 here
                 list(range(a.shape[2]))*2))
# [(1, 0, 0), (0, 0, 1), (2, 0, 2), (1, 0, 3), 
#  (0, 1, 0), (2, 1, 1), (0, 1, 2), (1, 1, 3)]

This lets me obtain the desired result:

b_where_a_is_max_along0 = np.array([b[i] for i in index]).reshape(2,4)

# For verification:
print(a.max(axis=0) == np.array([a[i] for i in index]).reshape(2,4))
#[[ True  True  True  True]
# [ True  True  True  True]]

What is the smart, numpy way to achieve this? Thanks :)

Upvotes: 1

Views: 271

Answers (2)

Divakar

Reputation: 221664

Use advanced-indexing -

m,n = a.shape[1:]
b_out = b[a.argmax(0),np.arange(m)[:,None],np.arange(n)]
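
Why this works: the three index arrays have shapes (m, n), (m, 1) and (n,), which broadcast together to (m, n), so each output element is `b[idx[j, k], j, k]`. A minimal sketch that checks this against an explicit loop, using the question's seeded arrays (the names `idx`, `rows`, `cols` and `expected` are only illustrative):

```python
import numpy as np

np.random.seed(7)
a = np.random.rand(3*2*4).reshape((3, 2, 4))
b = np.random.rand(3*2*4).reshape((3, 2, 4))

m, n = a.shape[1:]
idx = a.argmax(0)              # shape (m, n): position of the max along axis 0
rows = np.arange(m)[:, None]   # shape (m, 1)
cols = np.arange(n)            # shape (n,)

# The three index arrays broadcast to (m, n), so
# b_out[j, k] == b[idx[j, k], j, k]
b_out = b[idx, rows, cols]

# Reference result from an explicit loop, for comparison only
expected = np.empty((m, n))
for j in range(m):
    for k in range(n):
        expected[j, k] = b[idx[j, k], j, k]

print(np.array_equal(b_out, expected))
# Applying the same indices to a recovers its maximum along axis 0
print(np.array_equal(a[idx, rows, cols], a.max(axis=0)))
```

Both checks should print `True`.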

Sample run -

Setup input array a and get its argmax along first axis -

In [185]: a = np.random.randint(11,99,(3,2,4))

In [186]: idx = a.argmax(0)

In [187]: idx
Out[187]: 
array([[0, 2, 1, 2],
       [0, 1, 2, 0]])

In [188]: a
Out[188]: 
array([[[49*, 58, 13, 69],   # * are the max positions
        [94*, 28, 55, 86*]],

       [[34, 17, 57*, 50],
        [48, 73*, 22, 80]],

       [[19, 89*, 42, 71*],
        [24, 12, 66*, 82]]])

Verify results with b -

In [193]: b
Out[193]: 
array([[[18*, 72, 35, 51],   # Mark * at the same positions in b
        [74*, 57, 50, 84*]], # and verify

       [[58, 92, 53*, 65],
        [51, 95*, 43, 94]],

       [[85, 23*, 13, 17*],
        [17, 64, 35*, 91]]])

In [194]: b[a.argmax(0),np.arange(2)[:,None],np.arange(4)]
Out[194]: 
array([[18, 23, 53, 17],
       [74, 95, 35, 84]])
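
As a possible alternative that is not part of the answer above: on NumPy 1.15 or later, `np.take_along_axis` can express the same lookup without building the `arange` index arrays by hand. A rough sketch, assuming arrays shaped like the sample run (the seed is arbitrary and only there for repeatability):

```python
import numpy as np

rng = np.random.RandomState(0)     # arbitrary seed, for repeatability only
a = rng.randint(11, 99, (3, 2, 4))
b = rng.randint(11, 99, (3, 2, 4))

idx = a.argmax(0)                  # shape (2, 4)

# take_along_axis needs indices with the same number of dimensions as b,
# so add a length-1 leading axis and drop it again afterwards.
b_out = np.take_along_axis(b, idx[None, ...], axis=0)[0]

# Compare with the advanced-indexing expression
same = np.array_equal(b_out, b[idx, np.arange(2)[:, None], np.arange(4)])
print(same)
```

The explicit `arange` version has the advantage of working on older NumPy releases as well.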

Upvotes: 1

Paul Panzer

Reputation: 53089

You could use `np.ogrid`:

>>> x = np.random.random((2,3,4))
>>> x
array([[[ 0.87412737,  0.11069105,  0.86951092,  0.74895912],
        [ 0.48237622,  0.67502597,  0.11935148,  0.44133397],
        [ 0.65169681,  0.21843482,  0.52877862,  0.72662927]],

       [[ 0.48979028,  0.97103611,  0.36459645,  0.80723839],
        [ 0.90467511,  0.79118429,  0.31371856,  0.99443492],
        [ 0.96329039,  0.59534491,  0.15071331,  0.52409446]]])
>>> y = np.argmax(x, axis=1)
>>> y
array([[0, 1, 0, 0],
       [2, 0, 0, 1]])
>>> i, j = np.ogrid[:2,:4]
>>> x[i ,y, j]
array([[ 0.87412737,  0.67502597,  0.86951092,  0.74895912],
       [ 0.96329039,  0.97103611,  0.36459645,  0.99443492]])
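
The session above picks values from `x` itself along axis 1. Adapted to the question's layout (argmax of `a` along axis 0, values taken from `b`), a sketch might look like this; `j` and `k` are just illustrative names:

```python
import numpy as np

np.random.seed(7)
a = np.random.rand(3*2*4).reshape((3, 2, 4))
b = np.random.rand(3*2*4).reshape((3, 2, 4))

# Open grids for the two axes that are kept:
# j has shape (2, 1), k has shape (1, 4)
j, k = np.ogrid[:a.shape[1], :a.shape[2]]

b_where_a_is_max_along0 = b[a.argmax(axis=0), j, k]

# Sanity check: the same indices applied to a give its maximum along axis 0
check = np.array_equal(a[a.argmax(axis=0), j, k], a.max(axis=0))
print(check)
```

The position of the argmax array in the index tuple has to match the axis it was computed along: first here, second in the session above.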

Upvotes: 1
