Reputation: 13
I'm trying to graph a function I define piecewise. As an example, take
def f(x,y):
    if x in I.open(0, 1):
        if y in I.open(0, 1):
            return (x+y)
        else:
            return(0)
I then define Z as follows:
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
Z = f(X,Y)
If I run this, I get:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
I've seen some answers to similar questions, but I haven't been able to fix this problem with anything I've read. I understand (I think) that the problem is in the statement if x in I.open(0,1): running f(X,Y) asks whether X is in I.open(0,1), and since X is an array, this may be true for some elements of X while not true for others. But when you graph a function like Z=X+Y, the computer has no problem determining which values of X and Y to use at each step, so why can't it do that here?
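A minimal reproduction of the difference, using a small stand-in array instead of the meshgrid and plain comparisons instead of I.open (an assumption, since I.open's source isn't shown): arithmetic on an array is applied element by element, but an if statement needs one True or False, and NumPy will not reduce an array of booleans to a single value on its own.

```python
import numpy as np

X = np.array([-0.5, 0.5, 1.5])  # small stand-in for the meshgrid array

# Arithmetic is element-wise: one result per element.
print((X + 1).tolist())  # [0.5, 1.5, 2.5]

# A comparison is element-wise too and yields an array of booleans...
mask = (X > 0) & (X < 1)
print(mask.tolist())     # [False, True, False]

# ...but `if` must reduce that array to a single bool, which is ambiguous.
try:
    if mask:
        pass
except ValueError as err:
    print("ValueError:", err)
```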
Upvotes: 1
Views: 316
Reputation: 80329
First, note that your function doesn't return anything (it implicitly returns None) when its first if-test is False. You'd need an else part for the outer if x in I.open(0, 1): as well.
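For single numbers, a corrected version of the original function could look like the sketch below. It assumes, as the rest of this answer does, that I.open(0, 1) means the open interval 0 < x < 1, so plain chained comparisons stand in for it; f_scalar is a hypothetical name. With both else branches present it returns 0 instead of None outside the square, but it still cannot take arrays.

```python
def f_scalar(x, y):
    """Scalar-only version; assumes I.open(0, 1) means 0 < value < 1."""
    if 0 < x < 1:
        if 0 < y < 1:
            return x + y
        else:
            return 0
    else:
        return 0  # this branch was missing: the original fell through to None

print(f_scalar(0.5, 0.25))  # 0.75
print(f_scalar(2.0, 0.5))   # 0
```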
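Now, you want to use your function in a vectorized (NumPy) way. Python if-tests can't be used that way: an if needs a single True or False, but a comparison on an array produces one boolean per element, and NumPy refuses to decide whether such an array as a whole is true, hence the ValueError. Often, np.where can be used instead.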
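As a minimal illustration of np.where(condition, a, b): it takes the value from a where the condition is True and from b where it is False, element by element, and a scalar such as 0 is broadcast to the shape of the condition. The array below is an arbitrary example.

```python
import numpy as np

x = np.array([-1.0, 0.25, 0.75, 2.0])

# Keep x where it lies strictly between 0 and 1, otherwise use 0.
result = np.where((x > 0) & (x < 1), x, 0)
print(result.tolist())  # [0.0, 0.25, 0.75, 0.0]
```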
It is unclear where I.open() comes from; your code is missing essential imports. Supposing it stands for the open interval ]0,1[, the function could be written in a vectorized form as:
import numpy as np

def f(x, y):
    return np.where((x > 0) & (x < 1) & (y > 0) & (y < 1),  # condition
                    x + y,  # result of the if clause
                    0)      # result of the else clause
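Note that & is used for the vectorized version of the logical AND (| would be the logical OR). Because NumPy overloads Python's bitwise operators for this, they keep their usual precedence, which is higher than that of comparisons such as > and <. Each comparison therefore has to be wrapped in its own parentheses, so more brackets are needed than in a standard expression using and/or.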
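A small demonstration of why the parentheses matter, on an arbitrary float array: without them, x > 0 & x < 1 is parsed as the chained comparison x > (0 & x) < 1, and for a float array the bitwise 0 & x already fails with a TypeError, because NumPy does not define & for floating-point values.

```python
import numpy as np

x = np.array([-0.5, 0.5, 1.5])

# Correct: each comparison is parenthesized before combining with &.
inside = (x > 0) & (x < 1)
print(inside.tolist())   # [False, True, False]

# Element-wise OR with |: outside the open interval, i.e. <= 0 or >= 1.
outside = (x <= 0) | (x >= 1)
print(outside.tolist())  # [True, False, True]

# Without parentheses: parsed as x > (0 & x) < 1, and 0 & x fails on floats.
try:
    x > 0 & x < 1
except TypeError as err:
    print("TypeError:", err)
```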
With arrays as input, np.where evaluates the condition for each element separately and takes x + y or 0 at each position, so the result has the same shape as X and Y.
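To see this element-wise behaviour on a grid, the sketch below applies the same f to a tiny 3×3 meshgrid (the coordinates are arbitrary). Only the cells where both coordinates lie strictly between 0 and 1 get x + y. The boolean-mask assignment at the end is an equivalent alternative to np.where, shown only for comparison.

```python
import numpy as np

def f(x, y):
    return np.where((x > 0) & (x < 1) & (y > 0) & (y < 1), x + y, 0)

# Default 'xy' indexing: rows follow the y values, columns follow the x values.
X, Y = np.meshgrid(np.array([-0.5, 0.5, 1.5]), np.array([0.25, 0.75, 2.0]))
Z = f(X, Y)
print(Z.shape)     # (3, 3)
print(Z.tolist())  # [[0.0, 0.75, 0.0], [0.0, 1.25, 0.0], [0.0, 0.0, 0.0]]

# Equivalent formulation: start from zeros and fill only the masked cells.
mask = (X > 0) & (X < 1) & (Y > 0) & (Y < 1)
Z_mask = np.zeros_like(X)
Z_mask[mask] = X[mask] + Y[mask]
print(np.array_equal(Z, Z_mask))  # True
```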
Here is some example code that plots the result, using a few more points:
import matplotlib.pyplot as plt
import numpy as np

def f(x, y):
    return np.where((x > 0) & (x < 1) & (y > 0) & (y < 1),  # condition
                    x + y,  # result of the if clause
                    0)      # result of the else clause

X1d = np.linspace(-5, 5, 100)
Y1d = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(X1d, Y1d)
Z = f(X, Y)

plt.imshow(Z, extent=[X1d[0], X1d[-1], Y1d[0], Y1d[-1]], interpolation='nearest')
plt.xticks(range(-5, 6))
plt.yticks(range(-5, 6))
plt.grid(True, ls=':', color='white')
plt.colorbar()
plt.show()
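If a function really has to keep its scalar if/else logic, np.vectorize is another option (a different technique from the np.where approach used here): it wraps a scalar function so it accepts arrays, but NumPy documents it as essentially a for loop, so it is much slower than np.where on large grids. A sketch, again assuming I.open(0, 1) means 0 < value < 1; f_scalar and f_vec are hypothetical names.

```python
import numpy as np

def f_scalar(x, y):
    # Scalar-only logic; assumes I.open(0, 1) means the open interval (0, 1).
    if 0 < x < 1 and 0 < y < 1:
        return x + y
    return 0.0

# np.vectorize calls f_scalar once per element pair (a loop, not true vectorization).
# otypes fixes the output dtype instead of inferring it from the first call.
f_vec = np.vectorize(f_scalar, otypes=[float])

X, Y = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100))
Z = f_vec(X, Y)
print(Z.shape)  # (100, 100)
```

The resulting Z can be passed to the same plt.imshow call as above.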
Upvotes: 2