Surendra

Reputation: 31

Label alignment in RNN Transducer training

I am trying to understand how the RNN Transducer is trained with ground truth labels. In the case of CTC, I know that the model is trained with a loss function that sums the scores of all possible alignments of the ground truth labels.

But in RNN-T, the prediction network has to receive the previous output as input in order to produce the next output, similar to the "teacher forcing" method. My doubt here is: should the ground truth labels be converted into all possible alignments with the blank label, with each alignment then fed to the network via teacher forcing?

Upvotes: 3

Views: 1516

Answers (1)

Phil

Reputation: 61

RNN-T has a transcription network (analogous to an acoustic model), a prediction network (a language model) and a joint network (or joint function, depending on the implementation) that combines the outputs of the prediction network and the transcription network.

During training, you process each utterance by:

  • Propagating all T acoustic frames through the transcription network and storing the outputs (transcription network hidden states)
  • Propagating the ground truth label sequence, of length U, through the prediction network, passing in an all-zero vector at the beginning of the sequence. Note that you do not need to worry about blank states at this point
  • Propagating all T*U combinations of transcription and prediction network hidden states through the joint network, whether that be a simple sum and exponential as per Graves (2012) or a feed-forward network as per the more recent Google ASR publications (e.g. He et al. 2019). A sketch of this forward pass follows the list.
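Here is a minimal sketch of that forward pass in PyTorch. The module sizes, the one-hot prediction-network input and the names encoder, predictor and joiner are all illustrative assumptions, not a reference implementation:

```python
import torch
import torch.nn as nn

T, U, H, V = 100, 20, 256, 30    # frames, labels, hidden size, vocab size (index 0 = blank)

encoder = nn.LSTM(80, H, batch_first=True)    # transcription network
predictor = nn.LSTM(V, H, batch_first=True)   # prediction network (one-hot inputs for simplicity)
joiner = nn.Sequential(nn.Linear(2 * H, H), nn.Tanh(), nn.Linear(H, V))  # joint network

feats = torch.randn(1, T, 80)           # acoustic features for one utterance
labels = torch.randint(1, V, (1, U))    # ground truth labels, no blanks

f, _ = encoder(feats)                   # (1, T, H): one state per acoustic frame
# Prepend the all-zero vector so the predictor has an input at step 0
pred_in = torch.cat([torch.zeros(1, 1, V),
                     nn.functional.one_hot(labels, V).float()], dim=1)
g, _ = predictor(pred_in)               # (1, U+1, H): one state per label prefix

# Broadcast to every (t, u) combination and run the joint network once
joint_in = torch.cat([f.unsqueeze(2).expand(-1, T, U + 1, H),
                      g.unsqueeze(1).expand(-1, T, U + 1, H)], dim=-1)
logits = joiner(joint_in)               # (1, T, U+1, V)
```

Note that because of the prepended zero vector the prediction network produces U+1 states, so the grid is strictly T x (U+1).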

The T*U outputs from the joint network can be viewed as a grid, as per Figure 1 of Graves 2012. The loss function can then be efficiently realised using the forward-backward algorithm (Section 2.4, Graves 2012). Only horizontal (consuming acoustic frames) and vertical (consuming labels) transitions are permitted. Stepping from t to t+1 is analogous to the blank state in CTC, whilst non-blank symbols are output when making vertical transitions, i.e. from output label u to u+1. Note that you can consume multiple time frames without outputting a non-blank symbol (as per CTC), but you can also output multiple labels without advancing through t.
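To make the loss concrete, here is a sketch of the forward pass of that recursion in log space, following Section 2.4 of Graves 2012 and assuming blank has index 0 (the function name and signature are mine, not from any library):

```python
import numpy as np

def rnnt_log_likelihood(log_probs, labels, blank=0):
    """log_probs: (T, U+1, V) log-softmax joint outputs; labels: length-U sequence."""
    T, U1, _ = log_probs.shape
    U = U1 - 1
    alpha = np.full((T, U + 1), -np.inf)  # alpha[t, u]: log-prob of all paths reaching (t, u)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t > 0:  # horizontal transition: emit blank at (t-1, u), consume a frame
                alpha[t, u] = np.logaddexp(
                    alpha[t, u], alpha[t - 1, u] + log_probs[t - 1, u, blank])
            if u > 0:  # vertical transition: emit label u at (t, u-1), stay on frame t
                alpha[t, u] = np.logaddexp(
                    alpha[t, u], alpha[t, u - 1] + log_probs[t, u - 1, labels[u - 1]])
    # Terminate with a final blank from the top-right node of the grid
    return alpha[T - 1, U] + log_probs[T - 1, U, blank]
```

The training loss is the negative of this value; in practice you would also compute the backward variables (or let automatic differentiation handle the gradients) and batch over utterances.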

To more directly answer your question, note that only non-blank outputs are passed back to the input of the prediction network, and that the transcription and prediction networks are operating asynchronously.
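A greedy decoding sketch makes that asynchrony concrete. It reuses the illustrative encoder, predictor and joiner from the training sketch above, again with blank at index 0:

```python
def greedy_decode(feats, max_symbols_per_step=10):
    f, _ = encoder(feats)                          # (1, T, H)
    g, state = predictor(torch.zeros(1, 1, V))     # start from the all-zero input
    hyp = []
    for t in range(f.size(1)):                     # t advances only via blanks
        emitted = 0
        while emitted < max_symbols_per_step:      # cap vertical moves per frame
            logits = joiner(torch.cat([f[:, t], g[:, -1]], dim=-1))
            k = logits.argmax(-1).item()
            if k == 0:                             # blank: move to the next frame
                break
            hyp.append(k)                          # non-blank: feed it back ...
            g, state = predictor(                  # ... and step the predictor
                nn.functional.one_hot(torch.tensor([[k]]), V).float(), state)
            emitted += 1
    return hyp
```

The time index only advances when a blank is emitted, while the prediction network only steps when a non-blank label is produced, which is exactly the asynchronous behaviour described above.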

References:

  • Sequence Transduction with Recurrent Neural Networks, Graves 2012
  • Streaming End-to-end Speech Recognition For Mobile Devices, He et al. 2019

Upvotes: 6
