leftaroundabout

Reputation: 120731

How to write conditional code that's compatible with both plain Python values and NumPy arrays?

For writing “piecewise functions” in Python, I'd normally use if (in either the control-flow or ternary-operator form).

def spam(x):
    return x+1 if x>=0 else 1/(1-x)

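Now, with NumPy, the mantra is to avoid working on single values in favour of vectorisation, for performance. So I reckon something like this would be preferred (though, as Leon remarks, the following is wrong: the right-hand sides are not masked, so their shapes do not match the masked left-hand sides):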

def eggs(x):
    y = np.zeros_like(x)
    positive = x>=0
    y[positive] = x+1
    y[np.logical_not(positive)] = 1/(1-x)
    return y

(Correct me if I've missed something here, because frankly I find this very ugly.)

Now, of course eggs will only work if x is actually a NumPy array, because otherwise x>=0 simply yields a single boolean, which can't be used for indexing (or at least doesn't do the right thing).

Is there a good way to write code that looks more like spam but works idiomatically on NumPy arrays, or should I just use np.vectorize(spam)?
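
For reference, the np.vectorize fallback mentioned above might look like the sketch below. The name spam_vec and the choice of otypes=[float] are assumptions made for illustration. Note that the NumPy documentation describes np.vectorize as a convenience wrapper that is essentially a Python-level loop, so it keeps the readable scalar code but does not bring the performance of true vectorisation.

```python
import numpy as np

def spam(x):
    # Scalar piecewise definition from the question.
    return x + 1 if x >= 0 else 1 / (1 - x)

# otypes pins the output dtype to float, so the result type does not
# depend on what the first evaluated element happens to return.
spam_vec = np.vectorize(spam, otypes=[float])

print(spam_vec(np.arange(-3, 3)))  # element-wise over an array
print(spam_vec(5))                 # a plain number also works, but yields a 0-d array
```

Because spam is called once per element, only the branch that applies is evaluated, so no division by zero occurs at x = 1.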

Upvotes: 6

Views: 453

Answers (2)

Praveen

Reputation: 7222

Use np.where. You'll get an array as the output even for plain number input, though.

def eggs(x):
    y = np.asarray(x)
    return np.where(y>=0, y+1, 1/(1-y))

This works for both arrays and plain numbers:

>>> eggs(5)
array(6.0)
>>> eggs(-3)
array(0.25)
>>> eggs(np.arange(-3, 3))
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:2: RuntimeWarning: divide by zero encountered in true_divide
array([ 0.25      ,  0.33333333,  0.5       ,  1.        ,  2.        ,  3.        ])
>>> eggs(1)
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:3: RuntimeWarning: divide by zero encountered in long_scalars
  # -*- coding: utf-8 -*-
array(2.0)

As ayhan remarks, this raises a warning, since 1/(1-x) gets evaluated for the whole range. But a warning is just that: a warning. If you know what you're doing, you can ignore the warning. In this case, you're only choosing 1/(1-x) from indices where it can never be inf, so you're safe.
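
If the warning is a nuisance, one option is to silence it locally rather than globally. The following is a minimal sketch, assuming that suppressing only the divide-by-zero category inside the function is acceptable; it uses the np.errstate context manager and otherwise leaves the function unchanged.

```python
import numpy as np

def eggs(x):
    y = np.asarray(x)
    # Both branches are still computed for every element; errstate only
    # silences the divide-by-zero warning raised by the discarded
    # 1/(1-y) value at y == 1.
    with np.errstate(divide="ignore"):
        return np.where(y >= 0, y + 1, 1 / (1 - y))

print(eggs(np.arange(-3, 3)))
print(eggs(1))
```

The error state is restored when the with block exits, so divide-by-zero warnings elsewhere in the program are unaffected.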

Upvotes: 3

Maxim

Reputation: 7695

I would use numpy.asarray (which is a no-op if the argument is already a NumPy array) if I want to handle both numbers and NumPy arrays:

def eggs(x):
    x = np.asfarray(x)
    m = x>=0
    x[m] = x[m] + 1
    x[~m] = 1 / (1 - x[~m])
    return x

(Here I used asfarray to enforce a floating-point type, since your function requires floating-point computations. Note that np.asfarray was removed in NumPy 2.0; np.asarray(x, dtype=float) is the replacement there.)

This is less efficient than your spam function for single inputs, and arguably uglier. However, it seems to be the easiest choice.

EDIT: If you want to ensure that x is not modified (as pointed out by Leon), you can replace np.asfarray(x) with np.array(x, dtype=np.float64), since the array constructor copies by default.
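
Put together, the non-mutating variant described in the edit might look like this. It is only a sketch: the sample array a is made up for illustration, and the function body is the masked version from above with the conversion swapped for the copying constructor.

```python
import numpy as np

def eggs(x):
    # np.array copies by default, so the caller's array is left untouched.
    x = np.array(x, dtype=np.float64)
    m = x >= 0
    # Each branch is computed only on the elements it applies to,
    # so 1/(1-x) is never evaluated at x == 1.
    x[m] = x[m] + 1
    x[~m] = 1 / (1 - x[~m])
    return x

a = np.array([-3.0, 0.0, 2.0])  # illustrative input
b = eggs(a)
print(a)  # still the original values
print(b)
```

Since the masks restrict each computation to its own region, this version does not trigger the divide-by-zero warning at all.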

Upvotes: 2
