Reputation: 35
I am using numba's @njit decorator to compile a function that gets used in parallel processes, but it is slower than I expected. Processes that should differ by an order of magnitude in execution time take around the same time, which makes it look like there is a lot of compilation overhead.
I have a function

from numba import njit

@njit
def foo(ar):
    # (do something)
    return ar
and a normal Python function

def bar(x):
    # (do something)
    return foo(x)
which gets called in parallel processes like

import concurrent.futures

if __name__ == "__main__":
    with concurrent.futures.ProcessPoolExecutor(max_workers=maxWorkers) as executor:
        results = executor.map(bar, args)

where args is a long list of arguments.
Does this mean that foo() gets compiled separately within each process? That would explain the extra overhead. Is there a good solution for this? I could just call foo() once on one of the arguments before spawning the processes, forcing it to compile ahead of time. Is there a better way?
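For concreteness, the warm-up I have in mind would look something like this (assuming args[0] has the same types as the real inputs):

if __name__ == "__main__":
    foo(args[0])  # trigger JIT compilation once in the parent process
    with concurrent.futures.ProcessPoolExecutor(max_workers=maxWorkers) as executor:
        results = executor.map(bar, args)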
Upvotes: 3
Views: 1026
Reputation: 50846
Multiprocessing causes spawned processes to re-execute the module-level code that is not under the main section (i.e. if __name__ == "__main__"). This indeed includes the compilation of the Numba function. Caching can be used to compile the function once and cache it, so that subsequent compilations are much faster (the code can be loaded from the cache), assuming the function context is the same (e.g. parameter types, dependence on global variables, compilation flags, etc.). This feature is enabled with @nb.njit(cache=True). For more information about this, please read this section of the documentation. In your case, the main process will compile the function and the other ones will load it from the cache.
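A minimal sketch of the change, using foo from your question (by default the on-disk cache is written to a __pycache__ directory next to the source file):

import numba as nb

@nb.njit(cache=True)
def foo(ar):
    # (do something)
    return ar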
Note that it is often better to use the multithreading feature of Numba instead of multiprocessing, since spawning processes is more expensive (both in time and in memory usage). That being said, only a few kinds of functions can be called from a multithreaded Numba context (mainly NumPy and Numba functions).
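As a rough illustration of that multithreading feature: if the per-argument work can be written in Numba-compatible code, a batched function with parallel=True and nb.prange lets Numba's threads replace the process pool. The batched foo_all and the 2D layout below are hypothetical, just to show the pattern:

import numba as nb
import numpy as np

@nb.njit(parallel=True, cache=True)
def foo_all(batch):
    # Hypothetical batched version: each row of `batch` is one argument;
    # nb.prange distributes the outer loop over Numba's worker threads.
    out = np.empty_like(batch)
    for i in nb.prange(batch.shape[0]):
        out[i] = batch[i]  # (do something per row)
    return out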
Upvotes: 2