carlina

Reputation: 1

Why does the query in multihead attention not affect the output?

I am experimenting with multihead attention and am trying to understand why the value of the query embeddings in the following code has no impact on the attention output, as well as why the output is repeated across the 1st dimension (indexed from 0):

import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=2, num_heads=1, batch_first=True)

kv = torch.randn(1, 1, 2)
print("Key / Val")
print(kv)
for _ in range(2):
    q = torch.randn(1, 2, 2)
    print("Query")
    print(q)
    attn_output, _ = multihead_attn(q, kv, kv)
    print("Output")
    print(attn_output)

The output I get is:

Key / Val
tensor([[[ 0.1782, -1.3460]]])
Query
tensor([[[-0.7521,  0.6856],
         [-0.8761, -1.6864]]])
Output
tensor([[[ 0.0517, -0.3687],
         [ 0.0517, -0.3687]]], grad_fn=<TransposeBackward0>)
Query
tensor([[[-0.9609, -1.0166],
         [-1.1555, -1.3593]]])
Output
tensor([[[ 0.0517, -0.3687],
         [ 0.0517, -0.3687]]], grad_fn=<TransposeBackward0>)

Upvotes: 0

Views: 538

Answers (1)

kmkurn

Reputation: 685

why the value of the query embeddings in the following code has no impact on the attention output

This is because your kv tensor represents a source sequence of length 1. Recall that multihead attention basically performs a weighted average over the value vectors in the source sequence. The weights are computed by applying the softmax operation over the source sequence length dimension. For a batch index b and source and target sequence indices src_i and trg_i respectively, the process in code looks roughly like below:

import torch
import torch.nn.functional as F

# q and kv are the tensors from the question; b and trg_i are fixed
# batch and target-position indices. Learned projections and scaling
# are omitted to keep the idea clear.
attn_weight = torch.zeros(kv.size(1))
for src_i in range(kv.size(1)):
    attn_weight[src_i] = q[b, trg_i] @ kv[b, src_i]
attn_weight = F.softmax(attn_weight, dim=0)
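
For completeness, here is a rough single-head sketch of the whole computation, including the learned projections and the scaling that the loop above leaves out. The projection matrices Wq, Wk, Wv, Wo below are made-up stand-ins, not the module's actual weights, so this only illustrates the mechanism:

import math
import torch
import torch.nn.functional as F

d = 2
q = torch.randn(1, 2, 2)   # same shapes as in the question
kv = torch.randn(1, 1, 2)

# hypothetical stand-ins for the learned projection matrices
Wq, Wk, Wv, Wo = (torch.randn(d, d) for _ in range(4))

Q = q @ Wq.T                                     # (1, 2, 2) projected queries
K = kv @ Wk.T                                    # (1, 1, 2) projected keys
V = kv @ Wv.T                                    # (1, 1, 2) projected values

scores = Q @ K.transpose(-2, -1) / math.sqrt(d)  # (1, 2, 1)
weights = F.softmax(scores, dim=-1)              # all ones: source length is 1
out = (weights @ V) @ Wo.T                       # both rows of out are identical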

The value of F.softmax(x, dim) is always a tensor of ones if x.size(dim) == 1 for any tensor x. This is true in your example where kv.size(1) == attn_weight.size(0) == 1. Thus, the value of q doesn't matter here because the attention weight will always be one. Setting the source sequence length to greater than 1 will give a different output. For example,

import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=2, num_heads=1, batch_first=True)

kv = torch.randn(1, 2, 2)  # dim 1 is now 2
print("Key / Val")
print(kv)
for _ in range(2):
    q = torch.randn(1, 2, 2)
    print("Query")
    print(q)
    attn_output, _ = multihead_attn(q, kv, kv)
    print("Output")
    print(attn_output)

gives

Key / Val
tensor([[[-0.6962,  1.0835],
         [-2.5286,  1.4495]]])
Query
tensor([[[-1.6931,  0.3306],
         [ 0.3503,  0.7533]]])
Output
tensor([[[ 0.1860, -0.2510],
         [ 0.2282, -0.3111]]], grad_fn=<TransposeBackward0>)
Query
tensor([[[ 0.0181, -0.6645],
         [-1.6577,  0.9446]]])
Output
tensor([[[ 0.2368, -0.3233],
         [ 0.1802, -0.2428]]], grad_fn=<TransposeBackward0>)

Notice how the outputs are different.
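
As a quick standalone check of the earlier claim that the softmax over a length-1 dimension is always a tensor of ones:

import torch
import torch.nn.functional as F

x = torch.randn(1)           # a single arbitrary logit
print(F.softmax(x, dim=0))   # tensor([1.]) regardless of the value of x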

why the output is repeated across the 1st dimension

Because the computation I explained above happens once for each index of the target sequence, i.e. q.size(1) times. Since the source sequence has length 1, the attention weight is 1 for every target position, so each position receives the same weighted average of the single value vector and the rows of the output are identical.
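
Another way to see this is to inspect the attention weights that nn.MultiheadAttention returns as its second output (discarded as _ in the question's code). With a source length of 1 they should all be exactly one. A minimal check, reusing the shapes from the question:

import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=2, num_heads=1, batch_first=True)

kv = torch.randn(1, 1, 2)   # source length 1, as in the question
q = torch.randn(1, 2, 2)    # target length 2

attn_output, attn_weights = multihead_attn(q, kv, kv)
print(attn_weights)         # shape (1, 2, 1), every entry is 1.0
print(attn_output)          # both rows identical, independent of q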

Upvotes: 0
