Mojtaba Komeili

Reputation: 537

KL Divergence for two probability distributions in PyTorch

I have two probability distributions. How should I find the KL-divergence between them in PyTorch? The regular cross entropy only accepts integer labels.

Upvotes: 24

Views: 51887

Answers (5)

PumpkinQ

Reputation: 91

If you are using the normal distribution, then the following code will directly compare the two distributions themselves:

import torch

p = torch.distributions.normal.Normal(p_mu, p_std)
q = torch.distributions.normal.Normal(q_mu, q_std)

# KL(p || q), computed element-wise for each pair of univariate normals
loss = torch.distributions.kl_divergence(p, q)

Here p_mu, p_std, q_mu and q_std are tensors, and p and q are Normal distribution objects.

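This code works without a NotImplementedError because PyTorch registers a KL implementation for the Normal/Normal pair; kl_divergence raises NotImplementedError only for distribution pairs that have no registered implementation.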
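As a sanity check, the sketch below compares kl_divergence for two Normal objects against the standard closed-form KL between univariate normals, log(sd_q/sd_p) + (sd_p^2 + (mu_p - mu_q)^2) / (2 sd_q^2) - 1/2. The parameter values are arbitrary examples chosen for illustration, not taken from the answer.

```python
import torch

# Arbitrary example parameters: two independent pairs of univariate normals
p_mu, p_std = torch.tensor([0.0, 1.0]), torch.tensor([1.0, 2.0])
q_mu, q_std = torch.tensor([0.5, 0.0]), torch.tensor([1.5, 1.0])

p = torch.distributions.Normal(p_mu, p_std)
q = torch.distributions.Normal(q_mu, q_std)

# Element-wise KL(p || q), one value per pair
kl = torch.distributions.kl_divergence(p, q)

# Closed form for two univariate normals
closed_form = (torch.log(q_std / p_std)
               + (p_std**2 + (p_mu - q_mu)**2) / (2 * q_std**2)
               - 0.5)

print(kl)
print(torch.allclose(kl, closed_form))  # True
```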

Upvotes: 2

Gaurav Shrivastava

Reputation: 943

If you have two probability distributions in the form of PyTorch distribution objects, you are better off using the function torch.distributions.kl.kl_divergence(p, q). For details, see the torch.distributions.kl documentation.
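For the discrete case in the question, a minimal sketch (assuming both distributions are probability vectors over the same outcomes) wraps them in Categorical objects and passes them to kl_divergence. The numbers reuse the Wikipedia-style example from another answer here; the manual sum is only a check, and the reverse call shows that KL divergence is not symmetric.

```python
import torch
from torch.distributions import Categorical
from torch.distributions.kl import kl_divergence

# Two discrete distributions over the same three outcomes
P = torch.tensor([0.36, 0.48, 0.16])
Q = torch.tensor([1 / 3, 1 / 3, 1 / 3])

p = Categorical(probs=P)
q = Categorical(probs=Q)

kl_pq = kl_divergence(p, q)  # KL(P || Q)
kl_qp = kl_divergence(q, p)  # KL(Q || P), generally a different value

# The same quantity from the definition: sum_i P_i * log(P_i / Q_i)
manual = (P * (P / Q).log()).sum()

print(kl_pq, kl_qp)
print(torch.allclose(kl_pq, manual))  # True
```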

Upvotes: 13

Union find

Reputation: 8150

If you are working with Torch distributions:

import torch

mu = torch.zeros(100)
sd = torch.ones(100)

p = torch.distributions.Normal(mu, sd)
q = torch.distributions.Normal(mu, sd)

# Identical distributions, so every per-dimension KL term is 0
out = torch.distributions.kl_divergence(p, q).mean()
print(out.item() == 0)  # True
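Note that .mean() above averages the per-dimension KL values. If the 100 entries are meant to describe one diagonal Gaussian, a sketch using torch.distributions.Independent sums over the event dimension instead and returns one KL per sample. The batch size of 4 and the mean shift of 0.1 are illustrative assumptions.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

# A batch of 4 samples, each a 100-dimensional diagonal Gaussian
p_mu = torch.zeros(4, 100)
q_mu = torch.full((4, 100), 0.1)
sd = torch.ones(4, 100)

p = Normal(p_mu, sd)
q = Normal(q_mu, sd)

# Per-dimension KL, shape (4, 100); .mean() would average all 400 entries
per_dim = kl_divergence(p, q)

# Treat the last dimension as the event dimension, so the KL is summed
# over the 100 dimensions and returned once per sample: shape (4,)
p_vec = Independent(p, 1)
q_vec = Independent(q, 1)
per_sample = kl_divergence(p_vec, q_vec)

print(per_dim.shape, per_sample.shape)
print(torch.allclose(per_sample, per_dim.sum(-1)))  # True
```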

Upvotes: 1

hantian_pang

Reputation: 1039

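The function kl_div does not follow Wikipedia's notation directly: it takes the log-probabilities of Q as its first argument and the probabilities of P as its second, and returns KL(P || Q).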

I use the following:

import torch
import torch.nn.functional as F

# this is the same example as on Wikipedia
P = torch.tensor([0.36, 0.48, 0.16])
Q = torch.tensor([0.333, 0.333, 0.333])

(P * (P / Q).log()).sum()
# tensor(0.0863), 10.2 µs ± 508

F.kl_div(Q.log(), P, reduction='sum')
# tensor(0.0863), 14.1 µs ± 408 ns

Compared to kl_div, the manual computation is even faster.
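To make the argument order concrete, the sketch below checks that F.kl_div(Q.log(), P, reduction='sum') equals KL(P || Q) from the definition. It also shows the log_target=True option available in recent PyTorch versions, where the target is passed as log-probabilities too.

```python
import torch
import torch.nn.functional as F

# Same P and Q as in the answer above
P = torch.tensor([0.36, 0.48, 0.16])
Q = torch.tensor([0.333, 0.333, 0.333])

manual = (P * (P / Q).log()).sum()  # KL(P || Q) from the definition

# input = log-probabilities of Q, target = probabilities of P
via_kl_div = F.kl_div(Q.log(), P, reduction='sum')

# With log_target=True, the target is also given in log space
via_log_target = F.kl_div(Q.log(), P.log(), reduction='sum', log_target=True)

print(manual, via_kl_div, via_log_target)
print(torch.allclose(manual, via_kl_div))      # True
print(torch.allclose(manual, via_log_target))  # True
```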

Upvotes: 19

jdhao

Reputation: 28389

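Yes, PyTorch has a function named kl_div under torch.nn.functional that directly computes the KL divergence between tensors. Suppose you have tensors a and b of the same shape, each holding probabilities. Note that kl_div expects its first argument (input) as log-probabilities and its second (target) as probabilities, and it computes KL(b || a). You can use the following code: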

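import torch.nn.functional as F

# a and b hold probabilities; kl_div wants the input in log space
out = F.kl_div(a.log(), b, reduction='sum')  # KL(b || a)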

For more details, see the torch.nn.functional.kl_div documentation.
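In a training loop the inputs are usually batched logits. The sketch below assumes hypothetical random tensors logits_model and logits_target of shape (8, 5). It uses reduction='batchmean', which the PyTorch documentation recommends for matching the mathematical KL, and contrasts it with the default 'mean', which divides by the total number of elements.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for model outputs: 8 rows over 5 classes
logits_model = torch.randn(8, 5)
logits_target = torch.randn(8, 5)

log_q = F.log_softmax(logits_model, dim=-1)  # input: log-probabilities
p = F.softmax(logits_target, dim=-1)         # target: probabilities

# 'batchmean' sums over all elements and divides by the batch size (8)
loss = F.kl_div(log_q, p, reduction='batchmean')

# The default 'mean' divides by the total element count (8 * 5)
loss_mean = F.kl_div(log_q, p, reduction='mean')

# Reference: per-row KL(p || q), averaged over the batch
per_row = (p * (p.log() - log_q)).sum(dim=-1)
print(torch.allclose(loss, per_row.mean()))  # True
print(torch.allclose(loss_mean, loss / 5))   # True
```

With the default 'mean', the result is smaller by a factor equal to the number of classes, which is easy to miss when weighting the loss.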

Upvotes: 24
