Night Walker

Reputation: 21280

Multi label classification with unbalanced labels

I am building a multi-label classification network. My ground truths (GTs) are vectors of length 512, e.g. [0,0,0,1,0,1,0,...,0,0,0,1]. They are mostly zeros: each vector has about 5 ones, and the rest are zeros.
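For example, one such GT vector could be built like this (the active indices here are made up):

import torch

# Hypothetical multi-hot target: 512 classes, about 5 active labels.
gt = torch.zeros(512)
gt[torch.tensor([3, 17, 120, 301, 444])] = 1.0
print(gt.sum())  # tensor(5.)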

I am thinking to do:

Use sigmoid activation for the output layer.

Use binary_crossentropy as the loss function.
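In PyTorch that would look roughly like this (the 2048 input size and hidden layer are just placeholders; the sigmoid is folded into BCEWithLogitsLoss for numerical stability):

import torch

# Hypothetical network: 2048 input features -> 512 independent binary outputs.
model = torch.nn.Sequential(
    torch.nn.Linear(2048, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 512),  # raw logits, no sigmoid here
)

# BCEWithLogitsLoss applies the sigmoid internally, which is more
# numerically stable than a separate sigmoid + BCELoss.
criterion = torch.nn.BCEWithLogitsLoss()

inputs = torch.randn(64, 2048)
targets = torch.randint(0, 2, size=(64, 512)).float()
print(criterion(model(inputs), targets))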

But how can I solve the imbalance issue? The network can learn to always predict zeros and still get a really low loss.

How can I make it actually learn to predict the ones?

Upvotes: 1

Views: 1360

Answers (1)

Szymon Maszke

Reputation: 24814

You cannot easily upsample, as this is a multi-label case (which I originally missed in the post).

What you can do is give the 1 labels much higher weights, something like this:

import torch


class BCEWithLogitsLossWeighted(torch.nn.Module):
    def __init__(self, weight, *args, **kwargs):
        super().__init__()
        # reduction="none" keeps the per-element losses so the
        # positive entries can be reweighted before reducing.
        self.bce = torch.nn.BCEWithLogitsLoss(*args, **kwargs, reduction="none")
        self.weight = weight

    def forward(self, logits, labels):
        loss = self.bce(logits, labels)
        # Boolean mask of the positive (1) entries.
        binary_labels = labels.bool()
        loss[binary_labels] *= self.weight
        # Or any other reduction
        return torch.mean(loss)


criterion = BCEWithLogitsLossWeighted(50)
logits = torch.randn(64, 512)
labels = torch.randint(0, 2, size=(64, 512)).float()

print(criterion(logits, labels))
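Note that PyTorch's built-in BCEWithLogitsLoss can achieve the same thing via its pos_weight argument (for hard 0/1 labels this is equivalent to the module above, without any custom code):

import torch

# pos_weight multiplies the loss term of the positive targets; one
# constant value of 50 for all 512 labels mirrors the module above.
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.full((512,), 50.0))

logits = torch.randn(64, 512)
labels = torch.randint(0, 2, size=(64, 512)).float()
print(criterion(logits, labels))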

You can also use focal loss to focus on the positive examples (implementations are available in several libraries).

EDIT:

Focal loss can also be coded along these lines (in functional form, because that's what I have in my repo, but you should be able to work from it):

import typing

import torch


def binary_focal_loss(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    gamma: float,
    weight=None,
    pos_weight=None,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
) -> torch.Tensor:
    # Modulating factor: down-weights elements the network already
    # predicts as confidently positive (sigmoid close to 1).
    probabilities = (1 - torch.sigmoid(outputs)) ** gamma
    loss = probabilities * torch.nn.functional.binary_cross_entropy_with_logits(
        outputs,
        targets.float(),
        weight,
        reduction="none",
        pos_weight=pos_weight,
    )

    return reduction(loss)
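A quick usage sketch, assuming the function above is in scope (gamma = 2 is the value used in the original focal loss paper):

logits = torch.randn(64, 512)
labels = torch.randint(0, 2, size=(64, 512)).float()

print(binary_focal_loss(logits, labels, gamma=2.0))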

Upvotes: 3
