Franke Hsu

Reputation: 300

How to speed up code correctly using Numba?

I am currently working on speeding up my Python functions.

import numpy as np

def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)


def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

def distance(u,v,lon1,lat1):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    return dlon, dlat

As you can see, this is simple code based on NumPy. Most of the articles I read online say that I only need to put @numba.jit as a decorator in front of a function to speed up my code with Numba.

Here is the test I ran.

u = np.random.randn(10000)
v = np.random.randn(10000)
lon1 = np.random.uniform(-99,-96,10000)
lat1 = np.random.uniform( 23, 25,10000)
print(u)

%%timeit
for i in range(10000):
    distance(u,v,lon1,lat1)

5.61 s ± 58.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

After adding the Numba decorator:

import numba

@numba.njit()
def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

@numba.njit()
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

@numba.njit()
def distance(u, v, lon1, lat1, R=6.371*1e6):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    return d_lon(lat1,lat2,dlon), d_lat(dlat)

%%timeit
for i in range(10000):
    a,b = distance(u,v,lon1,lat1)

7.76 s ± 64.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

As you can see above, the Numba version is slower than the pure Python version. Could anyone help me fix this?

P.S. Versions used:
llvmlite 0.32.0rc1
numba 0.49.0rc2

------ Timing tests following macroeconomist's answer ------

According to his answer, even though Numba is getting smarter, code that is going to be Numba-decorated is better written in a plain, loop-based "Fortran"/"C" style. Below is a comparison of the computation times for the different approaches I tried.

def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

def distance(u,v,lon1,lat1):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    return dlon, dlat

%%timeit
for i in range(10000):
    distance(u,v,lon1,lat1)

54 s ± 485 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@numba.jit(nogil=True)
def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

@numba.jit(nogil=True)
def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

def distance(u, v, lon1, lat1, R=6.371*1e6):
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    return d_lon(lat1,lat2,dlon), d_lat(dlat)

%%timeit
for i in range(10000):
    a,b = distance(u,v,lon1,lat1)

1min 21s ± 815 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

def d_lat(dlat,R=6.371*1e6):
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)

def d_lon(lat1,lat2,dlon,R=6.371*1e6):
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                           np.cos(np.deg2rad(lat2)) *
                           np.sin(np.deg2rad(dlon)/2)**2)

@numba.njit(nogil=True)
def distance(u, v, lon1, lat1, R=6.371*1e6):
    def d_lat(dlat,R=6.371*1e6):
        return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2)             
    def d_lon(lat1,lat2,dlon,R=6.371*1e6):
        return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) * 
                               np.cos(np.deg2rad(lat2)) *
                               np.sin(np.deg2rad(dlon)/2)**2)
    lat2, lon2 = lat1.copy(), lon1.copy()
    lat2[v>0], lat2[v<0], = lat1[v>0]+1, lat1[v<0]-1,
    lon2[u>0], lon2[u<0], = lon1[u>0]+1, lon1[u<0]-1,
    dlat = d_lat(lat2 - lat1)
    dlon = d_lon(lat1,lat2,lon2 - lon1)
    return dlon, dlat

%%timeit
for i in range(10000):
    a,b = distance(u,v,lon1,lat1)

1min 2s ± 239 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@numba.njit() 
def d_lat(dlat,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2) 

@numba.njit() 
def d_lon(lat1,lat2,dlon,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *  
                           np.cos(np.deg2rad(lat2)) * 
                           np.sin(np.deg2rad(dlon)/2)**2) 

@numba.njit() 
def distance(u, v, lon1, lat1): 
    lon2 = np.empty_like(lon1) 
    lat2 = np.empty_like(lat1) 
    dlon = np.empty_like(lon1) 
    dlat = np.empty_like(lat1) 

    for i in range(len(v)): 
        vi = v[i] 
        if vi > 0: 
            lat2[i] = lat1[i]+1 
            dlat[i] = 1 
        elif vi < 0: 
            lat2[i] = lat1[i]-1 
            dlat[i] = -1 
        else: 
            lat2[i] = lat1[i] 
            dlat[i] = 0 

    for i in range(len(u)): 
        ui = u[i] 
        if ui > 0:  
            lon2[i] = lon1[i]+1 
            dlon[i] = 1 
        elif ui < 0: 
            lon2[i] = lon1[i]-1 
            dlon[i] = -1 
        else: 
            lon2[i] = lon1[i] 
            dlon[i] = 0 

    return d_lon(lat1,lat2,dlon), d_lat(dlat)

%%timeit
for i in range(10000):
    distance(u,v,lon1,lat1)

35.9 s ± 537 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Upvotes: 0

Views: 792

Answers (1)

macroeconomist

Reputation: 701

A couple of issues jump out.

First, the calculations in your distance function are unnecessarily complicated, and they are written in a style (with lots of fancy indexing, e.g. lat2[v>0]) that may not be ideal for the Numba compiler. Although Numba is getting smarter, I find that there is still a high return to writing code in a simple, loop-oriented way.

Second, optional arguments can slow Numba down a little. I found that this was true primarily for the optional R in your distance function.
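As a minimal sketch of one way to avoid the optional argument (an illustration, not a benchmarked result): the radius can be read from a module-level constant, which Numba treats as a compile-time constant. The names R_EARTH and d_lat_const are hypothetical, and the try/except fallback is only there so the snippet also runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if Numba is missing
    def njit(func):
        return func

# Hypothetical name: Earth radius in metres, fixed when the function is compiled.
R_EARTH = 6.371e6

@njit
def d_lat_const(dlat):
    # Same formula as d_lat, but with no optional R argument in the signature.
    return 2 * R_EARTH * np.sqrt(np.sin(np.deg2rad(dlat) / 2) ** 2)

dlat = np.array([1.0, -1.0, 0.0])
result = d_lat_const(dlat)
print(result)
```

Whether this makes a measurable difference depends on the Numba version and the call pattern, so it is worth timing on your own data.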

Fixing these two issues (in particular, replacing your vectorized code with simpler loops that minimize operations), we get code of the form

@numba.njit() 
def d_lat(dlat,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.sin(np.deg2rad(dlat)/2)**2) 

@numba.njit() 
def d_lon(lat1,lat2,dlon,R=6.371*1e6): 
    return 2 * R * np.sqrt(np.cos(np.deg2rad(lat1)) *  
                           np.cos(np.deg2rad(lat2)) * 
                           np.sin(np.deg2rad(dlon)/2)**2) 

@numba.njit() 
def distance(u, v, lon1, lat1): 
    lon2 = np.empty_like(lon1) 
    lat2 = np.empty_like(lat1) 
    dlon = np.empty_like(lon1) 
    dlat = np.empty_like(lat1) 

    for i in range(len(v)): 
        vi = v[i] 
        if vi > 0: 
            lat2[i] = lat1[i]+1 
            dlat[i] = 1 
        elif vi < 0: 
            lat2[i] = lat1[i]-1 
            dlat[i] = -1 
        else: 
            lat2[i] = lat1[i] 
            dlat[i] = 0 

    for i in range(len(u)): 
        ui = u[i] 
        if ui > 0:  
            lon2[i] = lon1[i]+1 
            dlon[i] = 1 
        elif ui < 0: 
            lon2[i] = lon1[i]-1 
            dlon[i] = -1 
        else: 
            lon2[i] = lon1[i] 
            dlon[i] = 0 

    return d_lon(lat1,lat2,dlon), d_lat(dlat) 

On my (slower) system, this decreases the run time, after the initial cost of compilation, from around 7 seconds to around 4 seconds. At that point, I believe the run time is dominated by the raw cost of the math functions themselves: np.sin, np.cos, np.sqrt, etc.
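If the math functions really are the bottleneck, one possible next step is to call fewer of them. The sketch below is untimed and rests on an observation about the code above: dlon and dlat are only ever +1, -1 or 0 degrees, so the sin term is the same for every non-zero step and can be computed once. It also fuses the two loops and uses scalar math calls instead of temporary arrays. The names distance_fused and R_EARTH are hypothetical, and the try/except fallback only exists so the snippet runs without Numba. The first call is made separately so that compilation time is not included in the measurement.

```python
import math
import time
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if Numba is missing
    def njit(func):
        return func

R_EARTH = 6.371e6  # hypothetical name: Earth radius in metres

@njit
def distance_fused(u, v, lon1, lat1):
    n = len(u)
    out_lon = np.empty(n)
    out_lat = np.empty(n)
    # sin(dlon/2) for a one-degree step; the sign does not matter once squared.
    half_step = math.sin(math.radians(1.0) / 2.0)
    step_lat = 2.0 * R_EARTH * abs(half_step)
    for i in range(n):
        # Latitude: step by +1, -1 or 0 degrees depending on the sign of v.
        if v[i] > 0:
            lat2 = lat1[i] + 1.0
            out_lat[i] = step_lat
        elif v[i] < 0:
            lat2 = lat1[i] - 1.0
            out_lat[i] = step_lat
        else:
            lat2 = lat1[i]
            out_lat[i] = 0.0
        # Longitude: the distance is zero unless u is strictly positive or negative.
        if u[i] > 0 or u[i] < 0:
            c = math.cos(math.radians(lat1[i])) * math.cos(math.radians(lat2))
            out_lon[i] = 2.0 * R_EARTH * math.sqrt(c * half_step * half_step)
        else:
            out_lon[i] = 0.0
    return out_lon, out_lat

rng = np.random.default_rng(0)
u = rng.standard_normal(1000)
v = rng.standard_normal(1000)
lon1 = rng.uniform(-99, -96, 1000)
lat1 = rng.uniform(23, 25, 1000)

distance_fused(u, v, lon1, lat1)  # first call pays the compilation cost
start = time.perf_counter()
d_lon_m, d_lat_m = distance_fused(u, v, lon1, lat1)
print(f"steady-state call: {time.perf_counter() - start:.6f} s")
```

This returns the same two arrays as d_lon(lat1, lat2, dlon) and d_lat(dlat) in the version above, but with two cos calls and one sqrt per element instead of several whole-array passes.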

Upvotes: 1
