Reputation: 8439
Is there a way in a numba jitted function to evaluate every function in a tuple (or list) of functions, at compile time?
Please note that this question is about how to use a Python loop to construct jit code at compile time, rather than iterating over a tuple at runtime, which I know is not supported.
A full non-working example is below, but the core of it is that the following works:
@jit(nopython=True)
def do_stuff(func_tuple):
    results = []
    results.append(func_tuple[0]())
    results.append(func_tuple[1]())
    results.append(func_tuple[2]())
    results.append(func_tuple[3]())
    return results
but the following does not:
@jit(nopython=True)
def do_stuff_2(func_tuple):
    results = []
    for i in range(4):
        results.append(func_tuple[i]())
    return results
The error message is as follows, and its meaning is quite clear: indexing into such a tuple is not supported at runtime.
Invalid usage of getitem with parameters ((type(CPUDispatcher(<function f1 at 0x116968268>)), type(CPUDispatcher(<function f2 at 0x1169688c8>)), type(CPUDispatcher(<function f3 at 0x1169a1b70>)), type(CPUDispatcher(<function f4 at 0x1169a1f28>))), int64)
* parameterized
[1] During: typing of intrinsic-call at numba_minimal_not_working_example_2.py (36)
File "numba_minimal_not_working_example_2.py", line 36:
def do_stuff_2(func_tuple):
<source elided>
for i in range(4):
results.append(func_tuple[i]())
^
However, I only need the indexing to occur at compile time - I basically just want to generate functions similar to do_stuff, but to do so automatically depending on the number of elements in the tuple.
In principle this can happen at compile time, because numba considers the length of a tuple to be part of its type. But I have not been able to work out how to do this. I have tried various tricks involving recursion and/or the @generated_jit decorator, but I haven't managed to hit on something that works. Is there a way to achieve this?
Here is the full example:
from numba import jit

@jit(nopython=True)
def f1():
    return 1

@jit(nopython=True)
def f2():
    return 2

@jit(nopython=True)
def f3():
    return 3

@jit(nopython=True)
def f4():
    return 4

func_tuple = (f1, f2, f3, f4)

# this works:
@jit(nopython=True)
def do_stuff(func_tuple):
    results = []
    results.append(func_tuple[0]())
    results.append(func_tuple[1]())
    results.append(func_tuple[2]())
    results.append(func_tuple[3]())
    return results

# but this does not:
@jit(nopython=True)
def do_stuff_2(func_tuple):
    results = []
    for i in range(4):
        results.append(func_tuple[i]())
    return results

# this doesn't either (similar error to do_stuff_2).
@jit(nopython=True)
def do_stuff_3(func_tuple):
    results = [f() for f in func_tuple]
    return results

print(do_stuff(func_tuple))  # prints '[1, 2, 3, 4]'
print(do_stuff_2(func_tuple))  # gives the error above
#print(do_stuff_3(func_tuple))  # gives a similar error
Upvotes: 0
Views: 653
Reputation: 26886
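This is a known limitation of Numba, and the traceback points at it: the failing operation is a getitem on a tuple of four differently typed dispatchers with an int64 index.
Each jitted function has its own type, so when you ask to @jit your function in nopython mode, Numba cannot infer a single type for func_tuple[i] when i is only known at runtime.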
One workaround could be to use @jit(nopython=False) on do_stuff_2(), which would then be able to handle such code by making use of the Python objects system.
By contrast, you will not be able to @jit the do_stuff_3() function, not even with nopython=False, since this comprehension over a tuple of functions is not supported by numba (at least up to version 0.39.0).
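For the compile-time unrolling the question asks about, one option is to generate the unrolled source with ordinary Python and only then hand the resulting function to Numba. The following is a minimal sketch of that idea, not a Numba feature: make_unrolled_caller and its decorator parameter are hypothetical names. It assumes that decorating the generated function with jit(nopython=True) behaves like the hand-written do_stuff, because the generated body is the same. The demonstration uses plain Python functions so that it runs without Numba.

```python
def make_unrolled_caller(n, decorator=None):
    """Build a function that calls func_tuple[0]() ... func_tuple[n-1]().

    Every index is a literal constant in the generated source, so no
    runtime indexing into the tuple is needed. `decorator` is a
    hypothetical hook, e.g. numba.jit(nopython=True).
    """
    lines = ["def do_stuff_unrolled(func_tuple):", "    results = []"]
    for i in range(n):
        # One line per element, mirroring the hand-written do_stuff().
        lines.append(f"    results.append(func_tuple[{i}]())")
    lines.append("    return results")

    namespace = {}
    exec("\n".join(lines), namespace)  # define the function from source
    func = namespace["do_stuff_unrolled"]
    return decorator(func) if decorator is not None else func


# Demonstration with plain Python functions (no Numba required).
def g1():
    return 1

def g2():
    return 2

def g3():
    return 3

plain_tuple = (g1, g2, g3)
do_stuff_auto = make_unrolled_caller(len(plain_tuple))
print(do_stuff_auto(plain_tuple))  # prints '[1, 2, 3]'
```

With Numba installed, the intended use would be make_unrolled_caller(len(func_tuple), decorator=jit(nopython=True)), called with the tuple of jitted functions. Treat that as an assumption to verify on your Numba version rather than a guaranteed result.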
Upvotes: 1