Reputation: 311
So let's say I've got a variable a, which is a NumPy array. When an element of a is less than a certain value I want to apply one function, and when it is greater than this value I want to apply a different function.
I've tried doing this with an if statement, but I get the following error:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
I know from this answer that I need to use NumPy's a.any() or a.all(), but I'm unclear how and where I would use them in the function. I've provided a really simple example below:
```python
import numpy as np

a = np.linspace(1, 10, num=9)

def sortfoo(a):
    if a < 5:
        b = a*3
    else:
        b = a/2
    return b

result = sortfoo(a)
print(result)
```
So I guess I'm asking for an example of where and how I need to use any() and all() in the above.
It's a really basic question, but for some reason my brain isn't working clearly. Any help is much appreciated.
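For context on the error: a < 5 is an element-wise comparison, so it produces a boolean array with one entry per element, while if needs a single True/False. any() and all() collapse that array into one boolean. That makes the if legal, but the chosen branch is then applied to the whole array at once instead of element by element. Here is a short illustration (the name mask is just illustrative):

```python
import numpy as np

a = np.linspace(1, 10, num=9)
mask = a < 5          # element-wise comparison: a boolean array, not a single bool
print(mask)

# `if mask:` needs one truth value, so NumPy refuses to guess and raises ValueError.
try:
    if mask:
        pass
except ValueError as err:
    print("ValueError:", err)

# any()/all() collapse the mask to a single bool that `if` accepts,
# but the chosen branch would then be applied to *every* element of a.
print(mask.any())     # True: at least one element is < 5
print(mask.all())     # False: not every element is < 5
```

So any() and all() answer "is the condition true for some or all elements?". A per-element choice needs a vectorised approach instead, such as boolean indexing or np.where.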
Upvotes: 0
Views: 219
Reputation: 1165
Using boolean indexing in NumPy, you can do it like this:
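```python
import numpy as np

a = np.linspace(1, 10, num=9)
s = a < 5            # boolean mask: True where a < 5
a[s] = a[s] * 3      # triple the elements below 5 (modifies a in place)
a[~s] = a[~s] / 2    # halve the rest; ~s is the same as s == False
```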
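Note that the code above overwrites a in place. If you'd rather keep the original array, one option is to fill a new array using the same mask. This is a sketch, and the name b is only illustrative. ~s is the element-wise negation of the mask, equivalent to s == False:

```python
import numpy as np

a = np.linspace(1, 10, num=9)
s = a < 5                 # boolean mask: True where a < 5

b = np.empty_like(a)      # new array with the same shape and dtype as a
b[s] = a[s] * 3           # elements below 5 are tripled
b[~s] = a[~s] / 2         # all other elements are halved

print(b)                  # a itself is left unchanged
```

Because s and ~s together cover every index, every slot of the np.empty_like array gets written, so no uninitialised values remain.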
Upvotes: 2
Reputation: 5955
Given the description, this looks like a use case for np.where():
```python
>>> import numpy as np
>>> a = np.linspace(1, 10, num=9)
>>> b = np.where(a < 5, a * 3, a / 2)
>>> b
array([ 3.    ,  6.375 ,  9.75  , 13.125 ,  2.75  ,  3.3125,  3.875 ,
        4.4375,  5.    ])
```
Since you also mention wanting to apply different functions, you can use the same syntax:
```python
>>> def f1(n):
...     return n * 3
...
>>> def f2(n):
...     return n / 2
...
>>> np.where(a < 5, f1(a), f2(a))
array([ 3.    ,  6.375 ,  9.75  , 13.125 ,  2.75  ,  3.3125,  3.875 ,
        4.4375,  5.    ])
```
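Keep in mind that np.where evaluates both f1(a) and f2(a) on the whole array and only then picks elements from the two results. That is fine for cheap functions. If a function is expensive, or invalid on part of the domain (a log of negative values, say), np.piecewise is an alternative that the answer doesn't mention. It calls each function only on the elements whose condition holds. The sketch below assumes the same f1 and f2 as above:

```python
import numpy as np

def f1(n):
    return n * 3

def f2(n):
    return n / 2

a = np.linspace(1, 10, num=9)

# Each function is called only on the sub-array of elements where its condition is True;
# the results are written back into an output array with the same shape and dtype as a.
b = np.piecewise(a, [a < 5, a >= 5], [f1, f2])
print(b)
```

For this input the result matches the np.where version. The difference only shows up in what each function gets evaluated on.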
Upvotes: 3