Reputation: 465
While trying out TFP, I tried to sample from the posterior distribution of a conjugate normal model with known variance, that is:
x|mu ~ Normal(mu, 1.)
mu ~ Normal(4., 2.)
The tfp.mcmc.RandomWalkMetropolis sampler gives a different posterior compared to PyMC3 and the analytical solution. Note: PyMC3 recovers the correct posterior.
I also tried the HMC sampler in TFP, with the same (incorrect) result.
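For reference, the closed-form posterior I am comparing against (and what get_param_updates below computes) is, for n observations with sample mean x_bar, known likelihood scale sigma, and prior Normal(mu_0, sigma_0):

post_mu = (sigma^2 * mu_0 + n * sigma_0^2 * x_bar) / (n * sigma_0^2 + sigma^2)
post_sigma^2 = (sigma^2 * sigma_0^2) / (n * sigma_0^2 + sigma^2)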
!pip install tensorflow==2.0.0-beta0
!pip install tfp-nightly
### IMPORTS
import numpy as np
import pymc3 as pm
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
import matplotlib.pyplot as plt
import seaborn as sns
tf.random.set_seed(1905)
%matplotlib inline
sns.set(rc={'figure.figsize':(9.3,6.1)})
sns.set_context('paper')
sns.set_style('whitegrid')
### CREATE DATA
observed = tfd.Normal(loc=0., scale=1.).sample(20)
sns.distplot(observed, kde=False)
sns.despine();
### MODEL
# prior
mu_0, sigma_0 = 4., 2.
prior = tfd.Normal(mu_0, sigma_0)
# likelihood
mu, sigma = prior.sample(1), 1. # use a sample from the prior as guess for mu
likelihood = tfd.Normal(mu, sigma)
# function to get posterior analytically
def get_param_updates(data, sigma, prior_mu, prior_sigma):  # sigma is known
    n = len(data)
    sigma2 = sigma**2
    prior_sigma2 = prior_sigma**2
    x_bar = tf.reduce_mean(data)
    post_mu = ((sigma2 * prior_mu) + (n * prior_sigma2 * x_bar)) / ((n * prior_sigma2) + sigma2)
    post_sigma2 = (sigma2 * prior_sigma2) / ((n * prior_sigma2) + sigma2)
    post_sigma = tf.math.sqrt(post_sigma2)
    return post_mu, post_sigma
# posterior
mu_n, sigma_n = get_param_updates(observed,
                                  sigma=1.,
                                  prior_mu=mu_0,
                                  prior_sigma=sigma_0)
posterior = tfd.Normal(mu_n, sigma_n, name='posterior')
### PyMC3
# define model
with pm.Model() as model:
    mu = pm.Normal('mu', mu=4., sigma=2.)
    x = pm.Normal('observed', mu=mu, sigma=1., observed=observed)
    trace_pm = pm.sample(10000, tune=500, chains=1)
# plots
sns.distplot(posterior.sample(10**5))
sns.distplot(trace_pm['mu'])
plt.legend(labels=['Analytic Posterior', 'PyMC Posterior']);
### TFP
# definition of the joint_log_prob to evaluate samples
def joint_log_prob(data, proposal):
    prior = tfd.Normal(mu_0, sigma_0, name='prior')
    likelihood = tfd.Normal(proposal, sigma, name='likelihood')
    return (prior.log_prob(proposal) + tf.reduce_mean(likelihood.log_prob(data)))
# define a closure on joint_log_prob
def unnormalized_log_posterior(proposal):
    return joint_log_prob(data=observed, proposal=proposal)
# define how to propose state
rwm = tfp.mcmc.RandomWalkMetropolis(
    target_log_prob_fn=unnormalized_log_posterior
)
# define initial state
initial_state = tf.constant(0., name='initial_state')
# sample trace
trace, kernel_results = tfp.mcmc.sample_chain(
    num_results=10**5,
    num_burnin_steps=5000,
    current_state=initial_state,
    num_steps_between_results=1,
    kernel=rwm,
    parallel_iterations=1
)
# plots
sns.distplot(posterior.sample(10**5))
sns.distplot(trace_pm['mu'])
sns.distplot(trace)
sns.despine()
plt.legend(labels=['Analytic','PyMC3', 'TFP'])
plt.xlim(-5, 7);
I expected the same results from TFP, PyMC3, and the analytical solution (PyMC3 finds the correct posterior).
Upvotes: 0
Views: 328
Reputation: 1383
Random walk is not a great sampler for this kind of problem. You might need a crazy number of samples before it gets close to the true posterior.
PyMC3 uses NUTS -- an adaptive variant of Hamiltonian Monte Carlo. TFP supports plain HMC (tfp.mcmc.HamiltonianMonteCarlo); you should be able to drop it in place of RWM, but you may have to tune the step size and number of leapfrog steps yourself (that is what NUTS does adaptively for you). That alone should get you closer to consistent results.
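As a minimal sketch of that swap, reusing your unnormalized_log_posterior and initial_state: the step_size and num_leapfrog_steps values below are only placeholder guesses that would need tuning, and tfp.mcmc.SimpleStepSizeAdaptation is an optional wrapper that adapts the step size during burn-in.

# HMC kernel in place of RandomWalkMetropolis; step_size and
# num_leapfrog_steps are placeholder values that likely need tuning
hmc = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=unnormalized_log_posterior,
    step_size=0.1,
    num_leapfrog_steps=5
)
# optional: adapt the step size during the burn-in phase
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=hmc,
    num_adaptation_steps=4000
)
trace_hmc, kernel_results_hmc = tfp.mcmc.sample_chain(
    num_results=10**5,
    num_burnin_steps=5000,
    current_state=initial_state,
    kernel=adaptive_hmc,
    parallel_iterations=1
)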
Upvotes: 1