Reputation: 607
We are attempting to implement multi-label classification using a CNN in PyTorch. We have 8 labels and around 260 images, with a 90/10 split for the train/validation sets.
The classes are highly imbalanced: the most frequent class occurs in over 140 images, while the least frequent class occurs in fewer than 5 images.
We initially tried the BCEWithLogitsLoss function, which led to the model predicting the same label for all images.
We then implemented a focal loss approach to handle the class imbalance, as follows:
    import torch.nn as nn
    import torch

    class FocalLoss(nn.Module):
        def __init__(self, alpha=1, gamma=2):
            super(FocalLoss, self).__init__()
            self.alpha = alpha
            self.gamma = gamma

        def forward(self, outputs, targets):
            bce_criterion = nn.BCEWithLogitsLoss()
            bce_loss = bce_criterion(outputs, targets)
            pt = torch.exp(-bce_loss)
            focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
            return focal_loss
This resulted in the model predicting empty sets (no labels) for every image, since it could not reach a confidence greater than 0.5 for any class.
Is there an approach in PyTorch to help address this situation?
Upvotes: 2
Views: 3640
Reputation: 5373
There are basically three ways of dealing with this:
1. Discard data from the more common class
2. Weight minority class loss values more heavily
3. Oversample the minority class
Option 1 is implemented by selecting the files you include in your Dataset.
Option 2 is implemented with the pos_weight parameter for BCEWithLogitsLoss.
Option 3 is implemented with a custom Sampler passed to your DataLoader.
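As a rough illustration of option 2, here is a minimal sketch that derives pos_weight from label counts. The toy `targets` matrix (6 images, 3 labels) and the zero `logits` are made up for the example and stand in for your real labels and model outputs. The negatives-to-positives ratio is one common choice, not the only one.

```python
import torch
import torch.nn as nn

# Hypothetical multi-hot labels: 6 images x 3 labels (the question has 8 labels).
targets = torch.tensor([
    [1., 0., 0.],
    [1., 1., 0.],
    [1., 0., 0.],
    [1., 0., 1.],
    [0., 1., 0.],
    [1., 0., 0.],
])

# Count positives and negatives per label over the training set.
pos_counts = targets.sum(dim=0)
neg_counts = targets.shape[0] - pos_counts

# One weight per label: negatives / positives.
# clamp(min=1) avoids division by zero for a label with no positives.
pos_weight = neg_counts / pos_counts.clamp(min=1)

# pos_weight must have one entry per label.
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Stand-in for raw model outputs (logits, before any sigmoid).
logits = torch.zeros_like(targets)
loss = criterion(logits, targets)
```

With so few positives for the rarest labels, the resulting weights can get large, so it may be worth capping them if training becomes unstable.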
For deep learning, oversampling typically works best.
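One possible way to oversample is PyTorch's built-in WeightedRandomSampler, shown in the sketch below. The toy tensors are made up, and weighting each image by its rarest positive label is just one heuristic assumed here for the multi-label case, not a standard recipe.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical toy data: 6 images with multi-hot labels for 3 classes.
images = torch.randn(6, 3, 8, 8)
targets = torch.tensor([
    [1., 0., 0.],
    [1., 1., 0.],
    [1., 0., 0.],
    [1., 0., 1.],
    [0., 1., 0.],
    [1., 0., 0.],
])

# Inverse frequency of each label across the training set.
label_freq = targets.sum(dim=0)
inv_freq = 1.0 / label_freq.clamp(min=1)

# Assumed heuristic: weight each image by its rarest positive label.
# Images with no labels fall back to the smallest weight instead of zero.
sample_weights = (targets * inv_freq).max(dim=1).values
sample_weights = sample_weights.clamp(min=inv_freq.min().item())

# Draw with replacement so rare-label images can appear several times per epoch.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# sampler replaces shuffle=True; the two options are mutually exclusive.
loader = DataLoader(TensorDataset(images, targets), batch_size=2, sampler=sampler)
```

Because labels co-occur, drawing an image for its rare label also repeats any common labels on that image, so the batches will be more balanced but not perfectly so.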
Upvotes: 7