Prook

Reputation: 107

TensorFlow: negative KL divergence

I am working with a Variational Autoencoder-type model, and part of my loss function is the KL divergence between a normal distribution with mean 0 and variance 1 and another normal distribution whose mean and variance are predicted by my model.

I defined the loss in the following way:

def kl_loss(mean, log_sigma):
    # Standard normal prior N(0, I) with the same shape as the encoder output
    normal = tf.contrib.distributions.MultivariateNormalDiag(tf.zeros(mean.get_shape()),
                                                             tf.ones(log_sigma.get_shape()))
    # Encoder distribution N(mean, diag(exp(log_sigma)))
    enc_normal = tf.contrib.distributions.MultivariateNormalDiag(mean,
                                                                 tf.exp(log_sigma),
                                                                 validate_args=True,
                                                                 allow_nan_stats=False,
                                                                 name="encoder_normal")
    # KL(prior || encoder), following the argument order above
    kl_div = tf.contrib.distributions.kl_divergence(normal,
                                                    enc_normal,
                                                    allow_nan_stats=False,
                                                    name="kl_divergence")
    return kl_div

The inputs are unconstrained vectors of length N with

log_sigma.get_shape() == mean.get_shape()
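
For reference (and assuming I am reading the argument order correctly), this computes KL(N(0, I) || N(mu, diag(sigma^2))) with sigma_i = exp(log_sigma_i), which for diagonal Gaussians has the closed form

\mathrm{KL}\big(\mathcal{N}(0, I)\,\|\,\mathcal{N}(\mu, \mathrm{diag}(\sigma^2))\big) = \sum_{i=1}^{N}\left(\log \sigma_i + \frac{1 + \mu_i^2}{2\sigma_i^2} - \frac{1}{2}\right)

Each summand is minimized at mu_i = 0, sigma_i = 1 with value 0, so the analytic value is never negative.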

Now during training I observe negative KL divergence values after a few thousand iterations, reaching values as low as -10. Below you can see the TensorBoard training curves:

KL divergence curve

Zoom in of KL divergence curve

Now this seems odd to me, since the KL divergence should be non-negative whenever it is defined. I understand that "The K-L divergence is only defined if P and Q both sum to 1 and if Q(i) > 0 for any i such that P(i) > 0." (see https://mathoverflow.net/questions/43849/how-to-ensure-the-non-negativity-of-kullback-leibler-divergence-kld-metric-rela), but I don't see how this could be violated in my case. Any help is highly appreciated!

Upvotes: 7

Views: 3352

Answers (1)

aki

Reputation: 11

I faced the same problem. It happens because of limited float precision. If you look closely, the negative values occur close to 0 and are bounded by a small negative value. Adding a small positive value to the loss is a workaround.
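
A minimal sketch of that workaround, reusing the kl_loss function from the question (the helper name and epsilon value here are illustrative, not from the original post):

def kl_loss_stable(mean, log_sigma, eps=1e-6):
    # kl_loss as defined in the question
    kl_div = kl_loss(mean, log_sigma)
    # Shift by a small positive constant so values that dip slightly below
    # zero due to float precision stay non-negative in the loss;
    # alternatively, clamp with tf.maximum(kl_div, 0.0).
    return kl_div + eps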

Upvotes: 1
