apostofes

Reputation: 3723

Multi-head attention calculation

I create a model with a multi-head attention layer:

import torch
import torch.nn as nn
query = torch.randn(2, 4)
key = torch.randn(2, 4)
value = torch.randn(2, 4)
model = nn.MultiheadAttention(4, 1, bias=False)
model(query, key, value)
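
nn.MultiheadAttention returns a tuple (attn_output, attn_output_weights); the attention output to reproduce is the first element:

attn_output, attn_weights = model(query, key, value)
attn_output.shape  # torch.Size([2, 4])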

I attempt to match that attention output as follows:

softmax_output = torch.softmax(((query@model.in_proj_weight[:4])@((key@model.in_proj_weight[4:8]).t()))/2, dim=1)
intermediate_output = softmax_output@(value@model.in_proj_weight[8:12])
final_output = intermediate_output@model.out_proj.weight

but final_output does not match the attention output returned by the layer.
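
A minimal way to confirm the mismatch (using the attention output unpacked from the tuple above):

attn_output, _ = model(query, key, value)
torch.allclose(final_output, attn_output)  # False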

Upvotes: 1

Views: 782

Answers (1)

apostofes

Reputation: 3723

I was able to match the output:

q_w = query@model.in_proj_weight[:4].t()
k_w = key@model.in_proj_weight[4:8].t()
v_w = value@model.in_proj_weight[8:12].t()

softmax_output = torch.softmax((q_w@k_w.t())/2, dim=1)

attention = softmax_output@v_w

final_output = attention@model.out_proj.weight.t()

I was missing the transpose earlier; as with nn.Linear, the projection weights are stored with shape (out_features, in_features), so the inputs have to be multiplied by the transposed weights.
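
For completeness, a minimal end-to-end sketch of the comparison, assuming the same 4-dimensional, single-head, bias-free setup as in the question (the seed is only there to make the sketch reproducible):

import torch
import torch.nn as nn

torch.manual_seed(0)  # illustrative seed, not part of the original question
query = torch.randn(2, 4)
key = torch.randn(2, 4)
value = torch.randn(2, 4)
model = nn.MultiheadAttention(4, 1, bias=False)

# reference output from the layer (first element of the returned tuple)
attn_output, _ = model(query, key, value)

# manual computation: in_proj_weight has shape (12, 4) and stacks the
# q, k, v projections; like nn.Linear, each projection is applied as x @ W.t()
q_w = query @ model.in_proj_weight[:4].t()
k_w = key @ model.in_proj_weight[4:8].t()
v_w = value @ model.in_proj_weight[8:12].t()

# attention scores are scaled by sqrt(head_dim) = sqrt(4) = 2
softmax_output = torch.softmax((q_w @ k_w.t()) / 2, dim=1)
attention = softmax_output @ v_w
final_output = attention @ model.out_proj.weight.t()

torch.allclose(final_output, attn_output)  # True (up to floating-point tolerance)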

Upvotes: 1
