Reputation: 3892
I am writing a Python function for scientific computation. One of its arguments represents a real-valued input parameter. If a complex value is passed for this argument, the result will be incorrect, because I have not implemented the special handling that complex-valued input requires. The function nevertheless returns the incorrect value without any error or exception, because every line in it is still syntactically valid and executes without complaint.
For example, please consider a function like this:
import numpy as np

def foo(vara):
    """
    This function evaluates the Foo formula for the real variable vara.
    This function does not work for the complex variable vara because I am
    too lazy to take care of the branch cut of the complex square-root function.
    """
    if vara<0:
        vv = -0.57386286*vara
    else:
        vv = 3.49604327*vara
    return np.sqrt(vv)
The function foo will return a complex value if the argument vara is complex, because numpy.sqrt is also defined for complex arguments. The returned value will be incorrect, however, since foo was implemented with only real arguments in mind.
How can I check inside the function that an argument is real-valued, so that I can make the function raise an exception or exit with an error otherwise?
Please note that I want the function to keep working both for Python's native float type and for a numpy array with float elements. I only want to prohibit calling the function with a complex variable or a numpy array of complex elements.
(I thought of multiplying the argument by 1.0j and checking that the real part of the result is zero, but this does not look neat.)
Upvotes: 1
Views: 721
Reputation: 3892
(I am answering my own question. I am not sure this is the best way, but I wanted to leave the code that I tried for the record.)
Based on polpak's answer, I wrote the following code. I think it satisfies the conditions I raised. The function is pedantic in that it rejects any input other than a float scalar or a float ndarray. (Perhaps it does not even accept every kind of float ndarray.) In particular, it rejects integer scalars and integer ndarrays as well as complex scalars and complex ndarrays.
#!/usr/bin/python
import numpy as np
import types

def foo(vara):
    """vara must be a real-valued scalar or ndarray."""
    real_types = [types.FloatType, np.float16, np.float32, np.float64, np.float128]
    print '----------'
    print 'vara:', vara
    if isinstance(vara, np.ndarray):
        # For arrays, compare the dtype against the accepted float types.
        if not any(vara.dtype==t for t in real_types):
            print 'NG.'
            print ' type(vara)=', type(vara)
            print ' vara.dtype=', vara.dtype
            # raise an error here
        else:
            print 'OK.'
            print ' type(vara)=', type(vara)
            print ' vara.dtype=', vara.dtype
    else:
        # For scalars, check the type of the value itself.
        if not any(isinstance(vara, t) for t in real_types):
            print 'NG.'
            print ' type(vara)=', type(vara)
            # raise an error here
        else:
            print 'OK.'
            print ' type(vara)=', type(vara)

varalist=[3.0,
          np.array([0.5, 0.2]),
          np.array([3, 4, 1]),
          np.array([3.4+1.2j, 0.8+0.7j]),
          np.array([3.4+0.0j, 0.8+0.0j]),
          np.array([1.3, 4.2, 5.9], dtype=complex),
          np.array([1.3, 4.2, 5.9], dtype=complex).real ]
for vara in varalist:
    foo(vara)
The output of this code was as follows.
$ ./main003.py
----------
vara: 3.0
OK.
type(vara)= <type 'float'>
----------
vara: [ 0.5 0.2]
OK.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= float64
----------
vara: [3 4 1]
NG.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= int64
----------
vara: [ 3.4+1.2j 0.8+0.7j]
NG.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= complex128
----------
vara: [ 3.4+0.j 0.8+0.j]
NG.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= complex128
----------
vara: [ 1.3+0.j 4.2+0.j 5.9+0.j]
NG.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= complex128
----------
vara: [ 1.3 4.2 5.9]
OK.
type(vara)= <type 'numpy.ndarray'>
vara.dtype= float64
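The concern that a hand-written list of types may miss some float dtypes could perhaps be addressed with np.issubdtype, which tests a dtype against the abstract np.floating category. What follows is only a minimal sketch, written in Python 3 syntax (unlike the Python 2 listing above). The helper name require_real_float is hypothetical, and the sketch assumes it is acceptable to pass scalars through np.asarray, which also means that plain lists of floats are accepted.

```python
import numpy as np

def require_real_float(vara):
    """Return vara as an ndarray if its dtype is a real float; raise TypeError otherwise.

    Hypothetical helper for illustration only.
    """
    arr = np.asarray(vara)  # a Python float becomes a 0-d float64 array
    # np.floating covers float16/32/64 and any extended-precision float type;
    # integer and complex dtypes are not subtypes of it.
    if not np.issubdtype(arr.dtype, np.floating):
        raise TypeError("expected real floating-point input, got dtype %s" % arr.dtype)
    return arr
```

With this check, 3.0 and float arrays pass, while integer and complex scalars or arrays raise TypeError, including complex arrays whose imaginary parts all happen to be zero.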
Upvotes: 0
Reputation: 6633
If you want to only forbid complex data types this will do the trick:
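import numpy as np

# The built-in complex is the same type as types.ComplexType on Python 2.
complex_types = [complex, np.complex64, np.complex128]

def is_complex_sequence(vara):
    # True if vara is iterable and any of its elements is complex.
    return (hasattr(vara, '__iter__')
            and any(isinstance(v, t) for v in vara for t in complex_types))

def is_complex_scalar(vara):
    # True if vara itself is a complex scalar.
    return any(isinstance(vara, t) for t in complex_types)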
Then in your function you can just do:
    if is_complex_scalar(vara) or is_complex_sequence(vara):
        raise ValueError('Argument must not be a complex number')
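As a possible alternative to hand-written type lists, NumPy provides np.iscomplexobj, which reports whether a scalar or array has a complex type (it looks at the type, not the values, so 3.4+0.0j still counts as complex). The sketch below is one way the question's foo might use it, in Python 3 syntax. Two things here are my assumptions rather than part of the original function: the if/else is swapped for np.where so that array input works element-wise, and integer input is converted to float instead of being rejected.

```python
import numpy as np

def foo(vara):
    """Evaluate the Foo formula for real vara; complex input is rejected up front."""
    # Rejects complex scalars and complex-dtype arrays, even if the imaginary part is 0.
    if np.iscomplexobj(vara):
        raise TypeError("vara must be real-valued, got a complex type")
    vara = np.asarray(vara, dtype=float)  # assumption: integers are converted, not rejected
    # Element-wise version of the original if/else branch.
    vv = np.where(vara < 0, -0.57386286 * vara, 3.49604327 * vara)
    return np.sqrt(vv)
```

Raising TypeError instead of ValueError is a design choice: the problem is the kind of number passed in, not its particular value.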
Upvotes: 1