Reputation: 662
I am struggling to find the indices of the rows of a NumPy array that are equal to a given sublist.
import numpy as np

a = np.array([[False, True, True, True],
              [ True, True, True, True],
              [ True, True, True, True]])
sub = [True, True, True, True]
index = np.where(a.tolist() == sub)[0]
print(index)
This code gives me
array([0 0 0 1 1 1 1 2 2 2 2])
which I cannot explain. Shouldn't the output be array([1, 2]), and if so, why isn't it?
Also, how can I get that output?
Upvotes: 3
Views: 6864
Reputation: 78690
If I understand correctly, here's my idea:
>>> a
array([[False,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]])
>>> sub
array([ True,  True,  True,  True])
>>>
>>> result, = np.where(np.all(a == sub, axis=1))
>>> result
array([1, 2])
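For reference, here is the same idea as a standalone script rather than an interactive session. It rebuilds a and sub from the question; the intermediate name row_matches is just an illustrative choice, not part of the original answer.

```python
import numpy as np

# The example data from the question, as NumPy boolean arrays
a = np.array([[False, True, True, True],
              [True, True, True, True],
              [True, True, True, True]])
sub = np.array([True, True, True, True])

# Element-wise comparison (sub is broadcast across every row),
# then keep only the rows where every element matched
row_matches = np.all(a == sub, axis=1)

# np.where on a 1-D mask returns a 1-tuple; unpack its single element
result, = np.where(row_matches)
print(result)  # [1 2]
```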
Details regarding this solution:
a == sub
gives you
>>> a == sub
array([[False,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]])
a boolean array where each True/False value indicates whether the element of a is equal to the corresponding element of sub. (sub is broadcast along the rows here, i.e. it is compared against every row of a.)
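To make the broadcasting step concrete, the sketch below stretches sub to the shape of a explicitly with np.broadcast_to and checks that comparing against the stretched copy gives the same mask as a == sub. The explicit copy is only for illustration; NumPy performs this stretching implicitly.

```python
import numpy as np

a = np.array([[False, True, True, True],
              [True, True, True, True],
              [True, True, True, True]])
sub = np.array([True, True, True, True])

# Explicitly stretch sub (shape (4,)) to a's shape (3, 4);
# broadcasting does the same thing implicitly inside a == sub
sub_stretched = np.broadcast_to(sub, a.shape)
print(sub_stretched.shape)  # (3, 4)

# Comparing against the stretched copy yields the same boolean mask
same = np.array_equal(a == sub, a == sub_stretched)
print(same)  # True
```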
np.all(a == sub, axis=1)
gives you
>>> np.all(a == sub, axis=1)
array([False, True, True])
a boolean array with one entry per row of a, which is True exactly when that row is equal to sub.
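Applying np.where to this intermediate result gives you the indices where this boolean array is True. Because the mask is 1-D, np.where returns a tuple holding a single index array, which is why the code unpacks it with result, = ...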
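The tuple-unpacking detail can be checked directly. This sketch hardcodes the mask from the previous step; np.flatnonzero is shown as one alternative that returns the index array without the surrounding tuple.

```python
import numpy as np

# The per-row result of np.all(a == sub, axis=1) from the answer
mask = np.array([False, True, True])

# With a single argument, np.where returns one index array per dimension
indices_tuple = np.where(mask)
print(len(indices_tuple))  # 1

# np.flatnonzero returns the indices directly, no tuple to unpack
print(np.flatnonzero(mask))  # [1 2]
```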
Details regarding your attempt:
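np.where(a == sub)
gives you two arrays which together hold the indices where the array a == sub is True: the first array contains the row indices and the second the column indices. (Drop the tolist: a.tolist() == sub compares two plain Python lists and yields a single False rather than an element-wise array, so the output shown in the question corresponds to a == sub.) Taking [0], as your code does, keeps only the row indices, which is why every row number is repeated once for each of its True elements.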
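The sketch below reproduces both points from the question's data: the tolist comparison collapses to one bool, while a == sub followed by [0] gives the repeated row indices the asker saw.

```python
import numpy as np

a = np.array([[False, True, True, True],
              [True, True, True, True],
              [True, True, True, True]])
sub = [True, True, True, True]

# Two plain Python lists compared with == give a single bool
print(a.tolist() == sub)  # False

# a == sub is element-wise (NumPy converts sub to an array and broadcasts it)
rows, cols = np.where(a == sub)

# rows is what np.where(...)[0] returns: the row index of every True cell,
# so each row number repeats once per matching column
print(rows)  # [0 0 0 1 1 1 1 2 2 2 2]
print(cols)  # [1 2 3 0 1 2 3 0 1 2 3]
```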
>>> np.where(a == sub)
(array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]),
array([1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
If you zip these two arrays together, you get the (row, column) pairs where a == sub is True, i.e.
>>> for row, col in zip(*np.where(a == sub)):
...     print('a == sub is True at ({}, {})'.format(row, col))
...
a == sub is True at (0, 1)
a == sub is True at (0, 2)
a == sub is True at (0, 3)
a == sub is True at (1, 0)
a == sub is True at (1, 1)
a == sub is True at (1, 2)
a == sub is True at (1, 3)
a == sub is True at (2, 0)
a == sub is True at (2, 1)
a == sub is True at (2, 2)
a == sub is True at (2, 3)
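If you want those (row, column) pairs as an array rather than printing them one by one, np.argwhere returns them directly with shape (N, 2). The sketch below checks that it holds the same data as transposing the output of np.nonzero (which np.where with one argument wraps).

```python
import numpy as np

a = np.array([[False, True, True, True],
              [True, True, True, True],
              [True, True, True, True]])
sub = np.array([True, True, True, True])

# One (row, column) pair per True element of the mask
pairs = np.argwhere(a == sub)
print(pairs.shape)         # (11, 2)
print(pairs[:3].tolist())  # [[0, 1], [0, 2], [0, 3]]

# Same data as stacking the row and column arrays side by side
same = np.array_equal(pairs, np.transpose(np.nonzero(a == sub)))
print(same)  # True
```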
Upvotes: 8
Reputation: 1680
You can also do this without NumPy, using only plain Python:
res = [i for i, v in enumerate(a) if all(e==f for e, f in zip(v, sub))]
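A variant of the same plain-Python idea, wrapped in a helper (find_row_indices is a hypothetical name, not from the answer). It compares whole rows with list equality, which also checks the lengths: the zip-based version above stops at the shorter sequence, so it would accept a row that merely starts with sub.

```python
def find_row_indices(rows, target):
    """Return the indices of the rows that equal target element by element.

    Works on a list of lists; a 2-D NumPy array also works, because
    iterating over it yields its rows.
    """
    target = list(target)
    # list == list compares lengths as well as every element
    return [i for i, row in enumerate(rows) if list(row) == target]


a = [[False, True, True, True],
     [True, True, True, True],
     [True, True, True, True]]
sub = [True, True, True, True]

print(find_row_indices(a, sub))  # [1, 2]
```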
Upvotes: 0