Niloufar Modir

Reputation: 31

'RuntimeError: Expected object of scalar type Long but got scalar' for torch.nn.CrossEntropyLoss()

I'm using this loss function with xlm-roberta-large-longformer, and it raises the error in the title:

    import torch.nn.functional as f
    from scipy.special import softmax
    
    loss_func = torch.nn.CrossEntropyLoss()
    output = torch.softmax(logits.view(-1,num_labels), dim=0).float()
    target = b_labels.type_as(logits).view(-1,num_labels)
    loss = loss_func(output, target)
    train_loss_set.append(loss.item()) 

When I try

    b_labels.type_as(logits).view(-1,num_labels).long()

it tells me

    RuntimeError: multi-target not supported at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15

What should I do?

Upvotes: 0

Views: 58

Answers (1)

Ivan

Reputation: 40678

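Your target tensor should contain one integer class index per sample (dtype Long, shape (N,)), not a one-hot or multi-hot encoding of the class. Casting the one-hot tensor to Long keeps its (N, num_labels) shape, which is why you then get the "multi-target not supported" error.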

You can recover the class labels from a one-hot encoding by taking argmax along the class dimension:

    >>> b_labels.argmax(1)
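To show how this fits into the loss computation, here is a minimal sketch. The tensor values, the batch size of 4, and num_labels = 3 are made up for illustration; only the variable names come from the question. It also passes the raw logits to the loss rather than softmax output, since torch.nn.CrossEntropyLoss applies log-softmax internally.

```python
import torch

torch.manual_seed(0)

num_labels = 3  # hypothetical number of classes

# Stand-in for the model output: raw, unnormalized scores of shape (N, C).
logits = torch.randn(4, num_labels)

# Stand-in for the one-hot labels from the question, shape (N, C).
b_labels = torch.tensor([[1, 0, 0],
                         [0, 0, 1],
                         [0, 1, 0],
                         [0, 0, 1]])

loss_func = torch.nn.CrossEntropyLoss()

# Collapse each one-hot row to its class index: shape (N,), dtype torch.int64 (Long).
target = b_labels.view(-1, num_labels).argmax(dim=1)

# Feed raw logits; the loss applies log-softmax over the class dimension itself.
loss = loss_func(logits.view(-1, num_labels), target)

print(target)       # tensor([0, 2, 1, 2])
print(loss.item())
```

Note that argmax only makes sense if every row has exactly one active class. If the labels are truly multi-hot (several classes per sample), a multi-label loss such as torch.nn.BCEWithLogitsLoss with float targets is probably the better fit.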

Upvotes: 1
