ndrwnaguib

Reputation: 6115

Applying Kullback-Leibler (aka KL divergence) element-wise in PyTorch

I have two tensors named x_t and x_k with the following shapes: NxHxW and KxNxHxW respectively, where K is the number of autoencoders used to reconstruct x_t (if you have no idea what this is, assume they're K different nets aiming to predict x_t; this probably has nothing to do with the question anyway), N is the batch size, H the matrix height, and W the matrix width.

I'm trying to apply the Kullback-Leibler divergence to both tensors (after broadcasting x_t to the shape of x_k along the Kth dimension) using PyTorch's nn.functional.kl_div method.

However, it does not seem to be working as I expected. I'm looking to calculate the kl_div between each observation in x_t and x_k, resulting in a tensor of size KxN (i.e., the kl_div of each observation for each of the K autoencoders).

The actual output is a single value if I use the reduction argument, and a tensor of the same size as the inputs (i.e., KxNxHxW) if I do not use it.

Has anyone tried something similar?


Reproducible example:

import torch
import torch.nn.functional as F
#                  K   N   H  W
x_t = torch.randn(    10, 5, 5)  # target:          N x H x W
x_k = torch.randn( 3, 10, 5, 5)  # reconstructions: K x N x H x W

# explicit broadcast of x_t to the shape of x_k (K x N x H x W)
x_broadcasted = x_t.expand_as(x_k)

loss = F.kl_div(x_t, x_k, reduction="none")  # or "batchmean", among other reduction options
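For illustration, this is roughly what the two behaviours described above look like, continuing from the snippet:

out_none = F.kl_div(x_t, x_k, reduction="none")
print(out_none.shape)     # torch.Size([3, 10, 5, 5]) -- the full K x N x H x W tensor

out_reduced = F.kl_div(x_t, x_k, reduction="batchmean")
print(out_reduced.shape)  # torch.Size([]) -- a single scalar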

Upvotes: 2

Views: 3679

Answers (1)

Jatentaki

Reputation: 13103

It's unclear to me what exactly constitutes a probability distribution in your model. With reduction='none', kl_div, given log(x_n) and y_n, computes kl_div = y_n * (log(y_n) - log(x_n)), which is the summand of the actual Kullback-Leibler divergence; the summation (or, in other words, taking the expectation) is up to you. If your point is that H and W are the two dimensions over which you want to take the expectation, it's as simple as

loss = F.kl_div(x_t, x_k, reduction="none").sum(dim=(-1, -2))

This is of shape [K, N]. If your network output is to be interpreted differently, you need to specify more precisely which are the event dimensions and which are the sample dimensions of your distribution.
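To make the event/sample distinction concrete, here is a minimal sketch that treats each HxW map as a single distribution over H*W events; the log_softmax / softmax normalization is an assumption on my part, since the question does not say how the raw outputs relate to probabilities:

import torch
import torch.nn.functional as F

K, N, H, W = 3, 10, 5, 5
x_t = torch.randn(N, H, W)      # target reconstruction
x_k = torch.randn(K, N, H, W)   # K candidate reconstructions

# Turn each HxW map into a distribution over H*W events:
# log-probabilities for the first argument, probabilities for the second.
log_p = F.log_softmax(x_t.view(N, -1), dim=-1)   # [N, H*W]
q = F.softmax(x_k.view(K, N, -1), dim=-1)        # [K, N, H*W]

# Pointwise KL terms, then sum over the event dimension.
loss = F.kl_div(log_p.expand_as(q), q, reduction="none").sum(dim=-1)  # [K, N]

The final .sum(dim=-1) is the expectation over the event dimension; K and N are left untouched as sample dimensions, which gives the [K, N] tensor asked for.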

Upvotes: 3
