I am trying to calculate the cross-entropy loss with PyTorch's `BCELoss` function for a binary classification problem. While tinkering, I found this weird behaviour:
```python
import torch
from torch import nn

sigmoid = nn.Sigmoid()
loss = nn.BCELoss(reduction="sum")
target = torch.tensor([0., 1.])
input1 = torch.tensor([1., 1.], requires_grad=True)
input2 = sigmoid(torch.tensor([10., 10.], requires_grad=True))
print(input2)                # tensor([1.0000, 1.0000], grad_fn=<SigmoidBackward>)
print(loss(input1, target))  # tensor(100., grad_fn=<BinaryCrossEntropyBackward>)
print(loss(input2, target))  # tensor(9.9996, grad_fn=<BinaryCrossEntropyBackward>)
```
Since `input1` and `input2` have the same values, shouldn't both calls return the same loss instead of 100 and 9.9996? I expected 100: for the element with target 0 the loss contains log(1 - 1) = log(0) ≈ -infinity, which PyTorch clamps to -100. https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
What is going on here and where am I going wrong?
`sigmoid(10)` is not exactly equal to 1:

```python
>>> 1 / (1 + torch.exp(-torch.tensor(10.))).item()
0.9999545833234493
```

In your case:

```python
>>> sigmoid(torch.tensor([10., 10.], requires_grad=True)).tolist()
[0.9999545812606812, 0.9999545812606812]
```

Thus `input1` is not the same as `input2`: `[1.0, 1.0]` vs `[0.9999545812606812, 0.9999545812606812]`. The `1.0000` shown by `print(input2)` is only PyTorch's default 4-digit print precision.
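To see why `sigmoid(10)` stays below 1 in `float32` while a larger input would not, here is a minimal pure-Python sketch. It evaluates the sigmoid in double precision and then rounds to `float32` with `struct`; that rounding step is an assumption that approximates, but is not, how PyTorch computes its default `float32` sigmoid. The helpers `sigmoid_f64` and `to_float32` are hypothetical names used only for illustration.

```python
import math
import struct


def sigmoid_f64(z: float) -> float:
    """Logistic sigmoid evaluated in double precision."""
    return 1.0 / (1.0 + math.exp(-z))


def to_float32(x: float) -> float:
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


p10 = to_float32(sigmoid_f64(10.0))
p20 = to_float32(sigmoid_f64(20.0))

print(p10)        # 0.9999545812606812, the value PyTorch reports for sigmoid(10)
print(1.0 - p10)  # gap to 1.0, roughly 4.5e-05
print(p20)        # 1.0: exp(-20) is below half the float32 spacing just under 1.0
```

Only once the logit is large enough for the gap to vanish in `float32` does the sigmoid output become exactly 1.0; a logit of 10 is not that large.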
Let's compute BCE manually:
```python
import torch

def bce(x, y):
    # Unclamped binary cross-entropy for one prediction x and one target y
    return -(y * torch.log(x) + (1 - y) * torch.log(1 - x)).item()

# input1
x1 = torch.tensor(1.)
x2 = torch.tensor(1.)
y1 = torch.tensor(0.)
y2 = torch.tensor(1.)
print("input1:", sum([bce(x1, y1), bce(x2, y2)]))

# input2
x1 = torch.tensor(0.9999545812606812)
x2 = torch.tensor(0.9999545812606812)
y1 = torch.tensor(0.)
y2 = torch.tensor(1.)
print("input2:", sum([bce(x1, y1), bce(x2, y2)]))
```

```
input1: nan
input2: 9.999631525119185
```
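The `nan` for `input1` comes from its target-1 element: there (1 - y) * log(1 - x) evaluates 0 * log(0) = 0 * -inf, and IEEE 754 arithmetic defines zero times infinity as NaN, which then propagates through the sum. The sketch below repeats the same arithmetic with plain Python floats standing in for the tensors; `-math.inf` stands in for what `torch.log` returns at 0 (Python's `math.log(0.0)` raises `ValueError` instead).

```python
import math

log_zero = -math.inf  # stand-in for torch.log(torch.tensor(0.))

# input1, element with x = 1 and target y = 0: -(0 * log(1) + 1 * log(0))
term_y0 = -(0.0 * math.log(1.0) + 1.0 * log_zero)
# input1, element with x = 1 and target y = 1: -(1 * log(1) + 0 * log(0))
term_y1 = -(1.0 * math.log(1.0) + 0.0 * log_zero)

print(term_y0)            # inf
print(term_y1)            # nan, because 0 * -inf is NaN
print(term_y0 + term_y1)  # nan
```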
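For `input1` we get `nan`, but according to the docs:

> Our solution is that BCELoss clamps its log function outputs to be greater than or equal to -100. This way, we can always have a finite loss value and a linear backward method.

That's why PyTorch's `BCELoss` returns 100 for `input1`: the element with target 0 contributes -(-100) = 100, and the element with target 1 contributes 0, because its log(1 - x) = log(0) term is clamped to -100 before being multiplied by 1 - y = 0, so no `nan` appears. For `input2` no clamping happens: -log(1 - 0.9999546) ≈ 9.99959 for the target-0 element plus -log(0.9999546) ≈ 0.0000454 for the target-1 element gives ≈ 9.9996.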
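The clamping rule from the quoted documentation can be reproduced with a small pure-Python sketch of BCE with `reduction="sum"`. `clamped_log` and `bce_sum` are hypothetical helpers; they mimic only the arithmetic described in the docs (log outputs floored at -100), not PyTorch's actual kernel, and the `input2` values are the `float32` sigmoid outputs printed earlier.

```python
import math

LOG_FLOOR = -100.0  # BCELoss clamps log outputs to >= -100, per the docs


def clamped_log(p: float) -> float:
    """log(p) floored at -100; log(0) maps to -100 instead of -inf."""
    if p <= 0.0:
        return LOG_FLOOR
    return max(math.log(p), LOG_FLOOR)


def bce_sum(probs, targets):
    """Summed binary cross-entropy using the clamped log."""
    return sum(
        -(y * clamped_log(x) + (1.0 - y) * clamped_log(1.0 - x))
        for x, y in zip(probs, targets)
    )


target = [0.0, 1.0]
input1 = [1.0, 1.0]
input2 = [0.9999545812606812, 0.9999545812606812]  # float32 sigmoid(10)

print(bce_sum(input1, target))  # 100.0: only the target-0 element hits the clamp
print(bce_sum(input2, target))  # about 9.9996: no clamping needed
```

Because the clamp turns log(0) into the finite value -100, the 0 * log(0) product becomes 0 rather than `nan`, which is exactly the difference between `BCELoss` and the unclamped manual computation.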