user107511

Reputation: 822

Reducing numba @jitclass compilation time (with caching?)

I'm using numba's @jitclass and I want to reduce the compilation time.

For example (my actual class is much bigger):

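```python
from numba.experimental import jitclass

@jitclass
class State:
    a: int
    b: int

    def __init__(self):
        # Assign to the instance attributes (a bare `a = 0` would only bind a local)
        self.a = 0
        self.b = 0

    def update(self, x: int):
        self.a += x
        self.b -= x
```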

I tried passing cache=True to @jitclass, but that parameter does not seem to be supported:

```python
@jitclass(cache=True)
class State:
    ...
```

I also tried turning the class into a plain data holder and compiling each method as a free function with @njit(cache=True):

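```python
from numba import njit
from numba.experimental import jitclass

@jitclass
class State:
    a: int
    b: int

    def __init__(self):
        self.a = 0
        self.b = 0

# Former method, now a free function that Numba can try to cache on disk
@njit(cache=True)
def update(state: State, x: int):
    state.a += x
    state.b -= x
```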

But this seems to make compilation time even worse. My guess is that since State itself isn't cached, it is recompiled every time, which in turn forces every function that depends on it to be recompiled as well.

Is there any solution that reduces this compilation time?

Upvotes: 2

Views: 810

Answers (1)

Jérôme Richard

Reputation: 50528

This is a known open issue, and it does not look like it will be solved soon (it has been reported for at least 6 years). You can find more information in the following issues: #1, #2, #3. A jitclass currently cannot be cached safely across interpreter restarts, and this is a fairly deep issue. To quote one developer:

Note that if this pickling did work, then the cached bytecode would be invalid across interpreter restarts, as it contains a hard-coded memory address of where the type used to live.

Using structref seems to be a workaround (though not a great one). Restructuring the code so as not to use jitclass, if possible, is a better option. Note that both jitclass and structref are still experimental features.

Upvotes: 2
