hopeso

Reputation: 165

Cython optimization of the code

I'm struggling to speed up my Python particle-tracking code with Cython.

Here's my pure Python code:

from scipy.integrate import odeint
import numpy as np
from numpy import sqrt, pi, sin, cos
from time import time as Time
import multiprocessing as mp
from functools import partial

cLight = 299792458.
Dim = 6

class Integrator:
    def __init__(self, ring):
        self.ring = ring

    def equations(self, X, s):
        dXds = np.zeros(Dim)

        E, B = self.ring.getEMField( [X[0], X[2], s], X[4] )

        h = 1 + X[0]/self.ring.ringRadius
        p_s = np.sqrt(X[5]**2 - self.ring.particle.mass**2 - X[1]**2 - X[3]**2)
        dtds = h*X[5]/p_s
        gamma = X[5]/self.ring.particle.mass
        beta = np.array( [X[1], X[3], p_s] ) / X[5]

        dXds[0] = dtds*beta[0]
        dXds[2] = dtds*beta[1]
        dXds[1] = p_s/self.ring.ringRadius + self.ring.particle.charge*(dtds*E[0] + dXds[2]*B[2] - h*B[1])
        dXds[3] = self.ring.particle.charge*(dtds*E[1] + h*B[0] - dXds[0]*B[2])
        dXds[4] = dtds
        dXds[5] = self.ring.particle.charge*(dXds[0]*E[0] + dXds[2]*E[1] + h*E[2])
        return dXds

    def odeSolve(self, X0, sRange):
        sol = odeint(self.equations, X0, sRange)
        return sol

class Ring:
    def __init__(self, particle):
        self.particle = particle
        self.ringRadius = 7.112
        self.magicB0 = self.particle.magicMomentum/self.ringRadius

    def getEMField(self, pos, time):
        x, y, s = pos
        theta = (s/self.ringRadius*180/pi) % 360
        r = sqrt(x**2 + y**2)
        arg = 0 if r == 0 else np.angle( complex(x/r, y/r) )
        rn = r/0.045

        k2 = 37*24e3
        k10 = -4*24e3

        E = np.zeros(3)
        B = np.array( [ 0, self.magicB0, 0 ] )

        for i in range(4):
            if ((21.9+90*i < theta < 34.9+90*i or 38.9+90*i < theta < 64.9+90*i) and (-0.05 < x < 0.05 and -0.05 < y < 0.05)):
                E = np.array( [ k2*x/0.045 + k10*rn**9*cos(9*arg), -k2*y/0.045 -k10*rn**9*sin(9*arg), 0] )
                break
        return E, B

class Particle:
    def __init__(self):
        self.mass = 105.65837e6
        self.charge = 1.
        self.gm2 = 0.001165921 

        self.magicMomentum = self.mass/sqrt(self.gm2)
        self.magicEnergy = sqrt(self.magicMomentum**2 + self.mass**2)
        self.magicGamma = self.magicEnergy/self.mass
        self.magicBeta = self.magicMomentum/(self.magicGamma*self.mass)


def runSimulation(nParticles, tEnd):
    particle = Particle()
    ring = Ring(particle)
    integrator = Integrator(ring)

    Xs = np.array( [ np.array( [45e-3*(np.random.rand()-0.5)*2, 0, 0, 0, 0, particle.magicEnergy] ) for i in range(nParticles) ] )
    sRange = np.arange(0, tEnd, 1e-9)*particle.magicBeta*cLight 

    ode = partial(integrator.odeSolve, sRange=sRange)

    t1 = Time()

    pool = mp.Pool()
    sol = np.array(pool.map(ode, Xs))

    t2 = Time()
    print ("%.3f sec" %(t2-t1))

    return t2-t1

The most time-consuming step is integrating the ODE, which is implemented by odeSolve() and equations() in class Integrator. The getEMField() method of class Ring is also called once for every call to equations() during the solve. I was hoping for a significant speedup (at least 10x to 20x) from Cython, but I only got about 1.5x with the following Cython script:

import cython
import numpy as np
cimport numpy as np
from libc.math cimport sqrt, pi, sin, cos

from scipy.integrate import odeint
from time import time as Time
import multiprocessing as mp
from functools import partial

cdef double cLight = 299792458.
cdef int Dim = 6

@cython.boundscheck(False)
cdef class Integrator:
    cdef Ring ring

    def __init__(self, ring):
        self.ring = ring

    cpdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] equations(self,
                  np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] X,
                  double s):
        cdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] dXds = np.zeros(Dim)
        cdef double h, p_s, dtds, gamma
        cdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] beta, E, B

        E, B = self.ring.getEMField( [X[0], X[2], s], X[4] )

        h = 1 + X[0]/self.ring.ringRadius
        p_s = np.sqrt(X[5]*X[5] - self.ring.particle.mass*self.ring.particle.mass - X[1]*X[1] - X[3]*X[3])
        dtds = h*X[5]/p_s
        gamma = X[5]/self.ring.particle.mass
        beta = np.array( [X[1], X[3], p_s] ) / X[5]

        dXds[0] = dtds*beta[0]
        dXds[2] = dtds*beta[1]
        dXds[1] = p_s/self.ring.ringRadius + self.ring.particle.charge*(dtds*E[0] + dXds[2]*B[2] - h*B[1])
        dXds[3] = self.ring.particle.charge*(dtds*E[1] + h*B[0] - dXds[0]*B[2])
        dXds[4] = dtds
        dXds[5] = self.ring.particle.charge*(dXds[0]*E[0] + dXds[2]*E[1] + h*E[2])
        return dXds

    cpdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] odeSolve(self,
                 np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] X0,
                 np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] sRange):
        sol = odeint(self.equations, X0, sRange)
        return sol

@cython.boundscheck(False)
cdef class Ring:
    cdef Particle particle
    cdef double ringRadius
    cdef double magicB0

    def __init__(self, particle):
        self.particle = particle
        self.ringRadius = 7.112
        self.magicB0 = self.particle.magicMomentum/self.ringRadius

    cpdef tuple getEMField(self,
                   list pos,
                   double time):
        cdef double x, y, s
        cdef double theta, r, rn, arg, k2, k10
        cdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] E, B

        x, y, s = pos
        theta = (s/self.ringRadius*180/pi) % 360
        r = sqrt(x*x + y*y)
        arg = 0 if r == 0 else np.angle( complex(x/r, y/r) )
        rn = r/0.045

        k2 = 37*24e3
        k10 = -4*24e3

        E = np.zeros(3)
        B = np.array( [ 0, self.magicB0, 0 ] )

        for i in range(4):
            if ((21.9+90*i < theta < 34.9+90*i or 38.9+90*i < theta < 64.9+90*i) and (-0.05 < x < 0.05 and -0.05 < y < 0.05)):
                E = np.array( [ k2*x/0.045 + k10*rn**9*cos(9*arg), -k2*y/0.045 -k10*rn**9*sin(9*arg), 0] )
                #E = np.array( [ k2*x/0.045, -k2*y/0.045, 0] )
                break
        return E, B

cdef class Particle:
    cdef double mass
    cdef double charge
    cdef double gm2

    cdef double magicMomentum
    cdef double magicEnergy
    cdef double magicGamma
    cdef double magicBeta

    def __init__(self):
        self.mass = 105.65837e6
        self.charge = 1.
        self.gm2 = 0.001165921 

        self.magicMomentum = self.mass/sqrt(self.gm2)
        self.magicEnergy = sqrt(self.magicMomentum**2 + self.mass**2)
        self.magicGamma = self.magicEnergy/self.mass
        self.magicBeta = self.magicMomentum/(self.magicGamma*self.mass)

def runSimulation(nParticles, tEnd):
    particle = Particle()
    ring = Ring(particle)
    integrator = Integrator(ring)

    #nParticles = 5
    Xs = np.array( [ np.array( [45e-3*(np.random.rand()-0.5)*2, 0, 0, 0, 0, particle.magicEnergy] ) for i in range(nParticles) ] )
    sRange = np.arange(0, tEnd, 1e-9)*particle.magicBeta*cLight 

    ode = partial(integrator.odeSolve, sRange=sRange)

    t1 = Time()

    pool = mp.Pool()
    sol = np.array(pool.map(ode, Xs))

    t2 = Time()
    print ("%.3f sec" %(t2-t1))

    return t2-t1

What should I do to get the most out of Cython? (I also tried Numba, and the performance gain was enormous, roughly a 20x speedup. However, I had an extremely hard time using Numba with Python class instances, so I decided to use Cython instead.)

For reference, the Cython annotation output from compilation was attached as four screenshots (images not reproduced here).

Upvotes: 2

Views: 484

Answers (1)

DavidW

Reputation: 30929

This is a very incomplete answer since I haven't profiled or timed anything, or even checked that it gives the same result. However, here are some suggestions that reduce the amount of Python-interacting code that Cython generates:

  • Add the @cython.cdivision(True) compilation directive. This means that a ZeroDivisionError won't be raised on float division by zero and you'll get an inf or NaN value instead. (Only do this if you don't want the error to be raised.)

  • Change p_s = np.sqrt(...) to p_s = sqrt(...). This removes a numpy call that only operates on a single value. You seem to have done this elsewhere so I don't know why you missed this line.

  • Where possible use fixed size C arrays instead of numpy arrays:

    cdef double beta[3]
    # ...
    beta[0] = X[1]/X[5]
    beta[1] = X[3]/X[5]
    beta[2] = p_s/X[5]
    

    You can do this when the size is known at compile time (and fairly small) and when you don't want to return it. This avoids allocating a numpy array and some subsequent type-checking to assign it to the typed numpy array. I think beta is the only place you can do this.

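  • np.angle( complex(x/r, y/r) ) can be replaced by atan2(y/r, x/r) (using atan2 from libc.math). You can also drop the division by r, i.e. use atan2(y, x), since scaling both arguments by the same positive number doesn't change the angle.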

  • cdef int i helps make your for loop faster in getEMField. (Cython is often good at automatically picking up the types of loop variables but seems to have failed here.)

  • I suspect it's quicker to assign E element-by-element than as a whole array:

            E[0] = k2*x/0.045 + k10*rn**9*cos(9*arg)
            E[1] = -k2*y/0.045 -k10*rn**9*sin(9*arg)
    
  • There isn't much value in specifying types like list and tuple and it may actually make the code slightly slower (because it will waste time checking the types).

  • A bigger change would be to pass E and B into getEMField as pointers rather than allocating them with np.zeros and np.array. This would let you allocate them as static C arrays in equations (cdef double E[3]). The downside is that getEMField would have to be cdef, so it would no longer be callable from Python (but you could write a Python-callable wrapper function too if you like).
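
As a rough illustration of three of the points above (atan2 instead of np.angle, element-by-element assignment, and writing into a buffer supplied by the caller), here is a plain-Python sketch. It is not the Cython code itself and has not been timed. The helper names fill_e_field, in_quad and e_field_reference are made up for this example, and cmath.phase is used as a stdlib stand-in for np.angle on a single complex value. It only checks that the rewritten field formula agrees with the one in the question.

```python
import cmath
import math

# Constants copied from the question's getEMField.
K2 = 37 * 24e3
K10 = -4 * 24e3
R0 = 0.045


def in_quad(theta, x, y):
    """Return True inside one of the four field regions defined in the question."""
    if not (-0.05 < x < 0.05 and -0.05 < y < 0.05):
        return False
    for i in range(4):
        start = 90 * i
        if 21.9 + start < theta < 34.9 + start or 38.9 + start < theta < 64.9 + start:
            return True
    return False


def fill_e_field(x, y, theta, E):
    """Write the field into the caller-supplied 3-element buffer E (hypothetical helper)."""
    E[0] = E[1] = E[2] = 0.0
    if in_quad(theta, x, y):
        rn = math.sqrt(x * x + y * y) / R0
        arg = math.atan2(y, x)  # replaces np.angle(complex(x/r, y/r))
        multipole = K10 * rn**9
        E[0] = K2 * x / R0 + multipole * math.cos(9 * arg)
        E[1] = -K2 * y / R0 - multipole * math.sin(9 * arg)


def e_field_reference(x, y, theta):
    """The question's formula; cmath.phase stands in for np.angle on one value."""
    r = math.sqrt(x * x + y * y)
    arg = 0.0 if r == 0 else cmath.phase(complex(x / r, y / r))
    rn = r / R0
    if in_quad(theta, x, y):
        return [K2 * x / R0 + K10 * rn**9 * math.cos(9 * arg),
                -K2 * y / R0 - K10 * rn**9 * math.sin(9 * arg),
                0.0]
    return [0.0, 0.0, 0.0]


E = [0.0, 0.0, 0.0]  # allocated once, reused for every call
points = [(0.01, 0.02, 30.0), (-0.03, 0.004, 130.0), (0.0, 0.0, 50.0), (0.01, 0.0, 5.0)]
for x, y, theta in points:
    fill_e_field(x, y, theta, E)
    ref = e_field_reference(x, y, theta)
    assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-6) for a, b in zip(E, ref))
```

In the Cython version the buffer would be a C array (cdef double E[3]) declared in equations and passed to a cdef function as a double*, as described in the last bullet. The Python list here only mimics that calling pattern.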

Upvotes: 1
