Ryan

Reputation: 10139

Implementation of Focal loss for multi label classification

I'm trying to write a focal loss for multi-label classification:

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.25):
        self._gamma = gamma
        self._alpha = alpha

    def forward(self, y_true, y_pred):
        cross_entropy_loss = torch.nn.BCELoss(y_true, y_pred)
        p_t = ((y_true * y_pred) +
               ((1 - y_true) * (1 - y_pred)))
        modulating_factor = 1.0
        if self._gamma:
            modulating_factor = torch.pow(1.0 - p_t, self._gamma)
        alpha_weight_factor = 1.0
        if self._alpha is not None:
            alpha_weight_factor = (y_true * self._alpha +
                                   (1 - y_true) * (1 - self._alpha))
        focal_cross_entropy_loss = (modulating_factor * alpha_weight_factor *
                                    cross_entropy_loss)
        return focal_cross_entropy_loss.mean()

But when I run this, I get:

  File "train.py", line 82, in <module>
    loss = loss_fn(output, target)
  File "/home/bubbles/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 538, in __call__
    for hook in self._forward_pre_hooks.values():
  File "/home/bubbles/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
    type(self).__name__, name))
AttributeError: 'FocalLoss' object has no attribute '_forward_pre_hooks'

Any suggestions would be really helpful. Thanks in advance.

Upvotes: 7

Views: 4070

Answers (1)

Szymon Maszke

Reputation: 24874

You shouldn't inherit from torch.nn.Module as it's designed for modules with learnable parameters (e.g. neural networks).

Just create a normal functor (a class with a __call__ method) or a plain function, and you should be fine.
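A minimal functional sketch, assuming y_pred already holds probabilities (for example the output of torch.sigmoid) and y_true is a float tensor of 0/1 labels with the same shape. It uses torch.nn.functional.binary_cross_entropy with reduction="none" so the focal weights are applied per label before averaging. The name focal_loss is only illustrative. The prediction comes first so it matches the call loss_fn(output, target) in the traceback; the question's forward(self, y_true, y_pred) has the two swapped.

```python
import torch
import torch.nn.functional as F


def focal_loss(y_pred, y_true, gamma=2.0, alpha=0.25):
    """Focal loss for multi-label classification.

    y_pred: probabilities in (0, 1), e.g. torch.sigmoid(logits).
    y_true: float tensor of 0/1 labels with the same shape as y_pred.
    """
    # Per-label BCE; reduction="none" keeps one value per element.
    bce = F.binary_cross_entropy(y_pred, y_true, reduction="none")
    # p_t = predicted probability assigned to the true label.
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
    # (1 - p_t) ** gamma shrinks the loss of well-classified labels.
    loss = (1.0 - p_t) ** gamma * bce
    if alpha is not None:
        # alpha weights positive labels, (1 - alpha) weights negative ones.
        loss = loss * (y_true * alpha + (1 - y_true) * (1 - alpha))
    return loss.mean()


preds = torch.tensor([[0.9, 0.2], [0.4, 0.7]])
labels = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
print(focal_loss(preds, labels))
```

If output holds raw logits rather than probabilities, pass torch.sigmoid(output) as the first argument.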

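BTW, if you do inherit from it, you must call super().__init__() in your __init__(), ideally as its first line. nn.Module.__init__() is what creates internal attributes such as _forward_pre_hooks, which is why skipping it produces the AttributeError above.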
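A minimal reproduction, using hypothetical Broken and Fixed classes whose only difference is the super().__init__() call. Without it, calling the module raises an AttributeError; the exact attribute named in the message depends on the PyTorch version (the traceback in the question shows _forward_pre_hooks).

```python
import torch.nn as nn


class Broken(nn.Module):
    def __init__(self, gamma=2):
        # super().__init__() is missing, so nn.Module never creates its
        # internal hook and parameter dictionaries.
        self._gamma = gamma

    def forward(self, x):
        return x


class Fixed(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()  # sets up _parameters, _forward_pre_hooks, etc.
        self._gamma = gamma

    def forward(self, x):
        return x


try:
    Broken()(1.0)
    broken_failed = False
except AttributeError as err:
    # The missing attribute name varies between PyTorch versions.
    broken_failed = True
    print("Broken:", err)

print("Fixed:", Fixed()(1.0))
```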

EDIT

Actually, inheriting from nn.Module might be a good idea: it allows you to use the loss as part of a neural network, and it is a common pattern in PyTorch implementations and in PyTorch Lightning.
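A sketch of the nn.Module route, assuming the inputs are probabilities. It is the question's FocalLoss with super().__init__() added, with the element-wise BCE computed by torch.nn.functional.binary_cross_entropy(..., reduction="none") (torch.nn.BCELoss(y_true, y_pred) constructs a loss module rather than computing a loss), and with arguments ordered (prediction, target). Classifier, in_features and num_labels are hypothetical names used only to show the loss held as a child module of a network.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-label focal loss; expects probabilities (e.g. after sigmoid)."""

    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()  # required so nn.Module can set up its internals
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, y_pred, y_true):
        # Same argument order as the call loss_fn(output, target).
        bce = F.binary_cross_entropy(y_pred, y_true, reduction="none")
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        loss = (1.0 - p_t) ** self.gamma * bce
        if self.alpha is not None:
            loss = loss * (y_true * self.alpha + (1 - y_true) * (1 - self.alpha))
        return loss.mean()


class Classifier(nn.Module):
    """Hypothetical model that owns its loss as a submodule."""

    def __init__(self, in_features, num_labels):
        super().__init__()
        self.linear = nn.Linear(in_features, num_labels)
        self.loss_fn = FocalLoss()  # registered as a child module

    def forward(self, x, target=None):
        probs = torch.sigmoid(self.linear(x))
        if target is None:
            return probs
        return probs, self.loss_fn(probs, target)


model = Classifier(in_features=8, num_labels=3)
x = torch.randn(4, 8)
target = (torch.rand(4, 3) > 0.5).float()
probs, loss = model(x, target)
loss.backward()
print(probs.shape, loss.item())
print([name for name, _ in model.named_children()])  # ['linear', 'loss_fn']
```

Because FocalLoss is an nn.Module, assigning it as an attribute registers it alongside the layers, which is the pattern typically used when a LightningModule stores its criterion.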

Upvotes: 6
