Reputation: 4162
I am working on multiclass classification (4 classes) for a language task and I am using a BERT model for the classification. I am following the blog post Transfer Learning for NLP: Fine-Tuning BERT for Text Classification. My fine-tuned BERT model returns nn.LogSoftmax(dim=1).
My data is quite imbalanced, so I used sklearn.utils.class_weight.compute_class_weight to compute the class weights and passed them to the loss:
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
weights = torch.tensor(class_weights, dtype=torch.float)
cross_entropy = nn.NLLLoss(weight=weights)
My results were not very good, so I thought of experimenting with Focal Loss. Here is my code for it:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        BCE_loss = nn.CrossEntropyLoss()(inputs, targets)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
I have 3 questions now. The first, and most important, is about Focal Loss: can I use the weight parameter inside nn.CrossEntropyLoss()?
Upvotes: 14
Views: 19418
Reputation: 11
I tried to implement it based on weights computed with sklearn's compute_class_weight. I think this code could be extended to multiclass by changing F.nll_loss to cross-entropy loss (see the sketch after the code).
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightedFocalLoss(nn.Module):
    def __init__(self, alpha, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = alpha  # per-class weight tensor, e.g. from compute_class_weight
        self.gamma = gamma

    def forward(self, inputs, targets):
        # inputs are log-probabilities (e.g. from nn.LogSoftmax), targets are class indices
        BCE_loss = F.nll_loss(inputs, targets, reduction='none')
        targets = targets.type(torch.long)
        # at = self.alpha.gather(0, targets.data.view(-1))
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha[targets] * (1 - pt)**self.gamma * BCE_loss
        # weighted mean: normalize by the sum of the per-sample class weights
        loss_weighted_manual = F_loss.sum() / self.alpha[targets].sum()
        return loss_weighted_manual
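A minimal sketch of the multiclass extension mentioned above, assuming the model outputs raw logits rather than log-probabilities (the class name WeightedFocalLossLogits is hypothetical, not from the original answer); F.cross_entropy replaces F.nll_loss and everything else stays the same:

class WeightedFocalLossLogits(nn.Module):
    def __init__(self, alpha, gamma=2):
        super().__init__()
        self.alpha = alpha  # per-class weight tensor
        self.gamma = gamma

    def forward(self, inputs, targets):
        # inputs: raw logits of shape (batch, num_classes); targets: class indices
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal = self.alpha[targets] * (1 - pt)**self.gamma * ce_loss
        return focal.sum() / self.alpha[targets].sum()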
Upvotes: 0
Reputation: 388
I think the OP would have gotten their answer by now. I am writing this for other people who might ponder upon this.
There is one problem in the OP's implementation of Focal Loss:

F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

In this line, the same alpha value is multiplied with every class output probability (pt). Additionally, the code doesn't show how we get pt. A very good implementation of Focal Loss can be found in What is Focal Loss and when should you use it. But that implementation is only for binary classification, as it has alpha and 1 - alpha for the two classes in the self.alpha tensor.
In the case of multi-class or multi-label classification, the self.alpha tensor should contain a number of elements equal to the total number of labels. The values could be the inverse label frequency or the inverse normalized label frequency (just be cautious with labels that have a frequency of 0). A sketch of how such an alpha could be built follows below.
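A minimal sketch of building such an alpha tensor from inverse normalized label frequencies (the variable names train_labels and num_classes are assumptions, not from the original answer):

import numpy as np
import torch

# hypothetical example: integer class labels for the training set
train_labels = np.array([0, 0, 0, 1, 1, 2, 3, 3, 3, 3])
num_classes = 4

# inverse normalized label frequency; guard against zero-frequency labels
counts = np.bincount(train_labels, minlength=num_classes).astype(np.float64)
counts = np.clip(counts, a_min=1.0, a_max=None)   # avoid division by zero
freqs = counts / counts.sum()
alpha = torch.tensor(1.0 / freqs, dtype=torch.float)

# alpha can now be passed to a focal loss such as WeightedFocalLoss(alpha) above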
Upvotes: 5
Reputation: 99
I was searching for this myself and found most implementations way too cumbersome. One can use PyTorch's CrossEntropyLoss instead (and use ignore_index) and add the focal term. Keep in mind that the class weights need to be applied after getting pt from the CE loss, so they must be applied separately rather than inside CE as weight=alpha:
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2, ignore_index=-100, reduction='mean'):
        super().__init__()
        # use standard CE loss without reduction as basis
        self.CE = nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_index)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target):
        '''
        input (B, N)
        target (B)
        '''
        minus_logpt = self.CE(input, target)
        pt = torch.exp(-minus_logpt)  # don't forget the minus here
        focal_loss = (1 - pt)**self.gamma * minus_logpt

        # apply class weights
        if self.alpha is not None:
            focal_loss *= self.alpha.gather(0, target)

        if self.reduction == 'mean':
            focal_loss = focal_loss.mean()
        elif self.reduction == 'sum':
            focal_loss = focal_loss.sum()
        return focal_loss
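A short usage sketch, assuming class weights come from sklearn's compute_class_weight as in the question and that the model outputs raw logits (no LogSoftmax); the variable names model, input_ids, attention_mask and labels are illustrative:

# hypothetical usage with a 4-class BERT classifier head
class_weights = compute_class_weight(class_weight='balanced',
                                     classes=np.unique(train_labels), y=train_labels)
alpha = torch.tensor(class_weights, dtype=torch.float)

criterion = FocalLoss(alpha=alpha, gamma=2)
logits = model(input_ids, attention_mask)   # shape (B, 4), raw scores
loss = criterion(logits, labels)            # labels: shape (B), class indices
loss.backward()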
Upvotes: 3
Reputation: 7140
I think the implementation in your question is wrong. The alpha is the class weight.
In cross entropy the class weight is alpha_t, as shown in the following expression:

CE(p_t) = -alpha_t * log(p_t)

and focal loss keeps that same per-class weight:

FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

You see that it is alpha_t rather than a single alpha, i.e. the weight of each sample's target class, and we can see from this popular PyTorch implementation that alpha acts the same way as the class weight. A small equivalence check follows below.
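A small sketch of that equivalence (a quick check, not from the original answer): indexing a per-class alpha by the target gives the same per-sample scaling as passing weight= to CrossEntropyLoss with reduction='none'.

import torch
import torch.nn as nn

alpha = torch.tensor([0.1, 0.3, 0.4, 0.2])   # per-class weights (illustrative)
logits = torch.randn(8, 4)                    # hypothetical batch of 4-class logits
targets = torch.randint(0, 4, (8,))

unweighted = nn.CrossEntropyLoss(reduction='none')(logits, targets)
weighted = nn.CrossEntropyLoss(weight=alpha, reduction='none')(logits, targets)

# per sample, the class weight is alpha_t = alpha[target], exactly as in the formula above
assert torch.allclose(weighted, alpha[targets] * unweighted)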
Upvotes: 3
Reputation: 51
You may find the answer to your question in the following implementation:
import torch.nn.functional as F

def sigmoid_focal_loss(pred, target, weight=None, gamma=2.0, alpha=0.25,
                       reduction='mean', avg_factor=None):
    # signature reconstructed from the variables used in the body;
    # the gamma/alpha defaults are the usual focal-loss values (an assumption)
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt here is the probability mass on the wrong label, i.e. (1 - p_t),
    # so pt.pow(gamma) down-weights easy examples
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
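The snippet relies on a weight_reduce_loss helper that is not shown. A minimal sketch of what such a helper typically does, assuming it applies optional per-element weights and then the requested reduction (this is an assumption, not the quoted code):

def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    # assumed behaviour: optional element-wise weighting, then reduction
    if weight is not None:
        loss = loss * weight
    if avg_factor is None:
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss  # reduction == 'none'
    # with an explicit averaging factor, average by that factor instead of the element count
    if reduction == 'mean':
        return loss.sum() / avg_factor
    if reduction == 'none':
        return loss
    raise ValueError('avg_factor can only be used with reduction="mean" or "none"')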
You can also experiment with another focal loss version that is available.
Upvotes: 4