Nate
Nate

Reputation: 1948

Numba not speeding up function

I have some code I'm trying to speed up with numba. I've done some reading on the topic, but I haven't been able to figure it out 100%.

Here is the code:

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import seaborn as sns
from numba import jit, vectorize, float64, autojit
# Apply seaborn's global plot styling: 'talk' context, 'ticks' style, larger fonts, 6.5 x 5.5 figures, inward-pointing ticks.
sns.set(context='talk', style='ticks', font_scale=1.2, rc={'figure.figsize': (6.5, 5.5), 'xtick.direction': 'in', 'ytick.direction': 'in'})

#%% constraints
# Body-size bounds: death below x_min, x_max is the maximum weight.
x_min, x_max = 0, 20
# Maximum time (number of time steps).
t_max = 100
# Resource level.
R = 10.0
# Potential foraging efficiencies: 10 evenly spaced values from 0 to 1.
foraging_efficiencies = np.linspace(0, 1, num=10)

#%% make the body size and time categories
body_sizes = np.arange(x_min, x_max + 1)     # integer sizes x_min..x_max inclusive
time_steps = np.arange(t_max)                # time indices 0..t_max-1

#%% parameter functions
@jit
def metabolic_fmr(x, u, temp):
    """Metabolic cost function.

    The cost is ``0.125 * 2**(0.2*temp) * (1 + 0.5*u)`` plus a term
    proportional to ``x`` (``0.1*x``).
    """
    base_cost = 0.125*(2**(0.2*temp))
    effort_scale = 1 + 0.5*u
    return base_cost*effort_scale + x*0.1

def intake_dist(u):
    """Stochastic intake distribution (returns a vector).

    Evaluates the binomial pmf with ``R`` trials and success probability
    ``u`` at every intake level ``0, 1, ..., R``.
    """
    intake_levels = np.arange(R+1)
    return st.binom.pmf(intake_levels, R, u)

@jit
def mass_gain(x, u, temp):
    """Mass gain function (returns a vector).

    For every intake level ``0..R`` the new mass is the current mass ``x``
    minus the metabolic cost plus the intake, limited to ``[0, x_max]``.
    """
    unclamped = x - metabolic_fmr(x, u, temp) + np.arange(R+1)
    # Cap at the maximum weight first, then floor at zero.
    return np.maximum(np.minimum(unclamped, x_max), 0)

@jit
def prob_attack(P):
    """Probability of an attack: ``0.02 * P``."""
    return 0.02*P

@jit
def prob_see(u):
    """Probability of not seeing an attack: ``1 - (1 - u)**0.3``."""
    return 1-(1-u)**0.3

@jit
def prob_lethal(x):
    """Probability of lethality given a successful attack: ``0.5 * exp(-0.05*x)``."""
    return 0.5*np.exp(-0.05*x)

@jit
def prob_mort(P, u, x):
    """Probability of mortality.

    Product of the attack, not-seeing and lethality probabilities,
    capped at 1.
    """
    attack = prob_attack(P)
    unseen = prob_see(u)
    lethal = prob_lethal(x)
    return np.minimum(attack*unseen*lethal, 1)

#%% terminal fitness function
@jit
def terminal_fitness(x):
    """Terminal fitness for body size ``x``: ``15*x / (x + 5)``."""
    return 15.0*x/(x+5.0)

#%% linear interpolation function
@jit
def linear_interpolation(x, F, t):
    """Linearly interpolate column ``t`` of ``F`` at the positions ``x``.

    ``x`` is an array of (possibly fractional) row positions. Each position
    is split into an integer part (via ``astype``) and a remainder that is
    used as the interpolation weight; the two bracketing row indices are
    limited to ``[x_min, x_max]``.
    """
    lower = x.astype(int)
    weight = x-lower                    # computed before the indices are clamped
    upper = lower + 1
    upper[upper>x_max] = x_max          # keep the upper index inside the table
    lower[lower<x_min] = x_min          # keep the lower index inside the table
    return (1-weight)*F[lower,t] + weight*F[upper,t]

