Gil

Reputation: 177

Accelerating the following code using Numba

I am trying to use Numba to accelerate a piece of code. The code is simple, basically a loop with simple calculations on a numpy array.

import numpy as np
import time
from numba import jit, double

def MinimizeSquareDiffBudget(x, budget):
    if (budget > np.sum(x)):
        return x
    n = np.size(x,0)
    j = 1
    i = 0
    y = np.zeros((n, 1))
    while (budget > 0):
        while (x[i] == x[j]) and (j < n-1):
            j += 1
        i = j - 1
        if np.std(x) < 1e-10:
            to_give = budget/n
            y += to_give
            x = x - to_give
            break
        to_give = min(budget, (x[0] - x[j])*j)
        y[0:j] += to_give/j
        x[0:j] = x[0:j] - to_give/j
        budget = budget - to_give
        j = 1
    return y

Now, I tried optimizing it using @jit and by defining:

fastMinimizeSquareDiffBudget = jit(double[:,:](double[:,:], double[:,:]))(MinimizeSquareDiffBudget)

However, the time is roughly the same, while I expected Numba to be much faster.

Testing the code:

budget = 335.0

x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]
t = time.process_time()
y = MinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)

x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]

t = time.process_time()
y = fastMinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)

This takes 0.28 s for the direct implementation and 0.45 s for the Numba version. The same code written in C takes less than 0.001 s.

Any ideas?

Upvotes: 0

Views: 1027

Answers (1)

JoshAdel

Reputation: 68732

When you time only a single call of the jitted function, you are measuring both the run time and the time Numba takes to compile it. If you run the code a second time you'll see the actual speed-up, since Numba keeps an in-memory cache of the compiled function, so you only pay the compilation cost once per set of argument types.

On my machine, using Python 3.6 and Numba 0.31.0, the pure Python function takes 0.32 seconds. The first time I call fastMinimizeSquareDiffBudget it takes 0.57 seconds, but the second time it takes 0.31 seconds.

Now, the reason you're not seeing a huge speed-up is that Numba can't compile your function in nopython mode, so it falls back to the much slower object mode. If you pass nopython=True to jit, it will raise an error that shows where compilation fails. The two issues I saw were that you should use x.shape[0] instead of np.size(x,0), and that you can't use min the way you are: min(budget, (x[0] - x[j])*j) mixes a scalar with an array expression, because x is 2-D.

Upvotes: 1
