PKlumpp

Reputation: 5233

Numpy where broadcastable condition

I have used numpy.where() many times now, and I have always wondered about the following statement in the docs:

x, y and condition need to be broadcastable to some shape.

I see why this is necessary for both x and y. We want to assemble the resulting array from the two, so they should be broadcastable to the same shape. However, I do not understand why this is so important for the condition as well. It is only the decision rule. Suppose I have the following three shapes:

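import numpy as np

condition = np.zeros(100, dtype=bool)   # shape (100,)
x = np.zeros((100, 5))                  # shape (100, 5)
y = np.ones((100, 5))                   # shape (100, 5)
result = np.where(condition, x, y)      # raises ValueError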

This raises a ValueError because the "operands could not be broadcast together". As I understand it, this expression should work just fine, because I compose my result from x and y, which are broadcastable.

Can you help me understand why it is so important for the condition to be broadcastable along with x and y?

Upvotes: 6

Views: 536

Answers (1)

senderle

Reputation: 151077

The condition is fundamentally a boolean array, not a generic condition. You could think of it as a mask over the final broadcasted shape of x and y.

If you think of it that way, it should be clear that the mask must have the same shape, or be broadcastable to the same shape, as the final output.

To illustrate this, here's a simple example. To begin with, consider a scenario in which we have hand-defined a 3x3 mask array as our condition, and we pass in two 3-item arrays as x and y, shaped to broadcast appropriately:

import numpy

condition = numpy.array([[0, 1, 1],
                         [1, 0, 1],
                         [0, 0, 1]])
ones = numpy.ones(3)
numpy.where(condition, ones[:, None], ones[None, :] + 1)

The result looks like this:

>>> numpy.where(condition, ones[:, None], ones[None, :] + 1)
array([[2., 1., 1.],
       [1., 2., 1.],
       [2., 2., 1.]])

Because of the broadcasting step, x and y behave as if they were defined like this:

>>> x, y = numpy.broadcast_arrays(ones[:, None], ones[None, :] + 1)
>>> x
array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]])
>>> y
array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]])
>>> numpy.where(condition, x, y)
array([[2., 1., 1.],
       [1., 2., 1.],
       [2., 2., 1.]])
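
To make the broadcast-then-mask view concrete, here is a rough sketch of the selection step written as an ordinary Python function. The name where_sketch is hypothetical and not part of NumPy, and the real numpy.where is implemented in C; this only mirrors its observable behavior for ordinary array inputs (scalar dtype promotion may differ slightly).

```python
import numpy as np


def where_sketch(condition, x, y):
    """Hypothetical re-implementation of np.where's selection logic."""
    # Step 1: broadcast all three inputs to one common shape. This raises
    # ValueError for incompatible shapes, just as np.where does.
    cond_b, x_b, y_b = np.broadcast_arrays(
        np.asarray(condition, dtype=bool), x, y
    )
    # Step 2: start from a writable copy of y in the promoted dtype of x and y.
    out = np.array(y_b, dtype=np.result_type(x_b, y_b))
    # Step 3: treat the broadcast condition as a mask over that common shape,
    # copying x's value wherever the mask is True.
    out[cond_b] = x_b[cond_b]
    return out


condition = np.array([[True], [False]])   # shape (2, 1)
x = np.array([10, 20, 30])                # shape (3,)
y = np.zeros((2, 3))                      # shape (2, 3)

# Both lines print the same 2x3 array: row 0 comes from x, row 1 from y.
print(where_sketch(condition, x, y))
print(np.where(condition, x, y))
```

The key point is step 1: the condition takes part in broadcasting on equal terms with x and y, because step 3 needs one boolean for every element of the output.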

This is the fundamental behavior of where. The fact that you can pass in a condition like (x > 5) doesn't change anything about the above; (x > 5) becomes a boolean array, and it must have the same shape as the output, or else it must be broadcastable to that shape. Otherwise, the behavior of where would be ill-defined.
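
A small sketch of that point, using a hypothetical 2x5 array x that is not from the question: the comparison is already an ordinary boolean array by the time numpy.where sees it, and a smaller condition also works as long as it broadcasts to the output shape.

```python
import numpy as np

x = np.arange(10).reshape(2, 5)    # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]

# The comparison is evaluated first and yields a plain boolean array
# with the same shape as x; np.where only ever sees that array.
mask = x > 5
print(mask.dtype, mask.shape)      # bool (2, 5)
print(np.where(mask, x, -1))       # keeps x where mask is True, else -1

# A condition that is merely broadcastable also works: this (5,)-shaped
# mask is repeated across both rows, selecting columns 0, 2 and 4.
col_mask = np.array([True, False, True, False, True])
print(np.where(col_mask, x, 0))
```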

(By the way, I am assuming your question is not about why the shapes (100,), (100, 5), and (100, 5) aren't broadcastable; that seems to be a different question.)
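
For readers who do arrive with exactly that shape problem, here is a minimal sketch assuming the intent is one decision per row. The array contents are made-up placeholders, and numpy.broadcast_shapes requires NumPy 1.20 or later.

```python
import numpy as np

# Shapes from the question; the contents are hypothetical placeholders.
condition = np.arange(100) % 2 == 0    # shape (100,): True on even rows
x = np.zeros((100, 5))                 # shape (100, 5)
y = np.ones((100, 5))                  # shape (100, 5)

# Broadcasting aligns shapes from the right, so (100,) is compared with
# the trailing 5 of (100, 5): 100 != 5 and neither is 1, hence the error.
try:
    np.broadcast_shapes(condition.shape, x.shape)
except ValueError as exc:
    print("incompatible:", exc)

# Adding a trailing axis makes the condition (100, 1); the size-1 axis is
# stretched across the 5 columns, so each row gets a single decision.
print(np.broadcast_shapes(condition[:, None].shape, x.shape))   # (100, 5)
result = np.where(condition[:, None], x, y)
```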

Upvotes: 2
