Michael

Reputation: 7839

Numpy: Find column index for element on each row

Suppose I have a vector with elements to find:

a = np.array([1, 5, 9, 7])

I also have a matrix in which those elements should be searched for:

M = np.array([
[0, 1, 9],
[5, 3, 8],
[3, 9, 0],
[0, 1, 7]
])

I'd like to get an index array that tells, for each row j of M, in which column the element a[j] occurs.

The result would be:

[1, 0, 1, 2]

Does NumPy offer such a function?

(Thanks for the answers with list comprehensions, but that's not an option performance-wise. I also apologize for mentioning NumPy only in the final question.)

Upvotes: 6

Views: 4909

Answers (5)

Divakar

Reputation: 221684

For the first match in each row, one efficient way is to use argmax after extending a to 2D, as done in @Benjamin's post:

(M == a[:,None]).argmax(1)

Sample run -

In [16]: M
Out[16]: 
array([[0, 1, 9],
       [5, 3, 8],
       [3, 9, 0],
       [0, 1, 7]])

In [17]: a
Out[17]: array([1, 5, 9, 7])

In [18]: a[:,None]
Out[18]: 
array([[1],
       [5],
       [9],
       [7]])

In [19]: (M == a[:,None]).argmax(1)
Out[19]: array([1, 0, 1, 2])
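One caveat: argmax on a row that is entirely False returns 0, so a missing element is indistinguishable from a match in column 0. Below is a minimal sketch that flags such rows instead, assuming -1 is an acceptable "not found" value; the helper name `first_match_columns` and the second test vector are made up for illustration.

```python
import numpy as np

def first_match_columns(M, a, missing=-1):
    """Column of the first occurrence of a[j] in row j of M, or `missing`."""
    mask = M == a[:, None]             # (rows, cols) boolean match table
    idx = mask.argmax(axis=1)          # first True per row (0 if the row has none)
    idx[~mask.any(axis=1)] = missing   # overwrite rows with no match at all
    return idx

M = np.array([[0, 1, 9],
              [5, 3, 8],
              [3, 9, 0],
              [0, 1, 7]])
print(first_match_columns(M, np.array([1, 5, 9, 7])))  # [1 0 1 2]
print(first_match_columns(M, np.array([1, 4, 9, 7])))  # [ 1 -1  1  2]
```

The extra `mask.any(axis=1)` pass is still fully vectorized, so it keeps the performance character of the one-liner.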

Upvotes: 3

Benjamin

Reputation: 11860

Note the result of:

>>> M == a[:, None]
array([[False,  True, False],
       [ True, False, False],
       [False,  True, False],
       [False, False,  True]], dtype=bool)

The indices can be retrieved with:

>>> yind, xind = numpy.where(M == a[:, None])
>>> yind, xind
(array([0, 1, 2, 3], dtype=int64), array([1, 0, 1, 2], dtype=int64))
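Note that numpy.where returns one (row, column) pair per match, not one entry per row. A row that contains a[j] twice contributes two pairs and a row without it contributes none, so xind only lines up with a when every row has exactly one match. A small illustration, using a modified matrix `M2` invented for this example; counting the row indices with `np.bincount` is one way to detect such rows:

```python
import numpy as np

a = np.array([1, 5, 9, 7])
M2 = np.array([[1, 1, 9],   # 1 appears twice in row 0
               [5, 3, 8],
               [3, 4, 0],   # 9 is missing from row 2
               [0, 1, 7]])

yind, xind = np.where(M2 == a[:, None])
print(yind)  # [0 0 1 3]
print(xind)  # [0 1 0 2]

# number of matches per row; anything other than 1 breaks the one-per-row assumption
counts = np.bincount(yind, minlength=len(a))
print(counts)  # [2 1 0 1]
```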

Upvotes: 4

Julien Spronck

Reputation: 15433

# M.tolist() yields plain Python lists; NumPy row arrays have no .index() method
[sub.index(val) if val in sub else -1 for sub, val in zip(M.tolist(), a)]
# [1, 0, 1, 2]

Upvotes: 0

BPL

Reputation: 9863

Maybe something like this?

>>> [list(M[i,:]).index(a[i]) for i in range(len(a))]
[1, 0, 1, 2]
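The question rules out list comprehensions for performance reasons. Here is a rough sketch for checking that on your own machine, comparing this answer with the broadcasting/argmax approach from @Divakar's answer. The array sizes, the random data, and the helper names `list_comp` and `broadcast_argmax` are arbitrary choices for illustration, and the timings will vary by machine and NumPy version.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
n_rows, n_cols = 10_000, 50
M = rng.integers(0, 100, size=(n_rows, n_cols))
# pick a[j] from row j so every row is guaranteed to contain its target
a = M[np.arange(n_rows), rng.integers(0, n_cols, size=n_rows)]

def list_comp(M, a):
    # this answer: Python-level loop, list.index finds the first occurrence
    return [list(M[i, :]).index(a[i]) for i in range(len(a))]

def broadcast_argmax(M, a):
    # vectorized: compare every row against its target, take the first True
    return (M == a[:, None]).argmax(1)

# both return the column of the first occurrence in each row
assert list_comp(M, a) == broadcast_argmax(M, a).tolist()

for f in (list_comp, broadcast_argmax):
    t = timeit.timeit(lambda: f(M, a), number=5)
    print(f"{f.__name__}: {t / 5:.4f} s per call")
```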

Upvotes: 0

Gábor Erdős

Reputation: 3699

Lazy solution without any import:

a = [1, 5, 9, 7]

M = [
[0, 1, 9],
[5, 3, 8],
[3, 9, 0],
[0, 1, 7],
]

for n, i in enumerate(M):
    for j in a:
        if j in i:
            print("{} found at row {} column: {}".format(j, n, i.index(j)))

Output:

1 found at row 0 column: 1
9 found at row 0 column: 2
5 found at row 1 column: 0
9 found at row 2 column: 1
1 found at row 3 column: 1
7 found at row 3 column: 2

Upvotes: 0
