Mathew Carroll

NumPy: Quickest way to get argmax of row sums of 3D array

Suppose I have a 3-dimensional array with dimensions 10x10000x5. Interpreting this array as 10 'sub-arrays', each with 10000 rows and 5 columns, I want to do the following for each row index:

(1) Compute the sum of the row in each of the 10 sub-arrays.

(2) Determine which sub-array yields the highest sum.
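Written as a formula, steps (1) and (2) amount to the following, where $A_{i,r,j}$ denotes column $j$ of row $r$ in sub-array $i$ (0-based indices, as in NumPy):

```latex
s_{i,r} = \sum_{j=0}^{4} A_{i,r,j}, \qquad
\mathrm{result}_r = \operatorname*{arg\,max}_{i \in \{0,\dots,9\}} s_{i,r},
\qquad r = 0, \dots, 9999
```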

An example is shown below. It does the above for the first two rows only: 'firstrow' holds the sum of the first row of each sub-array, and 'secondrow' holds the sum of the second row of each sub-array. I then use np.argmax() to find the sub-array that yields the highest sum. But I want to do this for all 10000 rows, not just the first two.

import numpy as np
np.random.seed(777)
A = np.random.randn(10,10000,5)

firstrow = [None]*10   # sum of row 0 in each sub-array
secondrow = [None]*10  # sum of row 1 in each sub-array
for i in range(10):
    firstrow[i] = A[i].sum(axis=1)[0]
    secondrow[i] = A[i].sum(axis=1)[1]

np.argmax(np.array(firstrow)) # Sub-array 9 yields the highest sum
np.argmax(np.array(secondrow)) # Sub-array 8 yields the highest sum
#...

What is the quickest way of doing this for all 10000 rows?

Answers (1)

javidcf

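You can do it in a single vectorized expression: sum over axis 2 (the 5 columns of each row), then take the argmax over axis 0 (the 10 sub-arrays):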

result = A.sum(2).argmax(0)
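To see why this works, here is the same one-liner split into its two steps, with the intermediate shapes checked. This is a sketch that assumes the 10x10000x5 layout from the question (sub-array, row, column); `row_sums` is just an illustrative name. Note that `np.argmax` returns the first index when several sub-arrays tie for the maximum.

```python
import numpy as np

np.random.seed(777)
A = np.random.randn(10, 10000, 5)  # axes: (sub-array, row, column)

# Step (1): sum over axis 2 (the 5 columns) -> one sum per (sub-array, row)
row_sums = A.sum(axis=2)
assert row_sums.shape == (10, 10000)

# Step (2): argmax over axis 0 (the 10 sub-arrays) -> one sub-array index per row
result = row_sums.argmax(axis=0)
assert result.shape == (10000,)

# Identical to the one-liner with positional axis arguments
assert np.array_equal(result, A.sum(2).argmax(0))
```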

Tested on your example:

import numpy as np

np.random.seed(777)
A = np.random.randn(10, 10000, 5)

result = A.sum(2).argmax(0)

# Check against loop
first = [None] * 10
second = [None] * 10
for i in range(10):
    first[i] = A[i].sum(axis=1)[0]
    second[i] = A[i].sum(axis=1)[1]

print(result[0], np.argmax(np.array(first)))
# 9 9
print(result[1], np.argmax(np.array(second)))
# 8 8
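Since the question asks for the quickest way, here is a rough timing sketch comparing the one-liner with the question's per-row approach extended to every row (`vectorized` and `loop` are hypothetical helper names). Timings depend on the machine, so none are shown; the vectorized version does all the work inside compiled NumPy code instead of 10000 Python-level iterations, so it is expected to be much faster.

```python
import timeit

import numpy as np

np.random.seed(777)
A = np.random.randn(10, 10000, 5)

def vectorized(A):
    # Sum over columns, then pick the sub-array with the largest sum for each row
    return A.sum(2).argmax(0)

def loop(A):
    # Hypothetical baseline: for each row r, sum row r of every sub-array, then argmax
    return np.array([np.argmax(A[:, r, :].sum(axis=1)) for r in range(A.shape[1])])

# Both approaches must agree before comparing speed
assert np.array_equal(vectorized(A), loop(A))

t_vec = timeit.timeit(lambda: vectorized(A), number=3)
t_loop = timeit.timeit(lambda: loop(A), number=3)
print(f"vectorized: {t_vec:.4f} s, loop: {t_loop:.4f} s (3 runs each)")
```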
