8556732

Reputation: 311

Numpy boolean statement - help on using a.any() and a.all() in statement

So let's say I've got a variable a which is a numpy array. For elements of a that are less than a certain value I want to apply one function, and for elements greater than this value I would apply a different function.

I've tried doing this with a boolean if statement but get the following error:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I know from this answer that I need to use numpy's a.any() or a.all(), but I'm unclear how or where I would use them in the function. I've provided a really simple example below:

import numpy as np

a = np.linspace(1, 10, num=9)

def sortfoo(a):
    if a < 5:
        b = a*3
    else:
        b = a/2
    return b

result = sortfoo(a)
print(result)
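For anyone landing here, a minimal sketch of why the error appears: `a < 5` compares every element and returns a boolean array, while `if` needs a single True/False. `.any()` and `.all()` collapse that array to one value, which silences the error but sends the whole array down one branch, so on their own they don't give per-element behaviour. The name `mask` is just illustrative.

```python
import numpy as np

a = np.linspace(1, 10, num=9)

# Element-wise comparison: one True/False per element of a
mask = a < 5
print(mask)

# any()/all() reduce the whole array to a single bool
print(mask.any())  # True: at least one element is below 5
print(mask.all())  # False: not every element is below 5

# Using the array directly in `if` raises the ValueError from the question
try:
    if mask:
        pass
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)
```

So `if (a < 5).all():` would only take the first branch when every element is below 5; for per-element results you need a mask or `np.where`, as in the answers.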

So I guess I'm asking for an example of where and how I would apply any() and all() to the above.

Really basic question, but for some reason my brain isn't working clearly. Any help much appreciated.

Upvotes: 0

Views: 219

Answers (2)

B.Gees

Reputation: 1165

Using boolean indexing in numpy, you can do this:

import numpy as np
a = np.linspace(1, 10, num=9)
s = a < 5  # boolean mask: True where a < 5
a[s] = a[s] * 3
a[~s] = a[~s] / 2  # ~s selects the elements where the test is False
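One thing to watch: the code above overwrites `a` in place. If you want to keep the question's `sortfoo(a)` shape and leave the input untouched, here is a sketch applying the same masking idea to a fresh output array (the name `below` is just illustrative):

```python
import numpy as np

def sortfoo(a):
    """Return a*3 where a < 5 and a/2 elsewhere, without modifying a."""
    a = np.asarray(a, dtype=float)  # float output so a/2 is not truncated
    b = np.empty_like(a)            # fresh array for the result
    below = a < 5                   # boolean mask, one entry per element
    b[below] = a[below] * 3
    b[~below] = a[~below] / 2       # ~ inverts the mask
    return b

a = np.linspace(1, 10, num=9)
result = sortfoo(a)
print(result)
```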

Upvotes: 2

G. Anderson

Reputation: 5955

Given the description, this looks like a use case for np.where():

import numpy as np

a = np.linspace(1, 10, num=9)

b = np.where(a < 5, a * 3, a / 2)

b
array([ 3.    ,  6.375 ,  9.75  , 13.125 ,  2.75  ,  3.3125,  3.875 ,
        4.4375,  5.    ])
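If you'd rather keep the scalar if/else from the question, `np.vectorize` will call it once per element. A sketch, with `sortfoo_scalar` as a hypothetical renamed copy of the question's function; note that `np.vectorize` is a convenience wrapper around a Python-level loop, so `np.where` is usually much faster on large arrays.

```python
import numpy as np

def sortfoo_scalar(x):
    # Plain Python if/else: works because x is a single number here
    if x < 5:
        return x * 3
    else:
        return x / 2

# otypes=[float] fixes the output dtype instead of inferring it from the first call
vsortfoo = np.vectorize(sortfoo_scalar, otypes=[float])

a = np.linspace(1, 10, num=9)
result = vsortfoo(a)
print(result)
```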

Since you also mention wanting to apply different functions, you can use the same syntax:

def f1(n):
    return n*3

def f2(n):
    return n/2

np.where(a < 5, f1(a), f2(a))

array([ 3.    ,  6.375 ,  9.75  , 13.125 ,  2.75  ,  3.3125,  3.875 ,
        4.4375,  5.    ])
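One caveat with `np.where`: both `f1(a)` and `f2(a)` are computed over the entire array before `np.where` picks from them, which matters if a function is expensive or invalid outside its region (for example a log of non-positive values). `np.piecewise` calls each function only on the elements where its condition holds; a sketch reusing `f1` and `f2`:

```python
import numpy as np

def f1(n):
    return n * 3

def f2(n):
    return n / 2

a = np.linspace(1, 10, num=9)

# f1 is applied only where a < 5. Because funclist has one more entry than
# condlist, f2 is the default for elements where no condition is True.
result = np.piecewise(a, [a < 5], [f1, f2])
print(result)
```

The output takes the dtype of the input, so pass a float array (as `np.linspace` produces here) if any branch divides.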

Upvotes: 3
