ikamen

Reputation: 3485

KL Divergence of two torch.distribution.Distribution objects

I'm trying to determine how to compute KL Divergence of two torch.distribution.Distribution objects. I couldn't find a function to do that so far. Here is what I've tried:

import torch as t
from torch import distributions as tdist
import torch.nn.functional as F

def kl_divergence(x: t.distributions.Distribution, y: t.distributions.Distribution):
    """Compute the KL divergence between two distributions."""
    return F.kl_div(x, y)  

a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)

print(kl_divergence(a, b))  # TypeError: kl_div(): argument 'input' (position 1) must be Tensor, not Normal

Upvotes: 3

Views: 3134

Answers (2)

jodag

Reputation: 22214

torch.nn.functional.kl_div computes the KL-divergence *loss* between tensors (it expects log-probabilities as input, not distribution objects). The KL divergence between two torch.distributions.Distribution objects can be computed with torch.distributions.kl.kl_divergence.
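For example, applied to the two normals from the question (for these parameters the closed-form KL divergence is 0.5):

```python
import torch as t
from torch import distributions as tdist

a = tdist.Normal(0.0, 1.0)
b = tdist.Normal(1.0, 1.0)

# Closed-form KL divergence between the two distribution objects
kl = tdist.kl.kl_divergence(a, b)
print(kl)  # tensor(0.5000)
```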

Upvotes: 3

Bhupen

Reputation: 1320

tdist.Normal(...) returns a normal distribution object, not a tensor; you have to draw a sample out of the distribution first:

x = a.sample()
y = b.sample()
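On their own, two samples still won't give you the KL divergence. One way samples can be used is a Monte Carlo estimate via `log_prob` (a sketch, assuming the same `a` and `b` as in the question; the exact value from `torch.distributions.kl.kl_divergence` is 0.5 here):

```python
import torch as t
from torch import distributions as tdist

a = tdist.Normal(0.0, 1.0)
b = tdist.Normal(1.0, 1.0)

# Monte Carlo estimate: KL(a || b) = E_{x ~ a}[log a(x) - log b(x)]
t.manual_seed(0)
x = a.sample((100_000,))
kl_estimate = (a.log_prob(x) - b.log_prob(x)).mean()
print(kl_estimate)  # close to the exact value of 0.5
```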

Upvotes: 0
