MorganM


filter numpy array with row-specific criteria

Suppose I have a 2d numpy array and I want to filter for the elements that pass a certain criterion on a per-row basis. For example, I want only the elements that are above the 90th percentile for their specific row. I've come up with this solution:

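import numpy as np
a = np.random.random((6,5))
# 90th percentile of each row, shape (6,)
thresholds = np.percentile(a, 90, axis=1)
# repeat the thresholds across the columns to get a (6, 5) array
threshold_2d = np.vstack([thresholds]*a.shape[1]).T
mask = a > threshold_2d
# keep values above their row's threshold, NaN elsewhere
final = np.where(mask, a, np.nan)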

It works, and it's vectorized, but it feels a little awkward, especially the part where I create threshold_2d. Is there a more elegant way? Can I somehow get np.where to broadcast the per-row condition automatically, without having to build a matching 2d threshold array first?


Answers (1)

gboffi

Use broadcasting instead of building threshold_2d:

In [36]: np.random.seed(1023)

In [37]: a = np.random.random((6,5))

In [38]: thresholds = np.percentile(a, 90, axis=1)

In [39]: threshold_2d = np.vstack([thresholds]*a.shape[1]).T

In [40]: a>threshold_2d
Out[40]: 
array([[ True, False, False, False, False],
       [False, False,  True, False, False],
       [False,  True, False, False, False],
       [False, False, False,  True, False],
       [False, False, False, False,  True],
       [False,  True, False, False, False]], dtype=bool)

In [41]: a>thresholds[:,np.newaxis]
Out[41]: 
array([[ True, False, False, False, False],
       [False, False,  True, False, False],
       [False,  True, False, False, False],
       [False, False, False,  True, False],
       [False, False, False, False,  True],
       [False,  True, False, False, False]], dtype=bool)


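Indexing with numpy.newaxis inserts an axis of length one, so thresholds[:,np.newaxis] is a view with shape (6,1). That view broadcasts against a, which has shape (6,5): each row's threshold is stretched across the five columns of that row, giving exactly the same mask as threshold_2d without copying anything.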
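The same trick carries straight through to the np.where call from the question, so no 2-D threshold array is needed at any point. The sketch below is illustrative only: it uses np.random.default_rng instead of the session's np.random.seed, so its values differ from the output above. The second variant relies on the keepdims=True parameter of np.percentile, which keeps the reduced axis with length one so the thresholds already come out with shape (6,1).

```python
import numpy as np

# Illustrative input; the generator choice is an assumption, not from the answer
rng = np.random.default_rng(1023)
a = rng.random((6, 5))

# Per-row 90th percentile, shape (6,)
thresholds = np.percentile(a, 90, axis=1)

# Add a length-one trailing axis: (6, 1) broadcasts against (6, 5)
final = np.where(a > thresholds[:, np.newaxis], a, np.nan)

# Alternative: let np.percentile keep the reduced axis, giving shape (6, 1)
thresholds_kd = np.percentile(a, 90, axis=1, keepdims=True)
final_kd = np.where(a > thresholds_kd, a, np.nan)

print(final)
```

Both versions avoid the np.vstack copy; keepdims=True is handy when the threshold is computed and used in the same place, while [:, np.newaxis] works on any existing 1-D array.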

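As for why the extra axis is needed at all: NumPy aligns shapes starting from the trailing dimension, so comparing the (6,5) array directly with the 1-D thresholds of shape (6,) pairs the 5 columns with the 6 thresholds and raises an error. With a square array the shapes happen to be compatible, and the comparison silently tests element [i, j] against row j's threshold instead of row i's. A small sketch with hypothetical deterministic inputs (arange instead of random data) shows both cases:

```python
import numpy as np

a = np.arange(30.0).reshape(6, 5)
thresholds = np.percentile(a, 90, axis=1)  # shape (6,)

# Trailing axes 5 and 6 do not match, so broadcasting raises ValueError
try:
    a > thresholds
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

# With a square array the shapes are compatible, but the 1-D thresholds
# broadcast as a row vector: element [i, j] is compared with thresholds[j]
sq = np.arange(25.0).reshape(5, 5)
sq_thr = np.percentile(sq, 90, axis=1)
wrong = sq > sq_thr                  # uses row j's threshold for column j
right = sq > sq_thr[:, np.newaxis]   # uses row i's threshold for row i

print(broadcast_failed)               # True
print(np.array_equal(wrong, right))   # False
```

In the square case, right marks only the last element of each row (the row maximum), while wrong produces a different pattern with no error to flag it, which is a good reason to make the per-row axis explicit.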
