mcExchange

Reputation: 6498

Example CrossEntropyLoss for 3D semantic segmentation in pytorch

I have a network performing 3D convolutions on a 5D input tensor. The output of my network is of size (1, 12, 60, 36, 60), corresponding to (BatchSize, NumClasses, x-dim, y-dim, z-dim). I need to compute a voxel-wise cross entropy loss, but I keep getting errors.

When I try to compute the cross entropy loss using torch.nn.CrossEntropyLoss(), I get the following error message:

RuntimeError: multi-target not supported at .../src/THCUNN/generic/ClassNLLCriterion.cu:16

Here is an extract of my code:

import torch
import torch.nn as nn
from torch.autograd import Variable
criterion = torch.nn.CrossEntropyLoss()
images = Variable(torch.randn(1, 12, 60, 36, 60)).cuda()
labels = Variable(torch.zeros(1, 12, 60, 36, 60).random_(2)).long().cuda()
loss = criterion(images.view(1,-1), labels.view(1,-1))

The same happens when I create a one-hot tensor for the labels:

import numpy as np
nclasses = 12
labels = (np.random.randint(0,12,(1,60,36,60))) # Random labels with values between [0..11]
labels = (np.arange(nclasses) == labels[..., None] - 1).astype(int) # Converts labels to one_hot_tensor
a = np.transpose(labels,(0,4,3,2,1)) #  Reorder dimensions to match shape of "images" ([1, 12, 60, 36, 60])
b = Variable(torch.from_numpy(a)).cuda()
loss = criterion(images.view(1,-1), b.view(1,-1))

Any idea what I'm doing wrong? Can someone provide an example of computing cross entropy on a 5D output tensor?

Upvotes: 2

Views: 2988

Answers (2)

mcExchange

Reputation: 6498

I just looked at an implementation (fcn) for 2D semantic segmentation and tried to adapt it to 3D semantic segmentation. No guarantee that this is correct; I'll have to double check.

import torch
import torch.nn.functional as F

def cross_entropy3d(input, target, weight=None, size_average=True):
    # input: (n, c, h, w, z), target: (n, h, w, z)
    n, c, h, w, z = input.size()
    # log_p: (n, c, h, w, z)
    log_p = F.log_softmax(input, dim=1)
    # Move the class dimension last, then flatten: log_p: (n*h*w*z, c)
    log_p = log_p.permute(0, 2, 3, 4, 1).contiguous().view(-1, c)
    # target: (n*h*w*z,), flattened in the same voxel order as log_p
    target = target.reshape(-1)
    # Keep only voxels with a valid (non-negative) label
    mask = target >= 0
    log_p = log_p[mask]
    target = target[mask]
    # Sum the per-voxel losses, then average over the valid voxels if requested
    loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
    if size_average:
        loss = loss / mask.sum()
    return loss

images = torch.randn(5, 3, 16, 16, 16)
labels = torch.randint(0, 3, (5, 16, 16, 16))  # class indices in [0..2]
loss = cross_entropy3d(images, labels, weight=None, size_average=True)
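Since the answer asks for a double check, one way to sanity-test the flattening idea is to compare it against F.cross_entropy, which in recent PyTorch versions accepts an input of shape (N, C, d1, ..., dK) with a target of shape (N, d1, ..., dK). The sketch below is only an illustration under that assumption; the tensor sizes and variable names are made up, and it uses unweighted mean reduction with no ignored labels.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 3
logits = torch.randn(2, num_classes, 4, 5, 6)          # (n, c, h, w, z)
target = torch.randint(0, num_classes, (2, 4, 5, 6))   # (n, h, w, z), class indices

# Manual route: log-softmax over classes, move the class dimension last,
# flatten to (n*h*w*z, c), and apply the 2D negative log-likelihood loss.
log_p = F.log_softmax(logits, dim=1)
log_p_flat = log_p.permute(0, 2, 3, 4, 1).reshape(-1, num_classes)
loss_flat = F.nll_loss(log_p_flat, target.reshape(-1))

# Built-in route: pass the 5D logits and 4D class-index target directly.
loss_builtin = F.cross_entropy(logits, target)

# The two values should agree up to floating-point error.
print(torch.allclose(loss_flat, loss_builtin, atol=1e-5))
```

The important detail is that the permutation keeps the spatial axes in the same order as the target, so that row i of the flattened log-probabilities and entry i of the flattened target refer to the same voxel.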

Upvotes: 1

cleros

Reputation: 4343

The docs explain this behavior. Bottom line: it is actually computing the sparse cross entropy loss, so the target does not need a value for every dimension of the output, only the index of the correct class. They specifically state:

Input: (N,C), where C = number of classes
Target: (N), where each value is 0 <= targets[i] <= C-1
Output: scalar. If reduce is False, then (N) instead.
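Applied to the question, this contract means the target should hold one class index per voxel rather than a one-hot vector. A minimal sketch follows, assuming a recent PyTorch version in which nn.CrossEntropyLoss also accepts the K-dimensional form, input (N, C, d1, ..., dK) with target (N, d1, ..., dK). The spatial size is reduced from the question's (60, 36, 60) to keep the example small, and the variable names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()
num_classes = 12

logits = torch.randn(1, num_classes, 6, 4, 6)          # (N, C, x, y, z)
labels = torch.randint(0, num_classes, (1, 6, 4, 6))   # (N, x, y, z), values in [0..11]

# Class-index targets can be passed as-is; no flattening is required.
loss = criterion(logits, labels)

# If the labels only exist in one-hot form, shape (N, C, x, y, z),
# argmax over the class dimension recovers the class indices.
one_hot = F.one_hot(labels, num_classes=num_classes)   # (N, x, y, z, C)
one_hot = one_hot.permute(0, 4, 1, 2, 3)               # (N, C, x, y, z)
recovered = one_hot.argmax(dim=1)                      # (N, x, y, z)
loss_from_one_hot = criterion(logits, recovered)
```

Both calls return a scalar, and the two losses are identical because the recovered indices equal the original labels.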

I'm not sure about your use-case, but you might want to use the KL Divergence or the Binary Cross Entropy Loss instead. Both are defined over inputs and targets of equal size.
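As a rough sketch of the equal-size alternative mentioned above, nn.BCEWithLogitsLoss can take one-hot targets with the same shape as the network output. Note the assumption this carries: it treats every class channel as an independent binary problem, so it is not the same objective as softmax cross entropy. Shapes and names here are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 12
logits = torch.randn(1, num_classes, 6, 4, 6)          # (N, C, x, y, z)
labels = torch.randint(0, num_classes, (1, 6, 4, 6))   # (N, x, y, z), class indices

# Build a float one-hot target with the class dimension in position 1,
# so that it matches the shape of the logits exactly.
one_hot = F.one_hot(labels, num_classes=num_classes)   # (N, x, y, z, C)
one_hot = one_hot.permute(0, 4, 1, 2, 3).float()       # (N, C, x, y, z)

# Input and target have equal size; the result is a scalar (mean reduction).
bce = nn.BCEWithLogitsLoss()(logits, one_hot)
```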

Upvotes: 0
