Reputation: 51
I have a NumPy array and a list as follows:
y=np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x=[1,2,3]
I would like to return a tuple of arrays, each of which contains the indices in y of one element of x, i.e.:
(array([[0,2,4]]),array([[1,6,7]]),array([[3,5]]))
Can this be done in a vectorized fashion (without any loops)?
Upvotes: 1
Views: 1275
Reputation: 231355
For this small example, a dictionary approach is actually faster (than the repeated `where` calls):
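# Pre-seed one list per wanted value, then make a single pass over y.
dd = {i: [] for i in [1, 2, 3]}
for i, v in enumerate(y):
    v = v[0]  # each row of y is a 1-element array
    if v in dd:
        dd[v].append(i)
list(dd.values())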
This problem has come up in other SO questions. Alternatives using `unique` and `sort` have been proposed, but they are more complex and harder to recreate, and not necessarily faster.
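For reference, here is a minimal sketch of what a `sort` + `unique` alternative might look like. It is my own reconstruction, not code from those other questions: a stable `argsort` groups equal values while keeping their original index order, `np.unique(..., return_index=True)` finds where each group starts, and `np.split` cuts the sorted index array at those points. The final lookup still iterates over x in Python, so it is not loop-free in the strict sense.

```python
import numpy as np

y = np.array([[1], [2], [1], [3], [1], [3], [2], [2]])
x = [1, 2, 3]

flat = y.ravel()
# Stable sort keeps the original index order inside each group of equal values.
order = np.argsort(flat, kind="stable")
# starts[i] is the position in the sorted array where vals[i] first appears.
vals, starts = np.unique(flat[order], return_index=True)
# Cut the sorted indices into one chunk per distinct value.
groups = np.split(order, starts[1:])
lookup = dict(zip(vals.tolist(), groups))
# Values of x that never occur in y get an empty index array.
empty = np.array([], dtype=order.dtype)
res = tuple(lookup.get(k, empty) for k in x)
print(res)
```

The extra bookkeeping (stable sort, split points, lookup dict) is exactly the added complexity mentioned above; the payoff is a single O(n log n) sort instead of one full scan of y per element of x.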
It's not an ideal problem for NumPy. The result is a list of arrays or lists of differing sizes, which is a pretty good clue that a simple 'vectorized' whole-array solution is not possible. If speed is important enough, you may need to look at numba or cython implementations.
The relative timings of different methods depend on the mix of values. Few unique values with long sublists might favor methods that use repeated `where`; many unique values with short sublists might favor an approach that iterates over `y` once.
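A rough harness for checking that trade-off on your own data might look like the following. The function names `by_where` and `by_dict`, the 100,000-element size, and the random test data are illustrative assumptions. The harness prints timings without claiming any particular result, since the numbers depend on the machine and NumPy version.

```python
import timeit
import numpy as np

def by_where(flat, keys):
    # One full C-level scan of `flat` per key: work grows with len(keys) * len(flat).
    return [np.flatnonzero(flat == k) for k in keys]

def by_dict(flat, keys):
    # A single Python-level pass over `flat`: work grows with len(flat) only.
    groups = {k: [] for k in keys}
    for i, v in enumerate(flat.tolist()):
        if v in groups:
            groups[v].append(i)
    return [np.array(groups[k], dtype=np.intp) for k in keys]

rng = np.random.default_rng(0)
n = 100_000
for n_unique in (3, 1000):
    flat = rng.integers(0, n_unique, size=n)
    keys = list(range(n_unique))
    # Both methods must agree before their timings are worth comparing.
    assert all(np.array_equal(a, b)
               for a, b in zip(by_where(flat, keys), by_dict(flat, keys)))
    t_where = timeit.timeit(lambda: by_where(flat, keys), number=3)
    t_dict = timeit.timeit(lambda: by_dict(flat, keys), number=3)
    print(f"{n_unique} unique values: where {t_where:.3f}s, dict {t_dict:.3f}s")
```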
Upvotes: 0
Reputation: 164623
You can use collections.defaultdict followed by a comprehension:
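import numpy as np
from collections import defaultdict

y = np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x = [1,2,3]

# Map each value to the list of flat indices where it occurs (one pass over y).
d = defaultdict(list)
for idx, item in enumerate(y.flat):
    d[item].append(idx)
# Build one array per element of x, in the order of x.
res = tuple(np.array(d[k]) for k in x)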
(array([0, 2, 4]), array([1, 6, 7]), array([3, 5]))
Upvotes: 0
Reputation: 1559
Try the following:
y = y.flatten()
[np.where(y == searchval)[0] for searchval in x]
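If you want the comparison step itself to be a single whole-array operation, one possible variant (my own sketch, not part of the answer above) is to broadcast x against y into a boolean mask of shape (len(x), len(y)). `np.nonzero` returns matches in row-major order, so the column indices come out already grouped by element of x, and `np.split` cuts them using the per-row match counts. Be aware that the mask costs len(x) * len(y) booleans of memory, and `np.split` still loops internally to produce the differently sized pieces.

```python
import numpy as np

y = np.array([[1], [2], [1], [3], [1], [3], [2], [2]])
x = [1, 2, 3]

flat = y.ravel()
xs = np.asarray(x)
# mask[i, j] is True where flat[j] == x[i].
mask = xs[:, None] == flat[None, :]
# Row-major scan: column indices are grouped by row, ascending within each row.
_, cols = np.nonzero(mask)
# Matches per element of x determine where each group ends.
counts = mask.sum(axis=1)
res = tuple(np.split(cols, np.cumsum(counts)[:-1]))
print(res)
```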
Upvotes: 1
Reputation: 59274
One solution is to use map:
y = y.reshape(1, len(y))
# In Python 3, map returns a lazy iterator, so wrap it in list() to get the arrays.
list(map(lambda k: np.where(y == k)[-1], x))
[array([0, 2, 4]),
array([1, 6, 7]),
array([3, 5])]
The performance is reasonable. For 100,000 rows:
%timeit list(map(lambda k: np.where(y==k), x))
3.1 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Upvotes: 1