Joe Black

Reputation: 653

Question about tokens used in Transformer decoder attention layers during Inference

I was looking at the shapes used in the decoder (both the self-attention and enc-dec-attention blocks) and understand that the decoder runs differently during training versus inference, based on this link and the original Attention paper.

During inference, the decoder uses all previous tokens generated up to that time step (say the kth step), as shown in the diagram below and explained at this link.

Issue:

However, when I look at the actual shapes of the Q/K/V projections in the decoder self-attention, and at the decoder self-attention output being fed into the enc-dec-attention's Q matrix, I see only 1 token from the output being used.

I'm very confused about how the shapes of all the matrices in the decoder's self-attention and enc-dec-attention can match up with a variable-length input to the decoder during inference. I looked at several online materials but couldn't find an answer. I see only the batched GEMMs in the decoder's self-attention (not the enc-dec-attention) using variable shapes that grow with the number of previous steps k, while all other GEMMs are fixed size.
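To make the shapes I'm describing concrete, here is a rough sketch of what I think is happening at decoding step k (an illustration only; d_model, k and src_len are placeholder sizes, and batch and heads are ignored):

```python
import torch

d_model, k, src_len = 512, 7, 20   # placeholder sizes for illustration

# Decoder self-attention at step k:
q     = torch.randn(1, d_model)          # only the newest token is projected to Q
k_all = torch.randn(k, d_model)          # keys for all k tokens generated so far (variable)
v_all = torch.randn(k, d_model)          # values for all k tokens so far (variable)
scores = q @ k_all.T                     # (1, k)  <- the variable-shape batched GEMM
self_out = torch.softmax(scores, dim=-1) @ v_all   # (1, d_model) -- same shape as the query

# Enc-dec attention: Q comes from the one-token self-attention output,
# while K and V come from the encoder output and stay fixed for the whole generation.
enc_out = torch.randn(src_len, d_model)
cross   = torch.softmax(self_out @ enc_out.T, dim=-1) @ enc_out    # (1, d_model)
```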

Another diagram that shows self-attention and enc-dec-attention within decoder:

[image: decoder block showing the self-attention and enc-dec-attention layers]

Upvotes: 1

Views: 1728

Answers (1)

Arij Aladel

Reputation: 555

  1. This is possible because the transformer keeps previous key-value pairs, which are used only at the inference stage. The newly generated token is passed through the embedding layer, projected to a key and a value, and these are appended to the cached keys and values; the new token's query then attends over this full cache, so you get attention to the current token and all previously generated tokens. The cache is then updated again for the next generation step. To understand it, it is best to track the inference process token by token, which I did before (see the code sketch after this list). But wait: what about the very first generation step, when we have fed in nothing yet? The previous keys-values are also none! So the first token is generated depending entirely on the encoded input from the encoder. To make it easier to picture, I have drawn a small diagram; I hope it helps.
  2. As for shapes, the input shape for the decoder is fixed: as we see from the diagram above, it is always one token (the last token generated by the decoder!). Note also that the shape of the attention output is always the same as the query shape, i.e. one token. To make this easier to understand, I will use T5 from Hugging Face as an example: this condition shows what I am talking about, where the keys and values are first projected to generate the first token. For greedy search you can see here how, at each generation step, they call the whole transformer model to generate just the next token; here they concatenate the newly generated token to the previously generated tokens to check the stop condition, which is either generating the stop token or reaching the maximum length of generated tokens (a rough sketch of such a loop is at the end of this answer).
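Here is a minimal sketch of this key-value caching in the decoder self-attention (my own illustration, not actual T5/Hugging Face code; names like past_k and past_v are made up, and batch and heads are ignored):

```python
import torch

def cached_self_attention(x_new, w_q, w_k, w_v, past_k=None, past_v=None):
    """x_new is the embedding of the single newly generated token, shape (1, d_model)."""
    q     = x_new @ w_q                   # (1, d_model) -- query for the new token only
    k_new = x_new @ w_k                   # (1, d_model)
    v_new = x_new @ w_v                   # (1, d_model)

    # Append the new key/value to the cache of all previous steps.
    # At the very first step there is no cache yet (past_k/past_v are None).
    k = k_new if past_k is None else torch.cat([past_k, k_new], dim=0)   # (k_steps, d_model)
    v = v_new if past_v is None else torch.cat([past_v, v_new], dim=0)

    # Only K and V grow with the step count, so only these GEMMs have variable shapes.
    scores  = (q @ k.T) / (q.shape[-1] ** 0.5)   # (1, k_steps)
    weights = torch.softmax(scores, dim=-1)
    out     = weights @ v                        # (1, d_model) -- same shape as the query
    return out, (k, v)                           # return the updated cache for the next step
```

Because the output keeps the query's one-token shape, the Q matrix of the following enc-dec attention also sees only one token, which matches what you observed.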

Hope this answers your question; it all comes from my own earlier attempt to understand the inference process.
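And here is a rough sketch of the greedy-search loop in the same spirit (again my own simplification, not the actual Hugging Face generation code; decoder_step is a made-up helper standing in for one call of the full model with its cache):

```python
import torch

def greedy_decode(decoder_step, encoder_output, bos_id, eos_id, max_len=50):
    # decoder_step(last_token, encoder_output, cache) is assumed to return
    # the logits for the next token and the updated key-value cache.
    generated = torch.tensor([[bos_id]])              # (batch=1, seq=1)
    cache = None
    for _ in range(max_len):
        # The model is called with only the last generated token; all earlier
        # tokens are represented by the cached keys and values.
        logits, cache = decoder_step(generated[:, -1:], encoder_output, cache)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick, shape (1, 1)
        generated = torch.cat([generated, next_token], dim=1)     # concat to check the stop condition
        if next_token.item() == eos_id:                           # stop: EOS token or max length
            break
    return generated
```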

Upvotes: 0
