Reputation: 653
I was looking at the shapes used in the decoder (both the self-attention and enc-dec-attention blocks), and I understand there is a difference in the way the decoder runs during training versus during inference, based on this link and the original Attention paper.
During inference, the decoder uses all previously generated tokens up to that time step (say the k-th time step), as shown in the diagram below and explained at this link.
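To make that difference concrete, here is a toy shape comparison (my own numbers, single head, no batch dimension; nothing here is taken from the paper itself):

```python
import torch

d_model = 512                        # toy model width
T = 10                               # full target length available during training

# Training: the whole shifted target sequence is fed in one pass with a causal
# mask, so the decoder self-attention's Q, K and V all cover T positions.
tgt = torch.randn(T, d_model)        # (10, 512)

# Inference at step k: only the k tokens generated so far exist, and they were
# produced one at a time, so the decoder-side length changes every step.
k = 4
generated = torch.randn(k, d_model)  # (4, 512), grows by one row each step
print(tgt.shape, generated.shape)
```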
Issue:
However, when I look at the actual shapes of the QKV projections in the decoder self-attention, and at how the decoder self-attention output feeds the enc-dec-attention's Q matrix, I see only one token from the output being used.
I'm very confused about how the shapes of all the matrices in the decoder's self-attention and enc-dec-attention can match up with a variable-length input to the decoder during inference. I looked at several online materials but couldn't find an answer.
I see that only the batched GEMMs (BGemms) in the decoder's self-attention (not in the enc-dec-attention) use variable shapes spanning all previous k steps, while all the other GEMMs are fixed-size.
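For concreteness, this is roughly the shape bookkeeping I observe at decoding step k (toy sizes, one head, no batch dimension; the weight names are just placeholders):

```python
import torch

d_model, k, S = 512, 4, 7            # toy: model width, decoding step, source length

x_new   = torch.randn(1, d_model)    # only the newest token enters the decoder
enc_out = torch.randn(S, d_model)    # encoder output, fixed for the whole decode

W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))  # placeholder weights

# Fixed-size GEMMs: each projection only ever sees the single new token.
q_self = x_new @ W_q                 # (1, 512)
k_new  = x_new @ W_k                 # (1, 512)
v_new  = x_new @ W_v                 # (1, 512)

# Keys/values from the previous k-1 steps (kept around from earlier iterations).
k_self = torch.cat([torch.randn(k - 1, d_model), k_new], dim=0)    # (k, 512)
v_self = torch.cat([torch.randn(k - 1, d_model), v_new], dim=0)    # (k, 512)

# Variable-size batched GEMMs: the only shapes that change from step to step.
scores   = q_self @ k_self.T         # (1, k)
self_out = torch.softmax(scores / d_model**0.5, dim=-1) @ v_self   # (1, 512)

# Enc-dec attention: Q is built from that single token and K/V from the encoder
# output (separate weights in a real model), so every GEMM here is fixed-size.
scores_x = (self_out @ W_q) @ (enc_out @ W_k).T                    # (1, S)
```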
Another diagram that shows self-attention and enc-dec-attention within the decoder:
Upvotes: 1
Views: 1728
Reputation: 555
Hope this answers your question; it all comes from my earlier attempt to understand the inference process myself.
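For reference, here is a minimal sketch of how I picture the incremental decode (single head, no batching, made-up sizes and weight names; a real model of course has multiple heads, layers and a vocabulary projection). The point is that the decoder only processes the one newest token per step, its self-attention keys/values are cached and grow with the step, and the enc-dec attention K/V come from the encoder output and never change, so only the self-attention score/context GEMMs have variable shapes:

```python
import torch

d_model, S, n_steps = 512, 7, 5             # toy sizes: model width, source length, steps

enc_out = torch.randn(S, d_model)           # encoder output: computed once, reused
w_q,  w_k,  w_v  = (torch.randn(d_model, d_model) for _ in range(3))  # self-attn weights
w_qx, w_kx, w_vx = (torch.randn(d_model, d_model) for _ in range(3))  # cross-attn weights

enc_K = enc_out @ w_kx                      # cross-attention K/V are fixed, so they
enc_V = enc_out @ w_vx                      # can be projected once before the loop

k_cache, v_cache = [], []                   # self-attention cache, grows by 1 row/step

for step in range(n_steps):
    x = torch.randn(1, d_model)             # stand-in for the newest token's embedding

    # Decoder self-attention: all projections are fixed-size (1, d_model); only the
    # cached K/V and hence the score matrix grow with the step.
    q = x @ w_q                             # (1, d_model)
    k_cache.append(x @ w_k)
    v_cache.append(x @ w_v)
    K = torch.cat(k_cache, dim=0)           # (step+1, d_model)
    V = torch.cat(v_cache, dim=0)
    scores = q @ K.T / d_model ** 0.5       # (1, step+1)  <- the only variable shape
    self_out = torch.softmax(scores, dim=-1) @ V                # (1, d_model)

    # Enc-dec attention: Q is that single token, K/V come from the encoder, so every
    # shape here stays fixed no matter how far decoding has gone.
    q_x = self_out @ w_qx                                       # (1, d_model)
    scores_x = q_x @ enc_K.T / d_model ** 0.5                   # (1, S)
    cross_out = torch.softmax(scores_x, dim=-1) @ enc_V         # (1, d_model)

    print(step, scores.shape, scores_x.shape)
```

Precomputing enc_K and enc_V outside the loop is also why the enc-dec-attention GEMMs show up with fixed shapes in a profile: they depend only on the source length S, never on the decoding step.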
Upvotes: 0