Wu Shiauthie

Reputation: 109

pos_weight in binary cross entropy calculation

When we deal with imbalanced training data (many more negative samples than positive samples), the pos_weight parameter is usually used. The intended effect of pos_weight is that the model incurs a higher loss when a positive sample is misclassified than when a negative sample is. When I use the binary_cross_entropy_with_logits function, I found:

import torch

bce = torch.nn.functional.binary_cross_entropy_with_logits

pos_weight = torch.FloatTensor([5])
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)

preds_neg_wrong = torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)

However:

>>> loss_pos_wrong
tensor(2.0359)

>>> loss_neg_wrong
tensor(2.0359)

The losses derived from wrong positive samples and negative samples are the same, so how does pos_weight work in the imbalanced data loss calculation?

Upvotes: 5

Views: 6903

Answers (1)

Ivan

Reputation: 40768

TL;DR: both losses are identical because you are computing the same quantity: both calls contain the same (logit, label) pairs, only the order of the two batch elements is switched.


Why are you getting the same loss?

I think you got confused about the usage of F.binary_cross_entropy_with_logits (you can find a more detailed documentation page under nn.BCEWithLogitsLoss). In your case, the input (i.e. the output of your model) is one-dimensional, which means each batch element has a single logit x, not two.

In your example you have

preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])

This means your batch size is 2, and since by default the function averages the losses of the batch elements (reduction='mean'), you end up with the same result for BCE(preds_pos_wrong, label_pos) and BCE(preds_neg_wrong, label_neg): the two elements of your batch are just switched.

You can easily verify this by turning off the averaging over batch elements with the reduction='none' option:

>>> import torch.nn.functional as F
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
...        pos_weight=pos_weight, reduction='none')
tensor([2.3704, 1.7014])

>>> F.binary_cross_entropy_with_logits(preds_neg_wrong, label_neg,
...        pos_weight=pos_weight, reduction='none')
tensor([1.7014, 2.3704])
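To tie these per-element losses back to the single number you originally saw, here is a minimal sketch that recomputes both cases with the default reduction and checks that the mean is simply the average of the per-element values. It reuses the variable names from the question but, as an assumption for brevity, builds the tensors with torch.tensor instead of the legacy torch.FloatTensor constructor.

```python
import torch
import torch.nn.functional as F

pos_weight = torch.tensor([5.0])

preds_pos_wrong = torch.tensor([0.5, 1.5])
label_pos = torch.tensor([1.0, 0.0])

# Same (logit, label) pairs as above, with the two batch elements in swapped order
preds_neg_wrong = torch.tensor([1.5, 0.5])
label_neg = torch.tensor([0.0, 1.0])

# Per-element losses (no reduction)
per_elem_pos = F.binary_cross_entropy_with_logits(
    preds_pos_wrong, label_pos, pos_weight=pos_weight, reduction="none")
per_elem_neg = F.binary_cross_entropy_with_logits(
    preds_neg_wrong, label_neg, pos_weight=pos_weight, reduction="none")

# Default reduction='mean': just the average of the per-element losses
mean_pos = F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos, pos_weight=pos_weight)
mean_neg = F.binary_cross_entropy_with_logits(preds_neg_wrong, label_neg, pos_weight=pos_weight)

print(per_elem_pos.mean(), mean_pos)  # both tensor(2.0359)
print(per_elem_neg.mean(), mean_neg)  # both tensor(2.0359)
```

Because swapping the batch elements only permutes the per-element losses, any order-independent reduction (mean or sum) gives the same value for both calls.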

Looking into F.binary_cross_entropy_with_logits:

That being said, the per-element formula for binary cross-entropy is:

bce = -[y*log(sigmoid(x)) + (1-y)*log(1 - sigmoid(x))]

where y (respectively sigmoid(x)) corresponds to the positive class associated with that logit, and 1 - y (resp. 1 - sigmoid(x)) to the negative class.

The documentation could be more precise about the weighting scheme for pos_weight (not to be confused with weight, which rescales the whole loss of each element). As you said, the idea with pos_weight is to weight only the positive term, not the whole expression.
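The difference between weight and pos_weight can be checked directly. In the sketch below (illustrative values taken from the question; the weight vector [5, 5] is an arbitrary choice for comparison), weight multiplies the entire loss of every element, while pos_weight only scales the elements whose label is 1:

```python
import torch
import torch.nn.functional as F

preds = torch.tensor([0.5, 1.5])
label = torch.tensor([1.0, 0.0])

# No weighting at all
plain = F.binary_cross_entropy_with_logits(preds, label, reduction="none")

# weight: rescales the whole loss of each element, regardless of its label
with_weight = F.binary_cross_entropy_with_logits(
    preds, label, weight=torch.tensor([5.0, 5.0]), reduction="none")

# pos_weight: rescales only the y*log(sigmoid(x)) term, i.e. only elements with y = 1
with_pos_weight = F.binary_cross_entropy_with_logits(
    preds, label, pos_weight=torch.tensor([5.0]), reduction="none")

print(plain)            # tensor([0.4741, 1.7014])
print(with_weight)      # tensor([2.3704, 8.5071])
print(with_pos_weight)  # tensor([2.3704, 1.7014])
```

Only the first element (label 1) changes under pos_weight; the negative element keeps its unweighted loss.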

bce = -[w_p*y*log(sigmoid(x)) + (1-y)*log(1 - sigmoid(x))]

where w_p is the weight of the positive term, which compensates for the imbalance between positive and negative samples. In practice, it is typically set to w_p = #negative/#positive.
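As a minimal sketch of this rule of thumb, assuming a hypothetical label vector with 2 positives and 8 negatives, w_p can be derived from the labels. With a positive and a negative that are misclassified by the same margin (logits -2 and +2), the positive then costs exactly w_p times more, which is the asymmetry the question was looking for:

```python
import torch
import torch.nn.functional as F

# Hypothetical imbalanced labels: 2 positives, 8 negatives
labels = torch.tensor([1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])

# w_p = #negative / #positive (here 8 / 2 = 4), shaped [1] like pos_weight above
num_pos = labels.sum()
num_neg = labels.numel() - num_pos
w_p = (num_neg / num_pos).reshape(1)

# Equally wrong predictions: a positive with logit -2 and a negative with logit +2
preds = torch.tensor([-2.0, 2.0])
target = torch.tensor([1.0, 0.0])

loss = F.binary_cross_entropy_with_logits(preds, target, pos_weight=w_p, reduction="none")
print(w_p)                # tensor([4.])
print(loss[0] / loss[1])  # ratio is w_p: the misclassified positive costs 4x more
```

Without pos_weight both elements would have the same loss, since -log(sigmoid(-2)) equals -log(1 - sigmoid(2)).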

Therefore:

>>> w_p = torch.FloatTensor([5])
>>> preds = torch.FloatTensor([0.5, 1.5])
>>> label = torch.FloatTensor([1, 0])

With the built-in loss function:

>>> F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction='none')
tensor([2.3704, 1.7014])

Compared with the manual computation:

>>> z = torch.sigmoid(preds)
>>> -(w_p*label*torch.log(z) + (1-label)*torch.log(1-z))
tensor([2.3704, 1.7014])
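One caveat with the manual version: for very negative logits, torch.sigmoid underflows to 0 in float32 and torch.log returns -inf. A sketch of a numerically stable variant (the extra logit -200 is an arbitrary extreme value chosen for illustration) uses F.logsigmoid together with the identity log(1 - sigmoid(x)) = logsigmoid(-x), and matches the built-in function:

```python
import torch
import torch.nn.functional as F

w_p = torch.tensor([5.0])
preds = torch.tensor([0.5, 1.5, -200.0])
label = torch.tensor([1.0, 0.0, 1.0])

# Naive formula: sigmoid(-200) underflows to 0, so log(0) = -inf and the loss becomes inf
z = torch.sigmoid(preds)
naive = -(w_p * label * torch.log(z) + (1 - label) * torch.log(1 - z))

# Stable formula: log(sigmoid(x)) = logsigmoid(x), log(1 - sigmoid(x)) = logsigmoid(-x)
stable = -(w_p * label * F.logsigmoid(preds) + (1 - label) * F.logsigmoid(-preds))

builtin = F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction="none")
print(naive)
print(stable)
print(builtin)
```

The first two entries agree across all three versions; only the extreme logit exposes the overflow in the naive formula.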

Upvotes: 11
