Neptuner

Reputation: 321

How to understand masked multi-head attention in transformer

I'm currently studying the code of the Transformer, but I cannot understand the masked multi-head attention in the decoder. The paper says it is there to prevent you from seeing the word being generated, but I can't understand: if the words after the word being generated have not been generated yet, how could they be seen?

I tried to read the code of a Transformer implementation (link: https://github.com/Kyubyong/transformer). The code that implements the mask is shown below. It uses a lower triangular matrix as the mask, and I cannot understand why.

padding_num = -2 ** 32 + 1  # very large negative number, effectively -inf for the softmax
diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k) matrix of ones
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # lower triangular matrix of ones, (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # repeat for every example in the batch, (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)  # replace scores above the diagonal with the large negative value
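
For reference, here is a minimal NumPy sketch (not code from the linked repository, just an illustration of the same logic) for a 4x4 score matrix; the scores are random, only to show the pattern:

import numpy as np

T = 4
scores = np.random.randn(T, T)      # made-up attention scores, (T_q, T_k)
tril = np.tril(np.ones((T, T)))     # lower triangular matrix of ones
padding_num = -2 ** 32 + 1          # very large negative number, effectively -inf
masked = np.where(tril == 0, padding_num, scores)

print(np.round(masked, 2))
# Row i keeps the scores only for columns j <= i; everything above the
# diagonal becomes the huge negative value, so the softmax sends it to ~0.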

Upvotes: 32

Views: 35737

Answers (4)

artoby

Reputation: 1940

I had the very same question after reading the Transformer paper. I found no complete and detailed answer to it on the Internet, so I'll try to explain my understanding of masked multi-head attention.

The short answer is: we need masking to make the training parallel. Parallelization is good because it allows the model to train faster. I've also made a video explaining this mechanism.

Here's an example explaining the idea. Let's say we train to translate "I love you" to German. The encoder works in parallel mode - it can produce the vector representation of the input sequence ("I love you") in a constant number of steps (i.e. the number of steps doesn't depend on the length of the input sequence).

Let's say the encoder produces the numbers 11, 12, 13 as the vector representations of the input sequence. In reality these vectors will be much longer, but for simplicity we use short ones. Also for simplicity we ignore the service tokens, such as the beginning-of-sequence and end-of-sequence markers.

During the training we know that the translation should be "Ich liebe dich" (we always know the expected output during the training). Let's say the expected vector representations of the "Ich liebe dich" words are 21, 22, 23.

If we run the decoder training in sequential mode, it'll look like the training of a Recurrent Neural Network. The following sequential steps will be performed:

  • Sequential operation #1. Input: 11, 12, 13.
    • Trying to predict 21.
    • The predicted output won't be exactly 21, let's say it'll be 21.1.
  • Sequential operation #2. Input: 11, 12, 13, and also 21.1 as the previous output.
    • Trying to predict 22.
    • The predicted output won't be exactly 22, let's say it'll be 22.3.
  • Sequential operation #3. Input 11, 12, 13, and also 22.3 as the previous output.
    • Trying to predict 23.
    • The predicted output won't be exactly 23, let's say it'll be 23.5.

This means we'll need to make 3 sequential operations (in the general case, one sequential operation per generated token). We'll also accumulate error at each iteration. And we don't use attention here, since we only look at a single previous output.

Since we actually know the expected outputs, we can adjust the process and make it parallel. There's no need to wait for the output of the previous step.

  • Parallel operation #A. Inputs: 11, 12, 13.
    • Trying to predict 21.
  • Parallel operation #B. Inputs: 11, 12, 13, and also 21.
    • Trying to predict 22.
  • Parallel operation #C. Inputs: 11, 12, 13, and also 21, 22.
    • Trying to predict 23.

This algorithm can be executed in parallel, and it doesn't accumulate error. It also uses attention (i.e. it looks at all previous inputs), so it has more information about the context to consider while making each prediction.

And here is where we need the masking. The training algorithm knows the entire expected output (21, 22, 23). It hides (masks) a part of this known output sequence for each of the parallel operations.

  • When it executes #A - it hides (masks) the entire output.
  • When it executes #B - it hides 2nd and 3rd outputs.
  • When it executes #C - it hides 3rd output.

Masking itself is implemented as follows (from the original paper):

We implement this inside of scaled dot-product attention by masking out (setting to −∞) all values in the input of the softmax which correspond to illegal connections
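
To make that concrete, here is a minimal NumPy sketch of scaled dot-product attention with a causal mask applied before the softmax (this is only an illustration of the idea, not code from the paper; the shapes and values are arbitrary):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

T, d = 3, 4                                    # 3 target positions, dimension 4
np.random.seed(0)
Q = np.random.randn(T, d)                      # queries for the target positions
K = np.random.randn(T, d)                      # keys
V = np.random.randn(T, d)                      # values

scores = Q @ K.T / np.sqrt(d)                  # (T, T) raw attention scores
mask = np.triu(np.full((T, T), -1e9), k=1)     # large negative value above the diagonal ("illegal connections")
weights = softmax(scores + mask)               # row i attends only to positions <= i
output = weights @ V                           # (T, d) attended values

print(np.round(weights, 2))                    # entries above the diagonal are ~0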

Note: during inference (not training) the decoder works in sequential (not parallel) mode, because it doesn't know the output sequence in advance. But it is still different from the RNN approach: Transformer inference uses self-attention and looks at all previous outputs, not only the most recent one.
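
A rough sketch of that sequential inference loop (purely schematic; decoder_step here is a hypothetical stand-in for a full forward pass of the trained decoder over the whole generated prefix):

def greedy_decode(encoder_output, decoder_step, start_token, end_token, max_len=50):
    # decoder_step(encoder_output, prefix) is assumed to return the next token;
    # inside it, self-attention sees the entire prefix generated so far,
    # not just the very last token as a plain RNN step would.
    generated = [start_token]
    for _ in range(max_len):
        next_token = decoder_step(encoder_output, generated)
        generated.append(next_token)
        if next_token == end_token:
            break
    return generated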

Note 2: I've seen in some materials that masking can be used differently for non-translation applications. For example, for language modeling the masking can be used to hide some words from the input sentence and the model will try to predict them during the training using other, non-masked words (i.e. learn to understand the context).

Upvotes: 68

Mojtaba Nourani

Reputation: 11

I believe you were somewhat confused by people saying that masked attention is essential for causality. I just want to add that causality matters during inference; we all agree on that. The question is what happens during training, where we feed the target sequence to the decoder all at once. Yes, here we need the masked multi-head attention as well, and that's where the model learns to generate tokens in a causal way at each time step. Here are the steps (see the numeric sketch after them):

1- Input Embeddings: The decoder receives the entire target sequence as input embeddings.

2- Linear Projections: For each attention head, the embeddings are projected into queries (Q), keys (K), and values (V).

Q = X * W_Q;  K = X * W_K;  V = X * W_V

3- Compute Attention Scores: The attention scores are computed using the scaled dot-product of queries and keys.

attn = (Q * K^T)/sqrt(d)

4- Generate the mask: a matrix of zeros with all elements above the main diagonal set to minus infinity.

5- masked_score = attn + mask

Note: This ensures causality. Just draw a 3 by 3 matrix and assume the target sentence is "Tu es belle". Write the same sentence two times, once above the matrix and once at the left of the matrix.

[0.1, -infty, -infty;
 0.1, 0.2, -infty;
 0.1, 0.2, 0.3] 

You can see that every word can only attend to itself and the previous words.

6- apply softmax and continue

So in brief: during training the target sentence is fed to the decoder all at once, we still use (masked) self-attention over it, and yes, the masking is there for causality.
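
Here is a minimal NumPy sketch of steps 3-6 for a 3-token target such as "Tu es belle" (the attention scores are made up, just to show how the mask and the softmax interact):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Made-up attention scores for the 3 target positions (output of step 3).
attn = np.array([[0.1, 0.5, 0.9],
                 [0.1, 0.2, 0.7],
                 [0.1, 0.2, 0.3]])

# Step 4: zeros on and below the main diagonal, minus infinity above it.
mask = np.triu(np.full((3, 3), -np.inf), k=1)

masked_score = attn + mask        # step 5
weights = softmax(masked_score)   # step 6
print(np.round(weights, 2))
# Row 1 puts all its weight on "Tu", row 2 spreads it over "Tu" and "es",
# row 3 over all three words -- the lower-triangular pattern shown above.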

Upvotes: 1

nairbv

Reputation: 4323

Let's say the text we're training on is "one two three four five"

This is self-supervised training, and we're just going to train the model to predict the next word in the sequence. Rather than an encoder-decoder model we'll just use a GPT-style transformer (sometimes called "decoder only" because it's causal, sometimes called "encoder only" because it has no cross-attention).

If we're doing generative pre-training, we're going to train this model with input_tokens = [one, two, three, four] and output_tokens = [two, three, four, five]. We shift the tokens so that the model is always predicting the next token.

Now, for the output "two," we only want to consider the input "one". When learning to generate the output "three" we only want to consider the inputs "one" and "two." And so on.
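
A minimal Python sketch of that shift, and of which inputs each output is allowed to see under the causal mask (tokens are kept as plain strings here; a real pipeline would use token IDs):

text = ["one", "two", "three", "four", "five"]

input_tokens = text[:-1]    # [one, two, three, four]
output_tokens = text[1:]    # [two, three, four, five]

# The causal mask means position i may only attend to input positions <= i.
for i, target in enumerate(output_tokens):
    visible = input_tokens[:i + 1]
    print(f"predict {target!r} from {visible}")

# predict 'two' from ['one']
# predict 'three' from ['one', 'two']
# predict 'four' from ['one', 'two', 'three']
# predict 'five' from ['one', 'two', 'three', 'four']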

Now, after pre-training, we'll feed into the model "Mary had a little" and we expect to get the output "had a little lamb." The output is shifted by one token.

It may seem wasteful to train on the entire sentence in the output. You may be asking yourself: why not just train the model to predict only the next token? Why predict words that are already in the input, necessitating this causal mask? Well, parameters are shared. When the model learns to produce "two" by attending to "one," it's changing model parameters that also help it generate "five" by attending to "four." Longer sequence lengths end up being equivalent to larger batch sizes, so this redundant-seeming way of training is actually very data efficient.

Upvotes: 0

zhangjq

Reputation: 188

The decoder is auto-regressive and cannot see future words

  1. the decoder in the Transformer is auto-regressive;
  2. which means it predicts the next token based on the previous ones;
  3. so the input x must not be able to see future words;
  4. we use masked multi-head attention to achieve this.

Upvotes: 0
