Ogiad

Reputation: 173

How to speed up a computation that is slow even with Numba

I'm having trouble with a slow computation in my Python code. Based on the pycallgraph below, the bottleneck seems to be the function miepython.miepython.mie_S1_S2 (highlighted in pink), which takes 0.47 seconds per call.

[pycallgraph profiling image]

The source code for this module is as follows:

import numpy as np
from numba import njit, int32, float64, complex128

__all__ = ('ez_mie',
           'ez_intensities',
           'generate_mie_costheta',
           'i_par',
           'i_per',
           'i_unpolarized',
           'mie',
           'mie_S1_S2',
           'mie_cdf',
           'mie_mu_with_uniform_cdf',
           )


@njit((complex128, float64, float64[:]), cache=True)
def _mie_S1_S2(m, x, mu):
    """
    Calculate the scattering amplitude functions for spheres.

    The amplitude functions have been normalized so that when integrated
    over all 4*pi solid angles, the integral will be qext*pi*x**2.

    The units are weird, sr**(-0.5)

    Args:
        m: the complex index of refraction of the sphere
        x: the size parameter of the sphere
        mu: array of angles, cos(theta), to calculate scattering amplitudes

    Returns:
        S1, S2: the scattering amplitudes at each angle mu [sr**(-0.5)]
    """
    nstop = int(x + 4.05 * x**0.33333 + 2.0) + 1
    a = np.zeros(nstop - 1, dtype=np.complex128)
    b = np.zeros(nstop - 1, dtype=np.complex128)
    _mie_An_Bn(m, x, a, b)

    nangles = len(mu)
    S1 = np.zeros(nangles, dtype=np.complex128)
    S2 = np.zeros(nangles, dtype=np.complex128)

    nstop = len(a)
    for k in range(nangles):
        pi_nm2 = 0
        pi_nm1 = 1
        for n in range(1, nstop):
            tau_nm1 = n * mu[k] * pi_nm1 - (n + 1) * pi_nm2

            S1[k] += (2 * n + 1) * (pi_nm1 * a[n - 1]
                                    + tau_nm1 * b[n - 1]) / (n + 1) / n

            S2[k] += (2 * n + 1) * (tau_nm1 * a[n - 1]
                                    + pi_nm1 * b[n - 1]) / (n + 1) / n

            temp = pi_nm1
            pi_nm1 = ((2 * n + 1) * mu[k] * pi_nm1 - (n + 1) * pi_nm2) / n
            pi_nm2 = temp

    # calculate norm = sqrt(pi * Qext * x**2)
    n = np.arange(1, nstop + 1)
    norm = np.sqrt(2 * np.pi * np.sum((2 * n + 1) * (a.real + b.real)))

    S1 /= norm
    S2 /= norm

    return [S1, S2]

The source code is already jitted with Numba, so I expected it to be faster than it is. The total number of iterations of the double for loop in this function is around 25,000 (len(mu) = 50, len(a) - 1 = 500).

Any ideas on how to speed up this computation? Is something preventing Numba from compiling it efficiently? Or do you think the computation is already as fast as it can be?

[More details]

The code above calls another function, _mie_An_Bn. That function is also jitted, and its source code is as follows:

@njit((complex128, float64, complex128[:], complex128[:]), cache=True)
def _mie_An_Bn(m, x, a, b):
    """
    Compute arrays of Mie coefficients A and B for a sphere.

    This estimates the size of the arrays based on Wiscombe's formula. The length
    of the arrays is chosen so that the error when the series are summed is
    around 1e-6.

    Args:
        m: the complex index of refraction of the sphere
        x: the size parameter of the sphere

    Returns:
        An, Bn: arrays of Mie coefficients
    """
    psi_nm1 = np.sin(x)                   # nm1 = n-1 = 0
    psi_n = psi_nm1 / x - np.cos(x)       # n = 1
    xi_nm1 = complex(psi_nm1, np.cos(x))
    xi_n = complex(psi_n, np.cos(x) / x + np.sin(x))

    nstop = len(a)
    if m.real > 0.0:
        D = _D_calc(m, x, nstop + 1)

        for n in range(1, nstop):
            temp = D[n] / m + n / x
            a[n - 1] = (temp * psi_n - psi_nm1) / (temp * xi_n - xi_nm1)
            temp = D[n] * m + n / x
            b[n - 1] = (temp * psi_n - psi_nm1) / (temp * xi_n - xi_nm1)
            xi = (2 * n + 1) * xi_n / x - xi_nm1
            xi_nm1 = xi_n
            xi_n = xi
            psi_nm1 = psi_n
            psi_n = xi_n.real

    else:
        for n in range(1, nstop):
            a[n - 1] = (n * psi_n / x - psi_nm1) / (n * xi_n / x - xi_nm1)
            b[n - 1] = psi_n / xi_n
            xi = (2 * n + 1) * xi_n / x - xi_nm1
            xi_nm1 = xi_n
            xi_n = xi
            psi_nm1 = psi_n
            psi_n = xi_n.real

Example inputs look like the following:

m = 1.336-2.462e-09j
x = 8526.95
mu = np.array([-1., -0.7500396, 0.46037385, 0.5988121, 0.67384093,
               0.72468684, 0.76421644, 0.79175856, 0.81723714, 0.83962897,
               0.85924182, 0.87641596, 0.89383665, 0.90708978, 0.91931481,
               0.93067567, 0.94073113, 0.94961222, 0.95689496, 0.96467123,
               0.97138347, 0.97791831, 0.98339434, 0.98870543, 0.99414948,
               0.9975728, 0.9989995, 0.9989995, 0.9989995, 0.9989995,
               0.9989995, 0.99899951, 0.99899951, 0.99899951, 0.99899951,
               0.99899951, 0.99899951, 0.99899951, 0.99899951, 0.99899951,
               0.99899952, 0.99899952, 0.99899952, 0.99899952, 0.99899952,
               0.99899952, 0.99899952, 0.99899952, 0.99899952, 1.])

Upvotes: 3

Views: 170

Answers (1)

Jérôme Richard

Reputation: 50308

I am focusing on _mie_S1_S2 since it appears to be the most expensive function on the provided example dataset.

First of all, you can pass the parameter fastmath=True to the JIT decorator to accelerate the computation, as long as no values like +Inf, -Inf, -0 or NaN are computed.
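
As a toy illustration of the flag (deliberately not the Mie code itself), fastmath lets LLVM reassociate the floating-point additions in the reduction below, which typically enables SIMD vectorization:

import numpy as np
from numba import njit, float64

# With fastmath=True, LLVM may reorder/reassociate the additions, enabling
# SIMD vectorization. Only safe when no +Inf, -Inf, -0 or NaN is involved.
@njit(float64(float64[:]), fastmath=True)
def fast_sum(v):
    s = 0.0
    for i in range(v.size):
        s += v[i]
    return s

print(fast_sum(np.ones(1000)))  # 1000.0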

Then you can pre-compute some expensive expressions containing divisions or implicit integer-to-float conversions. Note that (2 * n + 1) / n = 2 + 1 / n and (n + 1) / n = 1 + 1 / n. This can be used to reduce the number of precomputed arrays, but it did not change the performance on my machine (this may differ on other target architectures). Note also that such a precomputation has a slight impact on the result accuracy (negligible most of the time, and sometimes even better than the reference implementation).
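
To make the identities concrete, here is a small self-contained check (the names inv_n, f1, f2 and f3 are mine, not from the final code below) showing that all three factors used later can be derived from a single precomputed 1/n table:

import numpy as np

nstop = 8  # tiny example size; in the real code this would be len(a) + 1
inv_n = 1.0 / np.arange(1, nstop + 2)  # inv_n[i] = 1 / (i + 1)

for n in range(1, nstop):
    f1 = inv_n[n - 1] + inv_n[n]  # == (2*n + 1) / (n + 1) / n
    f2 = 2.0 + inv_n[n - 1]       # == (2*n + 1) / n
    f3 = 1.0 + inv_n[n - 1]       # == (n + 1) / n
    assert np.isclose(f1, (2 * n + 1) / (n + 1) / n)
    assert np.isclose(f2, (2 * n + 1) / n)
    assert np.isclose(f3, (n + 1) / n)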

On my machine, this strategy makes the code 4.5 times faster with fastmath=True and 2.8 times faster without.

The k-based loop can be parallelized using Numba's parallel=True and prange. However, this may not always be faster on all machines (especially on ones with many cores) since the loop is already quite fast.

Here is the final code:

import numba as nb  # needed for nb.prange below

# Add fastmath=True to the decorator for the larger speedups reported below.
@njit((complex128, float64, float64[:]), cache=True, parallel=True)
def _mie_S1_S2_opt(m, x, mu):
    nstop = int(x + 4.05 * x**0.33333 + 2.0) + 1
    a = np.zeros(nstop - 1, dtype=np.complex128)
    b = np.zeros(nstop - 1, dtype=np.complex128)
    _mie_An_Bn(m, x, a, b)

    nangles = len(mu)
    S1 = np.zeros(nangles, dtype=np.complex128)
    S2 = np.zeros(nangles, dtype=np.complex128)

    factor1 = np.empty(nstop, dtype=np.float64)
    factor2 = np.empty(nstop, dtype=np.float64)
    factor3 = np.empty(nstop, dtype=np.float64)

    for n in range(1, nstop):
        factor1[n - 1] = (2 * n + 1) / (n + 1) / n
        factor2[n - 1] = (2 * n + 1) / n
        factor3[n - 1] = (n + 1) / n

    nstop = len(a)
    for k in nb.prange(nangles):
        pi_nm2 = 0
        pi_nm1 = 1
        for n in range(1, nstop):
            i = n - 1
            tau_nm1 = n * mu[k] * pi_nm1 - (n + 1.0) * pi_nm2

            S1[k] += factor1[i] * (pi_nm1 * a[i] + tau_nm1 * b[i])
            S2[k] += factor1[i] * (tau_nm1 * a[i] + pi_nm1 * b[i])

            temp = pi_nm1
            pi_nm1 = factor2[i] * mu[k] * pi_nm1 - factor3[i] * pi_nm2
            pi_nm2 = temp

    # calculate norm = sqrt(pi * Qext * x**2)
    n = np.arange(1, nstop + 1)
    norm = np.sqrt(2 * np.pi * np.sum((2 * n + 1) * (a.real + b.real)))

    S1 /= norm
    S2 /= norm

    return [S1, S2]

%timeit -n 1000 _mie_S1_S2_opt(m, x, mu)
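
Since fastmath relaxes strict IEEE-754 semantics, it is worth checking that the optimized function still matches the reference within rounding noise before trusting the timings. A quick sanity check, assuming both functions above are defined in the same session (the tolerance is a guess):

S1_ref, S2_ref = _mie_S1_S2(m, x, mu)
S1_opt, S2_opt = _mie_S1_S2_opt(m, x, mu)

# Both pairs should agree to within floating-point rounding noise.
assert np.allclose(S1_ref, S1_opt, rtol=1e-8)
assert np.allclose(S2_ref, S2_opt, rtol=1e-8)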

On my machine with 6 cores, the final optimized implementation is 12 times faster with fastmath=True and 8.8 times faster without. Note that applying similar strategies to the other functions may also help to speed them up.

Upvotes: 2
