I want to use PyTorch's nn.MultiheadAttention, but it doesn't work. I'm trying to reproduce a manually calculated attention example using PyTorch's built-in module, but I always get an error when running this code:
import torch
import torch.nn as nn
embed_dim = 4
num_heads = 1
x = [
[1, 0, 1, 0], # Input 1
[0, 2, 0, 2], # Input 2
[1, 1, 1, 1] # Input 3
]
x = torch.tensor(x, dtype=torch.float32)
w_key = [
[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0]
]
w_query = [
[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1]
]
w_value = [
[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)
keys = x @ w_key
querys = x @ w_query
values = x @ w_value
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
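For reference, a quick shape check on the tensors being passed in (the module was built with embed_dim=4):

print(querys.shape, keys.shape, values.shape)
# torch.Size([3, 3]) torch.Size([3, 3]) torch.Size([3, 3])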
Try this.

First, your x is a (3x4) matrix and the module was constructed with embed_dim=4, so nn.MultiheadAttention expects queries, keys, and values whose last dimension is 4. Your (4x3) weight matrices project x down to dimension 3, so you need (4x4) weight matrices instead.

Second, nn.MultiheadAttention seems to only accept batched input (although the docs say unbatched input is supported), so turn your single example into a batch of one via .unsqueeze(0).
import torch
import torch.nn as nn

embed_dim = 4
num_heads = 1
x = [
[1, 0, 1, 0], # Seq 1
[0, 2, 0, 2], # Seq 2
[1, 1, 1, 1] # Seq 3
]
x = torch.tensor(x, dtype=torch.float32)
w_key = [
[0, 0, 1, 1],
[1, 1, 0, 1],
[0, 1, 0, 1],
[1, 1, 0, 1]
]
w_query = [
[1, 0, 1, 1],
[1, 0, 0, 1],
[0, 0, 1, 1],
[0, 1, 1, 1]
]
w_value = [
[0, 2, 0, 1],
[0, 3, 0, 1],
[1, 0, 3, 1],
[1, 1, 0, 1]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)
keys = (x @ w_key).unsqueeze(0)      # add batch dimension: (3, 4) -> (1, 3, 4)
querys = (x @ w_query).unsqueeze(0)  # same for queries
values = (x @ w_value).unsqueeze(0)  # and for values
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
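One caveat: even with matching shapes, attn_output will not equal the hand-computed attention result, because nn.MultiheadAttention applies its own randomly initialized input and output projections (in_proj_weight / out_proj) on top of the tensors you pass in. To compare against the manual calculation, you can compute scaled dot-product attention on the projected tensors directly (a minimal sketch using the tensors defined above; note the 1/sqrt(d) scaling, which the module also applies):

# Manual scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
# This is the "by hand" result; the module wraps it with extra learned projections.
scores = querys @ keys.transpose(-2, -1) / (embed_dim ** 0.5)
weights = torch.softmax(scores, dim=-1)
manual_output = weights @ values
print(manual_output.shape)  # torch.Size([1, 3, 4])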