Reputation: 1506
I have been trying to understand the numpy where function but am not getting anywhere. I can understand simple comparisons such as where value > otherValue, but this example from the documentation is not becoming any clearer.
I would appreciate an easy-to-understand breakdown of it. Thanks for any help provided:
>>> np.where([[True, False], [True, True]],
... [[1, 2], [3, 4]],
... [[9, 8], [7, 6]])
array([[1, 8],
       [3, 4]])
Upvotes: 1
Views: 171
Reputation: 7130
I think we can simplify things and focus on np.where itself rather than on the nested lists:
np.where([True, False, True, True],
[1, 2, 3, 4],
[9, 8, 7, 6])
Out[4]: array([1, 8, 3, 4])
This simpler equivalent should make the point clear. np.where picks each element from the first list ([1, 2, 3, 4]) where the condition is True, and from the second list ([9, 8, 7, 6]) where the condition is False.
The first condition is True, so we choose 1 (the first list's element at that position); the second is False, so we choose 8 (the second list's element at that position); and so on.
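To make the position-by-position selection concrete, here is a minimal sketch that compares np.where with a plain Python list comprehension. The variable names (condition, x, y, manual, result) are chosen for illustration only, and the comprehension is just an equivalent for this 1-D case, not how numpy implements it.

```python
import numpy as np

condition = [True, False, True, True]
x = [1, 2, 3, 4]
y = [9, 8, 7, 6]

# Pure-Python equivalent for 1-D inputs: walk the three lists in parallel
# and take the value from x when the condition is True, otherwise from y.
manual = [xv if c else yv for c, xv, yv in zip(condition, x, y)]

# The three-argument form of np.where does the same selection and returns an array.
result = np.where(condition, x, y)

print(manual)           # [1, 8, 3, 4]
print(result.tolist())  # [1, 8, 3, 4]
```

Both lines print the same values, which is the whole behaviour of the three-argument form: one output element per position, taken from x or y depending on the condition at that position.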
Upvotes: 0
Reputation: 107287
The where() function accepts three arguments: condition, x and y. As stated in the documentation, if both x and y are specified, the output array contains elements of x where condition is True, and elements from y elsewhere.
In your case, for the first row it selects 1 from x and 8 from y (because the second condition is False), and for the second row, since both conditions are True, it selects both elements (3 and 4) from x.
np.where([[True, False], [True, True]],
[[1, 2], [3, 4]],
[[9, 8], [7, 6]])
array([[1, 8],
       [3, 4]])
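Since the comparison form (value > otherValue) is the part that already makes sense, it may help to see that a comparison simply produces the same kind of boolean array that the documentation example writes out by hand. The sketch below uses the same x and y. The condition x != 2 is an arbitrary choice made only because it reproduces the mask [[True, False], [True, True]], and the names picked and by_mask are illustrative.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])
y = np.array([[9, 8], [7, 6]])

# A comparison yields a boolean array of the same shape as x.
# Here it happens to equal [[True, False], [True, True]].
condition = x != 2

# Take from x where condition is True, from y elsewhere.
picked = np.where(condition, x, y)
print(picked.tolist())   # [[1, 8], [3, 4]]

# The same result via boolean-mask assignment: start from y and
# overwrite the positions where the condition holds with values from x.
by_mask = y.copy()
by_mask[condition] = x[condition]
print(by_mask.tolist())  # [[1, 8], [3, 4]]
```

The only position where the condition is False is row 0, column 1, so that is the only place where the value comes from y (8 instead of 2).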
Upvotes: 1