N. Virgo

Reputation: 8439

Iterating over a tuple at jit-compile time in numba

Is there a way in a numba jitted function to evaluate every function in a tuple (or list) of functions, at compile time?

Please note that this question is about how to use a Python loop to construct jit code at compile time, rather than iterating over a tuple at runtime, which I know is not supported.

A full example is given below; the core of it is that the following works:

@jit(nopython=True)
def do_stuff(func_tuple):
    results = []
    results.append(func_tuple[0]())
    results.append(func_tuple[1]())
    results.append(func_tuple[2]())
    results.append(func_tuple[3]())
    return results

but the following does not:

@jit(nopython=True)
def do_stuff_2(func_tuple):
    results = []
    for i in range(4):
        results.append(func_tuple[i]())
    return results

The error message is as follows, and its meaning is quite clear: indexing into such a tuple is not supported at runtime.

Invalid usage of getitem with parameters ((type(CPUDispatcher(<function f1 at 0x116968268>)), type(CPUDispatcher(<function f2 at 0x1169688c8>)), type(CPUDispatcher(<function f3 at 0x1169a1b70>)), type(CPUDispatcher(<function f4 at 0x1169a1f28>))), int64)
 * parameterized
[1] During: typing of intrinsic-call at numba_minimal_not_working_example_2.py (36)

File "numba_minimal_not_working_example_2.py", line 36:
def do_stuff_2(func_tuple):
    <source elided>
    for i in range(4):
        results.append(func_tuple[i]())
  ^

However, I only need the indexing to occur at compile time. I basically just want to generate functions similar to do_stuff, but automatically, depending on the number of elements in the tuple.

In principle this can happen at compile time, because numba considers the length of a tuple to be part of its type. But I have not been able to work out how to do this. I have tried various tricks involving recursion and/or the @generated_jit decorator, but I haven't managed to hit on something that works. Is there a way to achieve this?

Here is the full example:

from numba import jit

@jit(nopython=True)
def f1():
    return 1

@jit(nopython=True)
def f2():
    return 2

@jit(nopython=True)
def f3():
    return 3

@jit(nopython=True)
def f4():
    return 4

func_tuple = (f1, f2, f3, f4)

# this works:
@jit(nopython=True)
def do_stuff(func_tuple):
    results = []
    results.append(func_tuple[0]())
    results.append(func_tuple[1]())
    results.append(func_tuple[2]())
    results.append(func_tuple[3]())
    return results

# but this does not:
@jit(nopython=True)
def do_stuff_2(func_tuple):
    results = []
    for i in range(4):
        results.append(func_tuple[i]())
    return results

# this doesn't either (similar error to do_stuff_2).
@jit(nopython=True)
def do_stuff_3(func_tuple):
    results = [f() for f in func_tuple]
    return results


print(do_stuff(func_tuple)) # prints '[1, 2, 3, 4]'

print(do_stuff_2(func_tuple)) # gives the error above

#print(do_stuff_3(func_tuple)) # gives a similar error

Upvotes: 0

Views: 653

Answers (1)

norok2

Reputation: 26886

This is a known limitation of Numba, and the traceback hints at it: there is no valid getitem for a tuple of distinct function types indexed with an int64.

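Basically, each jitted function has its own type, so func_tuple is a heterogeneous tuple. In nopython mode Numba must assign a single type to func_tuple[i], and it cannot do so when i is only known at runtime. With a constant index such as func_tuple[0] the element type is known, which is why do_stuff compiles.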

One workaround is to use @jit(nopython=False) on do_stuff_2(), which lets Numba fall back to object mode and handle the indexing through ordinary Python objects, at the cost of the nopython speed-up. By contrast, you will not be able to @jit the do_stuff_3() function, not even with nopython=False, since comprehensions are not supported by Numba (at least up to version 0.39.0).
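
If you want to stay in nopython mode, one possible approach is to generate the unrolled function from the length of the tuple and jit the result, so that every index is a constant in the generated source. The following is only a minimal sketch of that idea, not something taken from the Numba documentation: make_unrolled_caller and do_stuff_unrolled are hypothetical names, and the try/except fallback is there only so that the snippet still runs when Numba is not installed.

```python
# Fall back to plain Python when Numba is missing (assumption made only so
# that this sketch stays runnable everywhere; with Numba it uses njit).
try:
    from numba import njit
except ImportError:
    def njit(func):
        return func


def make_unrolled_caller(n):
    """Build and jit a function that calls func_tuple[0] .. func_tuple[n-1].

    Hypothetical helper: it writes out the same source as the hand-written
    do_stuff from the question, but for an arbitrary tuple length n.
    """
    lines = ["def do_stuff_unrolled(func_tuple):", "    results = []"]
    for i in range(n):
        # The index is a literal constant in the generated source,
        # so no runtime indexing into the tuple is needed.
        lines.append(f"    results.append(func_tuple[{i}]())")
    lines.append("    return results")

    # Give the generated function a sensible __module__ via __name__.
    namespace = {"__name__": __name__}
    exec("\n".join(lines), namespace)
    return njit(namespace["do_stuff_unrolled"])


@njit
def f1():
    return 1


@njit
def f2():
    return 2


@njit
def f3():
    return 3


@njit
def f4():
    return 4


func_tuple = (f1, f2, f3, f4)

# One specialised function is generated per tuple length.
do_stuff_unrolled = make_unrolled_caller(len(func_tuple))
print(do_stuff_unrolled(func_tuple))  # [1, 2, 3, 4]
```

The Python loop runs once, when the function is built, and the jitted code only ever sees constant indices. Newer Numba releases also provide literal_unroll for looping over heterogeneous tuples, which may make this kind of code generation unnecessary; I have not checked how it behaves with a tuple of jitted functions.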

Upvotes: 1
