Marcello

Reputation: 377

How to efficiently filter maximum elements of a matrix per row

Given a 2D array, I'm looking for a pythonic way to get an array of the same shape that keeps only the maximum element of each row, with every other entry set to zero. See the max_row_filter function below:

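import numpy as np

def max_row_filter(mat2d):
    # Start from an all-zero array of the same shape (float dtype by default)
    m = np.zeros(mat2d.shape)
    for r in range(mat2d.shape[0]):
        # Column of the maximum in row r (the first one if there are ties)
        c = np.argmax(mat2d[r])
        m[r, c] = mat2d[r, c]
    return m

p = np.array([[1, 2, 3], [5, 4, 3], [9, 10, 3]])
max_row_filter(p)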

Out: array([[ 0.,  0.,  3.],
            [ 5.,  0.,  0.],
            [ 0., 10.,  0.]])

I'm looking for an efficient way to do this that is suitable for large arrays.

Upvotes: 2

Views: 484

Answers (2)

Tarifazo

Reputation: 4343

Alternative answer (this keeps duplicates: if the maximum occurs more than once in a row, every occurrence is kept):

p * (p==p.max(axis=1, keepdims=True))
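A minimal runnable sketch of the one-liner above, assuming a hypothetical fourth row [7, 7, 1] added to the question's example so the tie behaviour is visible. The comparison builds a boolean mask of row maxima; keepdims=True gives p.max a shape of (n_rows, 1) so it broadcasts across the columns.

```python
import numpy as np

# The question's matrix plus a hypothetical row with a tied maximum
p = np.array([[1, 2, 3],
              [5, 4, 3],
              [9, 10, 3],
              [7, 7, 1]])

# True wherever an entry equals the maximum of its row;
# the (n_rows, 1) row-max array broadcasts against p's columns
mask = p == p.max(axis=1, keepdims=True)

# Multiplying by the boolean mask zeroes everything else (dtype stays int)
result = p * mask
print(result)
```

Both 7s in the last row survive, which is the "keeps duplicates" behaviour. Because the whole computation is vectorized, there is no Python-level loop over rows.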

Upvotes: 3

Dani Mesejo

Reputation: 61910

If each row has a unique maximum (or you only want to keep one maximum per row), you could use numpy.argmax:

import numpy as np

p = np.array([[1, 2, 3],
              [5, 4, 3, ],
              [9, 10, 3]])

result = np.zeros_like(p)

rows, cols = zip(*enumerate(np.argmax(p, axis=1)))
result[rows, cols] = p[rows, cols]

print(result)

Output

[[ 0  0  3]
 [ 5  0  0]
 [ 0 10  0]]

Note that if the maximum occurs more than once in a row, argmax returns the index of the first occurrence, so only that entry is kept.
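A variant of the same argmax idea, sketched under the assumption that building the row indices with np.arange instead of zip(*enumerate(...)) is preferable for large arrays, since it avoids creating Python tuples (this has not been benchmarked here). The hypothetical row [7, 7, 1] shows the first-occurrence rule in action.

```python
import numpy as np

# The answer's matrix plus a hypothetical row with a tied maximum
p = np.array([[1, 2, 3],
              [5, 4, 3],
              [9, 10, 3],
              [7, 7, 1]])

rows = np.arange(p.shape[0])   # 0, 1, ..., n_rows - 1
cols = np.argmax(p, axis=1)    # first column holding each row's maximum

result = np.zeros_like(p)
# Advanced indexing pairs rows[i] with cols[i]: one entry copied per row
result[rows, cols] = p[rows, cols]
print(result)
```

Only the first 7 in the last row is kept, unlike the mask-based approach.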

Upvotes: 0
