I am trying to understand what the problem is with the following code, which samples extremely slowly:
```python
import pymc3 as pm
import theano as t

# train_new (features, shape (891, 13)) and targets (0/1 labels) are defined in the MWE gist
X = t.shared(train_new)
features = list(map(str, range(train_new.shape[1])))

with pm.Model() as logistic_model:
    glm = pm.glm.GLM(X, targets, labels=features,
                     intercept=False, family='binomial')
    trace = pm.sample(3000, tune=3000, jobs=-1)
```
The dataset is by no means big: its shape is (891, 13). Here is what I concluded on my own:
- it's not `theano.shared`, because if I remove it the performance is the same;
- it's not `pymc3.glm.GLM`, because when I build the model manually (which is probably simpler than the one in `GLM`), the performance is just as terrible:
```python
import theano.tensor as tt  # needed for tt.dot

with pm.Model() as logistic_model:
    invlogit = lambda x: 1 / (1 + pm.math.exp(-x))
    σ = pm.HalfCauchy('σ', beta=2)
    # centered hierarchical prior: each coefficient is Normal(0, σ)
    β = pm.Normal('β', 0, sd=σ, shape=X.get_value().shape[1])
    π = invlogit(tt.dot(X, β))
    likelihood = pm.Bernoulli('likelihood', π, observed=targets)
```
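To see what NUTS actually has to explore here, it can help to write the unnormalized log-posterior of this centered model out by hand. The following is a minimal NumPy sketch, not PyMC3's internals: `log_posterior_centered` and `log_sigmoid` are hypothetical helpers, and `X_demo`/`y_demo` are synthetic placeholders with the same shape as the question's dataset, not the real `train_new`/`targets`.

```python
import numpy as np


def log_sigmoid(z):
    # Numerically stable log(1 / (1 + exp(-z))) = -log(1 + exp(-z))
    return -np.logaddexp(0.0, -z)


def log_posterior_centered(sigma, beta, X, y, cauchy_scale=2.0):
    """Unnormalized log-density of: sigma ~ HalfCauchy(2), beta ~ Normal(0, sigma),
    y ~ Bernoulli(invlogit(X @ beta)). Hypothetical helper for illustration."""
    if sigma <= 0:
        return -np.inf
    # HalfCauchy prior: p(sigma) = 2 / (pi * s * (1 + (sigma / s)**2)) for sigma > 0
    lp = np.log(2.0 / (np.pi * cauchy_scale)) - np.log1p((sigma / cauchy_scale) ** 2)
    # Centered prior: every coefficient is Normal(0, sigma), so its width depends on sigma
    lp += np.sum(-0.5 * np.log(2.0 * np.pi) - np.log(sigma) - 0.5 * (beta / sigma) ** 2)
    # Bernoulli likelihood: y*log(p) + (1 - y)*log(1 - p), using 1 - invlogit(eta) = invlogit(-eta)
    eta = X @ beta
    lp += np.sum(y * log_sigmoid(eta) + (1 - y) * log_sigmoid(-eta))
    return lp


# Synthetic placeholder data with the question's shape (891 rows, 13 features)
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(891, 13))
y_demo = rng.integers(0, 2, size=891)

print(log_posterior_centered(1.0, np.zeros(13), X_demo, y_demo))
```

Only the β prior term couples σ and β: as σ shrinks, the density over β becomes sharply narrower, and that changing scale is the geometry a non-centered reparametrization is meant to remove.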
It starts at around 200 it/s and then quickly falls to 5 it/s. After about half of the sampling, it decreases further to around 2 it/s. This is a serious problem, as the model barely converges with a couple of thousand samples, and I need to draw many more samples than the current speed allows.
This is the log:

```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
99%|█████████▊| 5923/6000 [50:00<00:39, 1.97it/s]
```
I tried `pm.Metropolis()` as the step method; it was a bit faster, but it didn't converge.
MWE: a file with a minimal working example showing the issue and the data is here: https://gist.github.com/rubik/74ddad91317b4d366d3879e031e03396
Answer:
A non-centered version of the model should work much better:
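```python
# Non-centered prior: replaces the β = pm.Normal(...) line in the model above
β_raw = pm.Normal('β_raw', 0, sd=1, shape=X.get_value().shape[1])
β = pm.Deterministic('β', β_raw * σ)  # same prior on β, scaled deterministically
```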
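The two parametrizations describe the same prior on β; what changes is the geometry the sampler works in. A minimal NumPy sketch (an illustration, not PyMC3 code; `sigma = 0.3` and the helper names are arbitrary choices): drawing β directly from Normal(0, σ), or drawing β_raw from Normal(0, 1) and multiplying by σ, gives the same distribution, but the curvature of the log-density in the sampled variable is -1/σ² in the centered form and a constant -1 in the non-centered one.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.3
n = 200_000

# Centered: draw beta ~ Normal(0, sigma) directly
beta_centered = rng.normal(0.0, sigma, size=n)

# Non-centered: draw beta_raw ~ Normal(0, 1), then scale deterministically
beta_raw = rng.standard_normal(n)
beta_noncentered = beta_raw * sigma

print(beta_centered.std(), beta_noncentered.std())  # both close to sigma


def curvature_centered(sigma):
    # d^2/dbeta^2 of log Normal(beta | 0, sigma): grows without bound as sigma -> 0
    return -1.0 / sigma**2


def curvature_noncentered(sigma):
    # d^2/dbeta_raw^2 of log Normal(beta_raw | 0, 1): the same for every sigma
    return -1.0


for s in (2.0, 0.1, 0.01):
    print(s, curvature_centered(s), curvature_noncentered(s))
```

When σ is small, the centered posterior over β becomes very narrow, so a sampler adapted to a single step size has to take many tiny leapfrog steps in that region. This is a commonly cited cause of slow NUTS iterations in hierarchical models, and it plausibly explains the slowdown described in the question.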
If the effective sample size is small, your first impulse usually shouldn't be to just increase the number of samples, but to experiment with the parametrization.
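As a rough illustration of why more draws are a weak fix: for an autocorrelated chain of length N, the effective sample size is approximately N / (1 + 2 Σ ρ_k), where ρ_k is the lag-k autocorrelation. The NumPy sketch below uses a simple estimator that truncates the sum at the first negative autocorrelation (a simplification; library implementations such as ArviZ's use more careful estimators), applied to a synthetic AR(1) chain that stands in for a poorly mixing trace. `effective_sample_size` and `ar1_chain` are hypothetical helpers.

```python
import numpy as np


def autocorrelation(x, max_lag):
    # Sample autocorrelation rho_k for k = 0..max_lag (normalized by N)
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    n = len(x)
    var = np.dot(x, x) / n
    return np.array([np.dot(x[: n - k], x[k:]) / n / var for k in range(max_lag + 1)])


def effective_sample_size(x, max_lag=1000):
    # ESS ~ N / (1 + 2 * sum_k rho_k), summing until the first negative rho_k
    rho = autocorrelation(x, max_lag)
    total = 0.0
    for k in range(1, max_lag + 1):
        if rho[k] < 0:
            break
        total += rho[k]
    return len(x) / (1.0 + 2.0 * total)


def ar1_chain(phi, n, rng):
    # Stationary AR(1) chain with unit marginal variance; phi sets the autocorrelation
    x = np.empty(n)
    x[0] = rng.standard_normal()
    noise_sd = np.sqrt(1.0 - phi**2)
    for t in range(1, n):
        x[t] = phi * x[t - 1] + noise_sd * rng.standard_normal()
    return x


rng = np.random.default_rng(1)
n = 20_000
ess_iid = effective_sample_size(rng.standard_normal(n))
ess_sticky = effective_sample_size(ar1_chain(0.9, n, rng))
print(n, round(ess_iid), round(ess_sticky))
```

For φ = 0.9 the theoretical ESS is about N(1 - φ)/(1 + φ), roughly 5% of N. Doubling N only doubles that, whereas a parametrization that lowers the autocorrelation can raise the ESS much more cheaply.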
Also, you can use `tt.nnet.sigmoid` instead of your custom `invlogit`; that might be faster and more stable.
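The stability issue is easiest to see in log space, which is where the likelihood is evaluated. The NumPy sketch below illustrates the general problem and is not Theano's actual `sigmoid` implementation. It compares the log of the question's `1 / (1 + exp(-x))` with the algebraically equal `-log(1 + exp(-x))` computed via `np.logaddexp`. For a large negative input, the naive form overflows and returns -inf. The function names are hypothetical.

```python
import numpy as np


def naive_log_invlogit(x):
    # log of the question's invlogit: exp(-x) overflows to inf for large negative x
    with np.errstate(over="ignore", divide="ignore"):
        return np.log(1.0 / (1.0 + np.exp(-x)))


def stable_log_invlogit(x):
    # Same quantity rewritten as -log(1 + exp(-x)), evaluated without overflow
    return -np.logaddexp(0.0, -x)


x = np.array([-1000.0, 0.0, 1000.0])
print(naive_log_invlogit(x))   # values: -inf, log(0.5), 0.0
print(stable_log_invlogit(x))  # values: -1000.0, log(0.5), 0.0
```

Infinite log-densities, and the gradients derived from them, are the kind of values a gradient-based sampler like NUTS handles poorly. This is one plausible reason a dedicated sigmoid can help.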