Sam

Reputation: 207

BCEWithLogitsLoss Multi-label Classification

I'm a bit confused about how to accumulate the batch losses to obtain the epoch loss.

Two questions:

  1. Is #1 (see the comments below) the correct way to calculate the loss with masks?
  2. Is #2 the correct way to report the epoch loss?
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

for epoch in range(10):
    EPOCH_LOSS = 0.
    
    for inputs, gt_labels, masks in training_dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        
        #1: Is this the correct way to calculate the batch loss? Do I multiply batch_loss by outputs.shape[0] before adding it to epoch_loss?
        batch_loss = (masks * criterion(outputs, gt_labels.float())).mean()
        EPOCH_LOSS += batch_loss
        batch_loss.backward()
        optimizer.step()

    #2: then what do I do here? Do I divide EPOCH_LOSS by len(training_dataloader)?
    print(f'EPOCH LOSS: {EPOCH_LOSS/len(training_dataloader):.3f}')

Upvotes: 1

Views: 1086

Answers (1)

jhso

Reputation: 3283

In your criterion, the reduction field is left at its default of 'mean' (see the docs), so criterion(...) already returns a scalar and multiplying it by masks afterwards won't do what you expect. You should apply your masking one step earlier (prior to the loss calculation), like so:

batch_loss = criterion(outputs * masks, gt_labels.float() * masks)  # default 'mean' reduction already returns a scalar

OR

batch_loss = criterion(outputs[masks], gt_labels.float()[masks])  # masks must be a bool tensor for indexing

That said, without seeing your data, your masks may be in a different format, so check that this behaves as expected.
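If you would rather keep the masking where you had it, a third option is to construct the criterion with reduction='none', so it returns per-element losses that you can mask and average yourself. A minimal sketch with hypothetical shapes, assuming masks is a 0/1 float tensor the same shape as outputs (pos_weight is omitted for brevity but can be passed as before):

import torch

criterion = torch.nn.BCEWithLogitsLoss(reduction='none')

outputs   = torch.randn(8, 4)                    # raw logits, (batch, num_labels)
gt_labels = torch.randint(0, 2, (8, 4)).float()  # multi-hot targets
masks     = torch.randint(0, 2, (8, 4)).float()  # 1 = keep element, 0 = ignore it

per_element = criterion(outputs, gt_labels)      # shape (8, 4), nothing reduced yet
batch_loss = (per_element * masks).sum() / masks.sum().clamp(min=1)

Dividing by masks.sum() rather than the total number of elements makes the ignored entries drop out of the average entirely, and clamp(min=1) guards against an all-zero mask.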

As for your actual question, it depends on how you want to report it. I would simply sum all of the batch losses and report that, but you can divide by the number of batches if you want to report the average per-batch loss over the epoch.

Because this number is purely illustrative of your model's progress, it doesn't really matter which one you pick, as long as it is consistent between epochs so you can see that your model is learning.
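For concreteness, here is a minimal runnable sketch of that bookkeeping (the tiny model and fake batches are stand-ins, not your setup), printing both conventions:

import torch

torch.manual_seed(0)
model = torch.nn.Linear(16, 4)                   # stand-in for your model
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Fake "dataloader": three batches of (inputs, multi-hot labels)
batches = [(torch.randn(8, 16), torch.randint(0, 2, (8, 4)).float()) for _ in range(3)]

epoch_loss = 0.0
for inputs, gt_labels in batches:
    optimizer.zero_grad()
    batch_loss = criterion(model(inputs), gt_labels)
    batch_loss.backward()
    optimizer.step()
    epoch_loss += batch_loss.item()              # .item() detaches the scalar from the graph

print(f'epoch loss (sum):            {epoch_loss:.3f}')
print(f'epoch loss (mean per batch): {epoch_loss / len(batches):.3f}')

Note the .item(): accumulating the raw loss tensor instead would keep every batch's autograd graph alive until the end of the epoch.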

Upvotes: 2
