mac389

Reputation: 3133

Speed up NumPy's where function

I am trying to extract the indices of all values in a 1D array of numbers that exceed some threshold. The array has on the order of 1e9 elements.

My approach is the following in NumPy:

idxs = np.where(data > threshold)
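For reference, np.where called with a single argument returns a tuple with one index array per dimension, so for a 1D array the indices are in the first element. np.flatnonzero returns the same indices directly as a flat array. A minimal sketch on a small, made-up array (the values and threshold are illustrative only, not the question's data):

```python
import numpy as np

# Small illustrative array; the real data in the question has ~1e9 elements.
data = np.array([0.1, 0.7, 0.3, 0.9, 0.5])
threshold = 0.5

# With one argument, np.where returns a tuple: one index array per dimension.
idxs = np.where(data > threshold)

# For a 1D array, take the first element of the tuple...
print(idxs[0])                            # [1 3]
# ...or ask for the flat indices directly.
print(np.flatnonzero(data > threshold))   # [1 3]
```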

This takes upwards of 20 minutes, which is unacceptable. How can I speed this up, or are there faster alternatives?

(To be specific, it takes that long on Mac OS X 10.6.7 with a 1.86 GHz Intel CPU and 4 GB of RAM, with the machine doing nothing else.)
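Given the hardware above, a rough memory estimate is worth doing. Assuming the data is float64 (the question does not state the dtype), 1e9 elements take about 8 GB, roughly twice the 4 GB of RAM, so the machine may well be paging to disk, which could explain much of the runtime. The sketch below only does the arithmetic from dtype item sizes; it does not allocate anything.

```python
import numpy as np

n = 10**9  # approximate array length from the question

# Assumption: the data is float64; the question does not say which dtype is used.
data_bytes = n * np.dtype(np.float64).itemsize   # 8 bytes per element
# The temporary mask created by data > threshold uses 1 byte per element.
mask_bytes = n * np.dtype(np.bool_).itemsize
# Worst case for the index array from np.where (every element selected);
# intp is 8 bytes on a 64-bit build.
index_bytes = n * np.dtype(np.intp).itemsize

ram_bytes = 4 * 1024**3  # 4 GB of RAM, as stated in the question

for name, size in [("data", data_bytes), ("mask", mask_bytes),
                   ("indices (worst case)", index_bytes)]:
    print(f"{name}: {size / 1e9:.1f} GB")
print("data alone exceeds RAM:", data_bytes > ram_bytes)
```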

Upvotes: 5

Views: 3757

Answers (1)

user648852

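Try boolean mask indexing. Note that this gives you the selected values rather than their indices.

So the syntax would be:

 b = a[a > threshold]

b is a new array (a copy) holding the elements of a that satisfy the condition, not a view of a: boolean indexing is a form of advanced indexing, and NumPy's advanced indexing always copies. If you need the indices rather than the values, np.where(a > threshold)[0] or np.flatnonzero(a > threshold) gives them.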
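Whether a result shares memory with the original array can be checked with np.shares_memory. A small sketch on a made-up array, comparing boolean indexing (advanced indexing, which copies) with basic slicing (which returns a view):

```python
import numpy as np

a = np.arange(10, dtype=np.float64)
threshold = 4.5

b = a[a > threshold]   # boolean (advanced) indexing: copies the selected values
s = a[5:]              # basic slicing: a view into a's buffer

print(b)                        # [5. 6. 7. 8. 9.]
print(np.shares_memory(a, b))   # False
print(np.shares_memory(a, s))   # True

# Modifying the copy leaves the original array unchanged.
b[0] = -1.0
print(a[5])                     # 5.0
```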

Example:

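import numpy as np
import time

# 1e9 float64 values: about 8 GB of memory
a = np.random.random_sample(int(1e9))

t1 = time.time()
b = a[a > 0.5]  # boolean-mask indexing: returns the matching values
print(time.time() - t1, 'seconds')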

On my machine, that prints 22.389815092086792 seconds.

Edit:

I tried the same with np.where, and it is just as fast. I suspect something else is slowing your code down: are you deleting these values from the array?
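To compare np.where with boolean indexing on your own machine without needing 8 GB of RAM, a scaled-down version of the timing above can be run. The size n is an arbitrary assumption chosen so the test fits in memory, and timings vary by machine, so no numbers are shown. Keep in mind that the two calls return different things: np.where gives indices, boolean indexing gives values.

```python
import time
import numpy as np

n = 10**7          # hypothetical, scaled-down size so the test fits in RAM
threshold = 0.5
rng = np.random.default_rng(0)
a = rng.random(n)

def timed(label, func):
    """Run func once, print the elapsed wall-clock time, and return its result."""
    t0 = time.perf_counter()
    result = func()
    print(f"{label}: {time.perf_counter() - t0:.3f} seconds")
    return result

idxs = timed("np.where (indices)", lambda: np.where(a > threshold)[0])
vals = timed("boolean indexing (values)", lambda: a[a > threshold])

# Both select the same elements: indexing a with idxs reproduces vals.
assert np.array_equal(a[idxs], vals)
```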

Upvotes: 7
