chromos

Reputation: 153

Faster alternative to np.where for a sorted array

Given a large array a that is sorted along each row, is there a faster alternative to NumPy's np.where for finding the indices where min_v <= a <= max_v? I would expect that exploiting the sorted nature of the array could speed things up.

Here's an example setup that uses np.where to find those indices in a large array.

import numpy as np

# Initialise an example of an array in which to search
r, c = int(1e2), int(1e6)
a = np.arange(r*c).reshape(r, c)

# Set up search limits
min_v = (r*c/2)-10
max_v = (r*c/2)+10

# Find indices of occurrences
idx = np.where(((a >= min_v) & (a <= max_v)))

Upvotes: 5

Views: 998

Answers (2)

Armali

Reputation: 19375

When I use np.searchsorted on the 100 million numbers from the original example, with the outdated NumPy version 1.12.1 (I can't say how newer versions behave), it is not much faster than np.where:

>>> import timeit
>>> timeit.timeit('np.where(((a >= min_v) & (a <= max_v)))', number=10, globals=globals())
6.685825735330582
>>> timeit.timeit('np.searchsorted(a.ravel(), [min_v, max_v])', number=10, globals=globals())
5.304438766092062
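A likely reason for the small gain (an assumption, not checked against NumPy 1.12.1): min_v and max_v are Python floats while a holds integers, and searchsorted converts its inputs to a common dtype, so the whole 100-million-element array may be cast to float64 before the O(log n) search even starts. Below is a minimal sketch that keeps the bounds in a's integer dtype. The helper name flat_range_searchsorted is hypothetical, and the sketch assumes an integer array whose row-major flattening is sorted and bounds that fit in a.dtype:

```python
import math

import numpy as np


def flat_range_searchsorted(a, min_v, max_v):
    """Return (i1, i2) so that a.ravel()[i1:i2] holds exactly the values in
    [min_v, max_v]. Hypothetical helper; assumes `a` has an integer dtype,
    its row-major flattening is sorted, and the bounds fit in a.dtype."""
    flat = a.ravel()  # a view (no copy) when `a` is C-contiguous
    # For integers x: min_v <= x <= max_v  <=>  ceil(min_v) <= x <= floor(max_v).
    # Casting the bounds to a.dtype means searchsorted has no reason to convert
    # `flat` itself to float64.
    lo = np.array(math.ceil(min_v), dtype=a.dtype)
    hi = np.array(math.floor(max_v), dtype=a.dtype)
    i1 = np.searchsorted(flat, lo, side='left')   # first index with flat[i] >= lo
    i2 = np.searchsorted(flat, hi, side='right')  # first index with flat[i] > hi
    return int(i1), int(i2)


r, c = 10, 100
a = np.arange(r * c).reshape(r, c)
min_v, max_v = (r * c / 2) - 10, (r * c / 2) + 10  # float bounds, as in the question
i1, i2 = flat_range_searchsorted(a, min_v, max_v)
idx = np.unravel_index(np.arange(i1, i2), a.shape)
```

Whether this closes the gap to bisect on a particular NumPy version would have to be timed on the 100-million-element array.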

But even though the NumPy docs for searchsorted say "This function uses the same algorithm as the builtin python bisect.bisect_left and bisect.bisect_right functions", the latter are a lot faster:

>>> import bisect
>>> timeit.timeit('bisect.bisect_left(a.base, min_v), bisect.bisect_right(a.base, max_v)', number=10, globals=globals())
0.002058468759059906

Therefore, I'd use this:

# a.base is the flat 1-D array that a is a reshaped view of
idx = np.unravel_index(range(bisect.bisect_left(a.base, min_v),
                             bisect.bisect_right(a.base, max_v)), a.shape)
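a.base only refers to the flat array here because a was created by reshaping np.arange; for an array that owns its data, a.base is None. The sketch below applies the same bisect idea to a.ravel() instead, which is a view rather than a copy for a C-contiguous array. The function name where_between_bisect is hypothetical, and like the answer it assumes the row-major flattening of a is sorted:

```python
import bisect

import numpy as np


def where_between_bisect(a, min_v, max_v):
    """Same result as np.where((a >= min_v) & (a <= max_v)) when the row-major
    flattening of `a` is sorted. Hypothetical helper name."""
    flat = a.ravel()  # view for C-contiguous input; a copy (O(n)) otherwise
    # bisect only needs len() and indexing, so it works on a 1-D NumPy array,
    # doing about log2(n) element comparisons.
    start = bisect.bisect_left(flat, min_v)   # first i with flat[i] >= min_v
    stop = bisect.bisect_right(flat, max_v)   # first i with flat[i] > max_v
    return np.unravel_index(np.arange(start, stop), a.shape)


r, c = 10, 100
a = np.arange(r * c).reshape(r, c)
idx = where_between_bisect(a, (r * c / 2) - 10, (r * c / 2) + 10)
```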

Upvotes: 2

javidcf

Reputation: 59691

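Since the flattened array is sorted, you can use np.searchsorted to find where the run of matching values starts and ends, and then convert those flat positions back to row/column indices with np.unravel_index: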

import numpy as np

r, c = 10, 100
a = np.arange(r*c).reshape(r, c)

min_v = ((r * c) // 2) - 10
max_v = ((r * c) // 2) + 10

# Old method
idx = np.where(((a >= min_v) & (a <= max_v)))

# With searchsorted
i1 = np.searchsorted(a.ravel(), min_v, 'left')
i2 = np.searchsorted(a.ravel(), max_v, 'right')
idx2 = np.unravel_index(np.arange(i1, i2), a.shape)
print((idx[0] == idx2[0]).all() and (idx[1] == idx2[1]).all())
# True
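Both answers treat the flattened array as one sorted sequence, which holds for the np.arange example but not for an array that is only sorted within each row, as the question allows. The sketch below covers that case. It is not from either answer, and where_between_rowsorted is a hypothetical name. It runs one pair of searchsorted calls per row, so it costs roughly 2·r binary searches over rows of length c plus the size of the output:

```python
import numpy as np


def where_between_rowsorted(a, min_v, max_v):
    """Indices (rows, cols) of entries with min_v <= a <= max_v, in the same
    row-major order as np.where. Assumes only that each row of the 2-D array
    `a` is sorted ascending. Hypothetical helper name."""
    rows, cols = [], []
    for i, row in enumerate(a):
        lo = np.searchsorted(row, min_v, side='left')   # first col with row[j] >= min_v
        hi = np.searchsorted(row, max_v, side='right')  # first col with row[j] > max_v
        hi = max(hi, lo)  # empty range if min_v > max_v
        cols.append(np.arange(lo, hi, dtype=np.intp))
        rows.append(np.full(hi - lo, i, dtype=np.intp))
    if not rows:  # array with zero rows
        empty = np.empty(0, dtype=np.intp)
        return empty, empty.copy()
    return np.concatenate(rows), np.concatenate(cols)


# Each row is sorted individually, but the flattened array is not sorted.
rng = np.random.default_rng(0)
a = np.sort(rng.integers(0, 100, size=(20, 50)), axis=1)
idx = where_between_rowsorted(a, 30, 60)
```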

Upvotes: 2