#%% solver
@jit
def solver_jit(P, temp):
    """``@jit``-decorated solver; same body as ``solver_norm``.

    Works backwards in time from the terminal step: for every time step and
    every body size above ``x_min`` it evaluates each foraging efficiency,
    then records the best value and the efficiency that achieves it.

    Parameters
    ----------
    P : forwarded to ``prob_mort`` (scales the attack probability).
    temp : forwarded to ``mass_gain`` (enters the metabolic cost).

    Returns
    -------
    (D, F) : two arrays of shape ``(len(body_sizes), len(time_steps))``;
        ``D`` holds the chosen foraging efficiency, ``F`` the expected fitness.
    """
    # NOTE(review): this body calls the scipy-based intake_dist and np.argwhere;
    # the answer below reports np.argwhere as unsupported in nopython mode, so
    # numba presumably falls back to object mode here -- confirm with @njit.
    F = np.zeros((len(body_sizes), len(time_steps)))            # Expected fitness
    F[:,-1] = terminal_fitness(body_sizes)              # expected terminal fitness for every body size
    V = np.zeros((len(foraging_efficiencies), len(body_sizes), len(time_steps)))        # Fitness for each foraging effort
    D = np.zeros((len(body_sizes), len(time_steps)))            # Decision
    # Walk time backwards from the second-to-last step; step t reads F at t+1.
    for t in range(t_max-1)[::-1]:
        for x in range(x_min+1, x_max+1):               # iterate over every body size except dead
            for i in range(len(foraging_efficiencies)):     # iterate over every possible foraging efficiency
                u = foraging_efficiencies[i]
                g_u = intake_dist(u)                # calculate the distribution of intakes
                xp = mass_gain(x, u, temp)          # calculate the mass gain
                p_m = prob_mort(P, u, x)            # probability of mortality
                V[i,x,t] = (1 - p_m)*(linear_interpolation(xp, F, t+1)*g_u).sum()       # Fitness calculation
            vmax = V[:,x,t].max()
            # Ties are resolved in favour of the lowest index.
            idx = np.argwhere(V[:,x,t]==vmax).min()
            D[x,t] = foraging_efficiencies[idx]
            F[x,t] = vmax
    return D, F

def solver_norm(P, temp):
    """Undecorated solver used as the timing baseline for ``solver_jit``.

    Works backwards in time from the terminal step: for every time step and
    every body size above ``x_min`` it evaluates each foraging efficiency,
    then records the best value and the efficiency that achieves it.

    Parameters
    ----------
    P : forwarded to ``prob_mort`` (scales the attack probability).
    temp : forwarded to ``mass_gain`` (enters the metabolic cost).

    Returns
    -------
    (D, F) : two arrays of shape ``(len(body_sizes), len(time_steps))``;
        ``D`` holds the chosen foraging efficiency, ``F`` the expected fitness.
    """
    F = np.zeros((len(body_sizes), len(time_steps)))            # Expected fitness
    F[:,-1] = terminal_fitness(body_sizes)              # expected terminal fitness for every body size
    V = np.zeros((len(foraging_efficiencies), len(body_sizes), len(time_steps)))        # Fitness for each foraging effort
    D = np.zeros((len(body_sizes), len(time_steps)))            # Decision
    # Walk time backwards from the second-to-last step; step t reads F at t+1.
    for t in range(t_max-1)[::-1]:
        for x in range(x_min+1, x_max+1):               # iterate over every body size except dead
            for i in range(len(foraging_efficiencies)):     # iterate over every possible foraging efficiency
                u = foraging_efficiencies[i]
                g_u = intake_dist(u)                # calculate the distribution of intakes
                xp = mass_gain(x, u, temp)          # calculate the mass gain
                p_m = prob_mort(P, u, x)            # probability of mortality
                V[i,x,t] = (1 - p_m)*(linear_interpolation(xp, F, t+1)*g_u).sum()       # Fitness calculation
            vmax = V[:,x,t].max()
            # Ties are resolved in favour of the lowest index.
            idx = np.argwhere(V[:,x,t]==vmax).min()
            D[x,t] = foraging_efficiencies[idx]
            F[x,t] = vmax
    return D, F

The individual jit functions tend to be much faster than the un-jitted ones. For example, prob_mort is about 600% faster once it has been run through jit. However, the solver itself isn't much faster:

