divya balakrishnan

Reputation: 51

Indices of multiple elements in a numpy array

I have a numpy array and a list as follows:

y=np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x=[1,2,3]

I would like to return a tuple of arrays, each containing the indices in y of one element of x, i.e.

(array([[0,2,4]]),array([[1,6,7]]),array([[3,5]]))

Can this be done in a vectorized fashion (without any loops)?

Upvotes: 1

Views: 1275

Answers (4)

hpaulj

Reputation: 231355

For this small example, a dictionary approach is actually faster than the `where` approaches:

dd = {i: [] for i in [1, 2, 3]}  # one empty index list per value of interest
for i, v in enumerate(y):
    v = v[0]  # each row of y is a length-1 array
    if v in dd:
        dd[v].append(i)
list(dd.values())

This problem has come up in other SO questions. Alternatives using unique and sort have been proposed, but they are more complex and harder to recreate - and not necessarily faster.
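
For reference, here is a minimal sketch of one sort-based variant. It is not necessarily what those other questions proposed. It assumes the y and x from the question, and the names order, groups, lookup and empty are only illustrative.

```python
import numpy as np

y = np.array([[1], [2], [1], [3], [1], [3], [2], [2]])
x = [1, 2, 3]

flat = y.ravel()
# A stable sort keeps equal values in their original (ascending index) order.
order = np.argsort(flat, kind="stable")
# Sorted unique values and the number of times each one occurs.
values, counts = np.unique(flat, return_counts=True)
# Cut the sorted index array at the boundaries between runs of equal values.
groups = np.split(order, np.cumsum(counts)[:-1])
lookup = dict(zip(values.tolist(), groups))
# Values of x that never occur in y get an empty index array.
empty = np.array([], dtype=order.dtype)
res = tuple(lookup.get(k, empty) for k in x)
```

The grouping itself has no Python-level loop over y, but the result is still a collection of arrays of different lengths, and the final lookup still iterates over x.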

It's not an ideal problem for numpy. The result is a list of arrays or lists of differing size, which is a pretty good clue that a simple 'vectorized' whole-array solution is not possible. If speed is an important enough issue you may need to look at numba or cython implementations.

Different methods could have different relative times depending on the mix of values. Few unique values with long sublists might favor methods that call where repeatedly. Many unique values with short sublists might favor an approach that iterates on y.
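
One way to check this on your own data is a small timing harness along these lines. It is only a sketch: the helper names by_where and by_dict, the array sizes and the value ranges are all made up for illustration, and the timings will vary by machine.

```python
import timeit
import numpy as np

def by_where(y, x):
    # One full pass over y per value in x.
    flat = y.ravel()
    return [np.where(flat == k)[0] for k in x]

def by_dict(y, x):
    # A single Python-level pass over y, whatever the length of x.
    groups = {k: [] for k in x}
    for i, v in enumerate(y.ravel().tolist()):
        if v in groups:
            groups[v].append(i)
    return [groups[k] for k in x]

rng = np.random.default_rng(0)
# Hypothetical data: few distinct values, so long index lists...
few_y = rng.integers(0, 3, size=(20000, 1))
few_x = [0, 1, 2]
# ...versus many distinct values, so short index lists.
many_y = rng.integers(0, 2000, size=(20000, 1))
many_x = list(range(2000))

cases = [("few unique", few_y, few_x), ("many unique", many_y, many_x)]
for label, yy, xx in cases:
    for func in (by_where, by_dict):
        t = timeit.timeit(lambda: func(yy, xx), number=3)
        print(f"{label:12s} {func.__name__:9s} {t / 3:.4f} s per call")
```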

Upvotes: 0

jpp

Reputation: 164623

You can use collections.defaultdict followed by a comprehension:

y = np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x = [1,2,3]

from collections import defaultdict

d = defaultdict(list)
for idx, item in enumerate(y.flat):
    d[item].append(idx)

res = tuple(np.array(d[k]) for k in x)

(array([0, 2, 4]), array([1, 6, 7]), array([3, 5]))

Upvotes: 0

Rushabh Mehta

Reputation: 1559

Try the following:

y = y.flatten()
[np.where(y == searchval)[0] for searchval in x]

Upvotes: 1

rafaelc

Reputation: 59274

One solution is to use map:

y = y.reshape(1, len(y))
# In Python 3, map returns an iterator, so wrap it in list() to get the arrays
list(map(lambda k: np.where(y == k)[-1], x))

[array([0, 2, 4]), 
 array([1, 6, 7]), 
 array([3, 5])]

Performance is reasonable. For 100000 rows:

%timeit list(map(lambda k: np.where(y==k), x))
3.1 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Upvotes: 1
