user3668129

Reputation: 4810

Why doesn't the VAE model in PyTorch use torch.nn.KLDivLoss?

One example can be found here: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py

Why did they implement the KL divergence term themselves instead of using torch.nn.KLDivLoss?

Upvotes: 1

Views: 712

Answers (1)

Umang Gupta

Reputation: 16440

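torch.nn.KLDivLoss computes the KL divergence between two discrete (categorical) distributions. It takes the two distributions as input: the prediction p, passed as log-probabilities, and the target q, passed as probabilities. It computes the following:

\sum_{i=0}^{C-1} q[i] \log \frac{q[i]}{p[i]}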
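
As a rough illustration of what that means in practice, here is a minimal sketch. The tensors `p` and `q` are made-up values, not taken from the linked repository. It compares the module's output with the sum over classes of q[i] * (log q[i] - log p[i]) written out by hand.

```python
import torch
import torch.nn as nn

# Two made-up categorical distributions over C = 3 classes (illustrative values).
p = torch.tensor([0.2, 0.5, 0.3])  # predicted distribution
q = torch.tensor([0.1, 0.6, 0.3])  # target distribution

# KLDivLoss expects log-probabilities as input and probabilities as target.
kl_module = nn.KLDivLoss(reduction="sum")
kl_from_module = kl_module(p.log(), q)

# The same quantity written out by hand: sum_i q[i] * (log q[i] - log p[i]).
kl_by_hand = torch.sum(q * (q.log() - p.log()))

print(kl_from_module.item(), kl_by_hand.item())
```

Both inputs are finite vectors of class probabilities, which is the only case this loss handles.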

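However, for a VAE you need the KL divergence between two Gaussian distributions: the approximate posterior produced by the encoder and the prior, which is usually a standard normal. KLDivLoss won't compute this, because it works on vectors of class probabilities rather than on the means and variances of continuous distributions. Instead, the Gaussian KL term is computed with a closed-form formula.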
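
For a diagonal Gaussian posterior N(mu, sigma^2) and a standard normal prior N(0, I), the closed form is -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2). Below is a minimal sketch, assuming the encoder outputs a mean `mu` and a log-variance `log_var` of shape (batch, latent_dim). The helper name `gaussian_kl` and the random inputs are hypothetical and chosen only for illustration. The result is checked against `torch.distributions.kl_divergence`.

```python
import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kl(mu, log_var):
    """KL( N(mu, exp(log_var)) || N(0, I) ), summed over latent dims, averaged over the batch."""
    per_sample = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    return per_sample.mean()

# Made-up encoder outputs: batch of 4 samples, 8 latent dimensions.
torch.manual_seed(0)
mu = torch.randn(4, 8)
log_var = torch.randn(4, 8)

closed_form = gaussian_kl(mu, log_var)

# Reference value from torch.distributions (standard deviation = exp(0.5 * log_var)).
posterior = Normal(mu, (0.5 * log_var).exp())
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(posterior, prior).sum(dim=1).mean()

print(closed_form.item(), reference.item())
```

The two values should agree up to floating-point error, so writing the formula inline is a short, cheap alternative to building distribution objects.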

Upvotes: 2
