Louis

Tensorflow Probability VI: Discrete + Continuous RVs inference: gradient estimation?

See this tensorflow-probability issue

tensorflow==2.7.0
tensorflow-probability==0.14.1

TLDR

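To perform VI on discrete RVs, should I use:

  • A- the REINFORCE gradient estimator,
  • B- the Gumbel-Softmax reparameterization,
  • or another solution,

and how should I implement it?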

Problem statement

Sorry in advance for the long issue, but I believe the problem requires some explaining.

I want to implement a Hierarchical Bayesian Model involving both continuous and discrete Random Variables. A minimal example is a Gaussian Mixture model:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

G = 2

p = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Sample(
            tfd.Normal(0., 1.),
            sample_shape=(G,)
        ),
        z=tfd.Categorical(
            probs=tf.ones((G,)) / G
        ),
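        # Select the location of the sampled component; one_hot works for both
        # scalar and batched samples of z (plain mu[z] fails on batches).
        x=lambda mu, z: tfd.Normal(
            loc=tf.reduce_sum(tf.one_hot(z, G) * mu, axis=-1),
            scale=1.
        )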
    )
)

In this example I deliberately don't use the tfd.Mixture API, in order to expose the Categorical label. I want to perform Variational Inference in this context: for instance, given an observed x, fit a Categorical distribution with parametric probabilities to the posterior of z:

q_probs = tfp.util.TransformedVariable(
    tf.ones((G,)) / G,
    tfb.SoftmaxCentered(),
    name="q_probs"
)
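q_loc = tf.Variable(tf.zeros((G,)), name="q_loc")  # one location per component, matching mu in p
q_scale = tfp.util.TransformedVariable(
    1.,
    tfb.Exp(),
    name="q_scale"
)

q = tfd.JointDistributionNamed(
    model=dict(
        # Independent turns the G components into a single event, like tfd.Sample in p
        mu=tfd.Independent(
            tfd.Normal(q_loc, q_scale),
            reinterpreted_batch_ndims=1
        ),
        z=tfd.Categorical(probs=q_probs)
    )
)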

The issue is that when computing the ELBO and optimizing for q_probs, I cannot use the reparameterization gradient estimator, AFAIK because z is a discrete RV:


def log_prob_fn(**kwargs):
    return p.log_prob(
        **kwargs,
        x=tf.constant([2.])
    )


optimizer = tf.optimizers.SGD()

@tf.function
def fit_vi():
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=log_prob_fn,
        surrogate_posterior=q,
        optimizer=optimizer,
        num_steps=10,
        sample_size=8
    )

_ = fit_vi() 
# This last line raises:
# ValueError: Distribution `surrogate_posterior` must be reparameterized, i.e.,a diffeomorphic transformation
# of a parameterless distribution. (Otherwise this function has a biased gradient.)
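To see why the reparameterization (pathwise) estimator has nothing to work with here, a minimal stdlib-only sketch: sampling a Categorical by inverse CDF with the noise u held fixed makes the sample a step function of the probabilities, so its derivative is zero almost everywhere (and undefined at the jumps). The helpers sample_categorical and finite_difference are hypothetical illustrations, not TFP functions.

```python
import random


def sample_categorical(probs, u):
    """Inverse-CDF sampling: map a uniform draw u in [0, 1) to a category index."""
    cumulative = 0.0
    for k, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return k
    return len(probs) - 1  # guard against floating-point round-off


def finite_difference(probs, u, eps=1e-4):
    """Finite-difference slope of the sampled index w.r.t. probs[0], with u held fixed.

    The bumped vector is renormalized so it remains a valid distribution.
    """
    bumped = [probs[0] + eps] + list(probs[1:])
    total = sum(bumped)
    bumped = [p / total for p in bumped]
    return (sample_categorical(bumped, u) - sample_categorical(probs, u)) / eps


rng = random.Random(0)
probs = [0.5, 0.5]
slopes = [finite_difference(probs, rng.random()) for _ in range(1000)]
# The sample is a step function of probs: for (almost) every fixed u the
# slope is exactly 0, so there is no pathwise gradient to backpropagate.
num_zero = sum(s == 0.0 for s in slopes)
print(num_zero, "of", len(slopes), "slopes are exactly zero")
```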

I'm looking for a way to make this work. I've identified at least 2 ways to circumvent the issue: the REINFORCE gradient estimator and the Gumbel-Softmax reparameterization.

A- REINFORCE gradient

As noted at this TFP API link, a classical result in VI is that the REINFORCE gradient can deal with a non-differentiable objective function, for instance one involving discrete RVs.
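The identity behind the REINFORCE (score-function, or log-derivative) estimator, written for a discrete z with variational parameters φ and a generic integrand f. Only log q_φ is differentiated, never the sampling step; for the ELBO, the extra term coming from the φ-dependence of log q_φ vanishes because the score has zero mean:

```latex
\nabla_\phi \, \mathbb{E}_{q_\phi(z)}\big[f(z)\big]
  = \sum_z f(z)\, \nabla_\phi q_\phi(z)
  = \sum_z q_\phi(z)\, f(z)\, \nabla_\phi \log q_\phi(z)
  = \mathbb{E}_{q_\phi(z)}\big[f(z)\, \nabla_\phi \log q_\phi(z)\big]

\mathbb{E}_{q_\phi(z)}\big[\nabla_\phi \log q_\phi(z)\big]
  = \sum_z \nabla_\phi q_\phi(z) = \nabla_\phi \sum_z q_\phi(z) = \nabla_\phi 1 = 0

\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(z)}\big[\log p(x, z) - \log q_\phi(z)\big],
\qquad
\nabla_\phi \mathcal{L}(\phi)
  = \mathbb{E}_{q_\phi(z)}\Big[\big(\log p(x, z) - \log q_\phi(z)\big)\, \nabla_\phi \log q_\phi(z)\Big]
```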

Could I use a tfp.vi.GradientEstimators.SCORE_FUNCTION estimator instead of the tfp.vi.GradientEstimators.REPARAMETERIZATION one, through the lower-level tfp.vi.monte_carlo_variational_loss function? With the REINFORCE gradient, I only need the log_prob method of q to be differentiable; the sample method needn't be differentiated.

As far as I understand, the sample method of a Categorical distribution breaks the gradient, but the log_prob method does not. Am I correct to assume that this could help with my issue? Am I missing something here?
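A stdlib-only sketch of the score-function (REINFORCE) estimator for a single Categorical variable with softmax-parameterized probabilities: z is sampled without any gradient, and only the score d log q(z) / d logits = one_hot(z) - softmax(logits) is used. This is a conceptual illustration, not TFP's SCORE_FUNCTION implementation; f is a hypothetical stand-in for the ELBO integrand, and the estimate is compared with the closed-form gradient.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample(probs, rng):
    """Draw one category index from probs (no gradient flows through this)."""
    u, cumulative = rng.random(), 0.0
    for k, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return k
    return len(probs) - 1


def score_function_grad(logits, f, num_samples, rng):
    """REINFORCE estimate of d/dlogits E_{z ~ Cat(softmax(logits))}[f(z)].

    Only log q is differentiated: d log q(z) / d logit_j = one_hot(z)_j - p_j.
    """
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for _ in range(num_samples):
        z = sample(probs, rng)
        fz = f(z)
        for j, p in enumerate(probs):
            grad[j] += fz * ((1.0 if j == z else 0.0) - p)
    return [g / num_samples for g in grad]


def exact_grad(logits, f):
    """Closed form for comparison: d/dlogit_j E[f] = p_j * (f(j) - E[f])."""
    probs = softmax(logits)
    mean_f = sum(p * f(k) for k, p in enumerate(probs))
    return [p * (f(j) - mean_f) for j, p in enumerate(probs)]


def f(z):
    return [1.0, -2.0, 0.5][z]  # hypothetical objective, e.g. a stand-in for log p(x, z)


logits = [0.2, -0.1, 0.4]
print(score_function_grad(logits, f, 100_000, random.Random(0)))
print(exact_grad(logits, f))
```

The estimator is unbiased, but each sample only sees f at one category, which is where its variance comes from.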

Also, I wonder: why is this possibility not exposed in the tfp.vi.fit_surrogate_posterior API? Is the performance bad, i.e. is the variance of the estimator too large for practical purposes?

B- Gumbel-Softmax reparameterization

As described at this TFP API link, I could also reparameterize z as a variable y = tfd.RelaxedOneHotCategorical(...). The issue is that I need a proper categorical label to define x, so AFAIK I need to do the following:
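For reference, a stdlib-only sketch of how a relaxed one-hot (Gumbel-Softmax / Concrete) sample is formed, following the construction in Jang et al. and Maddison et al.: add Gumbel(0, 1) noise to log p and apply a tempered softmax. It is not TFP's RelaxedOneHotCategorical code, and the helper names are hypothetical. It also shows why tf.argmax(y) yields a valid label: by the Gumbel-max trick, the argmax of a relaxed sample is distributed exactly as Categorical(probs), whatever the temperature.

```python
import math
import random


def gumbel(rng):
    """One Gumbel(0, 1) draw via the inverse CDF -log(-log(u))."""
    u = rng.random()
    while u == 0.0:  # log(0) is undefined; resample (practically never happens)
        u = rng.random()
    return -math.log(-math.log(u))


def relaxed_one_hot_sample(probs, temperature, rng):
    """Gumbel-Softmax / Concrete sample: softmax((log p + g) / temperature)."""
    scores = [(math.log(p) + gumbel(rng)) / temperature for p in probs]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
probs = [0.2, 0.8]
samples = [relaxed_one_hot_sample(probs, 0.5, rng) for _ in range(20_000)]
# Gumbel-max trick: the argmax of log p + g (hence of y) is exactly Categorical(probs).
frac_one = sum(y[1] > y[0] for y in samples) / len(samples)
print(samples[0], frac_one)
```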

p_GS = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Sample(
            tfd.Normal(0., 1.),
            sample_shape=(G,)
        ),
        y=tfd.RelaxedOneHotCategorical(
            temperature=1.,
            probs=tf.ones((G,)) / G
        ),
        x=lambda mu, y: tfd.Normal(
            loc=mu[tf.argmax(y)],
            scale=1.
        )
    )
)

...but this would just move the gradient-breaking problem to tf.argmax. This is where I may be missing something. Following the Gumbel-Softmax paper (Jang et al., 2016), I could then use the "STRAIGHT-THROUGH" (ST) strategy and "plug" the gradients of the variable tf.one_hot(tf.argmax(y)) (the "discrete y") onto y (the "continuous y").

But again I wonder: how do I do this properly? I don't want to mix and match the gradients by hand, and I guess an autodiff backend is precisely meant to spare me this. How could I create a distribution that distinguishes the forward direction (sampling a "discrete y") from the backward direction (gradient computed using the "continuous y")? I guess this is the intended usage of the tfd.RelaxedOneHotCategorical distribution, but I don't see it implemented anywhere in the API.

Should I implement this myself? How? Could I use something along the lines of tf.custom_gradient?
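A framework-free sketch of the forward/backward split being asked about, assuming the straight-through rule from Jang et al.: the forward pass uses the hard one-hot vector at the argmax, and the backward pass treats dz/dy as the identity. The function names are hypothetical and the gradient is passed by hand purely for illustration; in TensorFlow this split is what a tf.custom_gradient function expresses, and autodiff would then carry the gradient from y back to the logits through the relaxed sample.

```python
def one_hot_argmax(y):
    """Forward pass: hard one-hot vector at the argmax of y."""
    k = max(range(len(y)), key=lambda i: y[i])
    return [1.0 if i == k else 0.0 for i in range(len(y))]


def straight_through(y):
    """Return (hard_value, grad_fn).

    The forward value is the hard one-hot; grad_fn passes the upstream
    gradient to y unchanged, i.e. it pretends dz/dy is the identity.
    """
    z = one_hot_argmax(y)

    def grad_fn(upstream):
        return list(upstream)

    return z, grad_fn


# Toy loss L(z) = sum_i z_i * mu_i, i.e. the mu of the selected component.
mu = [-1.0, 3.0]
y = [0.3, 0.7]  # a relaxed (continuous) sample
z, grad_fn = straight_through(y)
loss = sum(zi * mi for zi, mi in zip(z, mu))  # forward sees the discrete choice
dloss_dz = mu                                 # dL/dz_i = mu_i
dloss_dy = grad_fn(dloss_dz)                  # backward: routed to y via identity
print(z, loss, dloss_dy)  # [0.0, 1.0] 3.0 [-1.0, 3.0]
```

A common equivalent formulation that avoids a custom gradient is z_st = y + tf.stop_gradient(z - y): its forward value is z, while its gradient with respect to y is the identity.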

Actual question

Which solution (A, B, or another) is meant to be used with the TFP API, if any? How should I implement it efficiently?

Answers (1)

Louis

So the idea was not to make this a self-answered Q&A, but I looked into this issue for a couple of days and here are my conclusions:

  • solution A (REINFORCE) is a possibility: it doesn't introduce any bias, but as far as I understand it has high variance in its vanilla form, making it prohibitively slow for most real-world tasks. As detailed below, control variates can help tackle the variance issue;
  • solution B (Gumbel-Softmax) exists in the API as well, but I did not find any native way to make it work for hierarchical tasks. Below is my implementation.
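To make the variance point concrete, a minimal stdlib-only sketch assuming a two-category q and a hypothetical objective f whose values share a large offset (10 and 11). Subtracting a constant baseline b, the simplest control variate, keeps REINFORCE unbiased because the score d log q / d logit has zero expectation, yet it can shrink the variance dramatically:

```python
import random


def per_sample_grads(p0, f, baseline, num_samples, rng):
    """Per-sample REINFORCE terms for d/dlogit_0 E[f(z)], z in {0, 1}, q(z=0) = p0.

    Each term is (f(z) - baseline) * score, with score = one_hot(z)_0 - p0.
    A constant baseline leaves the mean unchanged since E[score] = 0.
    """
    terms = []
    for _ in range(num_samples):
        z = 0 if rng.random() < p0 else 1
        score_0 = (1.0 if z == 0 else 0.0) - p0
        terms.append((f(z) - baseline) * score_0)
    return terms


def mean(xs):
    return sum(xs) / len(xs)


def variance(xs):
    m = mean(xs)
    return sum((x - m) ** 2 for x in xs) / (len(xs) - 1)


def f(z):
    return [10.0, 11.0][z]  # hypothetical objective with a large common offset


p0 = 0.5  # the exact gradient is p0 * (f(0) - E[f]) = 0.5 * (10 - 10.5) = -0.25
vanilla = per_sample_grads(p0, f, baseline=0.0, num_samples=20_000, rng=random.Random(1))
# Baseline = E_q[f] = 10.5, an idealized choice; a running average of f is a practical stand-in.
with_baseline = per_sample_grads(p0, f, baseline=10.5, num_samples=20_000, rng=random.Random(1))
print(mean(vanilla), variance(vanilla))
print(mean(with_baseline), variance(with_baseline))
```

With two categories and this exact baseline, every term equals the true gradient, so the variance vanishes; in general a good baseline reduces the variance rather than eliminating it. Learnable control variates such as those in REBAR and RELAX push the same idea further.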

First off, we need to reparameterize the joint distribution p, since the KL divergence between a discrete and a continuous distribution is ill-defined (as explained in the Maddison et al. (2017) paper). To avoid breaking the gradients, I implemented a simple one_hot_straight_through operation that converts the continuous RV y into a discrete RV z:

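import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

G = 2

@tf.custom_gradient
def one_hot_straight_through(y):
    # Forward pass: hard one-hot encoding of the argmax of the relaxed sample y.
    depth = y.shape[-1]
    z = tf.one_hot(
        tf.argmax(
            y,
            axis=-1
        ),
        depth=depth
    )

    # Backward pass (straight-through): pass the upstream gradient to y
    # unchanged, as if dz/dy were the identity.
    def grad(upstream):
        return upstream

    return z, grad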


p = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Sample(
            tfd.Normal(0., 1.),
            sample_shape=(G,)
        ),
        y=tfd.RelaxedOneHotCategorical(
            temperature=1.,
            probs=tf.ones((G,)) / G
        ),
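        x=lambda mu, y: tfd.Normal(
            # Select the mean of the active component; axis=-1 keeps one
            # location per sample when sample_size > 1.
            loc=tf.reduce_sum(
                one_hot_straight_through(y)
                * mu,
                axis=-1
            ),
            scale=1.
        )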
    )
)

The variational distribution q follows the same reparameterization, and the following code runs:

q_probs = tfp.util.TransformedVariable(
    tf.ones((G,)) / G,
    tfb.SoftmaxCentered(),
    name="q_probs"
)
q_loc = tf.Variable(tf.zeros((G,)), name="q_loc")
q_scale = tfp.util.TransformedVariable(
    1.,
    tfb.Exp(),
    name="q_scale"
)

q = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Independent(
            tfd.Normal(q_loc, q_scale),
            reinterpreted_batch_ndims=1
        ),
        y=tfd.RelaxedOneHotCategorical(
            temperature=1.,
            probs=q_probs
        )
    )
)


def log_prob_fn(**kwargs):
    return p.log_prob(
        **kwargs,
        x=tf.constant([2.])
    )


optimizer = tf.optimizers.SGD()

@tf.function
def fit_vi():
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=log_prob_fn,
        surrogate_posterior=q,
        optimizer=optimizer,
        num_steps=10,
        sample_size=8
    )

_ = fit_vi()

Now there are several issues with that design:

  • first, we had to reparameterize not only q but also p, so we "modify our target model". As a result, p and q output continuous RVs rather than the discrete RVs originally intended. I think that introducing a hard option, like the one in the torch implementation, could be a nice addition to overcome this issue;
  • second, we introduce the burden of tuning the temperature parameter, which controls how closely the continuous RV y approximates its discrete counterpart z (y converges to z as the temperature goes to zero). An annealing strategy can be implemented that lowers the temperature to reduce the bias introduced by the relaxation, at the cost of higher variance. Alternatively, the temperature can be learned online, akin to entropy regularization (see Maddison et al. (2017) and Jang et al. (2017));
  • the gradients obtained with this estimator are biased, which is probably acceptable for most applications but is an issue in theory.
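A minimal sketch of the temperature trade-off, assuming an exponential decay clipped at a floor (a common shape for annealing schedules; tau0, tau_min and rate are hypothetical values) and a fixed, hypothetical draw of log p + g. As the temperature decreases, the relaxed vector approaches the one-hot vector at the same argmax:

```python
import math


def annealed_temperature(step, tau0=1.0, tau_min=0.5, rate=1e-4):
    """Exponential annealing, tau = max(tau_min, tau0 * exp(-rate * step))."""
    return max(tau_min, tau0 * math.exp(-rate * step))


def tempered_softmax(scores, temperature):
    """softmax(scores / temperature); scores play the role of log p + g."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.3, 1.0]  # one fixed, hypothetical draw of log p + g
for tau in (annealed_temperature(0), annealed_temperature(50_000), 0.1):
    y = tempered_softmax(scores, tau)
    print(f"tau={tau:.2f} -> y={[round(v, 4) for v in y]}")
```

The floor tau_min is there because very low temperatures make the relaxed sample nearly piecewise constant again, which is where the higher gradient variance mentioned above comes from.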

Recent methods like REBAR (Tucker et al. (2017)) or RELAX (Grathwohl et al. (2018)) can instead obtain unbiased estimators with lower variance than vanilla REINFORCE, but they do so at the cost of introducing learnable control variates with separate losses. Modifying the one_hot_straight_through function could probably implement this.

In conclusion, my opinion is that TensorFlow Probability's support for optimization over discrete RVs is too scarce at the moment, and that the API lacks native functions and tutorials to make this easier for the user.
