JimSD

Reputation: 144

The definition of "heads" in MultiheadAttention in Pytorch Transformer module

I am a bit confused about the definition of Multihead.
Are [1] and [2] below the same?

[1] My understanding of multi-head is that it means multiple attention patterns, as below:
"multiple sets of Query/Key/Value weight matrices (the Transformer uses eight attention heads, so we end up with eight sets for each encoder/decoder)."
http://jalammar.github.io/illustrated-transformer/

But

[2] But in class MultiheadAttention(Module) in the PyTorch Transformer module, it seems like embed_dim is DIVIDED by the number of heads. Why?

Or is embed_dim meant to be the feature dimension times the number of heads in the first place?

self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py

Upvotes: 2

Views: 1786

Answers (1)

Ashwin Geet D'Sa

Reputation: 7369

As per your understanding, multi-head attention means applying attention multiple times over the same data.

However, it is not implemented by creating a separate full-size set of weight matrices for each head. Instead, the single projection of size embed_dim is split (reshaped) into num_heads chunks of size head_dim = embed_dim // num_heads, one chunk per head. So in essence it is still attention applied multiple times, but each head attends over a different slice of the projected representation.
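A minimal sketch of that split, loosely following what nn.MultiheadAttention does internally (the dimensions and variable names here are illustrative assumptions, not the library's own code):

import torch
import torch.nn as nn

embed_dim, num_heads = 512, 8
head_dim = embed_dim // num_heads            # 64: each head gets a slice of embed_dim
seq_len, batch = 10, 2

x = torch.randn(seq_len, batch, embed_dim)

# One projection of total size embed_dim (standing in for the packed in-projection
# weights), then a reshape that splits it into num_heads chunks of head_dim.
q = nn.Linear(embed_dim, embed_dim)(x)
q = q.view(seq_len, batch * num_heads, head_dim).transpose(0, 1)
print(q.shape)   # torch.Size([16, 10, 64]) i.e. (batch * num_heads, seq_len, head_dim)

So the eight "sets of weights" live side by side inside one embed_dim-wide projection, which is why embed_dim must be divisible by num_heads.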

Upvotes: 1
