xirururu

Reputation: 5508

Understanding about the numpy.where

I am reading the numpy.where(condition[, x, y]) documentation, but I cannot understand this small example:

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

Can someone explain how this result is obtained?

Upvotes: 4

Views: 1246

Answers (2)

reticentroot

Reputation: 3682

It returns the coordinates (indices) of the elements that satisfy your condition:

import numpy as np

x = np.arange(9.).reshape(3, 3)
print(x)
print(np.where(x > 5))

Here print(x) outputs:

[[ 0.  1.  2.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]]

and np.where(x > 5) returns the indices of all elements greater than 5, as one array of row indices and one array of column indices:

(array([2, 2, 2]), array([0, 1, 2]))

where x[2, 0] == 6, x[2, 1] == 7 and x[2, 2] == 8.
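To make the "coordinates" reading concrete, here is a small sketch using the same x as above. It unpacks the two arrays, pairs them element by element, and then uses the whole tuple as an index, which picks out exactly the matching values (the same values a boolean mask x[x > 5] selects). The names rows, cols and selected are just illustrative.

```python
import numpy as np

x = np.arange(9.).reshape(3, 3)

# With only a condition, np.where returns one index array per dimension
rows, cols = np.where(x > 5)

# rows[i] together with cols[i] addresses one matching element
for r, c in zip(rows, cols):
    print(r, c, x[r, c])

# The whole tuple can also be used directly as an index into x
selected = x[np.where(x > 5)]
print(selected)  # [6. 7. 8.]
```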

Upvotes: 2

Kasravnd

Reputation: 107357

The first array (array([2, 2, 2])) holds the row indices and the second (array([0, 1, 2])) holds the column indices of the values that are greater than 5.
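One way to see how those two arrays come about is to rebuild them by hand. The sketch below uses a hypothetical helper, where_indices (not a NumPy function), that scans a 2-D boolean mask in row-major order and records the row and column of every True entry. For the single-argument form this should match np.where(mask), which is the same as np.nonzero(mask).

```python
import numpy as np

def where_indices(mask):
    """Hypothetical helper: rebuild np.where(mask) for a 2-D boolean mask
    by scanning it row by row (C order), the order NumPy reports indices in."""
    rows, cols = [], []
    n_rows, n_cols = mask.shape
    for r in range(n_rows):
        for c in range(n_cols):
            if mask[r, c]:
                rows.append(r)  # row index of a True entry
                cols.append(c)  # matching column index
    return np.array(rows, dtype=np.intp), np.array(cols, dtype=np.intp)

x = np.arange(9.).reshape(3, 3)
mask = x > 5          # only the last row is True
print(mask)
print(where_indices(mask))
```

Only the last row of the mask is True, so every row index is 2, and the column indices run 0, 1, 2.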

You can use zip to pair them into (row, column) tuples:

>>> list(zip(*np.where(x > 5)))
[(2, 0), (2, 1), (2, 2)]

Or use np.dstack:

>>> np.dstack(np.where( x > 5 ))
array([[[2, 0],
        [2, 1],
        [2, 2]]])
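Note that np.dstack adds an extra leading axis (the result has shape (1, 3, 2)). If you just want one (row, column) row per match, np.argwhere returns that 2-D array directly; it is the transpose of the tuple np.where gives you. A minimal sketch with the same x:

```python
import numpy as np

x = np.arange(9.).reshape(3, 3)

# np.argwhere returns one (row, column) pair per matching element
pairs = np.argwhere(x > 5)
print(pairs)
# [[2 0]
#  [2 1]
#  [2 2]]

# The same (3, 2) array built from np.where by stacking its arrays as columns
print(np.column_stack(np.where(x > 5)))
```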

Upvotes: 6
