norio

Reputation: 3892

How to make sure in python that an input argument is not complex-valued but real-valued

I am writing a function in Python for scientific computation. One of its arguments represents a real-valued input parameter. If a complex value is passed for this argument, the result will be incorrect, because I have not implemented the special handling that complex-valued input requires. However, the function will return that incorrect value without raising an error or exception, because every line of the function is still valid code for complex input.

For example, please consider a function like this:

import numpy as np
def foo(vara):
    """
    This function evaluates the Foo formula for the real variable vara.
    This function does not work for the complex variable vara because I am 
    too lazy to take care of the branch cut of the complex square-root function.
    """
    if vara<0:
        vv = -0.57386286*vara
    else:
        vv =  3.49604327*vara
    return np.sqrt(vv)

The function foo will return a value even if the argument vara is complex, because numpy.sqrt is also defined for complex arguments. However, the returned value will be incorrect, since foo was implemented with only real arguments in mind.

How can I check inside the function that an argument is real-valued, so that I can make the function raise an exception or exit with an error otherwise?

Please note that I want the function to keep working for both Python's native float type and numpy arrays with float-type elements. I only want to prohibit calling the function with a complex variable or a numpy array of complex elements.

(I thought of multiplying the argument by 1.0j and checking that the real part of the result is zero, but this does not look neat.)

Upvotes: 1

Views: 721

Answers (2)

norio

Reputation: 3892

(I am answering my own question. I am not sure this is the best way, but I wanted to leave the code I tried for the record.)

Based on polpak's answer, I wrote the following code. I think it satisfies the conditions I raised. The function is pedantic in that it rejects every type of input argument other than a float scalar or a float ndarray. (Perhaps it does not even accept all kinds of float ndarray.) In particular, it rejects integer scalars and integer ndarrays as well as complex scalars and complex ndarrays.

#!/usr/bin/python

import numpy as np
import types

def foo(vara):
    """vara must be a real-valued scalar or ndarray."""

    real_types = [types.FloatType, np.float16, np.float32, np.float64, np.float128]
    print '----------'
    print 'vara:', vara
    if isinstance(vara, np.ndarray):
        if not any(vara.dtype==t for t in real_types):
            print 'NG.'
            print '   type(vara)=', type(vara)
            print '   vara.dtype=', vara.dtype
            # raise an error here
        else:
            print 'OK.'
            print '   type(vara)=', type(vara)
            print '   vara.dtype=', vara.dtype
    else:
        if not any(isinstance(vara, t) for t in real_types):
            print 'NG.'
            print '   type(vara)=', type(vara)
            # raise an error here
        else:
            print 'OK.'
            print '   type(vara)=', type(vara)


varalist=[3.0, 
          np.array([0.5, 0.2]), 
          np.array([3, 4, 1]), 
          np.array([3.4+1.2j, 0.8+0.7j]),
          np.array([3.4+0.0j, 0.8+0.0j]),
          np.array([1.3, 4.2, 5.9], dtype=complex),
          np.array([1.3, 4.2, 5.9], dtype=complex).real ]

for vara  in varalist:
    foo(vara)

The output of this code was as follows.

$ ./main003.py 
----------
vara: 3.0
OK.
   type(vara)= <type 'float'>
----------
vara: [ 0.5  0.2]
OK.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= float64
----------
vara: [3 4 1]
NG.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= int64
----------
vara: [ 3.4+1.2j  0.8+0.7j]
NG.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= complex128
----------
vara: [ 3.4+0.j  0.8+0.j]
NG.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= complex128
----------
vara: [ 1.3+0.j  4.2+0.j  5.9+0.j]
NG.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= complex128
----------
vara: [ 1.3  4.2  5.9]
OK.
   type(vara)= <type 'numpy.ndarray'>
   vara.dtype= float64
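
For the record, the same pedantic check can probably be written more compactly with np.issubdtype, raising an exception instead of printing. This is only a sketch: the name require_real_float is made up for illustration, and unlike the code above it also accepts plain lists of floats, because np.asarray converts them before the dtype is inspected.

```python
import numpy as np

def require_real_float(vara):
    """Return vara unchanged if it is real floating-point; raise TypeError otherwise.

    Hypothetical helper, not part of numpy.
    """
    # np.asarray gives scalars and arrays a common dtype to inspect.
    dtype = np.asarray(vara).dtype
    # np.floating is the abstract parent of numpy's real float types,
    # so integer and complex dtypes both fail this test.
    if not np.issubdtype(dtype, np.floating):
        raise TypeError('vara must be real floating-point, got dtype %s' % dtype)
    return vara
```

Calling require_real_float(vara) at the top of foo should then stop integer and complex input with a TypeError, while float scalars and float ndarrays pass through untouched.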

Upvotes: 0

Chad S.

Reputation: 6633

If you only want to forbid complex data types, this will do the trick:

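import numpy as np

# Python's built-in complex plus numpy's complex scalar types
complex_types = [complex, np.complex64, np.complex128]

def is_complex_sequence(vara):
    # True if vara is iterable and any of its elements is complex
    return (hasattr(vara, '__iter__')
            and any(isinstance(v, t) for v in vara for t in complex_types))

def is_complex_scalar(vara):
    # True if vara itself is a complex scalar
    return any(isinstance(vara, t) for t in complex_types)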

Then in your function you can just do:

if is_complex_scalar(vara) or is_complex_sequence(vara):
    raise ValueError('Argument must not be a complex number')
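
As an alternative to hand-written type lists, numpy ships np.iscomplexobj, which checks the type (not the values) of a scalar or an array. A minimal sketch, assuming a hypothetical helper name reject_complex:

```python
import numpy as np

def reject_complex(vara):
    """Raise ValueError if vara has a complex type; otherwise return it unchanged.

    Hypothetical helper, not part of numpy.
    """
    # np.iscomplexobj looks at the type/dtype, not the values, so an
    # input such as 3.4+0.0j still counts as complex.
    if np.iscomplexobj(vara):
        raise ValueError('Argument must not be a complex number')
    return vara
```

Note that this only forbids complex input. Integers and integer arrays pass the check, and a complex array whose imaginary parts are all zero is still rejected.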

Upvotes: 1
