Reputation: 2663
Cross posting my question from the PyTorch forum:
I started receiving negative KL divergences between a target Dirichlet distribution and my model’s output Dirichlet distribution. Someone online suggested that this might indicate that the parameters of the Dirichlet distribution don’t sum to 1. I thought this was ridiculous, since the output of the model is passed through
output = F.softmax(self.weights(x), dim=1)
But after looking into it more closely, I found that torch.all(torch.sum(output, dim=1) == 1.) returns False! Looking at the problematic row, I see that it is tensor([0.0085, 0.9052, 0.0863], grad_fn=<SelectBackward>). But torch.sum(output[5]) == 1. produces tensor(False).
What am I misusing about softmax such that output probabilities do not sum to 1?
This is PyTorch version 1.2.0+cpu. Full model is copied below:
import torch
import torch.nn as nn
import torch.nn.functional as F
def assert_no_nan_no_inf(x):
    assert not torch.isnan(x).any()
    assert not torch.isinf(x).any()


class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Linear(
            in_features=2,
            out_features=3)

    def forward(self, x):
        output = F.softmax(self.weights(x), dim=1)
        assert torch.all(torch.sum(output, dim=1) == 1.)
        assert_no_nan_no_inf(x)
        return output
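For reference, here is a minimal way to exercise the model (the input shape, batch size, and seed are my own choices, not from the original post); depending on the random weights and inputs, the exact-equality assert inside forward() may or may not trip:
# Hypothetical reproduction sketch: feed a random batch of shape (N, 2)
# through the model and see whether the exact-equality check survives.
torch.manual_seed(0)
net = Network()
x = torch.randn(16, 2)

try:
    out = net(x)
    print("row sums:", out.sum(dim=1))
except AssertionError:
    print("softmax rows did not sum to exactly 1.0 (floating-point rounding)")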
Upvotes: 2
Views: 4897
Reputation: 16440
This is most probably due to floating-point numerical error from finite precision.
Instead of checking for strict equality, you should check that the deviation from 1 is within an acceptable tolerance.
For example, I get torch.norm(output.sum(dim=1) - 1) / N to be less than 1e-8, where N is the batch size.
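As a sketch of that kind of check (the tolerance values and the use of torch.allclose here are my own choices, not from the original answer):
import torch
import torch.nn.functional as F

logits = torch.randn(8, 3)        # hypothetical batch of raw scores
output = F.softmax(logits, dim=1)
row_sums = output.sum(dim=1)

# Exact equality can fail because of finite floating-point precision.
print(torch.all(row_sums == 1.))  # may print tensor(False)

# Compare against 1 within a tolerance instead.
print(torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-6))

# Or look at the normalized deviation, as in the answer above;
# it is typically a tiny number for float32 softmax outputs.
N = row_sums.shape[0]
print(torch.norm(row_sums - 1) / N)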
Upvotes: 2