NNNNNNN

Reputation: 39

Why does all my emission mu of HMM in pyro converge to the same number?

I'm trying to build a Gaussian HMM in Pyro to infer the parameters of a very simple Markov sequence. However, my model fails to infer the parameters, and something weird happens during training. On the same sequence, hmmlearn successfully infers the true parameters.
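For reference, the hmmlearn baseline looks roughly like this (a minimal sketch, assuming obs is the 1-D array of observations):

import numpy as np
from hmmlearn import hmm

# Fit a 3-state Gaussian HMM and read off the estimated emission parameters.
baseline = hmm.GaussianHMM(n_components = 3, covariance_type = "diag", n_iter = 100)
baseline.fit(np.asarray(obs).reshape(-1, 1))  # hmmlearn expects shape (n_samples, n_features)
print(baseline.means_)   # recovers approximately [-10, 0, 10]
print(baseline.covars_)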

Full code can be accessed here:

https://colab.research.google.com/drive/1u_4J-dg9Y1CDLwByJ6FL4oMWMFUVnVNd#scrollTo=ZJ4PzdTUBgJi

My model is modified from the example here:

https://github.com/pyro-ppl/pyro/blob/dev/examples/hmm.py

I manually created a first-order Markov sequence with 3 states; the true means are [-10, 0, 10] and the sigmas are [1, 2, 1].
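For illustration, such a sequence can be generated along these lines (the transition matrix below is a placeholder; the actual one is in the Colab notebook):

import numpy as np

means = np.array([-10.0, 0.0, 10.0])
sigmas = np.array([1.0, 2.0, 1.0])
# Placeholder transition matrix, for illustration only.
trans = np.array([[0.8, 0.1, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.1, 0.1, 0.8]])

# Sample a state chain, then sample one Gaussian emission per state.
states = [np.random.choice(3)]
for _ in range(999):
    states.append(np.random.choice(3, p = trans[states[-1]]))
obs = np.array([np.random.normal(means[s], sigmas[s]) for s in states])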

Here is my model:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.ops.indexing import Vindex
from torch.distributions import constraints

def model(observations, num_state):

    assert not torch._C._get_tracing_state()

    # Priors over the transition matrix and the initial-state distribution.
    with poutine.mask(mask = True):

        p_transition = pyro.sample("p_transition",
                                   dist.Dirichlet((1 / num_state) * torch.ones(num_state, num_state)).to_event(1))

        p_init = pyro.sample("p_init",
                             dist.Dirichlet((1 / num_state) * torch.ones(num_state)))

    # Emission parameters as point-estimated pyro.params.
    p_mu = pyro.param(name = "p_mu",
                      init_tensor = torch.randn(num_state),
                      constraint = constraints.real)

    p_tau = pyro.param(name = "p_tau",
                       init_tensor = torch.ones(num_state),
                       constraint = constraints.positive)

    # Discrete hidden states, enumerated out in parallel.
    current_state = pyro.sample("x_0",
                                dist.Categorical(p_init),
                                infer = {"enumerate" : "parallel"})

    for t in pyro.markov(range(1, len(observations))):

        current_state = pyro.sample("x_{}".format(t),
                                    dist.Categorical(Vindex(p_transition)[current_state, :]),
                                    infer = {"enumerate" : "parallel"})

        pyro.sample("y_{}".format(t),
                    dist.Normal(Vindex(p_mu)[current_state], Vindex(p_tau)[current_state]),
                    obs = observations[t])

My model is set up for training as follows:

from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

device = torch.device("cuda:0")
obs = torch.tensor(obs)
obs = obs.to(device)

torch.set_default_tensor_type("torch.cuda.FloatTensor")

# MAP guide over the global parameters (sites whose names start with "p_").
guide = AutoDelta(poutine.block(model, expose_fn = lambda msg : msg["name"].startswith("p_")))

Elbo = Trace_ELBO
elbo = Elbo(max_plate_nesting = 1)

optim = Adam({"lr": 0.001})
svi = SVI(model, guide, optim, elbo)
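The training loop itself is just repeated svi.step calls; a minimal sketch (the number of steps is arbitrary here, and num_state = 3 for this data):

num_state = 3
losses = []
for step in range(5000):  # number of steps chosen arbitrarily for illustration
    loss = svi.step(obs, num_state)
    losses.append(loss)
    if step % 500 == 0:
        print("step {}: ELBO loss = {:.2f}".format(step, loss))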

As training progresses, the ELBO decreases steadily, as shown below. However, the three state means all converge to the same value.

[Plot: ELBO and state means over training]

I have tried wrapping the for loop of my model in a pyro.plate and swapping pyro.param for pyro.sample (and vice versa), but nothing has worked.

Upvotes: -2

Views: 268

Answers (1)

James

Reputation: 180

I have not tried this model, but I think the problem can be solved by modifying the model as follows: replace the pyro.param emission parameters with pyro.sample sites, so that the means and scales become latent variables with priors, which is what MCMC requires.

def model(observations, num_state):

    assert not torch._C._get_tracing_state()

    with poutine.mask(mask = True):

        p_transition = pyro.sample("p_transition",
                                   dist.Dirichlet((1 / num_state) * torch.ones(num_state, num_state)).to_event(1))

        p_init = pyro.sample("p_init",
                             dist.Dirichlet((1 / num_state) * torch.ones(num_state)))

    # Emission parameters are now latent variables with priors.
    p_mu = pyro.sample("p_mu",
                       dist.Normal(torch.zeros(num_state), torch.ones(num_state)).to_event(1))

    # Note: HalfCauchy takes a positive scale, so torch.zeros would be invalid here.
    p_tau = pyro.sample("p_tau",
                        dist.HalfCauchy(torch.ones(num_state)).to_event(1))

    current_state = pyro.sample("x_0",
                                dist.Categorical(p_init),
                                infer = {"enumerate" : "parallel"})

    for t in pyro.markov(range(1, len(observations))):

        current_state = pyro.sample("x_{}".format(t),
                                    dist.Categorical(Vindex(p_transition)[current_state, :]),
                                    infer = {"enumerate" : "parallel"})

        pyro.sample("y_{}".format(t),
                    dist.Normal(Vindex(p_mu)[current_state], Vindex(p_tau)[current_state]),
                    obs = observations[t])

The model would then be trained using MCMC:

# MCMC
from pyro.infer import MCMC, NUTS

hmc_kernel = NUTS(model, target_accept_prob = 0.9, max_tree_depth = 7)
mcmc = MCMC(hmc_kernel, num_samples = 1000, warmup_steps = 100, num_chains = 1)
mcmc.run(obs, num_state)  # pass num_state to match the model signature

The results could then be analysed using:

mcmc.get_samples()
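For example, a minimal sketch of summarising the posterior over the emission means:

samples = mcmc.get_samples()
# "p_mu" has shape (num_samples, num_state); summarise each state's emission mean.
print(samples["p_mu"].mean(dim = 0))
print(samples["p_mu"].std(dim = 0))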

Upvotes: 0
