Reputation: 1966
Following an amazing blog, I implemented my own self-attention module. However, I found that PyTorch has already implemented a multi-head attention module. The input to the forward pass of the MultiheadAttention module includes Q (the query vectors), K (the key vectors), and V (the value vectors). It is strange that PyTorch wouldn't just take the input embedding and compute the Q, K, V vectors on the inside. In the self-attention module that I implemented, I compute these Q, K, V vectors from the input embeddings multiplied by the Q, K, V weights. At this point, I am not sure what Q, K, and V inputs the MultiheadAttention module requires. Should they be the Q, K, and V weights or vectors, and should these be normal tensors, or should they be Parameters?
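For reference, my module does roughly the following (a simplified single-head sketch; the class and variable names are just for illustration):

    import torch
    import torch.nn as nn

    class SelfAttention(nn.Module):
        def __init__(self, embed_dim):
            super().__init__()
            # learned projection weights for Q, K, V
            self.w_q = nn.Linear(embed_dim, embed_dim, bias=False)
            self.w_k = nn.Linear(embed_dim, embed_dim, bias=False)
            self.w_v = nn.Linear(embed_dim, embed_dim, bias=False)

        def forward(self, x):                  # x: (batch, seq_len, embed_dim)
            q = self.w_q(x)                    # Q = X W_q
            k = self.w_k(x)                    # K = X W_k
            v = self.w_v(x)                    # V = X W_v
            scores = q @ k.transpose(-2, -1) / (x.size(-1) ** 0.5)
            attn = torch.softmax(scores, dim=-1)
            return attn @ v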
Upvotes: 5
Views: 3286
Reputation: 49
For self-attention, the Q, K and V inputs to PyTorch's MultiheadAttention are the same tensor - they are all the input embeddings. Almost. So why does the forward pass of MultiheadAttention take three separate arguments? That is for the decoder of the transformer: in the decoder's cross-attention, Q comes from the previous decoder layer while K and V come from the encoder output. This is why the forward pass has three different parameters. If you take a look at Figure 1 in the original paper you will see it more clearly.
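As a minimal sketch of both uses (assuming a recent PyTorch where nn.MultiheadAttention accepts batch_first=True; sizes are just illustrative):

    import torch
    import torch.nn as nn

    mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

    x = torch.randn(2, 10, 512)       # encoder embeddings: (batch, seq, embed_dim)

    # Encoder self-attention: query, key and value are the same embeddings
    self_out, _ = mha(x, x, x)

    # Decoder cross-attention: query comes from the decoder side,
    # key and value come from the encoder output
    dec = torch.randn(2, 7, 512)      # decoder-side embeddings
    cross_out, _ = mha(dec, x, x)     # (query, key, value)

In both cases you pass plain activation tensors: the module holds its own learned projection weights and applies them to query, key and value internally, so you do not pass weights or Parameters yourself.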
Upvotes: 1
Reputation: 281
If you look at the implementation of multi-head attention in PyTorch, the Q, K and V projections are learned during the training process, and in most cases their dimensions are no larger than the embedding vectors. So you just need to define their dimensions; everything else is taken care of by the module. You have two choices:
kdim: total number of features in key. Default: None.
vdim: total number of features in value. Default: None.
The query vectors have the size of your embedding (embed_dim). Note: if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.
For more details, look at the source code: https://pytorch.org/docs/master/_modules/torch/nn/modules/activation.html#MultiheadAttention
Especially this class: class MultiheadAttention(Module):
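As a small sketch (sizes here are just illustrative), a layer where keys and values come from sources with feature sizes different from the query embedding could be set up like this:

    import torch
    import torch.nn as nn

    # Queries have embed_dim features; keys and values come from sources
    # with their own feature sizes, declared via kdim and vdim.
    mha = nn.MultiheadAttention(embed_dim=256, num_heads=4,
                                kdim=128, vdim=64, batch_first=True)

    query = torch.randn(2, 10, 256)   # (batch, target_len, embed_dim)
    key   = torch.randn(2, 20, 128)   # (batch, source_len, kdim)
    value = torch.randn(2, 20, 64)    # (batch, source_len, vdim)

    out, weights = mha(query, key, value)
    print(out.shape)                  # torch.Size([2, 10, 256])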
Upvotes: 1