Anastasia Manokhina

Reputation: 98

Replacing all values of a NumPy array that are smaller than the N-th largest item in each row

I have a 2D NumPy array of size ~70k * 10k. In every row, I want to replace with zero all values that are smaller than the N-th largest element of that row. For example:

arr = np.array([[1, 0, 6, 5, 2, 5], 
                [7, 5, 2, 6, 7, 3], 
                [3, 5, 1, 5, 6, 4]])

For N = 3 the result should be:

result = np.array([[0, 0, 6, 5, 0, 5], # 3 largest in row: 6, 5, 5
                   [7, 0, 0, 6, 7, 0], 
                   [0, 5, 0, 5, 6, 0]])

The positions of numbers that were not replaced and the shape of the array should stay the same.

Upvotes: 3

Views: 496

Answers (1)

MSeifert

Reputation: 152657

You could find the N-th largest value in each row using np.partition and then use boolean indexing to replace everything that's below that value in its row:

import numpy as np
arr = np.array([[1, 0, 6, 5, 2, 5], 
                [7, 5, 2, 6, 7, 3], 
                [3, 5, 1, 5, 6, 4]])

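N = 3
# After partitioning, column -N of each row holds that row's N-th largest value
nlargest = np.partition(arr, -N, axis=1)[:, -N]
# Broadcast the per-row threshold across columns and zero the smaller values in place
arr[arr < nlargest[:, None]] = 0
arr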
# array([[0, 0, 6, 5, 0, 5],
#        [7, 0, 0, 6, 7, 0],
#        [0, 5, 0, 5, 6, 0]])
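
Note that the boolean-indexing assignment above modifies arr in place. If you'd rather keep the original array, here is a minimal sketch of the same idea wrapped in a function that returns a new array via np.where. The function name keep_n_largest_per_row is only an illustrative choice, not something from NumPy:

```python
import numpy as np

def keep_n_largest_per_row(a, n):
    """Return a copy of `a` with values below each row's n-th largest set to 0.

    Hypothetical helper name, used here for illustration only.
    """
    a = np.asarray(a)
    # After partitioning, column -n of each row holds that row's n-th largest value.
    threshold = np.partition(a, -n, axis=1)[:, -n]
    # Broadcast the per-row threshold across columns; values equal to it are kept.
    return np.where(a < threshold[:, None], 0, a)

arr = np.array([[1, 0, 6, 5, 2, 5],
                [7, 5, 2, 6, 7, 3],
                [3, 5, 1, 5, 6, 4]])
result = keep_n_largest_per_row(arr, 3)
```

Because the comparison is a strict <, values that tie with the N-th largest are kept, so a row can end up with more than N non-zero entries. For example, with N = 2 the first row still keeps 6, 5 and 5.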

Upvotes: 4
