jason

Reputation: 2106

Why must embed dimension be divisible by the number of heads in MultiheadAttention?

I am learning about the Transformer. Here is the PyTorch documentation for MultiheadAttention. In their implementation, I saw there is a constraint:

 assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

Why require the constraint: embed_dim must be divisible by num_heads? If we go back to the equation

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O, where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

Assume Q, K, V are n x embed_dim matrices and all the weight matrices W are embed_dim x head_dim.

Then the concatenation [head_1, ..., head_h] will be an n x (num_heads*head_dim) matrix;

W^O has size (num_heads*head_dim) x embed_dim;

so [head_1, ..., head_h] * W^O will be an n x embed_dim output.

I don't see why we need to require that embed_dim be divisible by num_heads.

Let's say we have num_heads=10000; the results are the same, since the matrix-matrix product will absorb this information.
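
To make the shape argument concrete, here is a minimal sketch with plain tensor operations (not nn.MultiheadAttention; the sizes are made up for illustration), where head_dim is not embed_dim // num_heads and the output still comes out as n x embed_dim:

import torch

n, embed_dim, num_heads, head_dim = 5, 16, 3, 7    # 16 is not divisible by 3
Q = K = V = torch.rand(n, embed_dim)

heads = []
for _ in range(num_heads):
    W_q, W_k, W_v = (torch.rand(embed_dim, head_dim) for _ in range(3))
    q, k, v = Q @ W_q, K @ W_k, V @ W_v                       # each is n x head_dim
    attn = torch.softmax(q @ k.T / head_dim ** 0.5, dim=-1)   # n x n attention weights
    heads.append(attn @ v)                                    # n x head_dim

concat = torch.cat(heads, dim=-1)                   # n x (num_heads * head_dim)
W_o = torch.rand(num_heads * head_dim, embed_dim)
out = concat @ W_o                                  # n x embed_dim
print(out.shape)                                    # torch.Size([5, 16])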

Upvotes: 16

Views: 6475

Answers (2)

Wenuka

Reputation: 1082

From what I understand, it is a simplification added to keep the implementation simple. Theoretically, you could implement the model the way you propose (similar to the original paper). The PyTorch documentation briefly mentions it:

Note that `embed_dim` will be split across `num_heads` (i.e. each head will have dimension `embed_dim` // `num_heads`)
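
For example, a quick check with arbitrary sizes: constructing the module with a non-divisible pair trips exactly the assertion quoted in the question, while a divisible pair works.

import torch.nn as nn

nn.MultiheadAttention(embed_dim=512, num_heads=8)    # fine: each head gets 512 // 8 = 64 dims
nn.MultiheadAttention(embed_dim=512, num_heads=10)   # AssertionError: embed_dim must be divisible by num_heads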

Also, if you look at the PyTorch implementation, you can see it is a bit different (optimised, in my view) compared to the originally proposed model. For example, they use MatMul instead of Linear, and the Concat layer is dropped. Refer to the figure below, which shows the first encoder (with batch size 32, 10 words, 512 features).

[ONNX graph of the first encoder layer: batch size 32, 10 words, 512 features]

P.S.: If you need to see the model parameters (as in the above image), this is the code I used.

import torch

# change params as necessary
transformer_model = torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=1,
                                          num_decoder_layers=1, dim_feedforward=11)
src = torch.rand((11, 32, 512))  # (source sequence length, batch size, d_model)
tgt = torch.rand((20, 32, 512))  # (target sequence length, batch size, d_model)
torch.onnx.export(transformer_model, (src, tgt), "transformer_model.onnx")

Upvotes: 5

Theodor Peifer

Reputation: 3496

When you have a sequence of shape seq_len x emb_dim (e.g. 20 x 8) and you want to use num_heads=2, the sequence is split along the emb_dim dimension, so you get two 20 x 4 sequences. Every head needs to have the same shape, and if emb_dim isn't divisible by num_heads this won't work. Take for example a sequence of shape 20 x 9, again with num_heads=2: you would get 20 x 4 and 20 x 5, which are not the same dimension.
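
To see that split concretely, here is a reshape-only sketch (not the actual PyTorch code, just the same example sizes):

import torch

x = torch.rand(20, 8)              # seq_len x emb_dim, with num_heads=2
print(x.view(20, 2, 4).shape)      # torch.Size([20, 2, 4]) -> two heads of size 4

y = torch.rand(20, 9)              # emb_dim=9 cannot be split into two equal heads
y.view(20, 2, -1)                  # RuntimeError: shape '[20, 2, -1]' is invalid for input of size 180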

Upvotes: 1
