singularity

PyMC3 sample() function does not accept the "start" value to generate a trace

I am new to PyMC3 and Bayesian inference methods. I have a simple script that tries to infer the value of a decay constant (true value 1) from artificial data generated with a truncated exponential distribution:

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt 
import pymc3 as pm
import arviz as az


T = stats.truncexpon(b = 10.)
t = T.rvs(1000)

#Bayesian Inference

with pm.Model() as model: 
    #Define Priors
    lam = pm.Gamma('$\lambda$', alpha=1, beta=1)

    #Define Likelihood
    time = pm.Exponential('time', lam = lam, observed = t)

    #Inference
    trace = pm.sample(20, start = {'lam': 10.}, \
            step=pm.Metropolis(), chains=1, cores=1, \
            progressbar = True)


az.plot_trace(trace)
plt.show()
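Because the Gamma prior is conjugate to the Exponential likelihood, the posterior for λ has a closed form, which gives a reference value for where a well-behaved trace should settle. Below is a NumPy-only sketch, assuming the rate parametrisation used by `pm.Gamma(alpha, beta)` and `pm.Exponential(lam)`. The helper names are hypothetical, and, like the model above, the likelihood ignores the truncation at b = 10 (which matters little here, since P(T > 10) = e^-10 for a rate of 1).

```python
import numpy as np

def sample_truncexpon(n, b=10.0, rng=None):
    """Hypothetical helper: draw n values from Exponential(rate=1) truncated to [0, b] via the inverse CDF."""
    rng = np.random.default_rng() if rng is None else rng
    u = rng.uniform(size=n)
    # Truncated CDF: F(t) = (1 - exp(-t)) / (1 - exp(-b)); solve F(t) = u for t
    return -np.log1p(-u * (1.0 - np.exp(-b)))

def gamma_exponential_posterior(t, alpha=1.0, beta=1.0):
    """Conjugate update: Gamma(alpha, beta) prior (beta is a rate) + Exponential(lam) likelihood."""
    t = np.asarray(t)
    return alpha + t.size, beta + t.sum()

rng = np.random.default_rng(0)
t = sample_truncexpon(1000, rng=rng)
a_post, b_post = gamma_exponential_posterior(t)
post_mean = a_post / b_post          # mean of Gamma(a, b) with rate b
post_sd = np.sqrt(a_post) / b_post   # standard deviation of Gamma(a, b) with rate b
print(f"posterior mean {post_mean:.3f}, sd {post_sd:.3f}")
```

A trace started at 10 should therefore drift down towards roughly 1 once the sampler has moved away from the starting point.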

This code produces the following trace:

[Image: ArviZ trace plot of the sampled $\lambda$ values]

I am confused as to why the starting value of 10 is not used by the sampler: the trace above should start at 10. I am using Python 3.7 to run the code.

Thank you.


Answers (1)

merv

A few things are going on:

  • When the sampler first starts, it runs a tuning phase; samples drawn during this phase are discarded by default, which can be controlled with the discard_tuned_samples argument.
  • The keys in the start dictionary must match the name given to the random variable ('$\lambda$'), not the name of the Python variable (lam).
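Both points can be mimicked with a toy random-walk Metropolis sampler in plain NumPy. This is a hypothetical sketch, not PyMC3's implementation: `toy_sample` and `log_post` are illustrative names, the start value is looked up by the random variable's name, and the first `tune` draws are dropped unless `discard_tuned_samples=False`. Unlike PyMC3's `Metropolis`, this toy does not adapt its proposal scale during tuning, and raising `KeyError` for a wrongly keyed start dict is this toy's behaviour only.

```python
import numpy as np

def log_post(lam, t, alpha=1.0, beta=1.0):
    """Unnormalised log posterior: Gamma(alpha, beta) prior + Exponential(lam) likelihood."""
    if lam <= 0:
        return -np.inf
    return (alpha - 1) * np.log(lam) - beta * lam + t.size * np.log(lam) - lam * t.sum()

def toy_sample(draws, t, start, tune=500, discard_tuned_samples=True, scale=0.5, seed=None):
    """Hypothetical sampler: start is keyed by the RV *name*; tuning draws are optionally dropped."""
    rng = np.random.default_rng(seed)
    lam = start[r"$\lambda$"]  # this toy raises KeyError if the dict is keyed by 'lam'
    trace = []
    for _ in range(tune + draws):
        prop = lam + scale * rng.normal()
        # Metropolis acceptance test on the log scale
        if np.log(rng.uniform()) < log_post(prop, t) - log_post(lam, t):
            lam = prop
        trace.append(lam)
    trace = np.array(trace)
    return trace[tune:] if discard_tuned_samples else trace

t = np.random.default_rng(42).exponential(size=1000)
full = toy_sample(20, t, {r"$\lambda$": 10.0}, tune=50, discard_tuned_samples=False, seed=1)
print(full[:5])  # early draws, still on their way down from the start value
```

With the default `discard_tuned_samples=True`, the same call returns only the last 20 draws, by which point the chain has usually left the neighbourhood of 10.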

Incorporating both changes, one can try:

trace = pm.sample(20, start = {'$\lambda$': 10.},
            step=pm.Metropolis(), chains=1, cores=1,
            discard_tuned_samples=False)

However, there is another possible issue:

  • the starting value is not guaranteed to appear as the first draw; it does so only if the first proposal is rejected, which is down to chance.
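The point about the first draw can be checked with a bare random-walk Metropolis step in NumPy. This is a hypothetical sketch, not PyMC3's `Metropolis` class: `metropolis_step`, `first_draws` and the Gamma-shaped `log_p` target are illustrative assumptions, and it assumes, as described above, that the first recorded draw is the state after one step from the start value.

```python
import numpy as np

def log_p(lam):
    # Illustrative target: Gamma(1001, 1001) log density up to a constant
    return (1000.0 * np.log(lam) - 1001.0 * lam) if lam > 0 else -np.inf

def metropolis_step(x, log_p, scale, rng):
    """One random-walk Metropolis step; returns (new_state, accepted)."""
    prop = x + scale * rng.normal()
    if np.log(rng.uniform()) < log_p(prop) - log_p(x):
        return prop, True
    return x, False  # rejected: the chain stays where it was

def first_draws(start, log_p, n, scale=1.0, seed=None):
    """Record n draws, each taken after one step, plus whether each proposal was accepted."""
    rng = np.random.default_rng(seed)
    x, draws, accepted = start, [], []
    for _ in range(n):
        x, acc = metropolis_step(x, log_p, scale, rng)
        draws.append(x)
        accepted.append(acc)
    return np.array(draws), np.array(accepted)

for seed in range(3):
    d, acc = first_draws(10.0, log_p, 5, seed=seed)
    # d[0] equals the start value 10.0 only when acc[0] is False
    print(seed, d[0], acc[0])
```

Which seeds reproduce the start value depends entirely on the random numbers drawn, which is why the answer fixes a seed below.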

Fixing the game by setting a random seed, though, we can get a glimpse:

trace = pm.sample(20, start = {'$\lambda$': 10.},
            step=pm.Metropolis(), chains=1, cores=1,
            discard_tuned_samples=False, random_seed=1)

...

trace.get_values(varname='$\lambda$')[:10]

# array([10.        ,  5.42397358,  3.19841997,  1.09383329,  1.09383329,
#         1.09383329,  1.09383329,  1.09383329,  1.09383329,  1.09383329])
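The repeated 1.09383329 values in that output are rejected proposals: a Metropolis chain that rejects a move simply repeats its current value. The sketch below counts moves between consecutive draws to estimate the acceptance rate of those 10 printed values. The helper `empirical_acceptance` is hypothetical, and it assumes a continuous proposal, so an accepted move practically never lands on exactly the same value.

```python
import numpy as np

def empirical_acceptance(trace):
    """Fraction of consecutive transitions in which the chain actually moved."""
    trace = np.asarray(trace)
    return np.mean(trace[1:] != trace[:-1])

# The first 10 draws printed above (start value 10. included)
lam_draws = [10.0, 5.42397358, 3.19841997, 1.09383329, 1.09383329,
             1.09383329, 1.09383329, 1.09383329, 1.09383329, 1.09383329]
print(empirical_acceptance(lam_draws))  # 3 moves out of 9 transitions
```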
