Reputation: 3485
I'm trying to determine how to compute the KL divergence of two torch.distributions.Distribution objects. I couldn't find a function to do that so far. Here is what I've tried:
import torch as t
from torch import distributions as tdist
import torch.nn.functional as F

def kl_divergence(x: t.distributions.Distribution, y: t.distributions.Distribution):
    """Compute the KL divergence between two distributions."""
    return F.kl_div(x, y)

a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)
print(kl_divergence(a, b))  # TypeError: kl_div(): argument 'input' (position 1) must be Tensor, not Normal
Upvotes: 3
Views: 3134
Reputation: 22214
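torch.nn.functional.kl_div computes the KL-divergence loss between tensors: its input is expected to contain log-probabilities and its target probabilities, so it cannot take Distribution objects. The KL divergence between two distributions can be computed using torch.distributions.kl.kl_divergence.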
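A minimal sketch of that function applied to the distributions from the question, assuming a reasonably recent PyTorch. For two Normals with the same scale σ the closed form is (μa − μb)² / (2σ²), so KL(a ‖ b) should come out as 0.5 here. The second part is an illustrative comparison (the probability vectors p and q are made-up example values) showing how F.kl_div relates to the same quantity for discrete distributions: it takes the log-probabilities of the second distribution as input and the probabilities of the first as target.

```python
import torch
import torch.nn.functional as F
from torch import distributions as tdist
from torch.distributions.kl import kl_divergence

# The two Gaussians from the question.
a = tdist.Normal(0.0, 1.0)
b = tdist.Normal(1.0, 1.0)

# KL(a || b), computed in closed form by the registered Normal/Normal rule.
kl_ab = kl_divergence(a, b)
print(kl_ab)  # tensor(0.5000)

# Illustrative discrete case (p and q are hypothetical example values).
p = torch.tensor([0.1, 0.6, 0.3])
q = torch.tensor([0.3, 0.3, 0.4])
kl_pq = kl_divergence(tdist.Categorical(probs=p), tdist.Categorical(probs=q))

# F.kl_div(input, target) sums target * (log(target) - input), so passing
# log q as input and p as target gives KL(p || q).
kl_loss = F.kl_div(q.log(), p, reduction="sum")
print(torch.allclose(kl_pq, kl_loss))  # True
```

kl_divergence dispatches on the pair of distribution types and raises NotImplementedError when no rule is registered for that pair; additional rules can be added with torch.distributions.kl.register_kl.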
Upvotes: 3
Reputation: 1320
tdist.Normal(...) returns a normal distribution object, not a tensor; you have to draw a sample from the distribution to get a tensor:
x = a.sample()
y = b.sample()
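Note that a single draw from each distribution, passed to F.kl_div, is not itself an estimate of the KL divergence between a and b. A sketch of one way samples can be used, assuming the distributions expose log_prob (as tdist.Normal does): average the log-density ratio log a(x) − log b(x) over many draws x from a. This Monte Carlo estimate is mainly useful when kl_divergence has no closed-form rule for the pair; the sample count n is an arbitrary choice.

```python
import torch
from torch import distributions as tdist
from torch.distributions.kl import kl_divergence

torch.manual_seed(0)  # make the estimate reproducible

a = tdist.Normal(0.0, 1.0)
b = tdist.Normal(1.0, 1.0)

# Monte Carlo estimate of KL(a || b) = E_{x ~ a}[log a(x) - log b(x)].
# A single draw is very noisy, so average over many samples.
n = 100_000
x = a.sample((n,))
kl_mc = (a.log_prob(x) - b.log_prob(x)).mean()

# Compare with the closed-form value.
print(kl_mc.item(), kl_divergence(a, b).item())
```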
Upvotes: 0