In [3]: %timeit -n 10 solver_jit(200, 25)
10 loops, best of 3: 3.94 s per loop

In [4]: %timeit -n 10 solver_norm(200, 25)
10 loops, best of 3: 4.09 s per loop

I know some functions can't be jitted, so I replaced the st.binom.pmf function with a custom jit function, and that actually slowed the solver down to about 17s per loop, over 5x slower. Presumably because the scipy functions are, at this point, heavily optimized.

So I suspect the slowness is either in the linear_interpolation function or somewhere in the solver code outside of the jitted functions (because at one point I un-jitted all the functions and ran solver_norm and got the same time). Any thoughts on where the slow part would be and how to speed it up?

UPDATE

Here's the binomial code I used in an attempt to speed up jit

@jit
def factorial(n):
    """Recursive factorial: ``n * factorial(n-1)``, with ``factorial(0) == 1``."""
    if n != 0:
        return n*factorial(n-1)
    return 1

@vectorize([float64(float64,float64,float64)])
def binom(k, n, p):
    """Binomial pmf: ``n! / (k! (n-k)!) * p**k * (1-p)**(n-k)``."""
    n_choose_k = factorial(n)/(factorial(k)*factorial(n-k))
    return n_choose_k*p**k*(1-p)**(n-k)

@jit
def intake_dist(u):
    """Stochastic intake distribution (returns a vector).

    Evaluates ``binom`` with ``R`` trials and success probability ``u`` at
    every intake level ``0, 1, ..., R``.
    """
    intake_levels = np.arange(R+1)
    return binom(intake_levels, R, u)

UPDATE 2 I tried running my binomial code in nopython mode and found out I was doing it wrong because it was recursive. Upon fixing that by changing code to:

@jit(int64(int64), nopython=True)
def factorial(nn):
    """Iterative factorial of ``nn``; returns 1 when ``nn < 2``."""
    product = 1
    for factor in range(2, nn + 1):
        product *= factor
    return product

@vectorize([float64(float64,float64,float64)], nopython=True)
def binom(k, n, p):
    """Binomial pmf: ``n! / (k! (n-k)!) * p**k * (1-p)**(n-k)``."""
    n_choose_k = factorial(n)/(factorial(k)*factorial(n-k))
    return n_choose_k*p**k*(1-p)**(n-k)

the solver now runs at

In [34]: %timeit solver_jit(200, 25)
1 loop, best of 3: 921 ms per loop

which is about 3.5x faster. However, solver_jit() and solver_norm() still run at the same pace, which means there is some code outside the jit functions slowing it down.

Upvotes: 3

Views: 1522

Answers (2)

JoshAdel
JoshAdel

Reputation: 68732

I was able to make a few changes to your code to make it so the jit version could compile completely in nopython mode. On my laptop, this results in:

%timeit solver_jit(200, 25)
1 loop, best of 3: 50.9 ms per loop

%timeit solver_norm(200, 25)
1 loop, best of 3: 192 ms per loop

For reference, I'm using Numba 0.27.0. I'll admit that Numba's compilation errors still make it difficult to identify what is going on, but since I've been playing with it for a while, I've built up an intuition for what needs to be fixed. The complete code is below, but here is the list of changes I made:

  • In linear_interpolation change x.astype(int) to x.astype(np.int64) so it could compile in nopython mode.
  • In the solver, use np.sum as a function and not a method of an array.
  • np.argwhere isn't supported. Write a custom loop.

There are probably some further optimizations that could be made, but this gives an initial speed-up.

The full code:

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import seaborn as sns
from numba import jit, vectorize, float64, autojit, njit
# Apply seaborn's global plot styling: 'talk' context, 'ticks' style, larger fonts, 6.5 x 5.5 figures, inward-pointing ticks.
sns.set(context='talk', style='ticks', font_scale=1.2, rc={'figure.figsize': (6.5, 5.5), 'xtick.direction': 'in', 'ytick.direction': 'in'})

