Reputation: 121
I am using scipy's double integration function dblquad and I am trying to increase its speed. I have checked the solutions proposed online, but couldn't get them to work. To simplify the question, I have prepared the comparison below. What am I doing wrong, and what can I do to improve the speed?
from scipy import integrate
import timeit
from numba import njit, jit
def bb_pure(q, z, x_loc, y_loc, B, L):
    def f(y, x):
        return (
            3 * q * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5)
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

def bb_numbanjit(q, z, x_loc, y_loc, B, L):
    @njit
    def f(y, x):
        return (
            3 * q * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5)
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

def bb_numbajit(q, z, x_loc, y_loc, B, L):
    @jit
    def f(y, x):
        return (
            3 * q * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5)
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
####
starttime = timeit.default_timer()
for i in range(100):
    bb_pure(200, 5, 0, 0, i, i * 2)
print("Pure Function:", round(timeit.default_timer() - starttime, 2))
####
starttime = timeit.default_timer()
for i in range(100):
    bb_numbanjit(200, 5, 0, 0, i, i * 2)
print("Numba njit:", round(timeit.default_timer() - starttime, 2))
####
starttime = timeit.default_timer()
for i in range(100):
    bb_numbajit(200, 5, 0, 0, i, i * 2)
print("Numba jit:", round(timeit.default_timer() - starttime, 2))
Results
Pure Function: 3.22
Numba njit: 8.14
Numba jit: 8.15
Upvotes: 1
Views: 482
Reputation: 50826
The main issue is that you are timing the compilation of the Numba function. Indeed, when bb_numbanjit is called, the @njit decorator tells Numba to declare a lazily-compiled function which is compiled when the first call is performed, i.e. inside integrate.dblquad. The exact same behaviour applies to bb_numbajit. The Numba implementation is slower because the compilation time is large compared to the execution time. The problem is that the Numba functions are closures that read local variables, so each call to the enclosing function requires a new compilation. The typical way to solve this is to pass those values as extra parameters of the Numba function so it is compiled only once. Since dblquad needs a function of (y, x) here, you can use a proxy closure. Here is an example:
@njit
def f_numba(y, x, q, z, x_loc, y_loc, B, L):
    return (
        3 * q * z ** 3
        / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5)
    )

def bb_numbanjit(q, z, x_loc, y_loc, B, L):
    def f_proxy(y, x):
        return f_numba(y, x, q, z, x_loc, y_loc, B, L)

    return integrate.dblquad(f_proxy, 0, B, lambda x: 0, lambda x: L)[0]
This is twice as fast as the bb_pure solution.
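For instance, you can check that the rewritten version matches the original and re-run the benchmark from the question (a minimal sketch; the sample inputs 10 and 20 are arbitrary):

from math import isclose

# Both implementations should agree up to floating-point noise.
assert isclose(bb_pure(200, 5, 0, 0, 10, 20), bb_numbanjit(200, 5, 0, 0, 10, 20), rel_tol=1e-6)

starttime = timeit.default_timer()
for i in range(100):
    bb_numbanjit(200, 5, 0, 0, i, i * 2)  # f_numba is compiled only once, on the first call
print("Numba njit (proxy):", round(timeit.default_timer() - starttime, 2))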
One reason why this Numba solution is not much faster is that Python function calls are expensive, especially when there are a lot of parameters. Another problem is that some parameters appear to be constants, but Numba is not aware of that because they are passed as runtime arguments instead of compile-time constants. You can move the constants into global variables to let Numba optimize the code further (by pre-computing constant sub-expressions).
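For example, here is a minimal sketch of that idea, assuming q and z really are fixed in your application (the global names Q and Z are hypothetical):

Q = 200  # assumed fixed; Numba freezes globals at compile time
Z = 5    # assumed fixed

@njit
def f_const(y, x, x_loc, y_loc, B, L):
    # Since Q and Z are compile-time constants here, Numba can pre-compute
    # the constant sub-expressions 3 * Q * Z ** 3 and Z * Z.
    return (
        3 * Q * Z ** 3
        / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + Z * Z) ** 2.5)
    )

Note that Numba reads globals once, at compilation time, so changing Q or Z afterwards has no effect unless the function is recompiled.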
Note also that Numba functions are already wrapped by proxy functions internally. These proxy functions are a bit expensive for such basic numerical operations (they do some type checking and convert pure-Python objects to native values). That being said, there is not much to be done about that here due to the closure issue.
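As a side note, if you want to keep the remaining one-time compilation out of your measurements, Numba also supports eager compilation through an explicit signature (a sketch, assuming all arguments can be treated as double-precision floats):

# The signature makes Numba compile at decoration time instead of on the
# first call, so the timed loop no longer includes the compilation.
@njit("float64(float64, float64, float64, float64, float64, float64, float64, float64)")
def f_numba(y, x, q, z, x_loc, y_loc, B, L):
    return (
        3 * q * z ** 3
        / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5)
    )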
Upvotes: 1