samuelli97

Reputation: 71

Finding indices of search items in a NumPy array

I want to find the indices of the elements in my NumPy array whose value is one of the values contained in a set; for instance, this set may be (5,6,7,8).

Right now I am doing

np.where(np.isin(arr, [5,6,7,8]))

which works fine. I was wondering if there is a better way to achieve this functionality.

Upvotes: 2

Views: 199

Answers (3)

cs95

Reputation: 402483

You can't know whether your current solution is good or not if you don't know what the alternatives are.

First,

np.where(np.isin(arr, val))

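works for the general case. np.isin tests every element of arr for membership in val. It is not a naive linear search: depending on the sizes and dtypes, NumPy either compares arr against each value of val in turn (when val is small) or uses a sort-based method (or, in newer versions, a lookup table for integers).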
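The question describes the search values as a set. The NumPy documentation for np.isin warns that passing a Python set directly does not work as expected, so convert it to a list first. A minimal sketch with a made-up array:

```python
import numpy as np

arr = np.array([3, 5, 9, 7, 5, 1, 8])  # made-up example data

# Pass the search values as a list (or array), not a Python set:
# np.asarray turns a set into a 0-d object array rather than an
# array of its elements, so np.isin would not test membership.
targets = {5, 6, 7, 8}
mask = np.isin(arr, list(targets))

(idx,) = np.where(mask)  # 1-D input, so np.where returns a 1-tuple
print(idx)  # [1 3 4 6]
```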

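You can also substitute np.where with np.nonzero. With a single argument, np.where is documented as shorthand for np.asarray(condition).nonzero(), so the results are identical and calling np.nonzero directly only saves a little call overhead (in the timings below, the difference is within the noise).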
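A small check that the two calls agree, using a made-up 2-D array. np.flatnonzero is a related function that returns indices into the flattened array instead of one index array per axis:

```python
import numpy as np

arr = np.array([[5, 0, 7],
                [1, 8, 6]])  # made-up example data
mask = np.isin(arr, [5, 6, 7, 8])

# With a single argument, np.where is shorthand for np.nonzero,
# so both return the same tuple of index arrays (one per dimension).
rows_w, cols_w = np.where(mask)
rows_n, cols_n = np.nonzero(mask)

# np.flatnonzero returns indices into the ravelled (flattened) array.
flat = np.flatnonzero(mask)
```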

Next, there is

(arr[:, None] == val).argmax(0)

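which is very fast for small sizes of arr and val (N < 100). Note that it answers a slightly different question: for each value in val it returns the index of that value's first occurrence in arr, rather than every matching position in arr.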
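Because argmax of an all-False column is 0, a value that is absent from arr is indistinguishable from a match at index 0 unless you check for it. A sketch on made-up data, using -1 as an arbitrary sentinel chosen here for "not found":

```python
import numpy as np

arr = np.array([4, 8, 5, 8, 2])  # made-up example data
val = np.array([8, 5, 9])        # 9 does not occur in arr

eq = arr[:, None] == val   # shape (len(arr), len(val))
first = eq.argmax(0)       # row of the first True in each column (0 if none)
found = eq.any(0)          # False where the value is absent from arr

# Replace the misleading 0 for absent values with the -1 sentinel.
first_or_missing = np.where(found, first, -1)
```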

Finally, if arr is sorted, I recommend np.searchsorted, which locates each value by binary search:

np.searchsorted(arr, val)
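np.searchsorted returns the position where each value would be inserted to keep arr sorted. That equals the value's index only if the value is actually present (with the default side='left', the first occurrence among duplicates). A minimal sketch that filters out absent values, assuming a non-empty arr sorted in ascending order and using -1 as an arbitrary "not found" marker:

```python
import numpy as np

arr = np.array([1, 3, 5, 7, 9])  # assumed: non-empty, sorted ascending
val = np.array([5, 9, 4, 10])

idx = np.searchsorted(arr, val)  # insertion points, not guaranteed matches

# Clip so idx == len(arr) (value larger than every element) stays a valid
# index, then confirm the element at that position really equals the value.
safe = np.minimum(idx, len(arr) - 1)
present = arr[safe] == val
result = np.where(present, idx, -1)
```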

import numpy as np

arr = np.arange(100000)
val = np.random.choice(arr, 1000)

%timeit np.where(np.isin(arr, val))
%timeit np.nonzero(np.isin(arr, val))
%timeit (arr[:, None] == val).argmax(0)
%timeit np.searchsorted(arr, val)

8.3 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.88 ms ± 791 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
861 ms ± 6.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
235 µs ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
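The timed expressions do not all return the same thing: the isin variants return every matching position in arr (sorted, without repeats), while the argmax and searchsorted variants return one index per entry of val. The comparison is like-for-like only because arr here is sorted and has unique values. A quick consistency check on smaller, made-up data of the same kind (the seed is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
arr = np.arange(1000)      # sorted, unique values, like the benchmark
val = rng.choice(arr, 20)  # may contain repeats

pos_isin = np.nonzero(np.isin(arr, val))[0]   # sorted positions in arr, no repeats
pos_argmax = (arr[:, None] == val).argmax(0)  # one index per entry of val
pos_search = np.searchsorted(arr, val)        # one index per entry of val

# The same positions once repeats and ordering are removed.
same = (np.array_equal(pos_isin, np.unique(pos_argmax))
        and np.array_equal(pos_argmax, pos_search))
```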

The problem with (arr[:, None] == val).argmax(0) is memory: the comparison broadcasts to a dense boolean array of shape (len(arr), len(val)) that is almost entirely False, so it grows with the product of the two sizes and is wasteful when N is large (so don't use it for large N).
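For the benchmark above, the size of that intermediate array can be worked out directly, since a NumPy bool occupies one byte and the broadcast result is stored densely:

```python
import numpy as np

n_arr, n_val = 100_000, 1_000  # sizes from the benchmark above

# One byte per bool, one element per (arr, val) pair.
nbytes = n_arr * n_val * np.dtype(bool).itemsize
print(nbytes / 1e6, "MB")  # 100.0 MB
```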

Upvotes: 3

John Zwinck

Reputation: 249153

The code you have is correct and reasonable. You should keep it.

Upvotes: 4

iankit

Reputation: 9352

Your approach is valid, and it also works for multidimensional arrays:

>>> import numpy as np
>>> x = np.arange(9.).reshape(3, 3)
>>> goodvalues = [3, 4, 7]
>>> ix = np.isin(x, goodvalues)
>>> ix
array([[False, False, False],
       [ True,  True, False],
       [False,  True, False]], dtype=bool)
>>> np.where(ix)
(array([1, 1, 2]), array([0, 1, 1]))

This is straight out of the documentation here: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.where.html#numpy.where
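If you would rather have each match as a single coordinate pair than one index array per axis, np.argwhere gives the same positions in that layout. A short sketch reusing the documentation example above:

```python
import numpy as np

x = np.arange(9.).reshape(3, 3)
ix = np.isin(x, [3, 4, 7])

# np.argwhere returns one (row, col) pair per match, i.e. the same
# positions as np.nonzero(ix) arranged as rows instead of per-axis arrays.
coords = np.argwhere(ix)
print(coords.tolist())  # [[1, 0], [1, 1], [2, 1]]
```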

Upvotes: 2
