PinkBanter

Reputation: 1966

What should the Query Q, Key K and Value V vectors/matrices be in torch.nn.MultiheadAttention?

Following an amazing blog, I implemented my own self-attention module. However, I found that PyTorch has already implemented a multi-head attention module. The input to the forward pass of the MultiheadAttention module includes Q (the query vector), K (the key vector), and V (the value vector). It is strange that PyTorch doesn't just take the input embedding and compute the Q, K, V vectors on the inside. In the self-attention module I implemented, I compute these Q, K, V vectors from the input embeddings multiplied by the Q, K, V weights. At this point, I am not sure what Q, K, and V inputs the MultiheadAttention module requires. Should they be the Q, K, and V weights or the vectors themselves, and should they be normal tensors or Parameters?
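
For reference, here is a stripped-down sketch of the kind of module I mean (the names and shapes are just illustrative, not my exact code), where Q, K and V are computed inside the module from the input embeddings:

    import torch
    import torch.nn as nn

    class SelfAttention(nn.Module):
        def __init__(self, embed_dim):
            super().__init__()
            # Learned projections that turn the input embeddings into Q, K, V
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(embed_dim, embed_dim)
            self.v_proj = nn.Linear(embed_dim, embed_dim)
            self.scale = embed_dim ** -0.5

        def forward(self, x):              # x: (batch, seq_len, embed_dim)
            q = self.q_proj(x)             # Q, K, V computed internally
            k = self.k_proj(x)
            v = self.v_proj(x)
            attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
            return attn @ v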

Upvotes: 5

Views: 3286

Answers (2)

chorus12

Reputation: 49

Q, K and V as inputs to PyTorch's MultiheadAttention are the same - they are all embeddings. Almost. So why are there 3 parameters in the forward pass of MultiheadAttention? That is for the decoder of the Transformer - in the decoder's multi-head attention, Q comes from the decoder itself (the output of the previous decoder sub-layer) while K and V come from the encoder. This is why there are 3 different parameters in the forward pass. If you take a look at Figure 1 in the original paper you will see it more clearly.
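
A minimal sketch of both uses (the shapes and sizes are arbitrary examples):

    import torch
    import torch.nn as nn

    embed_dim, num_heads = 512, 8
    mha = nn.MultiheadAttention(embed_dim, num_heads)  # expects (seq_len, batch, embed_dim) by default

    # Self-attention (encoder): pass the same embeddings as query, key and value.
    x = torch.rand(10, 32, embed_dim)
    self_out, self_weights = mha(x, x, x)

    # Cross-attention (decoder): query comes from the decoder side,
    # key and value come from the encoder output.
    decoder_x = torch.rand(7, 32, embed_dim)
    encoder_out = torch.rand(10, 32, embed_dim)
    cross_out, cross_weights = mha(decoder_x, encoder_out, encoder_out)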

Upvotes: 1

Ashwiniku918

Reputation: 281

If you look at the implementation of multi-head attention in PyTorch, the Q, K and V projections are learned during the training process. In most cases they should be smaller than the embedding vectors. So you just need to define their dimensions; everything else is taken care of by the module. You have two choices:

    kdim: total number of features in key. Default: None.
    vdim: total number of features in value. Default: None. 

The query vector has the size of your embedding. Note: if kdim and vdim are None, they will be set to embed_dim so that query, key, and value have the same number of features.
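
A small sketch of both options (the sizes here are arbitrary examples):

    import torch
    import torch.nn as nn

    embed_dim, num_heads = 512, 8

    # kdim/vdim left as None: key and value must have embed_dim features, same as the query.
    mha_default = nn.MultiheadAttention(embed_dim, num_heads)

    # kdim/vdim set explicitly: key and value can have a different feature size;
    # the module projects them internally.
    mha_custom = nn.MultiheadAttention(embed_dim, num_heads, kdim=256, vdim=256)

    query = torch.rand(10, 32, embed_dim)   # (seq_len, batch, embed_dim)
    key = torch.rand(20, 32, 256)           # (seq_len, batch, kdim)
    value = torch.rand(20, 32, 256)         # (seq_len, batch, vdim)
    out, weights = mha_custom(query, key, value)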

For more details, look at the source code: https://pytorch.org/docs/master/_modules/torch/nn/modules/activation.html#MultiheadAttention

Especially this class: class MultiheadAttention(Module)

Upvotes: 1