#%% constraints
# Body-size bounds: death below x_min, x_max is the maximum weight.
x_min, x_max = 0, 20
# Maximum time (number of time steps).
t_max = 100
# Resource level.
R = 10.0
# Potential foraging efficiencies: 10 evenly spaced values from 0 to 1.
foraging_efficiencies = np.linspace(0, 1, num=10)

#%% make the body size and time categories
body_sizes = np.arange(x_min, x_max + 1)     # integer sizes x_min..x_max inclusive
time_steps = np.arange(t_max)                # time indices 0..t_max-1

#%% parameter functions
@njit
def metabolic_fmr(x, u, temp):
    """Metabolic cost function.

    The cost is ``0.125 * 2**(0.2*temp) * (1 + 0.5*u)`` plus a term
    proportional to ``x`` (``0.1*x``).
    """
    base_cost = 0.125*(2**(0.2*temp))
    effort_scale = 1 + 0.5*u
    return base_cost*effort_scale + x*0.1

@njit()
def factorial(nn):
    """Iterative factorial of ``nn``; returns 1 when ``nn < 2``."""
    product = 1
    for factor in range(2, nn + 1):
        product *= factor
    return product

@vectorize([float64(float64,float64,float64)], nopython=True)
def binom(k, n, p):
    """Binomial pmf: ``n! / (k! (n-k)!) * p**k * (1-p)**(n-k)``."""
    n_choose_k = factorial(n)/(factorial(k)*factorial(n-k))
    return n_choose_k*p**k*(1-p)**(n-k)

@njit
def intake_dist(u):
    """Stochastic intake distribution (returns a vector).

    Evaluates ``binom`` with ``R`` trials and success probability ``u`` at
    every intake level ``0, 1, ..., R``.
    """
    intake_levels = np.arange(R+1)
    return binom(intake_levels, R, u)

@njit
def mass_gain(x, u, temp):
    """Mass gain function (returns a vector).

    For every intake level ``0..R`` the new mass is the current mass ``x``
    minus the metabolic cost plus the intake, limited to ``[0, x_max]``.
    """
    unclamped = x - metabolic_fmr(x, u, temp) + np.arange(R+1)
    # Cap at the maximum weight first, then floor at zero.
    return np.maximum(np.minimum(unclamped, x_max), 0)

@njit
def prob_attack(P):
    """Probability of an attack: ``0.02 * P``."""
    return 0.02*P

@njit
def prob_see(u):
    """Probability of not seeing an attack: ``1 - (1 - u)**0.3``."""
    return 1-(1-u)**0.3

@njit
def prob_lethal(x):
    """Probability of lethality given a successful attack: ``0.5 * exp(-0.05*x)``."""
    return 0.5*np.exp(-0.05*x)

@njit
def prob_mort(P, u, x):
    """Probability of mortality.

    Product of the attack, not-seeing and lethality probabilities,
    capped at 1.
    """
    attack = prob_attack(P)
    unseen = prob_see(u)
    lethal = prob_lethal(x)
    return np.minimum(attack*unseen*lethal, 1)

#%% terminal fitness function
@njit
def terminal_fitness(x):
    """Terminal fitness for body size ``x``: ``15*x / (x + 5)``."""
    return 15.0*x/(x+5.0)

#%% linear interpolation function
@njit
def linear_interpolation(x, F, t):
    """Linearly interpolate column ``t`` of ``F`` at the positions ``x``.

    ``x`` is an array of (possibly fractional) row positions. Each position
    is split into an integer part (via ``astype``) and a remainder that is
    used as the interpolation weight; the two bracketing row indices are
    limited to ``[x_min, x_max]``.
    """
    lower = x.astype(np.int64)          # explicit np.int64 so this compiles in nopython mode
    weight = x-lower                    # computed before the indices are clamped
    upper = lower + 1
    upper[upper>x_max] = x_max          # keep the upper index inside the table
    lower[lower<x_min] = x_min          # keep the lower index inside the table
    return (1-weight)*F[lower,t] + weight*F[upper,t]

