Chao

Reputation: 99

Weird behaviour of loss function in PyTorch

I'm computing a custom cost function that simply takes the exponential of the cross-entropy loss divided by a parameter eta. During the first iterations (around 20), the training loss decreases, but after that I suddenly get NaN, and I don't understand why this is happening.

The code I'm using is the following:

import torch
import torch.nn as nn

# client_model, optimizer, train_loader and epoch are defined elsewhere
e_loss = []
eta = 2  # just an example of the value of eta I'm using
criterion = nn.CrossEntropyLoss()
for e in range(epoch):
    train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        client_model.train()
        optimizer.zero_grad()
        output = client_model(data)
        loss = torch.exp(criterion(output, target) / eta)  # this is the line where I apply my custom loss function
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0)
    train_loss = train_loss / len(train_loader.dataset)  # average loss per sample
    e_loss.append(train_loss)

Upvotes: 1

Views: 1075

Answers (1)

jodag

Reputation: 22184

Directly using exp is quite unstable when the input is unbounded. Cross-entropy loss can return very large values if the network very confidently predicts the wrong class (because -log(x) goes to infinity as x goes to 0). A single badly wrong prediction like this can cause numerical overflow, which makes the gradients become NaN, which in turn immediately makes the weights and outputs of your model NaN.

For example

>>> import torch
>>> import torch.nn.functional as F
>>> torch.exp(F.cross_entropy(torch.tensor([[-50.0, 50.0]]), torch.tensor([0])))
tensor(inf)
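
For scale, float32 overflows as soon as the argument to exp exceeds roughly 88.7 (the natural log of the float32 maximum, about 3.4e38). Continuing the same session:

>>> torch.exp(torch.tensor(88.0))
tensor(1.6516e+38)
>>> torch.exp(torch.tensor(89.0))
tensor(inf)

So with your division by eta, the loss becomes inf as soon as the cross-entropy on a batch exceeds roughly eta * 88.7 (around 177 with eta = 2), and the following backward pass is what turns your weights into NaN.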

Upvotes: 1
