user3243944

Fastest Double Integration Method

I am using scipy's double integration function dblquad and I am trying to increase its speed. I have checked the solutions proposed online, but couldn't get them to work. To simplify the question, I have prepared the comparison below. What am I doing wrong, or what can I do to improve the speed?

from scipy import integrate
import timeit
from numba import njit, jit

def bb_pure(q, z, x_loc, y_loc, B, L):

    def f(y, x):
        return (
            3
            * q
            * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
            )
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

def bb_numbanjit(q, z, x_loc, y_loc, B, L):
    
    @njit
    def f(y, x):
        return (
            3
            * q
            * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
            )
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

def bb_numbajit(q, z, x_loc, y_loc, B, L):
    
    @jit
    def f(y, x):
        return (
            3
            * q
            * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
            )
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

####

starttime = timeit.default_timer()
for i in range(100):
    bb_pure(200, 5, 0, 0, i, i*2)

print("Pure Function:", round(timeit.default_timer() - starttime,2))

####

starttime = timeit.default_timer()
for i in range(100):
    bb_numbanjit(200, 5, 0, 0, i, i*2)

print("Numba njit:", round(timeit.default_timer() - starttime,2))

####

starttime = timeit.default_timer()
for i in range(100):
    bb_numbajit(200, 5, 0, 0, i, i*2)

print("Numba jit:", round(timeit.default_timer() - starttime,2))

Results

Pure Function: 3.22
Numba njit: 8.14
Numba jit: 8.15

Answers (1)

Jérôme Richard

The main issue is that you are timing the compilation of the Numba function. When bb_numbanjit is called, the @njit decorator tells Numba to declare a lazily compiled function, which is only compiled on its first call, i.e. inside integrate.dblquad. The exact same behaviour applies to bb_numbajit. The Numba versions are slower because the compilation time is large compared to the execution time. The underlying problem is that the Numba functions are closures that read local parameters, so every call to bb_numbanjit creates a new function that has to be compiled again. The typical way to solve this is to pass those values as extra parameters to the Numba function so that it is compiled only once. Since dblquad needs a function of (y, x) only, you can wrap it in a plain Python proxy closure. Here is an example:

from numba import njit
from scipy import integrate

@njit
def f_numba(y, x, q, z, x_loc, y_loc, B, L):
    return (
        3
        * q
        * z ** 3
        / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
        )
    )

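def bb_numbanjit(q, z, x_loc, y_loc, B, L):
    # Plain Python proxy: only this cheap wrapper is re-created per call,
    # while f_numba is compiled once (per argument-type signature) and reused.
    def f_proxy(y, x):
        return f_numba(y, x, q, z, x_loc, y_loc, B, L)

    return integrate.dblquad(f_proxy, 0, B, lambda x: 0, lambda x: L)[0]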

This is about twice as fast as the bb_pure solution.
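When reproducing comparisons like the one in the question, it helps to keep the one-off compilation out of the measurement. A minimal sketch, assuming a hypothetical helper time_after_warmup (it is not part of timeit or Numba): it calls the function once before starting the timer, so a lazily compiled function pays its compilation cost outside the timed loop. Since Numba compiles a separate specialization for each combination of argument types, the warm-up call should use the same argument types as the timed calls.

```python
import timeit


def time_after_warmup(func, arg_tuples):
    """Return the total time of func(*args) over arg_tuples, excluding a warm-up call.

    Hypothetical helper: func is called once with the first argument tuple
    before the timer starts, so a lazily compiled function (e.g. one decorated
    with Numba's @njit) pays its one-off compilation cost outside the timing.
    """
    arg_tuples = list(arg_tuples)
    func(*arg_tuples[0])  # warm-up call: compilation happens here, not timed
    start = timeit.default_timer()
    for args in arg_tuples:
        func(*args)
    return timeit.default_timer() - start
```

For example, time_after_warmup(bb_numbanjit, [(200, 5, 0, 0, i, i * 2) for i in range(100)]) would time the same 100 integrations as the question's loop. This only helps when the compiled function is reused across calls, as with the module-level f_numba; the closure versions in the question recompile on every call, so a warm-up call cannot hide that cost.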

One reason why this Numba solution is not much faster is that Python function calls are expensive, especially when there are many parameters. Another problem is that some parameters appear to be constants, but Numba is not aware of that because they are passed as runtime arguments instead of compile-time constants. You can move the constants into global variables to let Numba optimize the code further (by precomputing constant subexpressions).
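The constant folding can also be done by hand. The sketch below is an assumed variant, not the global-variable approach described above: the hypothetical helper reduced_args folds the six physical parameters into four precomputed constants once per integration, and the hypothetical integrand f_reduced then takes six arguments instead of eight. It is written in plain Python so that it runs without Numba; decorating f_reduced with @njit and calling it from a proxy closure, as in bb_numbanjit above, is the intended use, but its speed has not been measured here.

```python
def reduced_args(q, z, x_loc, y_loc, B, L):
    """Hypothetical helper: fold the six parameters into the four constants f_reduced needs.

    Returns (c, dx, dy, zz) with c = 3*q*z**3, dx = x_loc - B/2,
    dy = y_loc - L/2 and zz = z*z, computed once instead of per evaluation.
    """
    return 3 * q * z ** 3, x_loc - B / 2, y_loc - L / 2, z * z


def f_reduced(y, x, c, dx, dy, zz):
    """Hypothetical integrand equivalent to f_numba, using precomputed constants.

    Assumption: this could be decorated with @njit; it is left as plain
    Python here so the sketch runs without Numba installed.
    """
    return c / ((x + dx) ** 2 + (y + dy) ** 2 + zz) ** 2.5
```

Inside a bb_* function this would read consts = reduced_args(q, z, x_loc, y_loc, B, L) followed by integrate.dblquad(lambda y, x: f_reduced(y, x, *consts), 0, B, lambda x: 0, lambda x: L)[0].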

Note also that Numba functions are already wrapped internally by proxy functions. These proxies are relatively expensive for such basic numerical operations (they do some type checking and convert pure-Python objects to native values). That being said, there is not much to do about it here because of the closure issue.