#%% solver
@njit
def solver_jit(P, temp):
    """Solver compiled in nopython mode.

    Works backwards in time from the terminal step: for every time step and
    every body size above ``x_min`` it evaluates each foraging efficiency,
    then records the best value and the efficiency that achieves it.

    Parameters
    ----------
    P : forwarded to ``prob_mort`` (scales the attack probability).
    temp : forwarded to ``mass_gain`` (enters the metabolic cost).

    Returns
    -------
    (D, F) : two arrays of shape ``(len(body_sizes), len(time_steps))``;
        ``D`` holds the chosen foraging efficiency, ``F`` the expected fitness.
    """
    F = np.zeros((len(body_sizes), len(time_steps)))            # Expected fitness
    F[:,-1] = terminal_fitness(body_sizes)              # expected terminal fitness for every body size
    V = np.zeros((len(foraging_efficiencies), len(body_sizes), len(time_steps)))        # Fitness for each foraging effort
    D = np.zeros((len(body_sizes), len(time_steps)))            # Decision
    # Walk time backwards from the second-to-last step; step t reads F at t+1.
    for t in range(t_max-2,-1,-1):
        for x in range(x_min+1, x_max+1):               # iterate over every body size except dead
            for i in range(len(foraging_efficiencies)):     # iterate over every possible foraging efficiency
                u = foraging_efficiencies[i]
                g_u = intake_dist(u)                # calculate the distribution of intakes
                xp = mass_gain(x, u, temp)          # calculate the mass gain
                p_m = prob_mort(P, u, x)            # probability of mortality
                # np.sum is called as a function (not the array method) for nopython mode.
                V[i,x,t] = (1 - p_m)*np.sum((linear_interpolation(xp, F, t+1)*g_u))       # Fitness calculation
            vmax = V[:,x,t].max()

            # np.argwhere is not supported in nopython mode, so find the first
            # index attaining the maximum with an explicit loop. idx starts at 0
            # so it is always defined, even if nothing compares equal to vmax
            # (e.g. NaN values).
            idx = 0
            for k in range(V.shape[0]):
                if V[k,x,t] == vmax:
                    idx = k
                    break
            D[x,t] = foraging_efficiencies[idx]
            F[x,t] = vmax
    return D, F

def solver_norm(P, temp):
    """Undecorated solver used as the timing baseline for ``solver_jit``.

    Works backwards in time from the terminal step: for every time step and
    every body size above ``x_min`` it evaluates each foraging efficiency,
    then records the best value and the efficiency that achieves it.

    Parameters
    ----------
    P : forwarded to ``prob_mort`` (scales the attack probability).
    temp : forwarded to ``mass_gain`` (enters the metabolic cost).

    Returns
    -------
    (D, F) : two arrays of shape ``(len(body_sizes), len(time_steps))``;
        ``D`` holds the chosen foraging efficiency, ``F`` the expected fitness.
    """
    F = np.zeros((len(body_sizes), len(time_steps)))            # Expected fitness
    F[:,-1] = terminal_fitness(body_sizes)              # expected terminal fitness for every body size
    V = np.zeros((len(foraging_efficiencies), len(body_sizes), len(time_steps)))        # Fitness for each foraging effort
    D = np.zeros((len(body_sizes), len(time_steps)))            # Decision
    # Walk time backwards from the second-to-last step; step t reads F at t+1.
    for t in range(t_max-1)[::-1]:
        for x in range(x_min+1, x_max+1):               # iterate over every body size except dead
            for i in range(len(foraging_efficiencies)):     # iterate over every possible foraging efficiency
                u = foraging_efficiencies[i]
                g_u = intake_dist(u)                # calculate the distribution of intakes
                xp = mass_gain(x, u, temp)          # calculate the mass gain
                p_m = prob_mort(P, u, x)            # probability of mortality
                V[i,x,t] = (1 - p_m)*(linear_interpolation(xp, F, t+1)*g_u).sum()       # Fitness calculation
            vmax = V[:,x,t].max()
            # Ties are resolved in favour of the lowest index.
            idx = np.argwhere(V[:,x,t]==vmax).min()
            D[x,t] = foraging_efficiencies[idx]
            F[x,t] = vmax
    return D, F

Upvotes: 2

osvil
osvil

Reputation: 56

As said, there is likely some code that is falling back to object mode. I just wanted to add that you can use njit instead of jit to disable object mode. That will help diagnose what code is the culprit.

Upvotes: 0

Related Questions