Kosi

Reputation: 1

Numba bad performance for a simple for loop (Python 3.10)

My code:


from numba import njit
from functools import wraps
import time

def timeit(my_func):
    @wraps(my_func)
    def timed(*args, **kw):
    
        tstart = time.time()
        output = my_func(*args, **kw)
        tend = time.time()
        
        print('"{}" took {:.3f} ms to execute\n'.format(my_func.__name__, (tend - tstart) * 1000))
        return output
    return timed

@timeit
@njit
def calculate_smth(a,b):
    result = 0
    for i_a in range(a):
        for i_b in range(b):
            result = result + i_a + i_b
    return result

if __name__ == "__main__":
    value = calculate_smth(1000,1000)

Without the Numba decorator my function completes in ~62 ms; with the njit decorator (after compiling beforehand) it needs ~370 ms. Can someone explain what I am missing?

Upvotes: 0

Views: 175

Answers (1)

matszwecja

Reputation: 7971

JIT stands for Just-In-Time, meaning the code is compiled at execution time, as opposed to AOT (Ahead-Of-Time). As you can read in the Numba docs, compilation is lazy by default, i.e. it happens the first time the function is called in a program. A timing that wraps that first call therefore measures compilation plus execution, not execution alone.

Numba also supports AOT compilation, as described in its documentation.

Another option is to pass the cache=True parameter to the numba.njit decorator, which stores the compiled function on disk so that later runs of the program can reuse it.

As a concrete example, if you edit your code to include a dummy call that triggers compilation, you can see that the function itself takes next to no time to run:

...
if __name__ == "__main__":
    calculate_smth(1,1)
    value = calculate_smth(1000, 1000)

Output:

"calculate_smth" took 590.578 ms to execute

"calculate_smth" took 0.000 ms to execute
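The 0.000 ms reading above may partly reflect the limited resolution of time.time() on some platforms. As a rough sketch, a timer based on time.perf_counter() that reports the first call separately from later calls makes the compile cost visible. The helper name time_calls and the repeats parameter are made up for this illustration, and the function below is a plain-Python stand-in; put @njit on it to see the Numba numbers.

```python
import time


def time_calls(func, *args, repeats=5):
    """Return (first_call_ms, best_later_ms) for func(*args)."""
    # First call: for a lazily compiled function this includes compilation.
    start = time.perf_counter()
    func(*args)
    first_ms = (time.perf_counter() - start) * 1000

    # Later calls: only the execution itself is measured.
    later_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        later_ms.append((time.perf_counter() - start) * 1000)
    return first_ms, min(later_ms)


def calculate_smth(a, b):
    # Plain-Python stand-in; decorate with numba.njit to compare.
    result = 0
    for i_a in range(a):
        for i_b in range(b):
            result = result + i_a + i_b
    return result


if __name__ == "__main__":
    first, best = time_calls(calculate_smth, 1000, 1000)
    print(f"first call: {first:.3f} ms, best later call: {best:.3f} ms")
```

With the plain-Python version the two numbers should be close to each other; with an njit-decorated version the first number is expected to be far larger than the second, because it includes compilation.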

Upvotes: 1
