Soul

Reputation: 11

PyTorch - nn.CrossEntropyLoss

I want to apply nn.CrossEntropyLoss

output = model(data)
output.shape -> 1,30,7 (batch, frame, class)
label.shape -> 1,30 (batch, frame)

In this case, can I compute the loss like this?

label = label.squeeze(0)    # (1, 30) -> (30,)
output = output.squeeze(0)  # (1, 30, 7) -> (30, 7)
criterion = nn.CrossEntropyLoss()
loss = criterion(output, label)

But if the batch size is 2, output.shape is (2, 30, 7) and label.shape is (2, 30).

How do I apply loss = criterion(output, label) in that case?

Upvotes: 1

Views: 123

Answers (1)

Shai

Reputation: 114926

The loss function nn.CrossEntropyLoss can be applied to multi-dimensional predictions: according to the documentation, the input may have shape (N, C, d1, ..., dK) with a target of shape (N, d1, ..., dK).

All you need is to make sure your C dimension (7 in your case) is the second:

output = output.transpose(1, 2)  # B,30,7 -> B,7,30
loss = criterion(output, label)

You do not need to change the labels at all.
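
As a rough check of the above, here is a minimal sketch using random stand-in tensors with the shapes from the question (the values, the seed, and the names B, T, C are assumptions for illustration only). It also compares the result against flattening the batch and frame dimensions, which should give the same value under the default reduction='mean'.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, T, C = 2, 30, 7                      # batch, frames, classes (shapes from the question)
output = torch.randn(B, T, C)           # stand-in for model(data)
label = torch.randint(0, C, (B, T))     # stand-in class indices (int64)

criterion = nn.CrossEntropyLoss()

# Option 1: move the class dimension to position 1 -> (B, C, T); labels stay (B, T)
loss_transposed = criterion(output.transpose(1, 2), label)

# Option 2: merge batch and frame -> (B*T, C) predictions and (B*T,) labels
loss_flat = criterion(output.reshape(-1, C), label.reshape(-1))

# With the default mean reduction both average over all B*T frames
same = torch.allclose(loss_transposed, loss_flat)
print(loss_transposed.item(), loss_flat.item(), same)
```

Either form works for any batch size, so the squeeze(0) step from the question is not needed.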

Upvotes: 2
