Reputation: 21280
I am building a multi-label classification network.
My GTs are vectors of length 512, e.g.
[0,0,0,1,0,1,0,...,0,0,0,1]
Each vector has only about 5 ones; the rest are zeros.
I am planning to:
Use sigmoid activation for the output layer.
Use binary_crossentropy as the loss function.
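In PyTorch terms, that setup would look roughly like this (a minimal sketch; the hidden size of 256 and batch size of 64 are just placeholders):
import torch

# Hypothetical output head: one sigmoid unit per label
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512),  # 512 logits, one per label
    torch.nn.Sigmoid(),         # independent probability per label
)
criterion = torch.nn.BCELoss()  # binary cross-entropy over all 512 outputs

features = torch.randn(64, 256)                       # placeholder batch
targets = torch.randint(0, 2, (64, 512)).float()      # multi-hot labels
print(criterion(model(features), targets))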
But how can I deal with the imbalance? The network can learn to always predict zeros and still achieve a really low loss.
How can I make it actually learn to predict the ones?
Upvotes: 1
Views: 1360
Reputation: 24814
You cannot easily oversample here, as this is a multi-label case (which I originally missed in the post).
What you can do is give the 1s much higher weights, something like this:
import torch


class BCEWithLogitsLossWeighted(torch.nn.Module):
    def __init__(self, weight, *args, **kwargs):
        super().__init__()
        # Notice the "none" reduction: we need the per-element losses
        # so we can re-weight the positive entries before averaging.
        self.bce = torch.nn.BCEWithLogitsLoss(*args, **kwargs, reduction="none")
        self.weight = weight

    def forward(self, logits, labels):
        loss = self.bce(logits, labels)
        binary_labels = labels.bool()
        # Scale the loss of every positive (label == 1) entry by `weight`
        loss[binary_labels] *= labels[binary_labels] * self.weight
        # Or any other reduction
        return torch.mean(loss)


loss = BCEWithLogitsLossWeighted(50)
logits = torch.randn(64, 512)
labels = torch.randint(0, 2, size=(64, 512)).float()
print(loss(logits, labels))
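As a side note, if the weight is the same constant for every label (as in the snippet above with hard 0/1 labels), you can get roughly the same effect with the built-in pos_weight argument of BCEWithLogitsLoss; just mentioning it as an alternative:
# Built-in alternative: one positive-class weight per label position
builtin = torch.nn.BCEWithLogitsLoss(pos_weight=torch.full((512,), 50.0))
print(builtin(logits, labels))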
You can also use focal loss to focus on the positive examples (there should be implementations available in several libraries).
EDIT:
Focal loss can be coded along these lines as well (functional form because that's what I have in my repo, but you should be able to work from that):
import typing

import torch


def binary_focal_loss(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    gamma: float,
    weight=None,
    pos_weight=None,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> torch.Tensor:
    # Modulating factor: down-weights easy examples, the confident the
    # prediction, the smaller (1 - p) ** gamma becomes.
    probabilities = (1 - torch.sigmoid(outputs)) ** gamma
    loss = probabilities * torch.nn.functional.binary_cross_entropy_with_logits(
        outputs,
        targets.float(),
        weight,
        reduction="none",
        pos_weight=pos_weight,
    )
    # Apply the chosen reduction (mean by default)
    return reduction(loss)
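Called on the same shapes as above it would look something like this (the gamma value of 2.0 is just the common default, tune it for your data):
logits = torch.randn(64, 512)
labels = torch.randint(0, 2, size=(64, 512)).float()
print(binary_focal_loss(logits, labels, gamma=2.0))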
Upvotes: 3