Reputation: 23
I am trying to build an ASR model myself and am learning how to use the CTC loss.
I ran a test and observed the following:
ctc_loss = nn.CTCLoss(blank=95)
output: tensor([[63, 8, 1, 38, 29, 14, 41, 71, 14, 29, 45, 41, 3]]): torch.Size([1, 13]); output_size: tensor([13])
input1: torch.Size([167, 1, 96]); input1_size: tensor([167])
After applying the argmax to this input (= the predicted phonemes)
torch.argmax(input1, dim=2)
I get a series of symbols:
tensor([[63, 63, 63, 63, 63, 63, 95, 95, 63, 63, 95, 95, 8, 8, 8, 95, 8, 95,
8, 8, 95, 95, 95, 1, 1, 95, 1, 95, 1, 1, 95, 95, 38, 95, 95, 38,
38, 38, 38, 38, 29, 29, 29, 29, 29, 29, 29, 95, 29, 29, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 14, 95, 14, 95, 95, 95, 95, 14, 95, 14, 41, 41,
41, 95, 41, 41, 41, 41, 41, 41, 71, 71, 71, 95, 71, 71, 71, 71, 71, 95,
95, 14, 14, 95, 14, 14, 95, 14, 14, 95, 29, 29, 95, 29, 29, 29, 29, 29,
29, 29, 45, 95, 95, 45, 45, 95, 45, 45, 45, 45, 41, 95, 41, 41, 95, 95,
95, 41, 41, 41, 3, 3, 3, 3, 3, 95, 3, 3, 3, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95]])
and the following loss value:
ctc_loss(input1, output, input_size, output_size)
# Returns 222.8446
With a different input:
input2: torch.Size([167, 1, 96]) input2_size: tensor([167])
torch.argmax(input2, dim=2)
the prediction is just a sequence of blank symbols.
tensor([[95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
95, 95, 95, 95, 95]])
However, the loss value with the same desired output is much lower.
ctc_loss(input2, output, input_size, output_size)
# Returns 3.7955
I don't understand why input1 is a better prediction than input2, yet the loss of input1 is higher than that of input2. Can someone explain that?
Upvotes: 0
Views: 2402
Reputation: 11240
The CTC loss does not operate on the argmax predictions but on the entire output distribution. It is the negative log of the total probability of all frame-level paths (alignments) that collapse to the desired output once repeated symbols are merged and blanks are removed. The output symbols can be interleaved with blank symbols in many ways, which leaves exponentially many such paths. It is therefore possible, in theory, that the total probability of the correct paths is high (so the loss is low) while the single most probable path is still all blanks.
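To make this concrete, here is a small brute-force sketch in plain Python. It is not how PyTorch computes the loss (that uses dynamic programming); it simply enumerates every path for a toy problem. The three-class vocabulary, the blank index 2, the probabilities, and the helper names are all made up for illustration.

```python
import itertools
import math

BLANK = 2  # hypothetical blank index for this toy example (0 = 'a', 1 = 'c')

def collapse(path, blank=BLANK):
    """Apply the CTC collapsing rule: merge repeats, then drop blanks."""
    out = []
    prev = None
    for symbol in path:
        if symbol != prev and symbol != blank:
            out.append(symbol)
        prev = symbol
    return out

def brute_force_ctc_loss(probs, target, blank=BLANK):
    """Negative log of the summed probability of all paths that collapse to target."""
    num_classes = len(probs[0])
    total = 0.0
    for path in itertools.product(range(num_classes), repeat=len(probs)):
        if collapse(path, blank) == target:
            total += math.prod(probs[t][s] for t, s in enumerate(path))
    return -math.log(total)

def greedy_decode(probs, blank=BLANK):
    """Take the argmax per frame, then collapse the resulting path."""
    best_path = [max(range(len(p)), key=p.__getitem__) for p in probs]
    return collapse(best_path, blank)

target = [0]  # the desired output is the single symbol 'a'

# Case 1: blank wins the argmax in every frame, but 'a' keeps a lot of mass.
mostly_blank = [[0.4, 0.0, 0.6]] * 3
# Case 2: 'a' wins the argmax in every frame, but mass leaks to 'c'.
argmax_correct = [[0.4, 0.3, 0.3]] * 3

decoded_blank = greedy_decode(mostly_blank)
decoded_correct = greedy_decode(argmax_correct)
loss_blank = brute_force_ctc_loss(mostly_blank, target)
loss_correct = brute_force_ctc_loss(argmax_correct, target)

print(decoded_blank, loss_blank)
print(decoded_correct, loss_correct)
```

In this toy setup the all-blank argmax still gets the lower loss, because many different paths containing 'a' add up to a large total probability, whereas in the second case the argmax decodes correctly but probability mass spent on 'c' cannot contribute to the target.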
In practice, this is quite rare, so I guess there might be a problem somewhere else. The CTCLoss as implemented in PyTorch requires log probabilities as its input, which you get, e.g., by applying the log_softmax function. Other kinds of input, such as raw logits or plain probabilities, might lead to strange results such as the one you observe.
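A minimal sketch of the expected usage, assuming your model returns raw scores of shape (T, N, C) as in the question. The random `logits` tensor is only a stand-in for your network output, and the variable names are mine, not from your code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

T, N, C = 167, 1, 96            # time steps, batch size, classes (incl. blank)
logits = torch.randn(T, N, C)   # stand-in for raw, unnormalized network outputs

targets = torch.tensor([[63, 8, 1, 38, 29, 14, 41, 71, 14, 29, 45, 41, 3]])
input_lengths = torch.tensor([T])
target_lengths = torch.tensor([targets.shape[1]])

ctc_loss = nn.CTCLoss(blank=95)

# Normalize over the class dimension so every frame holds log probabilities.
log_probs = F.log_softmax(logits, dim=2)

# Sanity check worth running on your own inputs: per frame, exp(log_probs)
# should sum to 1. If it does not, the tensor is not log probabilities.
frame_sums = log_probs.exp().sum(dim=2)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss.item())
```

If the same sum check fails on your `input1` or `input2`, the two loss values are not comparable and the inputs should be passed through log_softmax first.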
Upvotes: 2