user859385

Reputation: 627

How does numpy argmax work?

As I understand it, numpy argmax retrieves the maximum value along an axis, much as sum adds up values along an axis. For example,

import numpy as np

x = np.array([[12,11,10,9],[16,15,14,13],[20,19,18,17]])
print(x)
print(x.sum(axis=1))
print(x.sum(axis=0))

outputs:

[[12 11 10  9]
 [16 15 14 13]
 [20 19 18 17]]


[42 58 74]

[48 45 42 39]

This makes sense: summing along axis 1 adds across each row, giving [42 58 74], and summing along axis 0 adds down each column, giving [48 45 42 39]. However, I am confused about how argmax works. From my understanding, argmax is supposed to return the maximum number along the axis. Below are my code and output.
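The axis behavior of sum described here can be checked by recomputing both sums with plain Python loops. This is a minimal sketch (not part of the original post), assuming only NumPy is installed: summing along axis 1 collapses the column index and leaves one total per row, while summing along axis 0 collapses the row index and leaves one total per column.

```python
import numpy as np

x = np.array([[12, 11, 10, 9], [16, 15, 14, 13], [20, 19, 18, 17]])

# axis=1: for each row i, add x[i, 0] + x[i, 1] + ..., giving one value per row
row_totals = [sum(int(v) for v in row) for row in x]

# axis=0: for each column j, add x[0, j] + x[1, j] + ..., giving one value per column
col_totals = [sum(int(x[i, j]) for i in range(x.shape[0])) for j in range(x.shape[1])]

print(row_totals)  # [42, 58, 74]
print(col_totals)  # [48, 45, 42, 39]
print(np.array_equal(row_totals, x.sum(axis=1)))  # True
print(np.array_equal(col_totals, x.sum(axis=0)))  # True
```

argmax reduces over an axis in the same way, but instead of adding the values it reports where the largest one sits.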

Code: print(np.argmax(x,axis=1)). Output: [0 0 0]

Code: print(np.argmax(x,axis=0)). Output: [2 2 2 2]

Where do the 0 and 2 come from? I deliberately used larger integer values (9 to 20) so that the 0 and 2 in the output could not be confused with values inside the array.

Upvotes: 3

Views: 3250

Answers (2)

user3190961

Reputation: 1

Correction: with axis=0, argmax scans down each column and returns a row index; with axis=1, it scans across each row and returns a column index.

import numpy as np

x = np.array([[12,11,10,9],[16,15,14,13],[20,19,18,17]])
print(x)

[[12 11 10  9]
 [16 15 14 13]
 [20 19 18 17]]

np.argmax(x, axis=0)
array([2, 2, 2, 2])  # row index 2 (the third row) for each of the 4 columns

np.argmax(x, axis=1)
array([0, 0, 0])  # column index 0 (the first column) for each of the 3 rows
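A quick way to confirm which kind of index each call returns is to use the indices to pull values back out of x. In this sketch (not part of the original answer; it relies only on standard NumPy integer-array indexing), each row index from np.argmax(x, axis=0) is paired with its column number, and each column index from np.argmax(x, axis=1) with its row number. Both recover the same values as x.max along the corresponding axis.

```python
import numpy as np

x = np.array([[12, 11, 10, 9], [16, 15, 14, 13], [20, 19, 18, 17]])

rows = np.argmax(x, axis=0)        # one row index per column: [2 2 2 2]
cols = np.arange(x.shape[1])       # column numbers 0..3
col_max = x[rows, cols]            # x[2, 0], x[2, 1], x[2, 2], x[2, 3]

row_idx = np.arange(x.shape[0])    # row numbers 0..2
best_cols = np.argmax(x, axis=1)   # one column index per row: [0 0 0]
row_max = x[row_idx, best_cols]    # x[0, 0], x[1, 0], x[2, 0]

print(col_max)  # [20 19 18 17]
print(row_max)  # [12 16 20]
print(np.array_equal(col_max, x.max(axis=0)), np.array_equal(row_max, x.max(axis=1)))  # True True
```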

Upvotes: 0

llllllllll

Reputation: 16404

np.argmax(x, axis=1) returns the index of the maximum in every row.

axis=1 means "along axis 1", i.e., within each row.

[[12 11 10  9]    <-- max at index 0
 [16 15 14 13]    <-- max at index 0
 [20 19 18 17]]   <-- max at index 0

Thus its output is [0 0 0].
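The distinction the question trips over is between max (the largest value) and argmax (the position of the largest value). The minimal sketch below compares the two on the rows of x and adds a plain-Python version of argmax per row; the helper name row_argmax is hypothetical, chosen only for illustration. Like np.argmax, list.index returns the first occurrence when the maximum appears more than once.

```python
import numpy as np

x = np.array([[12, 11, 10, 9], [16, 15, 14, 13], [20, 19, 18, 17]])

def row_argmax(a):
    """Hypothetical helper: position of the largest entry in each row."""
    result = []
    for row in a.tolist():
        result.append(row.index(max(row)))  # index of the first maximum in this row
    return result

print(np.max(x, axis=1))     # [12 16 20]  <- the maximum values
print(np.argmax(x, axis=1))  # [0 0 0]     <- where those values sit in each row
print(row_argmax(x))         # [0, 0, 0]
```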

It's similar for np.argmax(x, axis=0), but now it returns the index of the maximum in every column.
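For the column case, a sketch (not from the original answer) showing three equivalent ways to get the same result: argmax along axis 0, argmax along axis 1 of the transpose (the columns of x are the rows of x.T), and an explicit scan down each column. Every column grows from top to bottom, so the maximum is always in row 2.

```python
import numpy as np

x = np.array([[12, 11, 10, 9], [16, 15, 14, 13], [20, 19, 18, 17]])

by_column = np.argmax(x, axis=0)        # for each column, the row holding its maximum
via_transpose = np.argmax(x.T, axis=1)  # the columns of x are the rows of x.T

# Explicit scan: walk down each column j and keep the row of the largest value seen.
# A strict ">" keeps the earlier row on ties, matching np.argmax's first-occurrence rule.
loop = []
for j in range(x.shape[1]):
    best = 0
    for i in range(1, x.shape[0]):
        if x[i, j] > x[best, j]:
            best = i
    loop.append(best)

print(by_column)      # [2 2 2 2]
print(via_transpose)  # [2 2 2 2]
print(loop)           # [2, 2, 2, 2]
```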

Upvotes: 5
