Nigel Ng

Reputation: 583

PyMC3: How to use pymc3.traceplot() to overlap means on traceplot

The function pymc3.traceplot() plots the traces from the sampling process. It accepts an argument, lines, that takes a dictionary through which values such as the posterior means can be drawn as lines on the plot. How do I use it to do this?

Upvotes: 0

Views: 2243

Answers (1)

aloctavodia

Reputation: 2070

You can pass any value you want, not only the mean.

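import pymc3 as pm
# trace is assumed to come from pm.sample() on a model with a variable named 'theta'
theta_val = 0.35
pm.traceplot(trace, lines={'theta': theta_val})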


Here, theta is the name of the variable in the model, and theta_val is the value you want to overlay on the plot.
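One caveat worth checking against your installed version: the dictionary form of lines matches the plotting code PyMC3 shipped when this answer was written. Later PyMC3 releases delegate plotting to ArviZ, whose plot_trace documents lines as a list of (var_name, selection_dict, values) tuples instead. The sketch below is a hypothetical helper (lines_dict_to_tuples is not part of PyMC3 or ArviZ) that converts the dictionary form into that tuple form, assuming no coordinate selection is needed.

```python
def lines_dict_to_tuples(lines):
    """Convert {'var': value} into [(var, {}, [value, ...]), ...].

    Hypothetical helper for illustration only. The empty dict means
    "no coordinate selection"; scalars are wrapped in a list so every
    entry carries a sequence of line positions.
    """
    converted = []
    for name, value in lines.items():
        if isinstance(value, (list, tuple)):
            positions = list(value)
        else:
            positions = [value]
        converted.append((name, {}, positions))
    return converted


theta_val = 0.35
arviz_style_lines = lines_dict_to_tuples({'theta': theta_val})
print(arviz_style_lines)
```

If your version expects the tuple form, the result would be passed as the lines argument in place of the dictionary.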

You can compute the mean from the trace by doing:

trace['theta'].mean()

Or, to build the dictionary for every variable in the trace at once:

lines = {var: trace[var].mean() for var in trace.varnames}
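The dictionary comprehension above can be tried out without running a sampler. The following is a minimal sketch that uses a hypothetical stand-in class, FakeTrace, which mimics only the two features the comprehension relies on: a varnames list and indexing by variable name that returns a NumPy array of samples. FakeTrace and the sample values are made up for illustration and are not part of PyMC3.

```python
import numpy as np


class FakeTrace:
    """Hypothetical stand-in for a PyMC3 trace (not a PyMC3 class)."""

    def __init__(self, samples):
        # Store each variable's samples as a float array, keyed by name.
        self._samples = {
            name: np.asarray(values, dtype=float)
            for name, values in samples.items()
        }
        self.varnames = list(self._samples)

    def __getitem__(self, name):
        return self._samples[name]


# Made-up samples for two scalar variables.
trace = FakeTrace({
    'theta': [0.30, 0.35, 0.40],
    'mu': [1.0, 2.0, 3.0, 4.0],
})

# Same comprehension as in the answer: one mean per variable name.
lines = {var: trace[var].mean() for var in trace.varnames}
print(lines)
```

With a real trace, the resulting dictionary would be passed as pm.traceplot(trace, lines=lines). Note that for a vector-valued variable, .mean() averages over every element; trace[var].mean(axis=0) would give one mean per component instead.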

Upvotes: 1
