Reputation: 402
I'm a newbie in Jax and not an Python expert either.
I'm running the jax version '0.2.14' on my mac laptop. Please find below a simple code, which at least for me give some results.
But, as stated in the comment jax_metropolis_sampler
method, I would like to save intermediate results 'positions' but I do not figure out to do it propertly using jax_fori_loop
and I guess doing like I have done is certainly horrible.
I'm pretty sure that someone can give me a better solution which exploit the jax parallelism. FOr the time beeing I have not look at forward/backward differentiation of my MixtureModel_jax.
Thanks in advance
import jax
import jax.numpy as jnp
from functools import partial
class MixtureModel_jax():
def __init__(self, locs, scales, weights, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loc = jnp.array([locs]).T
self.scale = jnp.array([scales]).T
self.weights = jnp.array([weights]).T
norm = jnp.sum(self.weights)
self.weights = self.weights/norm
self.num_distr = len(locs)
def pdf(self, x):
probs = jax.scipy.stats.norm.pdf(x,loc=self.loc, scale=self.scale)
return jnp.dot(self.weights.T,probs).squeeze()
def logpdf(self, x):
log_probs = jax.scipy.stats.norm.logpdf(x,loc=self.loc, scale=self.scale)
return jax.scipy.special.logsumexp(np.log(self.weights) + log_probs, axis=0)
@partial(jax.jit, static_argnums=(1,))
def jax_metropolis_kernel(rng_key, logpdf, position, log_prob):
key, subkey = jax.random.split(rng_key)
"""Moves the chain by one step using the Random Walk Metropolis algorithm."""
move_proposals = jax.random.normal(key, shape=position.shape) * 0.1
proposal = position + move_proposals
proposal_log_prob = logpdf(proposal)
log_uniform = jnp.log(jax.random.uniform(subkey))
do_accept = log_uniform < proposal_log_prob - log_prob
position = jnp.where(do_accept, proposal, position)
log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
return position, log_prob
@partial(jax.jit, static_argnums=(1, 2))
def jax_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
"""Generate samples using the Random Walk Metropolis algorithm."""
def mh_update(i, state):
key, position, log_prob = state
_, key = jax.random.split(key)
new_position, new_log_prob = jax_metropolis_kernel(key, logpdf, position, log_prob)
return (key, new_position, new_log_prob)
logp = logpdf(initial_position)
# Just return the last position
# rng_key, position, log_prob = jax.lax.fori_loop(0, n_samples,
# mh_update,
# (rng_key, initial_position, logp))
# return position
# Porposal to save intermediate positions: slow and horrible I guess !
spls = []
state = (rng_key, initial_position, logp)
for i in range(n_samples):
state = mh_update(i, state)
spls.append(state[1])
return spls
mixture_gaussian_model = MixtureModel_jax([0,1.5],[0.5,0.1],[8,2])
n_dim = 1
n_samples = 50
n_chains = 7
rng_key = jax.random.PRNGKey(42)
rng_keys = jax.random.split(rng_key, n_chains)
initial_position = jnp.zeros((n_dim, n_chains))
run_mcmc = jax.vmap(jax_metropolis_sampler,
in_axes=(0, None, None, 1),
out_axes=0)
positions = run_mcmc(rng_keys, n_samples,
mixture_gaussian_modelbda x: mixture_gaussian_model.logpdf(x),
initial_position)
print(len(positions))
print(positions[0].shape)
Upvotes: 1
Views: 1592
Reputation: 402
Here the solution I manage to get after @jakevdp hint
@partial(jax.jit, static_argnums=(1, 2))
def jax_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
def mh_update_sol2(i, state):
key, positions, log_prob = state
_, key = jax.random.split(key)
new_position, new_log_prob = jax_metropolis_kernel(key, logpdf, positions[i-1], log_prob)
positions=positions.at[i].set(new_position)
return (key, positions, new_log_prob)
logp = logpdf(initial_position)
all_positions = jnp.zeros((n_samples,)+initial_position.shape)
initial_state = (rng_key,all_positions, logp)
rng_key, all_positions, log_prob = jax.lax.fori_loop(1, n_samples,
mh_update_sol2,
initial_state)
return all_positions
n_dim = 1
n_samples = 100_000
n_chains = 100
rng_key = jax.random.PRNGKey(42)
rng_keys = jax.random.split(rng_key, n_chains)
initial_position = jnp.zeros((n_dim, n_chains))
run_mcmc = jax.vmap(jax_metropolis_sampler,
in_axes=(0, None, None, 1),
out_axes=0)
all_positions = run_mcmc(rng_keys, n_samples,
lambda x: mixture_gaussian_model.logpdf(x),
initial_position)
all_positions=all_positions.squeeze()
Then, after you can plot the 100 chains...
x_axis = jnp.arange(-3, 3, 0.001)
for i in range(all_positions.shape[0]):
plt.hist(all_positions[i],bins=50, density=True, histtype='step',label=f"chain [{i}]");
plt.plot(x_axis, mixture_gaussian_model.pdf(x_axis),'r-', lw=5, alpha=0.6, label='true pdf')
plt.legend()
plt.show()
Thanks fro your help.
Upvotes: 0
Reputation: 86443
The best way to do this would be to carry the list of previous positions in the fori_loop
function. Something like this:
def mh_update(i, state):
key, positions, log_prob = state
_, key = jax.random.split(key)
new_position, new_log_prob = jax_metropolis_kernel(key, logpdf, positions[-1], log_prob)
positions = jnp.vstack([positions, new_position])
return (key, positions, new_log_prob)
logp = logpdf(initial_position)
initial_state = (rng_key, initial_position[jnp.newaxis], logp)
rng_key, positions, log_prob = jax.lax.fori_loop(0, n_samples,
mh_update,
initial_state)
return positions
Upvotes: 1