Reputation: 321
I'm currently studying the code of the Transformer, but I cannot understand the masked multi-head attention of the decoder. The paper says it is to prevent you from seeing the word being generated, but I cannot understand: if the words after the word being generated have not been generated yet, how can they be seen?
I tried to read the code of the Transformer (link: https://github.com/Kyubyong/transformer). The masking is implemented as shown below. It uses a lower triangular matrix to mask, and I cannot understand why.
padding_num = -2 ** 32 + 1  # a very large negative number
diag_vals = tf.ones_like(inputs[0, :, :]) # (T_q, T_k); `inputs` holds the raw attention scores (N, T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1]) # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs) # zeros in the mask (upper triangle) get the large negative value
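Here is a minimal, self-contained sketch (my own simplification, not code from that repository) of what this produces for T_q = T_k = 3: row t of the scores keeps columns 0..t, and everything to the right of the diagonal gets a huge negative number, which the softmax then turns into (practically) zero attention weight:

import numpy as np

T = 3
scores = np.arange(1.0, T * T + 1).reshape(T, T)   # pretend attention scores (T_q, T_k)
tril = np.tril(np.ones((T, T)))                    # lower-triangular mask, 1 = keep, 0 = hide
masked = np.where(tril == 0, -1e9, scores)         # "future" key positions get a huge negative score

print(tril)
# [[1. 0. 0.]
#  [1. 1. 0.]
#  [1. 1. 1.]]
# Query position 0 may only attend to key 0, position 1 to keys 0-1, position 2 to keys 0-2;
# after the softmax, the -1e9 entries become (practically) zero attention weights.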
Upvotes: 32
Views: 35737
Reputation: 1940
I had the very same question after reading the Transformer paper. I found no complete and detailed answer to the question on the Internet, so I'll try to explain my understanding of Masked Multi-Head Attention.
The short answer is: we need masking to make the training parallel. And parallelization is good, as it allows the model to train faster. I've also made a video with an explanation of this mechanism.
Here's an example explaining the idea. Let's say we train to translate "I love you" to German. The encoder works in parallel mode - it can produce the vector representation of the input sequence ("I love you") within a constant number of steps (i.e. the number of steps doesn't depend on the length of the input sequence).
Let's say the encoder produces the numbers 11, 12, 13 as the vector representations of the input sequence. In reality these vectors will be much longer, but for simplicity we use short ones. Also for simplicity we ignore the service tokens, like the beginning-of-sequence token, the end-of-sequence token and others.
During the training we know that the translation should be "Ich liebe dich" (we always know the expected output during the training). Let's say the expected vector representations of the "Ich liebe dich" words are 21, 22, 23.
If we make the decoder training in sequential mode, it'll look like the training of the Recurrent Neural Network. The following sequential steps will be performed:
Sequential operation #1. Input to the decoder: 11, 12, 13. Trying to predict 21. The prediction won't be exactly 21, let's say it'll be 21.1.
Sequential operation #2. Input to the decoder: 11, 12, 13, and also 21.1 as the previous output. Trying to predict 22. The prediction won't be exactly 22, let's say it'll be 22.3.
Sequential operation #3. Input to the decoder: 11, 12, 13, and also 22.3 as the previous output. Trying to predict 23. The prediction won't be exactly 23, let's say it'll be 23.5.
This means we'll need to make 3 sequential operations (in the general case - a sequential operation per input). Also we'll have an accumulating error on each next iteration. And we don't use attention here, as we only look at a single previous output.
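To make the sequential mode concrete, here is a toy sketch (my own illustration, not real Transformer code; decoder_step is a made-up stand-in): every step has to wait for the previous, imperfect prediction, which is exactly why this can't be parallelized:

import random

def decoder_step(encoder_outputs, previous_output):
    # Stand-in for a real decoder step: returns an imperfect prediction.
    return previous_output + 1 + random.uniform(-0.5, 0.5)

encoder_outputs = [11, 12, 13]   # encoder representations of "I love you"
expected = [21, 22, 23]          # representations of "Ich liebe dich"

previous = 20                    # stand-in for the start-of-sequence token
for target in expected:
    prediction = decoder_step(encoder_outputs, previous)
    # ...compare `prediction` with `target` and compute the loss here...
    previous = prediction        # the next step has to wait for this imperfect output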
As we actually know the expected outputs we can adjust the process and make it parallel. There's no need to wait for the previous step output.
Parallel operation #1. Input to the decoder: 11, 12, 13. Trying to predict 21.
Parallel operation #2. Input to the decoder: 11, 12, 13, and also the known expected output 21. Trying to predict 22.
Parallel operation #3. Input to the decoder: 11, 12, 13, and also the known expected outputs 21, 22. Trying to predict 23.
This algorithm can be executed in parallel, and it doesn't accumulate the error. And this algorithm uses attention (i.e. looks at all previous inputs), so it has more information about the context to consider while making the prediction.
And here is where we need the masking. The training algorithm knows the entire expected output (21, 22, 23). It hides (masks) a part of this known output sequence for each of the parallel operations.
Masking itself is implemented as follows (quoting the original paper):
We implement this inside of scaled dot-product attention by masking out (setting to −∞) all values in the input of the softmax which correspond to illegal connections
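A minimal NumPy sketch of that idea (my own illustration, not code from the paper or from the repository above): compute the scaled dot-product scores, add a huge negative number (standing in for -inf) to every position above the diagonal, and let the softmax turn those positions into zero weights, so position t can only attend to positions 0..t:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def masked_attention(Q, K, V):
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                   # (T, T) attention scores
    mask = np.triu(np.ones_like(scores), k=1) * -1e9  # "minus infinity" above the diagonal
    weights = softmax(scores + mask)                  # masked positions get ~0 weight
    return weights @ V                                # row t mixes only V[0..t]

T, d = 3, 4
rng = np.random.default_rng(0)
Q = rng.normal(size=(T, d))
K = rng.normal(size=(T, d))
V = rng.normal(size=(T, d))
print(masked_attention(Q, K, V))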
Note: during inference (not training) the decoder works in the sequential (not parallel) mode, as it doesn't know the output sequence initially. But this is different from the RNN approach, as Transformer inference still uses self-attention and looks at all previous outputs (not only the very last one).
Note 2: I've seen in some materials that masking can be used differently for non-translation applications. For example, for language modeling masking can be used to hide some words from the input sentence, and the model will try to predict them during training using the other, non-masked words (i.e. it learns to understand the context).
Upvotes: 68
Reputation: 11
I believe you were somehow confused by some folks saying that the masked attention is essential for causality. I just wanted to add that causality is important during testing; that much we all agree on. The problem is during training, where we input the "target sequence" to the decoder all at once. Yes, here we need the masked multi-head attention as well, and that's where the model learns to generate tokens in a causal way at each time step. Here are the steps:
1- Input Embeddings: The decoder receives the entire target sequence as input embeddings.
2- Linear Projections: For each attention head, the embeddings are projected into queries (Q), keys (K), and values (V).
Q = X * W_Q; K = X * W_K; V = X * W_V
3- Compute Attention Scores: The attention scores are computed using the scaled dot-product of queries and keys.
attn = (Q * K^T)/sqrt(d)
4- Generate mask: Mask = a matrix of zeros with all elements above the main diagonal set to minus infinity
5- masked_score = attn + mask
Note: This ensures causality. Just draw a 3 by 3 matrix and assume the target sentence is "Tu es belle". Write the same sentence two times, once above the matrix and once at the left of the matrix.
[0.1, -infty, -infty;
0.1, 0.2, -infty;
0.1, 0.2, 0.3]
You can see that every word can only have a relationship with itself and the previous words.
6- Apply softmax and continue.
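To make step 6 concrete, here is a small NumPy sketch (my own, just reusing the example scores above): applying the softmax row by row gives attention weights that sum to 1 and put exactly zero weight on the masked positions:

import numpy as np

masked_score = np.array([
    [0.1, -np.inf, -np.inf],
    [0.1,     0.2, -np.inf],
    [0.1,     0.2,     0.3],
])

weights = np.exp(masked_score)
weights /= weights.sum(axis=1, keepdims=True)
print(weights.round(3))
# approximately:
# [[1.000 0.000 0.000]
#  [0.475 0.525 0.000]
#  [0.301 0.332 0.367]]
# "Tu" attends only to itself, "es" to "Tu" and itself, "belle" to all three words.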
So, in brief, we do have masked self-attention during training, where the target sentence is fed to the decoder all at once, and yes, it is there for causality.
Upvotes: 1
Reputation: 4323
Let's say the text we're training on is "one two three four five"
This is self-supervised training, and we're just going to train the model to predict the next word in the sequence. Rather than an encoder-decoder model, we'll just use a GPT-style transformer (sometimes called "decoder-only" because it's causal, sometimes called "encoder-only" because there's no cross-attention).
If we're doing generative pre-training we're going to train this model with:
input_tokens = [one, two, three, four]
and
output_tokens = [ two, three, four, five ]
We shift the tokens such that the model is always predicting the next token.
Now, for the output "two," we only want to consider the input "one". When learning to generate the output "three" we only want to consider the inputs "one" and "two." And so on.
Now, after pre-training, we'll feed into the model "Mary had a little" and we expect to get the output "had a little lamb." The output is shifted by one token.
It may seem wasteful to train on the entire sentence in the output. You may be asking yourself: why not just train the model to predict only the next token? Why predict words that are already in the input, necessitating this causal mask? Well, parameters are shared. When the model learns to produce "two" by attending to "one," it's changing the same parameters that help it generate "five" by attending to "four." Longer sequence lengths end up being equivalent to larger batch sizes, and so this redundant-seeming way of training is actually very data efficient.
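A tiny sketch of that setup (my own illustration, nothing from a specific library): the same sentence provides both the inputs and the shifted targets, and the causal mask is what lets a single forward pass train the prediction at every position at once:

tokens = ["one", "two", "three", "four", "five"]

input_tokens  = tokens[:-1]   # [one, two, three, four]
output_tokens = tokens[1:]    # [two, three, four, five]

# Position t is trained to predict output_tokens[t] while the causal mask
# restricts its attention to input_tokens[0..t].
for t, target in enumerate(output_tokens):
    visible = input_tokens[: t + 1]
    print(f"predict {target!r} from {visible}")
# predict 'two' from ['one']
# predict 'three' from ['one', 'two']
# predict 'four' from ['one', 'two', 'three']
# predict 'five' from ['one', 'two', 'three', 'four']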
Upvotes: 0
Reputation: 188
The decoder is auto-regressive and can't see the future words.
Upvotes: 0