lucky yang

Reputation: 1669

How to compute cross entropy loss for binary classification in PyTorch?

For binary classification, my output and label look like this:

output = [0.7, 0.3, 0.1, 0.9 ... ]
label = [1, 0, 0, 1 ... ]

where the output is the predicted probability that the label = 1.

And I want a cross entropy loss like this:

import torch

def cross_entropy(output, label):
    return torch.sum(-label * torch.log(output) - (1 - label) * torch.log(1 - output))

However, this gives me NaN, because the output might be zero and then log(output) blows up.

I know there is torch.nn.CrossEntropyLoss, but it does not apply to my data format here.
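
For reference, nn.CrossEntropyLoss expects raw logits of shape (N, C) and integer class indices as targets, not per-sample probabilities like mine. A minimal sketch of that format (the values here are made up):

import torch
import torch.nn as nn

# nn.CrossEntropyLoss wants one row of C unnormalized class scores per
# sample, plus an integer class index per sample as the target.
logits = torch.tensor([[1.2, -0.4],
                       [0.3,  2.1]])
targets = torch.tensor([0, 1])
loss = nn.CrossEntropyLoss()(logits, targets)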

Upvotes: 2

Views: 3742

Answers (2)

jneuendorf

Reputation: 442

As Leonard2 mentioned in a comment to the question, torch.nn.BCELoss (meaning "Binary Cross Entropy Loss") seems to be exactly what was asked for.
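
A minimal sketch of how it applies to the data format from the question (the numbers are the ones given above; BCELoss expects float tensors of probabilities and float targets):

import torch
import torch.nn as nn

output = torch.tensor([0.7, 0.3, 0.1, 0.9])  # predicted P(label == 1)
label = torch.tensor([1., 0., 0., 1.])       # targets must be floats
loss = nn.BCELoss()(output, label)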

Upvotes: 0

viven

Reputation: 11

import math

import torch
import torch.nn.functional as F

def my_binary_cross_entropy(output, label):
    label = label.float()
    loss = 0
    # Average the per-sample binary cross entropy over the batch.
    for i in range(len(label)):
        loss += -(label[i] * math.log(output[i])
                  + (1 - label[i]) * math.log(1 - output[i]))
    return loss / len(label)

label1 = torch.randint(0, 2, (3,)).float()
output = torch.rand(3)
my_binary_cross_entropy(output, label1)

The value it returns is the same as the one from F.binary_cross_entropy:

F.binary_cross_entropy(output, label1)
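
As a side note, one common way to avoid the NaN from log(0) that the question ran into is to clamp the probabilities away from 0 and 1 before taking the log. A sketch of that idea (safe_binary_cross_entropy and the eps value are my own choices, not part of the original answer):

import torch

def safe_binary_cross_entropy(output, label, eps=1e-7):
    # Keep probabilities strictly inside (0, 1) so log() stays finite.
    output = output.clamp(eps, 1 - eps)
    return (-(label * output.log()
              + (1 - label) * (1 - output).log())).mean()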

Upvotes: 1
