I'm reading the PyTorch tutorial on a multi-class classification problem, and the behavior of the loss calculation confuses me a lot. Can you help me with this?
The model used for classification goes like this:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
The training process goes as follows:
criterion = nn.CrossEntropyLoss()

optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
My question is: what's the exact behavior of the loss calculation in PyTorch here? During each iteration, the input of nn.CrossEntropyLoss() has two parts:
- the output of the model, a tensor of shape [batch_size, 10] with raw, unnormalized scores (no softmax applied), and
- the true labels, a tensor of class indices like 1, 2, or 3.
As far as I know, cross-entropy is usually calculated between two tensors of the same shape, like:
- a one-hot target such as [0, 0, 0, 1], where the 1 marks the right class, and
- an output tensor of probabilities such as [0.1, 0.2, 0.3, 0.4], which sums to 1.
So based on this assumption, nn.CrossEntropyLoss() here would need to:
- first normalize the output tensor into probabilities (softmax),
- one-hot encode the label to the same length as the output tensor, and
- then calculate the cross-entropy between the two (sketched below).
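In code, my guess at what happens internally would be something like this (made-up example values):

import torch
import torch.nn.functional as F

output = torch.tensor([[1.2, 3.4, 0.5, 2.1]])      # raw model output for one sample, 4 classes
label = torch.tensor([1])                          # true class index

probs = F.softmax(output, dim=1)                   # 1. normalize into probabilities
one_hot = F.one_hot(label, num_classes=4).float()  # 2. one-hot encode the label
loss = -(one_hot * torch.log(probs)).sum(dim=1).mean()  # 3. cross-entropy between the two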
Is this what nn.CrossEntropyLoss() does? Or do we need to one-hot encode the true labels before feeding them in?
Thanks a lot for your time in advance!
nn.CrossEntropyLoss first applies log-softmax (log(Softmax(x))) to get log-probabilities and then calculates the negative log-likelihood, as mentioned in the documentation:
This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class.
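You can check that equivalence directly; a minimal sketch with random inputs:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(8, 10)           # raw model outputs (no softmax applied)
targets = torch.randint(0, 10, (8,))  # class indices

loss_ce = nn.CrossEntropyLoss()(logits, targets)
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(loss_ce, loss_nll))  # True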
When using one-hot encoded targets, the cross-entropy can be calculated as follows:

H(y, ŷ) = -∑ᵢ yᵢ log(ŷᵢ)

where y is the one-hot encoded target vector and ŷ is the vector of probabilities for each class. To get the probabilities you would apply softmax to the output of the model. Since the logarithm of the probabilities is used, PyTorch combines the logarithm and the softmax into one operation, nn.LogSoftmax(), for numerical stability.
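To make that formula concrete, here is a minimal sketch that computes the loss from an explicit one-hot target by hand and checks it against nn.CrossEntropyLoss (the one-hot tensor is for illustration only; it is not what the loss expects as input):

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(8, 10)                        # raw model outputs for a batch of 8
targets = torch.randint(0, 10, (8,))               # class indices

log_probs = F.log_softmax(logits, dim=1)           # log(Softmax(x)), numerically stable
one_hot = F.one_hot(targets, num_classes=10).float()
manual = -(one_hot * log_probs).sum(dim=1).mean()  # -sum_i y_i * log(ŷ_i), averaged over the batch

print(torch.allclose(manual, nn.CrossEntropyLoss()(logits, targets)))  # True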
Since all of the values except one in the one-hot vector are zero, only a single term of the sum will be non-zero. Therefore, given the actual class c, it simplifies to:

H(y, ŷ) = -log(ŷ_c)
As long as you know the class index, the loss can be calculated directly, making it more efficient than using a one-hot encoded target; hence nn.CrossEntropyLoss expects the class indices, not one-hot vectors.
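In other words, the loss just picks one log-probability per sample by index, roughly like this sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))

log_probs = F.log_softmax(logits, dim=1)
loss = -log_probs[torch.arange(8), targets].mean()  # -log(ŷ_c) for the true class c of each sample
print(torch.allclose(loss, F.cross_entropy(logits, targets)))  # True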
The full calculation is given in the documentation of nn.CrossEntropyLoss:

loss(x, class) = -log(exp(x[class]) / ∑ⱼ exp(x[j])) = -x[class] + log(∑ⱼ exp(x[j]))
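As a quick sanity check of that formula on a single example:

import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 0.5])  # logits for one sample, 3 classes
c = 1                              # true class index

manual = -x[c] + torch.log(torch.exp(x).sum())  # -x[class] + log(sum_j exp(x[j]))
auto = F.cross_entropy(x.unsqueeze(0), torch.tensor([c]))
print(torch.allclose(manual, auto))             # True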