sangeetpaul

How to sample variable number of iterations of a for loop in numpyro?

I'm using numpyro to sample several variables in a model, one of which is the number of iterations of a for loop. Here is an analogous toy model:

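import numpyro
import numpyro.distributions as dist

def model():
    mu = 0.
    sigma = numpyro.sample("sigma", dist.HalfNormal())
    T = numpyro.sample("T", dist.DiscreteUniform(1, 5))
    # Under MCMC the model is traced by JAX, so T is an abstract value
    # that range() cannot convert to a Python int
    for i in range(1, T):
        mu += 1.
    # normal_logl and data are defined elsewhere (not shown)
    logl = numpyro.deterministic("logl", normal_logl(data, mu, sigma))
    numpyro.factor("log_likelihood", logl)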

The above model raises an error. Replacing the for loop with jax.lax.fori_loop doesn't help either. Is there a workaround?
