MerelyLearning

Different results when computing KL divergence with PyTorch distributions vs. manually

I noticed that the KL divergence term KL(Q(x)||P(x)) comes out differently when I compute it with

mean(Q(x)*(log Q(x) - log P(x)))

vs

torch.distributions.kl_divergence(Q, P)

where

Q = torch.distributions.Normal(some mean, some sigma)
P = torch.distributions.Normal(0, 1)

and when I plot the KL-divergence losses, I get two similar but different curves.

Can anyone point out what is causing this difference?

The full code is below:

import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

def kl_1(log_qx, log_px):
    """
    inputs: [B, z_dim] torch
    """
    return (log_qx.exp() * (log_qx-log_px)).mean()

# ground-truth (target) P(x)
P = dist.Normal(0, 1)


mus = np.arange(-5, 5, 0.1)
sigma = 1
N = 100
kls = {"1": [], "2": []}
for mu in mus:
    # prediction (current) Q(x)
    Q = dist.Normal(mu, sigma)
    
    # sample from Q
    qx = Q.sample((N,))
        
    # log prob
    log_qx = Q.log_prob(qx)
    log_px = P.log_prob(qx)
    
    # kl 1
    kl1 = kl_1(log_qx, log_px)
    kls['1'].append(kl1.numpy())
    
    # kl 2
    kl2 = dist.kl_divergence(Q, P)
    kls['2'].append(kl2.numpy())
    
plt.figure()
plt.scatter(mus, kls['1'], label="Q*(logQ-logP)")
plt.scatter(mus, kls['2'], label="kl_divergence")
plt.xlabel("mean of Q(x)")
plt.ylabel("computed KL Divergence")
plt.legend()
plt.show()

Answers (1)

Bob

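Weighting each term by the probability density q(x) is only needed when you compute the expected value as an integral over dx:

KL(Q||P) = ∫ q(x) (log q(x) - log p(x)) dx

If instead you draw the points from Q itself, the sampling already does that weighting, so the expected value is approximated directly by the sample mean of log q(x) - log p(x). This corresponds to integrating with respect to dC_Q(x) = q(x) dx, where C_Q(x) is the cumulative distribution function and q(x) is the probability density function of Q. Multiplying by Q(x) again, as kl_1 in the question does, counts the density twice. Dropping the exp() factor fixes it: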
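A minimal sketch of the point above, assuming Q = N(mu, sigma) and P = N(0, 1) as in the question. The function name `kl_estimates`, the sample count and the grid limits are illustrative choices, not part of the original code. It estimates the same KL four ways: (a) the plain sample mean over draws from Q, (b) the question's double-weighted version, (c) a deterministic grid integral over dx, where the explicit q(x) weight is required, and (d) PyTorch's `kl_divergence`. Estimates (a), (c) and (d) should agree up to Monte Carlo and discretization error; (b) should not.

```python
import torch
import torch.distributions as dist

torch.manual_seed(0)  # reproducible Monte Carlo draws


def kl_estimates(mu, sigma=1.0, n_samples=100_000):
    """Estimate KL(Q||P) for Q = N(mu, sigma), P = N(0, 1) in four ways."""
    Q = dist.Normal(mu, sigma)
    P = dist.Normal(0.0, 1.0)

    # (a) Monte Carlo: samples drawn from Q already carry the q(x) weighting,
    #     so the expectation is just the sample mean of log q - log p.
    x = Q.sample((n_samples,))
    log_q_x, log_p_x = Q.log_prob(x), P.log_prob(x)
    mc = (log_q_x - log_p_x).mean()

    # (b) Weighting the same samples by q(x) again (the question's kl_1):
    #     the density is counted twice, so this is not a KL estimate.
    double_weighted = (log_q_x.exp() * (log_q_x - log_p_x)).mean()

    # (c) Riemann sum over dx on a fine grid: grid points are NOT drawn
    #     from Q, so the explicit q(x) weight is needed here.
    grid = torch.linspace(mu - 10 * sigma, mu + 10 * sigma, 20_001)
    dx = grid[1] - grid[0]
    log_q_g, log_p_g = Q.log_prob(grid), P.log_prob(grid)
    integral = (log_q_g.exp() * (log_q_g - log_p_g)).sum() * dx

    # (d) Closed form provided by torch.distributions for Normal/Normal.
    exact = dist.kl_divergence(Q, P)

    return mc.item(), double_weighted.item(), integral.item(), exact.item()


if __name__ == "__main__":
    mc, dw, integral, exact = kl_estimates(2.0)
    print(f"MC {mc:.3f} | double-weighted {dw:.3f} | "
          f"grid integral {integral:.3f} | exact {exact:.3f}")
```

For mu = 2 and sigma = 1 the exact value is 2.0; the Monte Carlo and grid estimates should land close to it, while the double-weighted number comes out far smaller.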

import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

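def kl_1(log_qx, log_px):
    """
    Monte Carlo estimate of KL(Q||P).
    inputs: [B, z_dim] torch tensors of log Q(x) and log P(x),
            evaluated at samples x drawn from Q.
    No extra Q(x) weighting: sampling from Q already supplies it.
    """
    return (log_qx - log_px).mean()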

# ground-truth (target) P(x)
P = dist.Normal(0, 1)


mus = np.arange(-5, 5, 0.1)
sigma = 1
N = 100
kls = {"1": [], "2": []}
for mu in mus:
    # prediction (current) Q(x)
    Q = dist.Normal(mu, sigma)
    
    # sample from Q
    qx = Q.sample((N,))
        
    # log prob
    log_qx = Q.log_prob(qx)
    log_px = P.log_prob(qx)
    
    # kl 1
    kl1 = kl_1(log_qx, log_px)
    kls['1'].append(kl1.numpy())
    
    # kl 2
    kl2 = dist.kl_divergence(Q, P)
    kls['2'].append(kl2.numpy())
    
plt.figure()
plt.scatter(mus, kls['1'], label="mean(logQ-logP)")
plt.scatter(mus, kls['2'], label="kl_divergence")
plt.xlabel("mean of Q(x)")
plt.ylabel("computed KL Divergence")
plt.legend()
plt.show()
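Even with the fix, the two sets of points will not coincide exactly, because N = 100 samples give a noisy estimate. For two normals the KL has a closed form, and with sigma = 1 and P = N(0, 1) it reduces to mu²/2, so the `kl_divergence` points trace a parabola. In that same setting log q(x) - log p(x) = mu·x - mu²/2, whose variance under Q is mu², so the Monte Carlo estimate scatters with a standard deviation of about |mu|/sqrt(N), which grows toward the edges of the plot. The sketch below checks both claims; the helper names `kl_normal_closed_form` and `mc_kl_spread` are hypothetical, chosen for illustration.

```python
import math

import torch
import torch.distributions as dist


def kl_normal_closed_form(mu_q, sigma_q, mu_p=0.0, sigma_p=1.0):
    """Analytic KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2))."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5)


def mc_kl_spread(mu_q, sigma_q, n, repeats=200, seed=0):
    """Mean and std over `repeats` independent n-sample Monte Carlo
    estimates of KL(N(mu_q, sigma_q^2) || N(0, 1))."""
    torch.manual_seed(seed)
    Q = dist.Normal(mu_q, sigma_q)
    P = dist.Normal(0.0, 1.0)
    x = Q.sample((repeats, n))                               # one row per repeat
    estimates = (Q.log_prob(x) - P.log_prob(x)).mean(dim=1)  # one KL estimate per row
    return estimates.mean().item(), estimates.std().item()


if __name__ == "__main__":
    mu = 2.0
    Q, P = dist.Normal(mu, 1.0), dist.Normal(0.0, 1.0)
    print(kl_normal_closed_form(mu, 1.0))   # 2.0, i.e. mu**2 / 2 when sigma = 1
    print(dist.kl_divergence(Q, P).item())  # should match the line above
    for n in (100, 10_000):
        mean, std = mc_kl_spread(mu, 1.0, n)
        # std should shrink roughly like |mu| / sqrt(n) for sigma = 1
        print(n, round(mean, 3), round(std, 3))
```

Raising N in the answer's loop should therefore pull the two scatter plots together.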
