Philipp

Reputation: 662

Finding the index of a sublist in a Numpy Array

I am struggling with finding the index of a sublist in a Numpy Array.

import numpy as np

a = np.array([[False,  True,  True,  True],
              [ True,  True,  True,  True],
              [ True,  True,  True,  True]])
sub = np.array([True, True, True, True])
index = np.where(a.tolist() == sub)[0]
print(index)

This code gives me

[0 0 0 1 1 1 1 2 2 2 2]

which I cannot explain. Shouldn't the output be array([1, 2])? Why isn't it, and how can I get that output?

Upvotes: 3

Views: 6864

Answers (2)

timgeb

Reputation: 78690

If I understand correctly, here's my idea:

>>> a
array([[False,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]])
>>> sub
array([ True,  True,  True,  True])
>>> 
>>> result, = np.where(np.all(a == sub, axis=1))
>>> result
array([1, 2])

Details regarding this solution:

a == sub gives you

>>> a == sub
array([[False,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]])

a boolean array in which each element indicates whether the value in a equals the value in the same column of sub. (sub is broadcast along the rows here.)
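To make the broadcasting step concrete, here is a small sketch (using the arrays from the question) that compares `a == sub` with the same comparison against `sub` explicitly repeated to the shape of `a` via `np.broadcast_to`:

```python
import numpy as np

a = np.array([[False, True, True, True],
              [True, True, True, True],
              [True, True, True, True]])
sub = np.array([True, True, True, True])

# sub has shape (4,); broadcasting treats it as shape (1, 4) and repeats it
# along axis 0 so that it lines up with every row of a (shape (3, 4)).
stretched = np.broadcast_to(sub, a.shape)

implicit = a == sub          # element-wise comparison, broadcasting done by NumPy
explicit = a == stretched    # same comparison against the explicitly repeated rows

print(stretched.shape)                      # (3, 4)
print(np.array_equal(implicit, explicit))   # True
```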

np.all(a == sub, axis=1) gives you

>>> np.all(a == sub, axis=1)
array([False,  True,  True])

a boolean array with one entry per row of a, which is True where that row equals sub in every column.

Using np.where on this sub-result gives you the indices where this boolean array is True.
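A minimal sketch wrapping these steps in a reusable function. The name `find_row_indices` is hypothetical, and it assumes `sub` has the same length as the rows of `a` (otherwise the comparison cannot broadcast). It uses `np.flatnonzero`, which for a 1-D mask returns the same indices as `np.where(mask)[0]`, and `np.asarray` so plain nested lists are accepted too:

```python
import numpy as np


def find_row_indices(a, sub):
    """Return the indices of the rows of 2-D `a` that equal `sub` (hypothetical helper)."""
    a = np.asarray(a)
    sub = np.asarray(sub)
    # A row matches only if every element equals the corresponding element of sub.
    row_matches = np.all(a == sub, axis=1)
    # Positions where the 1-D mask is True; same result as np.where(row_matches)[0].
    return np.flatnonzero(row_matches)


a = np.array([[False, True, True, True],
              [True, True, True, True],
              [True, True, True, True]])
print(find_row_indices(a, [True, True, True, True]))   # [1 2]
```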

Details regarding your attempt:

np.where(a == sub) (the .tolist() call is unnecessary) returns a tuple of two arrays that together give the row and column indices where a == sub is True.

>>> np.where(a == sub)
(array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]),
 array([1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))

If you zip these two arrays together, you get the row/column index pairs where a == sub is True:

>>> for row, col in zip(*np.where(a==sub)):
...:    print('a == sub is True at ({}, {})'.format(row, col))
a == sub is True at (0, 1)
a == sub is True at (0, 2)
a == sub is True at (0, 3)
a == sub is True at (1, 0)
a == sub is True at (1, 1)
a == sub is True at (1, 2)
a == sub is True at (1, 3)
a == sub is True at (2, 0)
a == sub is True at (2, 1)
a == sub is True at (2, 2)
a == sub is True at (2, 3)
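NumPy also provides `np.argwhere`, which returns the same row/column pairs as a single `(N, 2)` array (it is the transpose of `np.nonzero`, which is what `np.where` returns when called with one argument). A short check with the arrays from the question, which also shows why each row index is repeated once per matching column:

```python
import numpy as np

a = np.array([[False, True, True, True],
              [True, True, True, True],
              [True, True, True, True]])
sub = np.array([True, True, True, True])

pairs = np.argwhere(a == sub)      # one (row, col) pair per True element
rows, cols = np.where(a == sub)    # the same information as two separate arrays

print(pairs.shape)                                            # (11, 2)
print(np.array_equal(pairs, np.column_stack((rows, cols))))   # True
# Row 0 has three matching columns and rows 1 and 2 have four each,
# so the row indices alone repeat accordingly.
print(rows.tolist())               # [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```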

Upvotes: 8

Sofien

Reputation: 1680

You can also do this without NumPy, using only plain Python:

res = [i for i, v in enumerate(a) if all(e == f for e, f in zip(v, sub))]
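One caveat with the `zip`-based version: `zip` stops at the shorter input, so a `sub` shorter than the rows still matches any row that merely starts with it. A minimal plain-Python sketch that compares whole rows instead, assuming `a` is a list of lists as in the question's first snippet (the helper name `find_row_indices_py` is hypothetical):

```python
def find_row_indices_py(a, sub):
    """Indices of rows in `a` (a list of lists) that exactly equal `sub` (hypothetical helper)."""
    # list(row) == list(sub) checks both the length and every element,
    # unlike all(e == f for e, f in zip(row, sub)), which ignores extra elements.
    return [i for i, row in enumerate(a) if list(row) == list(sub)]


a = [[False, True, True, True],
     [True, True, True, True],
     [True, True, True, True]]

print(find_row_indices_py(a, [True, True, True, True]))   # [1, 2]
print(find_row_indices_py(a, [True]))                     # [] (the zip version gives [1, 2])
```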

Upvotes: 0
