Joel

Reputation: 23887

How to get a Python function to work on an np.array or a float, with conditional logic

I have a function that I'd like to take numpy arrays or floats as input. I want to keep doing an operation until some measure of error is less than a threshold.

A simple example is dividing a number or array by 2 until it is below a threshold (for a float), or until its maximum is below the threshold (for an array):

def f(x):    # float version
    while x > 1e-5:
        x = x / 2
    return x

def f(x):    # np array version
    while max(x) > 1e-5:
        x = x / 2
    return x

Unfortunately, the built-in max won't work on a float, because a float is not iterable, and while x > 1e-5 fails if x is an array with more than one element, because the comparison produces a boolean array whose truth value is ambiguous. The only option I can find is np.vectorize, but that doesn't seem as efficient as I'd like. How can I get a single function to handle both cases?
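To make the two failure modes concrete, here is a small sketch (the variable names are illustrative only) showing which exception each version raises when given the "wrong" kind of input:

```python
import numpy as np

x_float = 0.5
x_array = np.array([0.5, 2.0])

# The built-in max() needs an iterable, so a plain float raises TypeError.
try:
    max(x_float)
except TypeError as exc:
    float_error = exc

# x_array > 1e-5 itself is fine (it gives an elementwise boolean array),
# but a while/if condition calls bool() on it, which is ambiguous for
# arrays with more than one element and raises ValueError.
try:
    bool(x_array > 1e-5)
except ValueError as exc:
    array_error = exc

print(type(float_error).__name__, "-", float_error)
print(type(array_error).__name__, "-", array_error)
```

So each version is fine for its own input type; the problem is only that neither condition is valid for both.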

Upvotes: 2

Views: 56

Answers (2)

Tinyfold

Reputation: 146

NumPy arrays have a method called argmax() (see the NumPy docs), which returns the index of the maximum element, so the maximum can be looked up through that index instead of with the built-in max(). Unlike the loop in the question, the version below halves only the current largest element on each pass, not the whole array.

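import numpy as np

def f(x):  # note that all the values of the np array will become <= 1e-5
    # x is modified in place, so pass a float array (or a copy)
    i = x.argmax()                  # flat index of the current maximum
    while x.flat[i] > 1e-5:
        x.flat[i] = x.flat[i] / 2   # halve only the largest element
        i = x.argmax()
    return x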
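To see how this differs from the loop in the question, here is a minimal sketch comparing the two behaviors on the same input. The function names halve_max_until_small and halve_all_until_small and the tol parameter are hypothetical. Both copy the input to a float array first, which avoids mutating the caller's array and avoids integer truncation when halving:

```python
import numpy as np

def halve_max_until_small(x, tol=1e-5):
    # argmax-based approach: halve only the current maximum each pass
    x = np.array(x, dtype=float)  # float copy of the input
    i = x.argmax()
    while x.flat[i] > tol:
        x.flat[i] = x.flat[i] / 2
        i = x.argmax()
    return x

def halve_all_until_small(x, tol=1e-5):
    # behavior from the question: halve every element each pass
    x = np.array(x, dtype=float)
    while x.max() > tol:
        x = x / 2
    return x

a = np.array([1.0, 1e-6])
per_element = halve_max_until_small(a)
uniform = halve_all_until_small(a)
print(per_element)
print(uniform)
```

With the argmax version, an element that is already below the threshold is left alone, whereas the question's loop rescales the whole array uniformly; which one is wanted depends on the use case.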

Upvotes: 0

rehaqds

Reputation: 2055

What about checking the type of the input inside the function and adapting it?

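import numpy as np

def f(x):    # for float or np.array
    if isinstance(x, (int, float)):   # also matches np.float64, a float subclass
        x = x * np.ones(1)            # wrap the scalar in a 1-element array
    while np.max(x) > 1e-5:
        x = x / 2
    return x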
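A variation on this answer drops the type check altogether: np.max also accepts a plain Python scalar (it is treated like a 0-d array), so the same condition works for both input types. A minimal sketch, where the tol parameter is a hypothetical addition; a side effect of this design is that a float input comes back as a float rather than a one-element array:

```python
import numpy as np

def f(x, tol=1e-5):
    # np.max works on a float as well as on an ndarray,
    # so no branching on the input type is needed
    while np.max(x) > tol:
        x = x / 2  # stays a float for float input, an ndarray for array input
    return x

scalar_result = f(1.0)
array_result = f(np.array([1.0, 3.0]))
print(type(scalar_result), scalar_result)
print(type(array_result), array_result)
```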

Upvotes: 2
