Newton

Reputation: 498

Computational Complexity of Self-Attention in the Transformer Model

I recently went through the Transformer paper from Google Research describing how self-attention layers could completely replace traditional RNN-based sequence encoding layers for machine translation. In Table 1 of the paper, the authors compare the computational complexities of different sequence encoding layers, and state (later on) that self-attention layers are faster than RNN layers when the sequence length n is smaller than the dimension of the vector representations d.

However, the self-attention layer seems to have a worse complexity than claimed, if my understanding of the computations is correct. Let X be the input to a self-attention layer. Then, X will have shape (n, d), since there are n word-vectors (corresponding to rows), each of dimension d. Computing the output of self-attention requires the following steps (consider single-headed self-attention for simplicity):

  1. Linearly transforming the rows of X to compute the query Q, key K, and value V matrices, each of which has shape (n, d). This is accomplished by post-multiplying X with 3 learned matrices of shape (d, d), amounting to a computational complexity of O(n d^2).
  2. Computing the layer output, specified in Equation 1 of the paper as softmax(Q K^T / sqrt(d)) V, where the softmax is computed over each row. Computing Q K^T has complexity O(n^2 d), and post-multiplying the result with V has complexity O(n^2 d) as well.

Therefore, the total complexity of the layer is O(n^2 d + n d^2), which is worse than that of a traditional RNN layer. I obtained the same result for multi-headed attention as well, after accounting for the appropriate intermediate representation dimensions (d_k, d_v) and finally multiplying by the number of heads h.
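
For concreteness, here is a minimal NumPy sketch of the single-headed computation described above (shapes and names are my own illustration, not the paper's code); the comments mark where the O(n d^2) and O(n^2 d) terms come from.

    import numpy as np

    n, d = 10, 64                        # sequence length, model dimension
    X = np.random.randn(n, d)            # input: n word-vectors of dimension d
    Wq, Wk, Wv = (np.random.randn(d, d) for _ in range(3))

    # Step 1: projections, each (n, d) @ (d, d) -> (n, d), i.e. O(n * d^2)
    Q, K, V = X @ Wq, X @ Wk, X @ Wv

    # Step 2: attention scores, (n, d) @ (d, n) -> (n, n), i.e. O(n^2 * d)
    scores = Q @ K.T / np.sqrt(d)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax

    # Weighted sum of values, (n, n) @ (n, d) -> (n, d), again O(n^2 * d)
    out = weights @ V                    # overall: O(n * d^2 + n^2 * d)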

Why have the authors ignored the cost of computing the Query, Key, and Value matrices while reporting total computational complexity?

I understand that the proposed layer is fully parallelizable across the n positions, but I believe that Table 1 does not take this into account anyway.

Upvotes: 49

Views: 51581

Answers (5)

Julian Chen

Reputation: 1604

I have the same question, and I think their claim is a little bit misleading. I will explain my understanding.

First, Q, K, and V here are just symbols used to distinguish the roles in the attention formula; it does not mean they have to be different. For vanilla self-attention they are all equal to the input X, without any linear projection/transformation.

This can be seen in Figure 2 of the Transformer paper: the linear layers are not included in the scaled dot-product attention block (left), while in multi-head attention there are additional linear layers before the dot-product attention (right). So their complexity result is for vanilla self-attention, without any linear projection, i.e. Q = K = V = X.
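
Under that reading, the block whose cost Table 1 counts is just the dot-product attention itself. A minimal sketch (my own illustration, not the authors' code):

    import numpy as np

    def vanilla_self_attention(X):
        """Scaled dot-product self-attention with Q = K = V = X (no projections)."""
        n, d = X.shape
        scores = X @ X.T / np.sqrt(d)                    # (n, n): O(n^2 * d)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        return weights @ X                               # (n, d): O(n^2 * d)

Only the two O(n^2 d) matrix products appear here; the O(n d^2) projection cost belongs to the multi-head wrapper around this block.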

Also, I found these slides from one of the authors of the Transformer paper, where you can see clearly that O(n^2 d) is only for the dot-product attention, without the linear projections, while the complexity of multi-head attention is actually O(n^2 d + n d^2).

Also, I don't think the argument of @igrinis is correct. Although the original attention paper did not require computing Q, K, and V, the complexity of its alignment model (an MLP) is actually O(d^2) for each pair of positions, so the total complexity of that attention layer is O(n^2·d^2), even larger than the QKV attention.
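
To make that counting concrete, here is a rough sketch of a Bahdanau-style alignment MLP evaluated naively once per (decoder, encoder) pair, which is the reading behind the O(d^2)-per-pair figure (names and shapes are my own illustration):

    import numpy as np

    d = 64
    W = np.random.randn(d, 2 * d)        # alignment MLP weights
    v = np.random.randn(d)

    def align(s_i, h_j):
        """Score one (decoder state, encoder state) pair: the (d, 2d) matvec is O(d^2)."""
        return v @ np.tanh(W @ np.concatenate([s_i, h_j]))

    # Scoring all ~n * n pairs this way gives the O(n^2 * d^2) total quoted above.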

Upvotes: 0

pandahop

Reputation: 1

One small caveat: the derivation seems correct to me, but it would be good to point out that there are h = 8 heads in each multi-head attention module of the original paper, and the per-head Q, K, V matrices are shrunk by the same factor h. So for the second part (n·n·d) the cost is divided by h per head and then multiplied by the h heads, but the first part is only divided by h per head, I believe.
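
A quick worked count under that reading, using the paper's d = 512, h = 8 and an illustrative n = 100 (a sketch of the accounting, not the authors' own):

    n, d, h = 100, 512, 8
    d_k = d // h                          # per-head dimension (64 in the paper)

    proj_per_head = 3 * n * d * d_k       # projecting X to this head's Q, K, V
    attn_per_head = 2 * n * n * d_k       # Q K^T plus the weighted sum over V

    total_proj = h * proj_per_head        # = 3 * n * d^2   -> O(n * d^2)
    total_attn = h * attn_per_head        # = 2 * n^2 * d   -> O(n^2 * d)
    print(total_proj, total_attn)         # 78_643_200 vs 10_240_000 here

Per head each term carries a factor 1/h, but summing over the h heads brings the totals back to n·d^2 and n^2·d.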

Upvotes: 0

igrinis

Reputation: 13666

First, you are correct in your complexity calculations. So, what is the source of confusion?

When attention was first introduced in the original Attention paper, it did not require calculating the Q, K and V matrices, as the values were taken directly from the hidden states of the RNNs, and thus the complexity of the attention layer is O(n^2·d).

Now, to understand what Table 1 contains, keep in mind how most people scan papers: they read the title and abstract, then look at the figures and tables, and only if the results look interesting do they read the paper more thoroughly. The main idea of the Attention Is All You Need paper was to replace the RNN layers completely with an attention mechanism in the seq2seq setting, because RNNs were really slow to train. If you look at Table 1 in this context, you see that it compares RNN, CNN and Attention and highlights the motivation for the paper: using attention should have been beneficial over RNNs and CNNs. It should have been advantageous in three aspects: a constant number of sequential operations, a constant maximum path length between positions, and lower computational complexity for the usual Google setting, where n ~= 100 and d ~= 1000.

But as with any idea, it hit the hard wall of reality. In order for that great idea to work, they had to add positional encoding, reformulate the attention and add multiple heads to it. The result is the Transformer architecture, which, while it has a computational complexity of O(n^2·d + n·d^2), is still much faster than an RNN (in terms of wall-clock time) and produces better results.
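
To see why the n ~= 100, d ~= 1000 regime favours attention in Table 1's terms, a back-of-the-envelope comparison (illustrative numbers only):

    n, d = 100, 1000
    self_attention = n**2 * d          # 10_000_000 operations per layer (dominant term)
    recurrent = n * d**2               # 100_000_000 operations per layer
    print(recurrent / self_attention)  # 10.0 -> attention wins whenever n < d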

So the answer to your question is that the attention layer the authors refer to in Table 1 is strictly the attention mechanism; it is not the complexity of the whole Transformer. They are very well aware of the complexity of their model (I quote):

Separable convolutions [6], however, decrease the complexity considerably, to O(k·n·d + n·d^2). Even with k = n, however, the complexity of a separable convolution is equal to the combination of a self-attention layer and a point-wise feed-forward layer, the approach we take in our model.

Upvotes: 23

Yoan B. M.Sc

Reputation: 1505

You cannot compare this to a traditional RNN encoder-decoder; the architecture described in the paper is meant to improve upon the classical attention mechanism first established in this paper.

In its initial form, the attention mechanism relied on a neural network trained to retrieve the relevant hidden states of the encoder. Instead of relying on a fixed retrieval strategy (for instance, using the last hidden state), you give the system some control over the process.

There is already a very good post on StackExchange explaining the differences in computational complexity here.

The paper you are describing is "replacing" this neural network with a dot product between two arrays, which is less demanding computationally than having to train a neural network, and relatively more efficient. But it's not meant to be more efficient than a regular RNN-based auto-encoder without attention.
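
A rough sketch of that replacement (shapes and names are my own, not the paper's): instead of evaluating a trained scoring network for every encoder position, the decoder state is simply dotted against all encoder states at once.

    import numpy as np

    n, d = 20, 64
    H = np.random.randn(n, d)             # encoder hidden states, one row per position
    s = np.random.randn(d)                # current decoder state

    scores = H @ s                        # one matrix-vector product: O(n * d)
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()                  # attention weights over the n positions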

How is this any less demanding computationally?

In a traditional RNN / LSTM based auto-encoder, each time step is encoded into a vector h. The decoder usually (again, there are a lot of different architectures, but that's the basic one) takes the last vector h as input to produce the output sequence.

In this scenario there is no attention mechanism; the decoder simply reads the last encoded state. The problem with this architecture is that as your sequence gets longer, all the relevant information gets squeezed into the last encoded state h(t), and you lose relevant information.

Introducing the attention mechanism

As described in the paper above, the original attention mechanism aims at circumventing this limitation by allowing the decoder to access not only the last encoded state but any of the previous encoded states, and to combine them in order to improve the prediction.

For each time step, a probability vector alpha is computed by a neural network to choose which encoded states to retrieve:

If we restrict α to be an one-hot vector, this operation becomes the same as retrieving from a set of elements h with index α. With the restriction removed, the attention operation can be thought of as doing "proportional retrieval" according to the probability vector α
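
A minimal sketch of that "proportional retrieval" (my own illustration): a one-hot alpha picks out a single encoder state, while a soft alpha returns a weighted mixture of all of them.

    import numpy as np

    n, d = 5, 8
    H = np.random.randn(n, d)                     # encoder states h_1 ... h_n

    alpha_hard = np.array([0., 0., 1., 0., 0.])   # one-hot: equivalent to indexing
    alpha_soft = np.array([.1, .2, .4, .2, .1])   # probabilities over positions

    assert np.allclose(alpha_hard @ H, H[2])      # retrieves h_3 exactly
    context = alpha_soft @ H                      # "proportional retrieval", shape (d,)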

I won't copy-paste the SE post; there you'll find the explanation of why the dot-product method is computationally more efficient than the neural network.

Conclusion

The key takeaway is that you cannot compare this to a traditional RNN encoder-decoder because there is no attention mechanism in such a network. It would be like comparing a CNN with an LSTM layer; they are just different architectures.

Upvotes: -2

Shai

Reputation: 114926

Strictly speaking, when considering the complexity of only the self-attention block (Fig. 2 left, Equation 1), the projection of x to q, k and v is not included in the self-attention. The complexities shown in Table 1 are only for the very core of the self-attention layer and thus are O(n^2 d).

Upvotes: 2
