youngdev

Reputation: 607

How to handle class imbalance in multi-label classification using PyTorch

We are attempting to implement multi-label classification using a CNN in PyTorch. We have 8 labels and around 260 images, with a 90/10 train/validation split.

The classes are highly imbalanced: the most frequent class occurs in over 140 images, while the least frequent occurs in fewer than 5.

We initially tried BCEWithLogitsLoss, which led to the model predicting the same label for all images.

We then implemented a focal loss approach to handle the class imbalance:

    import torch
    import torch.nn as nn

    class FocalLoss(nn.Module):
        def __init__(self, alpha=1, gamma=2):
            super().__init__()
            self.alpha = alpha
            self.gamma = gamma
            # reduction='none' keeps per-element losses so the focal
            # modulation is applied per label, not to the batch mean
            self.bce = nn.BCEWithLogitsLoss(reduction='none')

        def forward(self, outputs, targets):
            bce_loss = self.bce(outputs, targets)
            pt = torch.exp(-bce_loss)  # model's probability of the true label
            focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
            return focal_loss.mean()

This resulted in the model predicting an empty set (no labels) for every image, since it never reached a confidence above 0.5 for any class.

Is there an approach in PyTorch that can help address this situation?

Upvotes: 2

Views: 3640

Answers (1)

Karl

Reputation: 5373

There are basically three ways of dealing with this:

  1. Discard data from the more common class

  2. Weight minority class loss values more heavily

  3. Oversample the minority class

Option 1 is implemented by selecting the files you include in your Dataset.
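
A minimal sketch of option 1, assuming a hypothetical samples list of (image_path, label_set) pairs where label 0 is the over-represented class: keep every image that has any other label, and only a random fraction of the majority-only images.

    import random

    # samples: hypothetical list of (image_path, label_set) pairs,
    # where label_set is a Python set of integer class ids
    def undersample(samples, majority_label=0, keep_fraction=0.3, seed=42):
        rng = random.Random(seed)
        kept = []
        for path, labels in samples:
            # keep majority-only images with probability keep_fraction;
            # keep everything else unconditionally
            if labels == {majority_label} and rng.random() >= keep_fraction:
                continue
            kept.append((path, labels))
        return kept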

Option 2 is implemented with the pos_weight argument of BCEWithLogitsLoss.
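
A common choice is to weight each class by its ratio of negative to positive examples. A minimal sketch, assuming a hypothetical multi-hot label matrix train_label_matrix of shape (num_samples, num_classes):

    import torch
    import torch.nn as nn

    # train_label_matrix: hypothetical (num_samples, num_classes) multi-hot array
    targets = torch.tensor(train_label_matrix, dtype=torch.float32)

    pos_counts = targets.sum(dim=0)                    # positives per class
    neg_counts = targets.shape[0] - pos_counts         # negatives per class
    pos_weight = neg_counts / pos_counts.clamp(min=1)  # guard against zero counts

    # rare classes now contribute more to the loss when mispredicted
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)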

Option 3 is implemented with a custom Sampler passed to your DataLoader.
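
For multi-label data there is no single class per sample, so one common heuristic is to weight each sample by the rarity of its rarest positive label and draw with replacement using WeightedRandomSampler. A sketch, reusing the targets matrix from above and a hypothetical dataset:

    from torch.utils.data import DataLoader, WeightedRandomSampler

    class_counts = targets.sum(dim=0).clamp(min=1)  # positives per class
    class_weights = 1.0 / class_counts              # rarer class => larger weight
    # each sample is weighted by its rarest (highest-weight) positive label
    sample_weights = (targets * class_weights).max(dim=1).values
    sample_weights = sample_weights.clamp(min=1e-8)  # avoid zero-weight samples

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,  # rare samples may be drawn several times per epoch
    )
    loader = DataLoader(dataset, batch_size=16, sampler=sampler)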

For deep learning, oversampling typically works best.

Upvotes: 7
