Christian

Reputation: 477

How to use PyTorch's nn.MultiheadAttention

I want to use PyTorch's nn.MultiheadAttention, but I can't get it to work.
I just want to reproduce my manually calculated attention example with PyTorch's built-in functionality.

I always get an error when trying to run this example:

import torch
import torch.nn as nn

embed_dim = 4
num_heads = 1

x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
 ]
x = torch.tensor(x, dtype=torch.float32)

w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)


keys = x @ w_key
querys = x @ w_query
values = x @ w_value

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
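
For reference, the manual attention calculation I am trying to reproduce is roughly this (unscaled dot-product attention; note that nn.MultiheadAttention additionally scales the scores by 1/sqrt(d) and applies its own learned input/output projections, so its numbers won't match these):

import torch.nn.functional as F

# querys, keys, values are the (3x3) projections computed above
attn_scores = querys @ keys.T                  # (3, 3) raw dot-product scores
attn_weights = F.softmax(attn_scores, dim=-1)  # softmax over each row
output = attn_weights @ values                 # (3, 3) weighted sum of the values
print(output)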

Upvotes: 6

Views: 14507

Answers (2)

guifeng deng

Reputation: 11

Just change embed_dim = 4 to embed_dim = 3 and it works.
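
A minimal sketch of that change, assuming a recent PyTorch version that accepts unbatched (L, E) input (otherwise .unsqueeze(0) the inputs as in the other answer). With the original (4x3) weight matrices, the projected querys, keys and values each have 3 features, so embed_dim has to be 3:

import torch
import torch.nn as nn

embed_dim = 3   # matches the 3 features of the projected querys/keys/values
num_heads = 1

# querys, keys, values are the (3x3) projections from the question
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
print(attn_output.shape)           # torch.Size([3, 3])
print(attn_output_weights.shape)   # torch.Size([3, 3])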

Upvotes: 1

Muscle Guy

Reputation: 136

Try this.

First, your x is a (3x4) matrix, so you need (4x4) weight matrices instead of the (4x3) ones above.

Also, nn.MultiheadAttention only seems to accept batched input here, even though the docs say unbatched input is supported. So let's put your single example into batch mode via .unsqueeze(0).

import torch
import torch.nn as nn

embed_dim = 4
num_heads = 1

x = [
  [1, 0, 1, 0], # Seq 1
  [0, 2, 0, 2], # Seq 2
  [1, 1, 1, 1]  # Seq 3
 ]
x = torch.tensor(x, dtype=torch.float32)

w_key = [
  [0, 0, 1, 1],
  [1, 1, 0, 1],
  [0, 1, 0, 1],
  [1, 1, 0, 1]
]
w_query = [
  [1, 0, 1, 1],
  [1, 0, 0, 1],
  [0, 0, 1, 1],
  [0, 1, 1, 1]
]
w_value = [
  [0, 2, 0, 1],
  [0, 3, 0, 1],
  [1, 0, 3, 1],
  [1, 1, 0, 1]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)


keys = (x @ w_key).unsqueeze(0)     # to batch mode
querys = (x @ w_query).unsqueeze(0)
values = (x @ w_value).unsqueeze(0)

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
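
If you want to sanity-check the result, the output shapes should look like this (the actual values won't match the hand-computed example, since the module applies its own randomly initialized input/output projections on top):

print(attn_output.shape)           # torch.Size([1, 3, 4]) -> (batch, seq, embed_dim)
print(attn_output_weights.shape)   # torch.Size([1, 3, 3]) -> attention over the 3 positions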

Upvotes: 12
