Reputation: 144
I am a bit confused about the definition of Multihead.
Are [1] and [2] below the same?
[1]
My understanding of multi-head is that it means multiple attention patterns, as below.
"multiple sets of Query/Key/Value weight matrices (the Transformer uses eight attention heads, so we end up with eight sets for each encoder/decoder)."
http://jalammar.github.io/illustrated-transformer/
But
[2] In class MultiheadAttention(Module): in the PyTorch Transformer module, it seems that embed_dim is DIVIDED by the number of heads. Why?
Or is embed_dim meant to be the feature dimension times the number of heads in the first place?
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py
Upvotes: 2
Views: 1786
Reputation: 7369
As per your understanding, multi-head attention is attention applied multiple times over the same data.
However, it is not implemented by multiplying the set of weight matrices by the number of required attention heads. Instead, you rearrange (reshape) the single weight matrix according to the number of heads, so that each head operates on its own slice of it. In essence, it is still attention applied multiple times, but each head attends to a different part of the weights, and therefore to a different part of the embedding dimension.
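To make the reshaping concrete, here is a minimal sketch (not the actual PyTorch source; the tensor names and sizes below are made up for illustration) showing how a single embed_dim-sized projection is viewed as num_heads chunks of head_dim each:

import torch

# Illustration only: how one full-size projection is split across heads.
embed_dim, num_heads = 512, 8
head_dim = embed_dim // num_heads          # 64, matching the PyTorch assert
assert head_dim * num_heads == embed_dim

batch, seq_len = 2, 10
x = torch.randn(batch, seq_len, embed_dim)

# A single weight matrix of shape (embed_dim, embed_dim) produces all queries at once...
w_q = torch.randn(embed_dim, embed_dim)
q = x @ w_q.T                              # (batch, seq_len, embed_dim)

# ...and is then viewed as num_heads separate projections of size head_dim each.
q_heads = q.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
print(q_heads.shape)                       # torch.Size([2, 8, 10, 64])

So embed_dim stays the model's feature dimension; each head simply gets an embed_dim // num_heads slice of it rather than its own full-size set of weights.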
Upvotes: 1