Jeight An

Reputation: 43

Numpy, change max value in each row to 1 without changing others

I'm trying to change the maximum value of each row to 1 and leave the other values unchanged.

Each value is between 0 and 1.

I want to change this

>>> a = np.array([[0.5, 0.2, 0.1], 
...               [0.6, 0.3, 0.8], 
...               [0.3, 0.4, 0.2]])

into this

>>> new_a = np.array([[1, 0.2, 0.1],
...                   [0.6, 0.3, 1],
...                   [0.3, 1, 0.2]])

Is there a good solution for this problem, maybe using np.where (without a for loop)?

Upvotes: 4

Views: 1127

Answers (5)

Dariusz Krynicki

Reputation: 2718

The question differs from the desired output. The author says he wants to replace the max value and leave the others, but he actually replaces the max value and some others.

This is the solution for replacing only the maximum value of the whole array:

np.where(arr == np.amax(arr), 1, arr)
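
For comparison, here is a small sketch on the question's array showing what the one-liner above does next to a per-row variant. The names global_result, row_max and row_result are only illustrative, and the per-row version is one possible reading of the question, not necessarily what the asker had in mind.

```python
import numpy as np

arr = np.array([[0.5, 0.2, 0.1],
                [0.6, 0.3, 0.8],
                [0.3, 0.4, 0.2]])

# Global maximum: np.amax without an axis looks at the whole array,
# so only the single largest value (0.8) is replaced.
global_result = np.where(arr == np.amax(arr), 1, arr)

# Per-row maximum: keepdims=True keeps the result as a (3, 1) column,
# so each row is compared against its own maximum.
row_max = np.amax(arr, axis=1, keepdims=True)
row_result = np.where(arr == row_max, 1, arr)

print(global_result)
print(row_result)
```

With the global version only the second row changes; with the per-row version one entry in every row becomes 1.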

Upvotes: 3

U13-Forward

Reputation: 71580

Use np.argmax with integer-array index assignment:

>>> a[np.arange(len(a)), np.argmax(a, axis=1)] = 1
>>> a
array([[1. , 0.2, 0.1],
       [0.6, 0.3, 1. ],
       [0.3, 1. , 0.2]])
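
One thing worth knowing about the argmax approach: np.argmax returns only the first position of the maximum, so rows with tied maxima get a single 1. The sketch below uses a made-up array with a tie (not the asker's data) and works on a copy so the original stays untouched. The names by_argmax and tie_mask are illustrative.

```python
import numpy as np

# Hypothetical input with a tie in the first row
a = np.array([[0.5, 0.5, 0.1],
              [0.6, 0.3, 0.8]])

# Work on a copy so `a` itself is not modified
by_argmax = a.copy()
rows = np.arange(len(a))          # [0, 1]
cols = np.argmax(a, axis=1)       # first index of the max in each row
by_argmax[rows, cols] = 1

# Boolean mask of every entry equal to its row maximum (ties included)
tie_mask = a == a.max(axis=1)[:, None]

print(by_argmax)
print(tie_mask.sum(axis=1))       # number of maximal entries per row
```

If every tied maximum should become 1, a mask-based approach such as np.where is probably the better fit.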

Upvotes: 4

Jithendra Yenugula

Reputation: 61

U13-Forward's and AcaNg's answers are perfect. Here's another way to do it using numpy.where:

new_a = np.where(a == [[i] for i in np.amax(a, axis=1)], 1, a)
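
The list comprehension there just builds a column of row maxima. As a rough sketch (variable names are mine), the same column can be obtained without a Python-level loop by passing keepdims=True, assuming the question's array:

```python
import numpy as np

a = np.array([[0.5, 0.2, 0.1],
              [0.6, 0.3, 0.8],
              [0.3, 0.4, 0.2]])

# Column built with a list comprehension, as in the one-liner above
col_from_list = np.array([[i] for i in np.amax(a, axis=1)])

# Same column produced directly by NumPy
col_keepdims = np.amax(a, axis=1, keepdims=True)

print(col_from_list.shape, col_keepdims.shape)
print(np.array_equal(col_from_list, col_keepdims))

new_a = np.where(a == col_keepdims, 1, a)
print(new_a)
```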

Upvotes: 1

Kazi

Reputation: 391

Here's a more detailed step by step process that gives us the desired output:

import numpy as np

# input array
a = np.array([[0.5, 0.8, 0.1],
              [0.8, 0.9, 0.6],
              [0.4, 0.3, 12]])

# finding the max element for each row
# axis=1 is given because we want to find the max for each row
max_elements = np.amax(a, axis=1)

# this adds a trailing axis, turning max_elements from shape (3,) into (3, 1)
# the shape change lets each row of the input array (a) be compared with its own max
max_elements = max_elements[:, None]

# this code is checking the main condition
# if the value in a row matches with the max element of that row, change it to 1
# else keep it the same
new_arr = np.where(a == max_elements, 1, a)

print(new_arr)

Upvotes: 1

AcaNg

Reputation: 706

U13-Forward's answer does it perfectly. Here is another answer using numpy.where:

np.where(a == a.max(1, keepdims=True), 1, a)
# `a == a.max(1, keepdims=True)` -> for each row, find the element that is equal to the max element in that row
# `1` -> set it to `1`
# `a` -> others remain the same
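
The comparison inside np.where only works row by row if the row maxima are shaped as a column. A minimal sketch on the question's array (names like flat_mask and column_mask are illustrative) of how a flat vector of maxima broadcasts differently from a column of maxima:

```python
import numpy as np

a = np.array([[0.5, 0.2, 0.1],
              [0.6, 0.3, 0.8],
              [0.3, 0.4, 0.2]])

row_max = a.max(1)                      # shape (3,): one max per row

# A flat vector broadcasts along the last axis, so a[i, j] is compared
# with row_max[j], i.e. with the max of row j rather than row i.
flat_mask = a == row_max

# A (3, 1) column broadcasts across each row, so a[i, j] is compared
# with row_max[i], which is the intended per-row comparison.
column_mask = a == row_max[:, None]

print(flat_mask)
print(column_mask)
```

Only the column version marks exactly one maximum in each row here.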

Upvotes: 1
