Alex Long
Optionally passing parameters onto another function with jit

I am attempting to jit compile a Python function and use an optional argument to change the arguments of another function call.

I think jit might be tripping up because the default value of the optional argument is None: either jit doesn't know how to handle None at all, or it doesn't know how to handle the argument changing from None to a NumPy array. See below for a rough overview:

from numba import jit

@jit(nopython=True)
def foo(otherFunc, arg1, optionalArg=None):
    if optionalArg is not None:
        out = otherFunc(arg1, optionalArg)
    else:
        out = otherFunc(arg1)
    return out

Where optionalArg is either None or a NumPy array.

One solution would be to split this into three functions, as shown below, but this feels kind of janky and I don't like it, especially because speed is very important for this task.

from numba import jit

def foo(otherFunc, arg1, optionalArg=None):
    if optionalArg is not None:
        out = func1(otherFunc, arg1, optionalArg)
    else:
        out = func2(otherFunc, arg1)
    return out

@jit(nopython=True)
def func1(otherFunc, arg1, optionalArg):
    out = otherFunc(arg1, optionalArg)
    return out

@jit(nopython=True)
def func2(otherFunc, arg1):
    out = otherFunc(arg1)
    return out

Note that foo does more than just call otherFunc, which is what makes using jit worthwhile, but I'm almost certain the problem is not there, since the function worked before I added optionalArg, so I have left that part out.

For those of you who are curious, it's a fourth-order Runge-Kutta (RK4) implementation with optional extra parameters to pass to the differential equation. If you want to see the whole thing, just ask.

The traceback is rather long, but here is some of it:

inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
Traceback (most recent call last):

  File "<ipython-input-38-478197aa6a1a>", line 1, in <module>
    inte.rk4(de2,y0,0.001,200,vals=np.ones(4))

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
    error_rewrite(e, 'typing')

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
    reraise(type(e), e, None)

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
    raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E168C358>:

This continues...

inte.rk4 is the equivalent of foo, de2 is otherFunc, y0, 0.001 and 200 are just values that I swapped out for arg1 in my problem description above, and vals is optionalArg.

A similar thing happens when I try to run this with the vals parameter omitted:

ysExp=inte.rk4(deExp,y0,0.001,200)
Traceback (most recent call last):

  File "<ipython-input-39-7dde4bcbdc2f>", line 1, in <module>
    ysExp=inte.rk4(deExp,y0,0.001,200)

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
    error_rewrite(e, 'typing')

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
    reraise(type(e), e, None)

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
    raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E048EA90>:

This continues...

Answers (1)

Gambit1614
If you look at the documentation here, you can see that Numba lets you declare optional argument types explicitly. For example (this is the example from the documentation):

>>> from numba import jit, optional, intp
>>> @jit((optional(intp),))
... def f(x):
...     return x is not None
...
>>> f(0)
True
>>> f(None)
False

Additionally, based on the discussion in this GitHub issue, you can use the following workaround to implement an optional keyword argument. I have modified the code from the solution provided in the GitHub issue to suit your example:

from collections import OrderedDict

import numpy as np
from numba import int32, njit
# Recent Numba releases expose jitclass from numba.experimental;
# older releases used `from numba import jitclass` instead.
from numba.experimental import jitclass

np_arr = np.asarray([1, 2])

spec = OrderedDict()
spec['x'] = int32

@jitclass(spec)
class Foo(object):
    def __init__(self, x):
        self.x = x

    def otherFunc(self, optionalArg):
        if optionalArg is None:
            return self.x + 10
        else:
            return len(optionalArg)

@njit
def useOtherFunc(arg1, optArg):
    foo = Foo(arg1)
    print(foo.otherFunc(optArg))

arg1 = 5

useOtherFunc(arg1, np_arr)   # Output: 2
useOtherFunc(arg1, None)     # Output: 15

See this Colab notebook for the example shown above.
