Elidor00

Reputation: 1624

Number of learnable parameters of MultiheadAttention

While testing a model that uses PyTorch's MultiheadAttention, I noticed that increasing or decreasing the number of heads in the multi-head attention does not change the total number of learnable parameters of my model.

Is this behavior correct? And if so, why?

Shouldn't the number of heads affect the number of parameters the model can learn?

Upvotes: 5

Views: 2641

Answers (1)

KonstantinosKokos

Reputation: 3473

The standard implementation of multi-headed attention divides the model's dimensionality by the number of attention heads.

A model of dimensionality d with a single attention head would project embeddings to a single triplet of d-dimensional query, key and value tensors (each projection counting d² parameters, excluding biases, for a total of 3d²).

A model of the same dimensionality with k attention heads would project embeddings to k triplets of d/k-dimensional query, key and value tensors (each projection counting d×d/k = d²/k parameters, excluding biases, for a total of 3k·d²/k = 3d²).
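You can check this yourself by counting parameters for different head counts. A minimal sketch, assuming an embedding dimension of 512 (any d divisible by the head counts works):

```python
import torch.nn as nn

embed_dim = 512  # assumed model dimensionality d

for num_heads in (1, 4, 8):
    mha = nn.MultiheadAttention(embed_dim, num_heads)
    n_params = sum(p.numel() for p in mha.parameters())
    print(f"heads={num_heads}: {n_params} learnable parameters")
```

Each iteration prints the same count (4d² weights plus 4d biases: the three input projections and one output projection), regardless of num_heads.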


References:

From the original paper: [figure showing the multi-head attention equations]

The PyTorch implementation you cited: [screenshot of the MultiheadAttention source]

Upvotes: 6
