Huu Tuong Tu

Reputation: 23

How to use CTC Loss Seq2Seq correctly?

I am trying to build an ASR model myself and learn how to use CTC loss.

I ran a test and observed the following:

ctc_loss = nn.CTCLoss(blank=95)
output: tensor([[63,  8,  1, 38, 29, 14, 41, 71, 14, 29, 45, 41,  3]]); shape: torch.Size([1, 13]); output_size: tensor([13])

input1: torch.Size([167, 1, 96]); input1_size: tensor([167])

After applying argmax to this input (= the predicted phonemes)

torch.argmax(input1, dim=2)

I get a series of symbols:

tensor([[63, 63, 63, 63, 63, 63, 95, 95, 63, 63, 95, 95,  8,  8,  8, 95,  8, 95,
           8,  8, 95, 95, 95,  1,  1, 95,  1, 95,  1,  1, 95, 95, 38, 95, 95, 38,
          38, 38, 38, 38, 29, 29, 29, 29, 29, 29, 29, 95, 29, 29, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 14, 95, 14, 95, 95, 95, 95, 14, 95, 14, 41, 41,
          41, 95, 41, 41, 41, 41, 41, 41, 71, 71, 71, 95, 71, 71, 71, 71, 71, 95,
          95, 14, 14, 95, 14, 14, 95, 14, 14, 95, 29, 29, 95, 29, 29, 29, 29, 29,
          29, 29, 45, 95, 95, 45, 45, 95, 45, 45, 45, 45, 41, 95, 41, 41, 95, 95,
          95, 41, 41, 41,  3,  3,  3,  3,  3, 95,  3,  3,  3, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95]])
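
(For reference: a frame-level argmax path like this is usually turned into a label sequence by greedy CTC decoding, i.e., collapsing repeated symbols and then removing blanks. A minimal sketch, assuming blank id 95 as above; the function name is my own, not from the original post:)

import torch

# Greedy CTC decoding sketch (assumption, not part of the original post):
# collapse runs of repeated symbols, then drop the blank symbol (id 95).
def greedy_ctc_decode(log_probs: torch.Tensor, blank: int = 95) -> list:
    path = torch.argmax(log_probs, dim=2)[:, 0].tolist()  # best class per frame
    decoded, prev = [], None
    for s in path:
        if s != prev and s != blank:
            decoded.append(s)
        prev = s
    return decoded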

The loss value for this input is:

ctc_loss(input1, output, input1_size, output_size)
# Returns 222.8446

With a different input:

input2: torch.Size([167, 1, 96]) input2_size: tensor([167])
torch.argmax(input2, dim=2)

the prediction is just a sequence of blank symbols.

tensor([[95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
          95, 95, 95, 95, 95]]) 

However, the loss value with the same desired output is much lower.

ctc_loss(input2, output, input2_size, output_size)
# Returns 3.7955

I don't understand: input1 looks like a much better prediction than input2, yet input1's loss is higher than input2's. Can someone explain that?

Upvotes: 0

Views: 2402

Answers (1)

Jindřich

Reputation: 11240

The CTC loss does not operate on the argmax predictions but on the entire output distribution. It is the negative log-likelihood of the target sequence, marginalized over all frame-level alignments that collapse to the desired output. Because the output symbols can be interleaved with blank symbols in many ways, there are exponentially many such alignments. It is therefore possible, at least in theory, for the total probability of the correct alignments to be high (and the loss low) while the single most probable frame-level sequence is all blanks.
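
(A toy illustration of this marginalization, my own sketch rather than part of the original answer: with two frames, two classes where blank = 0 and the label is 1, and target [1], exactly three alignments collapse to the target: (1,1), (1,blank), (blank,1). The loss can be checked against a manual enumeration; note that the argmax path here is all blanks, yet the loss stays moderate, mirroring the question.)

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy case (assumed values): T=2 frames, C=2 classes (blank=0, label=1), target [1].
logits = torch.tensor([[[0.2, 0.1]], [[0.3, -0.4]]])  # shape (T, N, C) = (2, 1, 2)
log_probs = F.log_softmax(logits, dim=2)

loss = nn.CTCLoss(blank=0)(log_probs, torch.tensor([[1]]),
                           torch.tensor([2]), torch.tensor([1]))

# Enumerate the three alignments that collapse to [1]: (1,1), (1,0), (0,1).
p = log_probs.exp()[:, 0, :]  # per-frame probabilities, shape (T, C)
total = p[0, 1] * p[1, 1] + p[0, 1] * p[1, 0] + p[0, 0] * p[1, 1]

print(loss.item(), -total.log().item())  # the two printed values agree
# torch.argmax(log_probs, dim=2) is all blanks here, yet the loss is finite.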

In practice, this is quite rare, so I would guess there is a problem somewhere else. CTCLoss as implemented in PyTorch requires log-probabilities as input, which you get, e.g., by applying log_softmax over the class dimension. Feeding it anything else (raw logits or plain probabilities) can lead to strange results such as the one you observe.
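
(A minimal sketch of the expected call, assuming the dimensions from the question, 167 frames, batch size 1, 96 classes with blank id 95; the random tensors are placeholders for the model output and target:)

import torch
import torch.nn as nn
import torch.nn.functional as F

T, N, C = 167, 1, 96                      # frames, batch size, classes (assumed from the question)
ctc_loss = nn.CTCLoss(blank=95)

logits = torch.randn(T, N, C)             # placeholder for the raw network outputs
log_probs = F.log_softmax(logits, dim=2)  # CTCLoss expects log-probabilities

target = torch.randint(0, 95, (N, 13))    # placeholder target ids; blank (95) excluded
input_lengths = torch.tensor([T])
target_lengths = torch.tensor([13])

loss = ctc_loss(log_probs, target, input_lengths, target_lengths)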

Upvotes: 2
