Getting NaN as loss value

I have implemented focal loss in PyTorch following this paper, and ran into a problem: the loss comes out as NaN.

This is my implementation of focal loss:

import torch

def focal_loss(y_real, y_pred, gamma = 2):
    y_pred = torch.sigmoid(y_pred)
    return -torch.sum((1 - y_pred)**gamma * y_real * torch.log(y_pred) +
                       y_pred**gamma * (1 - y_real) * torch.log(1 - y_pred))

The training loop and my SegNet seem to be working, since I have tested them with Dice and BCE losses.

I suspect the error occurs during backprop. What could cause this? Is my implementation wrong?

Upvotes: 0

Views: 2688

Answers (2)

This version is working:

import torch

def focal_loss(y_real, y_pred, eps = 1e-8, gamma = 0):
    # Clamp probabilities so the modulating factor stays well-defined, and
    # compute the log term directly from the logits via the stable identity
    # x - x*y + log(1 + exp(-x)) = BCE-with-logits, which avoids
    # evaluating log(sigmoid(x)) when the sigmoid has saturated.
    probabilities = torch.clamp(torch.sigmoid(y_pred), min=eps, max=1-eps)
    return torch.mean((1 - probabilities)**gamma * 
           (y_pred - y_real * y_pred + torch.log(1 + torch.exp(-y_pred))))
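As a quick sanity check (a sketch, not part of the original answer): with the default gamma = 0 the modulating factor is 1, so this should reduce to plain BCE-with-logits, which we can compare against `torch.nn.functional.binary_cross_entropy_with_logits`:

```python
import torch
import torch.nn.functional as F

def focal_loss(y_real, y_pred, eps=1e-8, gamma=0):
    probabilities = torch.clamp(torch.sigmoid(y_pred), min=eps, max=1-eps)
    return torch.mean((1 - probabilities)**gamma *
           (y_pred - y_real * y_pred + torch.log(1 + torch.exp(-y_pred))))

# With gamma=0 the (1 - p)**gamma factor is 1, so the result should
# match plain BCE-with-logits on moderate logits.
logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])
a = focal_loss(targets, logits)
b = F.binary_cross_entropy_with_logits(logits, targets)
print(torch.allclose(a, b))  # True
```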

Upvotes: 1

Sandro H

Reputation: 136

This is most likely due to trying to calculate log(0).
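You can see the saturation directly (a quick float32 sketch):

```python
import torch

# In float32 the sigmoid saturates: sigmoid(100) rounds to exactly 1.0,
# so log(1 - sigmoid(100)) is log(0) = -inf, and a term like
# (1 - y_real) * log(1 - y_pred) with y_real = 1 becomes 0 * -inf = nan.
p = torch.sigmoid(torch.tensor(100.0))
print(torch.log(1 - p))        # tensor(-inf)
print(0.0 * torch.log(1 - p))  # tensor(nan)
```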

I would recommend changing the code like this:

import torch

# Note: in float32, 1 - 1e-9 rounds back to exactly 1.0, so EPS needs to
# be at least about 1e-7 for the upper clamp to actually take effect.
EPS = 1e-7
def focal_loss(y_real, y_pred, gamma = 2):
    y_pred = torch.sigmoid(y_pred)
    y_pred = torch.clamp(y_pred, EPS, 1. - EPS)
    return -torch.sum((1 - y_pred)**gamma * y_real * torch.log(y_pred) +
                       y_pred**gamma * (1 - y_real) * torch.log(1 - y_pred))
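With the clamp in place, both the forward pass and the gradients stay finite even for saturated logits. A quick check (repeating the definition so it runs standalone, with EPS = 1e-7 chosen so the clamp is effective in float32):

```python
import torch

EPS = 1e-7  # large enough that 1 - EPS != 1 in float32

def focal_loss(y_real, y_pred, gamma=2):
    y_pred = torch.sigmoid(y_pred)
    y_pred = torch.clamp(y_pred, EPS, 1. - EPS)
    return -torch.sum((1 - y_pred)**gamma * y_real * torch.log(y_pred) +
                      y_pred**gamma * (1 - y_real) * torch.log(1 - y_pred))

# Extreme logits saturate the sigmoid; without the clamp this yields NaN.
logits = torch.tensor([100.0, -100.0], requires_grad=True)
targets = torch.tensor([1.0, 0.0])
loss = focal_loss(targets, logits)
loss.backward()
print(torch.isfinite(loss))               # tensor(True)
print(torch.isfinite(logits.grad).all())  # tensor(True)
```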

Upvotes: 0
