Reputation: 2106
I am learning the Transformer. Here is the PyTorch documentation for MultiheadAttention. In their implementation, I saw this constraint:
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
Why require that embed_dim be divisible by num_heads?
If we go back to the equations: assume Q, K, V are `n x embed_dim` matrices, and every per-head weight matrix W is `embed_dim x head_dim`. Then the concatenation `[head_1, ..., head_h]` is an `n x (num_heads*head_dim)` matrix, W^O has size `(num_heads*head_dim) x embed_dim`, and the product `[head_1, ..., head_h] * W^O` is an `n x embed_dim` output.
I don't know why we require that embed_dim be divisible by num_heads.
Say we have `num_heads=10000`: the result is the same, since the matrix-matrix product with W^O absorbs this information.
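Here is a minimal sketch of that construction (per-head projection matrices with an arbitrary head_dim; this is not PyTorch's implementation), just to show the shapes still work out when embed_dim is not divisible by num_heads:
import torch

# hypothetical sizes: embed_dim=8 is NOT divisible by num_heads=3
n, embed_dim, num_heads, head_dim = 20, 8, 3, 5
x = torch.rand(n, embed_dim)

# one embed_dim x head_dim projection per head, plus the output projection W^O
W_q = [torch.rand(embed_dim, head_dim) for _ in range(num_heads)]
W_k = [torch.rand(embed_dim, head_dim) for _ in range(num_heads)]
W_v = [torch.rand(embed_dim, head_dim) for _ in range(num_heads)]
W_o = torch.rand(num_heads * head_dim, embed_dim)

heads = []
for Wq, Wk, Wv in zip(W_q, W_k, W_v):
    Q, K, V = x @ Wq, x @ Wk, x @ Wv                        # each n x head_dim
    attn = torch.softmax(Q @ K.T / head_dim ** 0.5, dim=-1) # n x n attention weights
    heads.append(attn @ V)                                  # n x head_dim

out = torch.cat(heads, dim=-1) @ W_o                        # n x embed_dim
print(out.shape)                                            # torch.Size([20, 8])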
Upvotes: 16
Views: 6475
Reputation: 1082
From what I understand, it is a simplification they have added to keep things simple. Theoretically, we could implement the model the way you propose (similar to the original paper). The PyTorch documentation briefly mentions it:
Note that `embed_dim` will be split across `num_heads` (i.e. each head will have dimension `embed_dim` // `num_heads`)
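Roughly, that split is just a reshape of the last dimension (a simplified sketch of the shapes, not the actual library source):
import torch

batch, seq_len, embed_dim, num_heads = 32, 10, 512, 8
head_dim = embed_dim // num_heads   # 64; the assert guarantees this division is exact
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

q = torch.rand(batch, seq_len, embed_dim)                        # output of one big in-projection
q = q.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
print(q.shape)                                                   # torch.Size([32, 8, 10, 64])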
Also, if you look at the PyTorch implementation, you can see it is a bit different (optimised, in my view) compared to the originally proposed model. For example, they use MatMul instead of Linear, and the Concat layer is skipped. Refer to the graph below, which shows the first encoder (with batch size 32, 10 words, 512 features).
P.S.: If you need to see the model params (like the image above), this is the code I used.
import torch

# build a small Transformer; change params as necessary
transformer_model = torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=1, num_decoder_layers=1, dim_feedforward=11)
src = torch.rand((11, 32, 512))  # (source sequence length, batch size, d_model)
tgt = torch.rand((20, 32, 512))  # (target sequence length, batch size, d_model)

# export to ONNX so the graph can be inspected (e.g. with Netron)
torch.onnx.export(transformer_model, (src, tgt), "transformer_model.onnx")
Upvotes: 5
Reputation: 3496
When you have a sequence of shape `seq_len x emb_dim` (e.g. `20 x 8`) and you want to use `num_heads=2`, the sequence is split along the `emb_dim` dimension. Therefore you get two `20 x 4` sequences. You want every head to have the same shape, and if `emb_dim` isn't divisible by `num_heads` this won't work. Take for example a sequence of `20 x 9` and again `num_heads=2`: you would get `20 x 4` and `20 x 5`, which are not the same dimension.
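A tiny sketch of that split (hypothetical shapes, not library code):
import torch

seq_len, emb_dim, num_heads = 20, 8, 2
x = torch.rand(seq_len, emb_dim)
heads = x.view(seq_len, num_heads, emb_dim // num_heads)   # two 20 x 4 slices
print(heads.shape)                                         # torch.Size([20, 2, 4])

# With emb_dim=9 the even split is impossible:
# torch.rand(20, 9).view(20, 2, -1)            -> RuntimeError (180 elements can't form 20 x 2 x ?)
# torch.nn.MultiheadAttention(9, num_heads=2)  -> trips the assert quoted in the question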
Upvotes: 1