See this tensorflow-probability issue
tensorflow==2.7.0
tensorflow-probability==0.14.1
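To perform VI on discrete RVs, should I use the REINFORCE gradient estimator (A), the Gumbel-Softmax reparameterization (B), or another approach, and how should I implement it?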
Sorry in advance for the long issue, but I believe the problem requires some explaining.
I want to implement a Hierarchical Bayesian Model involving both continuous and discrete Random Variables. A minimal example is a Gaussian Mixture model:
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

G = 2

p = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Sample(
            tfd.Normal(0., 1.),
            sample_shape=(G,)
        ),
        z=tfd.Categorical(
            probs=tf.ones((G,)) / G
        ),
        x=lambda mu, z: tfd.Normal(
            loc=mu[z],
            scale=1.
        )
    )
)
In this example I don't use the tfd.Mixture API on purpose, to expose the Categorical label. I want to perform Variational Inference in this context and, for instance, given an observed x, fit over the posterior of z a Categorical distribution with parametric probabilities:
q_probs = tfp.util.TransformedVariable(
    tf.ones((G,)) / G,
    tfb.SoftmaxCentered(),
    name="q_probs"
)
q_loc = tf.Variable(0., name="q_loc")
q_scale = tfp.util.TransformedVariable(
    1.,
    tfb.Exp(),
    name="q_scale"
)

q = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Normal(q_loc, q_scale),
        z=tfd.Categorical(probs=q_probs)
    )
)
The issue is: when computing the ELBO and trying to optimize for the optimal q_probs, I cannot use the reparameterization gradient estimators. AFAIK this is because z is a discrete RV:
def log_prob_fn(**kwargs):
    return p.log_prob(
        **kwargs,
        x=tf.constant([2.])
    )

optimizer = tf.optimizers.SGD()

@tf.function
def fit_vi():
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=log_prob_fn,
        surrogate_posterior=q,
        optimizer=optimizer,
        num_steps=10,
        sample_size=8
    )

_ = fit_vi()
# This last line raises:
# ValueError: Distribution `surrogate_posterior` must be reparameterized, i.e.,a diffeomorphic transformation
# of a parameterless distribution. (Otherwise this function has a biased gradient.)
I'm looking for a way to make this work. I've identified at least 2 ways to circumvent the issue: using the REINFORCE gradient estimator (A) or the Gumbel-Softmax reparameterization (B).
A. REINFORCE gradient: cf. this TFP API link. A classical result in VI is that the REINFORCE gradient can deal with a non-differentiable objective function, for instance one involving discrete RVs.
Could I use a tfp.vi.GradientEstimators.SCORE_FUNCTION estimator instead of the tfp.vi.GradientEstimators.REPARAMETERIZATION one, through the lower-level tfp.vi.monte_carlo_variational_loss function?
Using the REINFORCE gradient, I only need the log_prob method of q to be differentiable; the sample method needn't be differentiated.
As far as I understand it, the sample method of a Categorical distribution implies a gradient break, but the log_prob method does not. Am I correct to assume that this could help with my issue? Am I missing something here?
I also wonder: why is this possibility not exposed in the tfp.vi.fit_surrogate_posterior API? Is the performance bad, i.e. is the variance of the estimator too large for practical purposes?
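To make the REINFORCE idea concrete, here is a minimal framework-free sketch of the score-function estimator for a softmax-parameterized Categorical. It is not TFP code: all function names, the component means and the observation are made up for illustration, and the integrand is only the log-likelihood term, not the full ELBO. It compares the Monte Carlo estimate of the gradient of E_q[f(z)] with the value obtained by enumerating z, which is possible here because z takes only G = 2 values.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def score(z, probs):
    """Gradient of log q(z) w.r.t. the logits of q = Categorical(softmax(logits))."""
    return [(1.0 if k == z else 0.0) - probs[k] for k in range(len(probs))]

def exact_gradient(f, logits):
    """Closed-form gradient of E_q[f(z)], obtained by enumerating z."""
    probs = softmax(logits)
    mean_f = sum(p * f(z) for z, p in enumerate(probs))
    return [probs[k] * (f(k) - mean_f) for k in range(len(probs))]

def score_function_gradient(f, logits, num_samples, rng):
    """Monte Carlo estimate: average of f(z) * grad log q(z) over z ~ q."""
    probs = softmax(logits)
    grad = [0.0] * len(probs)
    for _ in range(num_samples):
        # Sampling z is never differentiated; only log q(z) is.
        z = rng.choices(range(len(probs)), weights=probs)[0]
        fz = f(z)
        for k, s_k in enumerate(score(z, probs)):
            grad[k] += fz * s_k / num_samples
    return grad

# Toy integrand (hypothetical values): log N(x | mu[z], 1) with fixed means.
mu, x = [-1.0, 1.0], 2.0

def log_lik(z):
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * (x - mu[z]) ** 2

logits = [0.0, 0.0]
exact = exact_gradient(log_lik, logits)
estimate = score_function_gradient(log_lik, logits, 20000, random.Random(0))
print("exact:   ", exact)
print("estimate:", estimate)
```

The estimate only approaches the enumerated gradient with many samples, which illustrates the variance concern raised above. Subtracting a baseline from f(z) is the usual first remedy.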
B. Gumbel-Softmax reparameterization: cf. this TFP API link. I could also reparameterize z as a variable y = tfd.RelaxedOneHotCategorical(...). The issue is: I need a proper categorical label to use in the definition of x, so AFAIK I need to do the following:
p_GS = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Sample(
            tfd.Normal(0., 1.),
            sample_shape=(G,)
        ),
        y=tfd.RelaxedOneHotCategorical(
            temperature=1.,
            probs=tf.ones((G,)) / G
        ),
        x=lambda mu, y: tfd.Normal(
            loc=mu[tf.argmax(y)],
            scale=1.
        )
    )
)
...but this would just move the gradient-breaking problem to tf.argmax. This is where I may be missing something. Following the Gumbel-Softmax paper (Jang et al., 2016), I could then use the "Straight-Through" (ST) strategy and "plug" the gradients of the variable tf.one_hot(tf.argmax(y)), the "discrete y", onto y, the "continuous y".
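For reference, the relaxation and its discretization can be sketched without any framework. This is a minimal illustration of Gumbel-Softmax sampling in plain Python, not the TFP implementation: the names sample_relaxed_one_hot and harden are hypothetical, and the probabilities are assumed to be strictly positive. The relaxed sample y lies on the simplex, while harden(y) is the one-hot "discrete y" used in the forward pass of the straight-through strategy.

```python
import math
import random

def sample_relaxed_one_hot(probs, temperature, rng):
    """Draw one Gumbel-Softmax sample: a point on the simplex, not a one-hot vector."""
    logits = []
    for p in probs:
        # Keep u strictly inside (0, 1) so both logarithms are finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        logits.append((math.log(p) + gumbel) / temperature)
    # Numerically stable softmax over the perturbed, tempered logits.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def harden(y):
    """Forward-pass discretization: one-hot vector at the argmax of y."""
    k = max(range(len(y)), key=lambda i: y[i])
    return [1.0 if i == k else 0.0 for i in range(len(y))]

rng = random.Random(0)
y = sample_relaxed_one_hot([0.2, 0.8], temperature=1.0, rng=rng)
z = harden(y)
print("continuous y:", y)
print("discrete y:  ", z)
```

Because the temperature only rescales the perturbed logits, the argmax of y follows the original Categorical distribution whatever the temperature. Only the gradient path through y is affected by the relaxation.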
But again I wonder: how do I do this properly? I don't want to mix and match the gradients by hand, and I guess an autodiff backend is precisely meant to spare me this. How could I create a distribution that distinguishes the forward direction (sampling a "discrete y") from the backward direction (gradient computed using the "continuous y")? I guess this is the intended usage of the tfd.RelaxedOneHotCategorical distribution, but I don't see it implemented anywhere in the API.
Should I implement this myself? How? Could I use something along the lines of tf.custom_gradient?
Which solution (A, B, or another) is meant to be used in the TFP API, if any? How should I implement said solution efficiently?
Upvotes: 0
Views: 256
Reputation: 161
So the idea was not to make a Q&A, but I looked into this issue for a couple of days and here are my conclusions:
First off, we need to reparameterize the joint distribution p, as the KL between a discrete and a continuous distribution is ill-defined (as explained in the Maddison et al. (2017) paper). To avoid breaking the gradients, I implemented a simple one_hot_straight_through operation that converts the continuous RV y into a discrete RV z:
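G = 2

@tf.custom_gradient
def one_hot_straight_through(y):
    # Forward pass: discretize the relaxed sample y into a one-hot vector z.
    depth = y.shape[-1]
    z = tf.one_hot(
        tf.argmax(y, axis=-1),
        depth=depth
    )

    def grad(upstream):
        # Backward pass: let the upstream gradient through unchanged.
        return upstream

    return z, grad

p = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Sample(
            tfd.Normal(0., 1.),
            sample_shape=(G,)
        ),
        y=tfd.RelaxedOneHotCategorical(
            temperature=1.,
            probs=tf.ones((G,)) / G
        ),
        x=lambda mu, y: tfd.Normal(
            # Sum over the component axis only, so that batched samples
            # (e.g. sample_size > 1) are not mixed together.
            loc=tf.reduce_sum(
                one_hot_straight_through(y) * mu,
                axis=-1
            ),
            scale=1.
        )
    )
)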
The variational distribution q follows the same reparameterization, and the following code does work:
q_probs = tfp.util.TransformedVariable(
    tf.ones((G,)) / G,
    tfb.SoftmaxCentered(),
    name="q_probs"
)
q_loc = tf.Variable(tf.zeros((G,)), name="q_loc")
q_scale = tfp.util.TransformedVariable(
    1.,
    tfb.Exp(),
    name="q_scale"
)

q = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Independent(
            tfd.Normal(q_loc, q_scale),
            reinterpreted_batch_ndims=1
        ),
        y=tfd.RelaxedOneHotCategorical(
            temperature=1.,
            probs=q_probs
        )
    )
)

def log_prob_fn(**kwargs):
    return p.log_prob(
        **kwargs,
        x=tf.constant([2.])
    )

optimizer = tf.optimizers.SGD()

@tf.function
def fit_vi():
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=log_prob_fn,
        surrogate_posterior=q,
        optimizer=optimizer,
        num_steps=10,
        sample_size=8
    )

_ = fit_vi()
Now there are several issues with that design:
- We reparameterize not only q but also p, so we "modify our target model". As a result, p and q no longer output discrete RVs as originally intended, but continuous ones. I think that introducing a hard option, like the one in the PyTorch implementation, would be a nice addition to overcome this issue;
- The relaxation introduces a temperature parameter, which makes the continuous RV y converge smoothly to its discrete counterpart z as it goes to zero. An annealing strategy that lowers the temperature during training, reducing the bias introduced by the relaxation at the cost of a higher variance, can be implemented. Alternatively, the temperature can be learned online, akin to an entropy regularization (see Maddison et al. (2017) and Jang et al. (2017));
- The relaxed estimator is biased. Recent methods like REBAR (Tucker et al. (2017)) or RELAX (Grathwohl et al. (2018)) can instead obtain unbiased estimators with a lower variance than the original REINFORCE, but they do so at the cost of introducing learnable control variates with separate losses. Modifications of the one_hot_straight_through function could probably implement this.
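The annealing strategy mentioned in the list above can be sketched as a simple schedule. This is a minimal illustration and not part of TFP: the function name annealed_temperature and its default values are hypothetical, and an exponential decay with a floor is only one possible choice. The returned value would presumably be fed to the temperature argument of tfd.RelaxedOneHotCategorical in both p and q, updated between optimization steps.

```python
import math

def annealed_temperature(step, initial=1.0, minimum=0.5, rate=1e-4):
    """Exponentially decay the temperature with the step count, never below `minimum`.

    All defaults are illustrative assumptions, to be tuned per problem.
    """
    return max(minimum, initial * math.exp(-rate * step))

# Temperature at a few training steps.
for step in (0, 1000, 5000, 20000):
    print(step, annealed_temperature(step))
```

The floor keeps the temperature away from zero, where the gradients of the relaxation become very noisy.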
In conclusion, my opinion is that TensorFlow Probability's support for optimization over discrete RVs is too limited at the moment, and that the API lacks native functions and tutorials to make it easier for the user.