Reputation: 1
I am experimenting with multihead attention and am trying to understand why the values of the query embeddings in the following code have no impact on the attention output, and why the output is repeated across the 1st dimension (indexed from 0):
import torch
import torch.nn as nn
multihead_attn = nn.MultiheadAttention(embed_dim=2, num_heads=1, batch_first=True)
kv = torch.randn(1, 1, 2)
print("Key / Val")
print(kv)
for _ in range(2):
    q = torch.randn(1, 2, 2)
    print("Query")
    print(q)
    attn_output, _ = multihead_attn(q, kv, kv)
    print("Output")
    print(attn_output)
The output I get is:
Key / Val
tensor([[[ 0.1782, -1.3460]]])
Query
tensor([[[-0.7521,  0.6856],
         [-0.8761, -1.6864]]])
Output
tensor([[[ 0.0517, -0.3687],
         [ 0.0517, -0.3687]]], grad_fn=<TransposeBackward0>)
Query
tensor([[[-0.9609, -1.0166],
         [-1.1555, -1.3593]]])
Output
tensor([[[ 0.0517, -0.3687],
         [ 0.0517, -0.3687]]], grad_fn=<TransposeBackward0>)
Upvotes: 0
Views: 538
Reputation: 685
why the value of the query embeddings in the following code has no impact on the attention output
This is because your kv tensor represents a source sequence of length 1. Recall that multihead attention essentially performs a weighted average over the value vectors in the source sequence, where the weights are computed by applying the softmax operation over the source-sequence-length dimension. For a batch index b and source and target sequence indices src_i and trg_i respectively, the process looks roughly like the code below (ignoring the learned projections and the scaling factor, which don't change the argument):
import torch.nn.functional as F

attn_weight = torch.zeros(kv.size(1))
for src_i in range(kv.size(1)):
    attn_weight[src_i] = q[b, trg_i] @ kv[b, src_i]
attn_weight = F.softmax(attn_weight, dim=0)
The value of F.softmax(x, dim) is always a tensor of ones whenever x.size(dim) == 1, regardless of the values in x. This is exactly the case in your example, where kv.size(1) == attn_weight.size(0) == 1.
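You can verify this directly (a minimal sketch; the input values here are arbitrary):

import torch
import torch.nn.functional as F

x = torch.randn(1)          # a single element along dim 0
print(F.softmax(x, dim=0))  # always tensor([1.]), no matter what x contains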
Thus, the value of q doesn't matter here because the attention weight will always be one. Setting the source sequence length to a value greater than 1 gives a different output. For example,
import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=2, num_heads=1, batch_first=True)
kv = torch.randn(1, 2, 2)  # dim 1 (the source sequence length) is now 2
print("Key / Val")
print(kv)
for _ in range(2):
    q = torch.randn(1, 2, 2)
    print("Query")
    print(q)
    attn_output, _ = multihead_attn(q, kv, kv)
    print("Output")
    print(attn_output)
gives
Key / Val
tensor([[[-0.6962,  1.0835],
         [-2.5286,  1.4495]]])
Query
tensor([[[-1.6931,  0.3306],
         [ 0.3503,  0.7533]]])
Output
tensor([[[ 0.1860, -0.2510],
         [ 0.2282, -0.3111]]], grad_fn=<TransposeBackward0>)
Query
tensor([[[ 0.0181, -0.6645],
         [-1.6577,  0.9446]]])
Output
tensor([[[ 0.2368, -0.3233],
         [ 0.1802, -0.2428]]], grad_fn=<TransposeBackward0>)
Notice how the outputs are different.
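You can also inspect the attention weights that the module returns alongside the output (a rough sketch reusing the variables above; the weights are averaged over heads by default):

attn_output, attn_weights = multihead_attn(q, kv, kv)
print(attn_weights)              # shape (1, 2, 2): one weight per (target, source) pair
print(attn_weights.sum(dim=-1))  # each target position's weights sum to 1

With a source sequence of length 1, attn_weights would just be a tensor of ones.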
why the output is repeated across the 1st dimension
Because what I explained above happens once for each index of the target sequence, i.e. q.size(1) times (twice in your example). With a single key/value vector, every target position gets the same attention weight of one and therefore the same weighted average, so the rows of the output are identical.
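As a rough sketch of why the rows come out identical (again ignoring the module's learned input/output projections and the scaling, which don't change the argument):

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 2)   # batch 1, target length 2
kv = torch.randn(1, 1, 2)  # batch 1, source length 1

for trg_i in range(q.size(1)):
    attn_weight = F.softmax(q[0, trg_i] @ kv[0].T, dim=0)  # always tensor([1.])
    out = attn_weight @ kv[0]                              # always equal to kv[0, 0]
    print(out)                                             # same vector for every target position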
Upvotes: 0