Jean-Eric

Reputation: 402

JIT unable to improve my JAX code: where am I wrong?

Here is a simple JAX code that shows the Metropolis algorithm in action, used to solve a 3-parameter Bayesian regression problem. Running without JIT compilation works fine, even on a CPU. Now I would like to know why, when the two lines concerning JIT are uncommented, the timing is not really different, either on CPU (JIT vs. no JIT) or when comparing a CPU run to a run on a K80 Nvidia GPU?

Have I coded this in a wrong/inefficient way?

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import jacfwd, jacrev, hessian
from functools import partial

import scipy.stats as scs
import numpy as np

#@partial(jax.jit, static_argnums=(1,))
def jax_metropolis_kernel(rng_key, logpdf, position, log_prob):
    """Moves the chain by one step using the Random Walk Metropolis algorithm."""
    key, subkey = jax.random.split(rng_key)
    move_proposals = jax.random.normal(key, shape=position.shape) * 0.1
        
    proposal = position + move_proposals
    proposal_log_prob = logpdf(proposal)

    log_uniform = jnp.log(jax.random.uniform(subkey))
    do_accept = log_uniform < proposal_log_prob - log_prob

    position = jnp.where(do_accept, proposal, position)
    log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    return position, log_prob

#@partial(jax.jit, static_argnums=(1, 2))
def jax_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
    """Generate samples using the Random Walk Metropolis algorithm.
    """
    def mh_update(i, state):
        key, positions, log_prob = state
        _, key = jax.random.split(key)
                
        new_position, new_log_prob = jax_metropolis_kernel(key, 
                                                           logpdf, 
                                                           positions[i-1], 
                                                           log_prob)
                
        
        positions=positions.at[i].set(new_position)
        return (key, positions, new_log_prob)

    #Initialisation
    keys = jax.random.split(rng_key,num=4)
    all_positions = jnp.zeros((n_samples,initial_position.shape[0]))  # 1 chain for each vmap call    ?
#    all_positions=all_positions.at[0,0].set(scs.norm.rvs(loc=1,scale=1))
#    all_positions=all_positions.at[0,1].set(scs.norm.rvs(loc=2,scale=1))
#    all_positions=all_positions.at[0,2].set(scs.uniform.rvs(loc=1,scale=2))
    
    all_positions=all_positions.at[0,0].set(jax.random.normal(keys[0])+1.)
    all_positions=all_positions.at[0,1].set(jax.random.normal(keys[1])+2.)
    all_positions=all_positions.at[0,2].set(jax.random.uniform(keys[2],minval=1.0, maxval=3.0))

    logp = logpdf(all_positions[0])
    
    initial_state = (rng_key,all_positions, logp)
    rng_key, all_positions, log_prob = jax.lax.fori_loop(1, n_samples, 
                                                 mh_update, 
                                                 initial_state)
    
    return all_positions

def jax_my_logpdf(par,xi,yi):
    # priors: a=par[0], b=par[1], sigma=par[2]
    logpdf_a = jax.scipy.stats.norm.logpdf(x=par[0],loc=1.,scale=1.)
    logpdf_b = jax.scipy.stats.norm.logpdf(x=par[1],loc=2.,scale=1.)
    logpdf_s = jax.scipy.stats.gamma.logpdf(x=par[2],a=3,scale=1.)

    val = xi*par[1]+par[0]
    tmp = jax.scipy.stats.norm.logpdf(x=val,loc=yi,scale=par[2])    
    log_likeh= jnp.sum(tmp)
    
    rc = log_likeh + logpdf_a + logpdf_b + logpdf_s

    return rc

######## Main ########
n_dim = 3
n_forget = 1_000
n_samples = 100_000 + n_forget
n_chains = 100
rng_key = jax.random.PRNGKey(42)

# generation of (xi,yi) set
sample_size = 5_000
sigma_e = 1.5             # true value of parameter error sigma
random_num_generator = np.random.RandomState(0)
xi = 10.0 * random_num_generator.rand(sample_size)
e = random_num_generator.normal(0, sigma_e, sample_size)
yi = 1.0 + 2.0 * xi +  e          # a = 1.0; b = 2.0; y = a + b*x


rng_keys = jax.random.split(rng_key, n_chains)    # generate an array of size (n_chains, 2)
initial_position = jnp.ones((n_dim, n_chains))    # generate an array of size (n_dim, n_chains)
                                                  # so for vmap one should connect axis 0 of rng_keys  
                                                  # and axis 1 of initial_position

#print("main initial_position shape",initial_position.shape)

run_mcmc = jax.vmap(jax_metropolis_sampler, 
                    in_axes=(0, None, None, 1),   # see comment above 
                    out_axes=0)                   # output axis 0 hold the vectorization over n_chains
                                                  # => (n_chains, n_samples, n_dims)


all_positions = run_mcmc(rng_keys, n_samples, 
                     lambda par: jax_my_logpdf(par,xi,yi), 
                     initial_position)

Then, once the code has been called once (so compilation has already happened), one can do

%timeit all_positions = run_mcmc(rng_keys, n_samples, 
                     lambda par: jax_my_logpdf(par,xi,yi), 
                     initial_position)

On CPU without JIT (i.e. with the @partial lines commented out) I get 1 min 27 s, while with JIT I get 1 min 20 s (both results averaged over 7 runs). Thanks for your advice.

Upvotes: 1

Views: 774

Answers (1)

jakevdp

Reputation: 86513

The reason JIT compilation does not give you any speedup here is that the bulk of your computation happens within the function you pass to fori_loop, and this is JIT-compiled by default, so in a relative sense there's simply not much to gain from JIT-compiling the remaining steps.
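
As a minimal sketch of this point (a toy loop, not your sampler): the body passed to jax.lax.fori_loop is traced and compiled whether or not the surrounding function is wrapped in jax.jit, so both variants below should show essentially the same timing once each has been warmed up.

import jax
import jax.numpy as jnp

def run_loop(x):
    # The loop body is compiled by fori_loop itself, even without jax.jit
    def body(i, carry):
        return carry + 1e-3 * jnp.sin(carry)
    return jax.lax.fori_loop(0, 100_000, body, x)

run_loop_jit = jax.jit(run_loop)

x = jnp.float32(1.0)
run_loop(x).block_until_ready()      # warm-up: compiles the loop body
run_loop_jit(x).block_until_ready()  # warm-up: compiles the whole function
# %timeit run_loop(x).block_until_ready()
# %timeit run_loop_jit(x).block_until_ready()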

As for why your computation takes multiple minutes to execute: you are running a fori_loop with 101,000 steps, and doing a fairly significant amount of work within each step (each step evaluates the log-posterior over your 5,000 data points, and vmap multiplies that by 100 chains). What you're seeing is simply how long it takes to run your code for the inputs you specified.
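
If you want a rough sanity check of the runtime, one hypothetical approach (reusing the names from your code) is to JIT-compile the kernel on its own, time a single step, and scale up; note that per-call dispatch overhead means this over-estimates the cost of one iteration inside the compiled loop.

logpdf_fn = lambda p: jax_my_logpdf(p, xi, yi)   # bind once so the same (static) function object is reused
kernel = jax.jit(jax_metropolis_kernel, static_argnums=(1,))

pos = jnp.ones(n_dim)
logp = logpdf_fn(pos)
key = jax.random.PRNGKey(0)

kernel(key, logpdf_fn, pos, logp)[0].block_until_ready()  # warm-up: triggers compilation
# %timeit kernel(key, logpdf_fn, pos, logp)[0].block_until_ready()
# total ≈ per-step time × n_samples, and the 100 vmapped chains add more work on top of that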

Upvotes: 1
