Reputation: 759
There is a famous trick in the U-Net architecture: using custom per-pixel weight maps (based on the distance to object borders, as in the U-Net paper) to increase accuracy.
By asking here and in several other places, I have learned about two approaches. I want to know which one is correct, or whether there is another, more correct approach.
The first is to call torch.nn.functional.cross_entropy in the training loop:
loss = torch.nn.functional.cross_entropy(output, target, w)
where w is the computed custom weight.
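For concreteness, here is a minimal runnable sketch of that call (shapes are assumed for a two-class segmentation; note that per the PyTorch docs, the weight argument of cross_entropy is a per-class tensor of shape (C,)):
import torch
import torch.nn.functional as F
# Assumed shapes for a two-class segmentation task
output = torch.randn(4, 2, 10, 10, requires_grad=True)  # logits: (batch, C, H, W)
target = torch.randint(0, 2, (4, 10, 10))  # class indices: (batch, H, W)
w = torch.tensor([0.3, 0.7])  # per-class weight, shape (C,)
loss = F.cross_entropy(output, target, w)
loss.backward()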
The second is to use reduction='none' when instantiating the loss function outside the training loop:
criterion = torch.nn.CrossEntropyLoss(reduction='none')
and then, in the training loop, multiply by the custom weight:
gt  # Ground truth, dtype torch.long, shape (batch, H, W)
pd  # Network output (logits), shape (batch, C, H, W)
W  # Per-pixel weighting based on the distance map from U-Net, shape (batch, H, W)
loss = criterion(pd, gt)  # Per-pixel loss, shape (batch, H, W)
loss = W * loss  # Apply the custom weight map
loss = torch.sum(loss.flatten(start_dim=1), dim=1)  # Sums the loss per image
loss = torch.mean(loss)  # Averages over the batch
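To make this second approach concrete, a minimal runnable sketch with dummy tensors (shapes assumed):
import torch
criterion = torch.nn.CrossEntropyLoss(reduction='none')
pd = torch.randn(4, 2, 10, 10, requires_grad=True)  # dummy logits: (batch, C, H, W)
gt = torch.randint(0, 2, (4, 10, 10))  # dummy ground truth: (batch, H, W)
W = torch.rand(4, 10, 10) + 1.0  # dummy per-pixel weight map
loss = criterion(pd, gt)  # per-pixel loss: (batch, H, W)
loss = (W * loss).flatten(start_dim=1).sum(dim=1).mean()
loss.backward()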
Now I am confused: which one is right, is there another way, or are both right?
Upvotes: 10
Views: 2989
Reputation: 2896
The weighting portion looks like simple class-weighted cross entropy, which is done like this for a given number of classes (two in the example below):
import torch
import torch.nn as nn

weights = torch.FloatTensor([.3, .7])
loss_func = nn.CrossEntropyLoss(weight=weights)
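As a quick usage check with dummy data (shapes assumed), loss_func from above applies directly to logits and integer class targets:
logits = torch.randn(8, 2)  # (batch, classes)
target = torch.randint(0, 2, (8,))  # (batch,)
print(loss_func(logits, target))  # scalar; class-1 errors count more due to the 0.7 weight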
EDIT:
Have you seen this implementation from Patrick Black?
import torch
import torch.nn.functional as F

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10
# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)
# Calculate log probabilities
logp = F.log_softmax(logits, dim=1)
# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))
# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)
# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)
# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()
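If it helps, this gather-based version computes the same per-pixel weighted negative log-likelihood as the reduction='none' approach from the question, just normalized by the weight sum per image. A quick check with dummy data (shapes assumed):
import torch
import torch.nn.functional as F
logits = torch.randn(10, 2, 10, 10)
target = torch.randint(0, 2, (10, 10, 10))
weights = torch.rand(10, 1, 10, 10) + 1.0
# Gather-based loss as above
logp = F.log_softmax(logits, dim=1).gather(1, target.view(10, 1, 10, 10))
gather_loss = -((logp * weights).view(10, -1).sum(1) / weights.view(10, -1).sum(1)).mean()
# Same value via cross_entropy with reduction='none'
ce = F.cross_entropy(logits, target, reduction='none')  # (batch, H, W)
ce = (weights.squeeze(1) * ce).view(10, -1).sum(1) / weights.view(10, -1).sum(1)
print(torch.allclose(gather_loss, ce.mean()))  # True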
Upvotes: 5
Reputation: 54
Note that torch.nn.CrossEntropyLoss() is a class whose forward method calls torch.nn.functional.cross_entropy. See https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html#CrossEntropyLoss
You can pass the weights when you define the criterion. Functionally, the two methods are the same.
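A quick check of that equivalence (dummy data, shapes assumed):
import torch
logits = torch.randn(8, 2)
target = torch.randint(0, 2, (8,))
w = torch.tensor([0.3, 0.7])
a = torch.nn.CrossEntropyLoss(weight=w)(logits, target)
b = torch.nn.functional.cross_entropy(logits, target, weight=w)
print(torch.allclose(a, b))  # True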
Now, I do not understand your idea of computing the loss inside the training loop in method 1 and outside the training loop in method 2. If you compute the loss outside the loop, how will you backpropagate?
Upvotes: 0