I'm using NumPyro to sample several variables in a model, one of which is the number of iterations of a for loop. Here is an analogous toy model:
```python
import numpyro
import numpyro.distributions as dist

def model():
    mu = 0.
    sigma = numpyro.sample("sigma", dist.HalfNormal())
    T = numpyro.sample("T", dist.DiscreteUniform(1, 5))
    for i in range(1, T):
        mu += 1.
    # data and normal_logl are defined elsewhere (not shown)
    logl = numpyro.deterministic("logl", normal_logl(data, mu, sigma))
    numpyro.factor("log_likelihood", logl)
```
The model above raises an error. Replacing the for loop with `jax.lax.fori_loop` doesn't help either. Is there a workaround?
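One possible workaround, sketched here in plain JAX and assuming an upper bound on `T` is known (5 in the toy model), is to loop over that fixed range and mask out the iterations with `i >= T`. The loop bounds are then static Python integers even when `T` is a traced value. `MAX_T`, `masked_loop`, and `count_iterations` are hypothetical names introduced for illustration; `count_iterations` is a loop-free version that only works because this toy body just adds a constant.

```python
import jax
import jax.numpy as jnp

MAX_T = 5  # assumed upper bound, matching DiscreteUniform(1, 5) in the toy model


def masked_loop(T):
    """Emulate `for i in range(1, T): mu += 1.` with static loop bounds.

    fori_loop always runs i = 1, ..., MAX_T - 1; iterations with i >= T
    keep the previous carry unchanged, so T may be a traced value.
    """
    def body(i, mu):
        # Apply the original loop body only where range(1, T) would visit i.
        return jnp.where(i < T, mu + 1.0, mu)

    return jax.lax.fori_loop(1, MAX_T, body, jnp.float32(0.0))


def count_iterations(T):
    """Loop-free equivalent for this toy body (a constant increment per step).

    T[..., None] lets a batch of T values broadcast against the index vector.
    """
    i = jnp.arange(1, MAX_T)                  # candidate indices 1, ..., MAX_T - 1
    active = i < jnp.asarray(T)[..., None]    # True where range(1, T) visits i
    return jnp.sum(jnp.where(active, 1.0, 0.0), axis=-1)


# Both functions accept a traced T, e.g. under jit.
mu_loop = jax.jit(masked_loop)(3)
mu_vec = jax.jit(count_iterations)(3)
print(mu_loop, mu_vec)
```

In the model, `mu = masked_loop(T)` would take the place of the Python loop. This only removes the data-dependent loop bound; sampling a discrete site such as `T` may additionally require enumeration or a sampler that handles discrete latents, for example NumPyro's `DiscreteHMCGibbs`.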