Paul

Reputation: 1176

Where are the different projection matrices for a Hugging Face transformer model?

My question is very simple. I have a pre-trained transformer model that I'm loading with PyTorch and Hugging Face. In Colab, I run the following code and print out the keys of the state dict:

from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased")
model.state_dict().keys()

The output of this is:

odict_keys(['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'transformer.layer.0.attention.q_lin.weight', 'transformer.layer.0.attention.q_lin.bias', 'transformer.layer.0.attention.k_lin.weight', 'transformer.layer.0.attention.k_lin.bias', 'transformer.layer.0.attention.v_lin.weight', 'transformer.layer.0.attention.v_lin.bias', 'transformer.layer.0.attention.out_lin.weight', 'transformer.layer.0.attention.out_lin.bias', 'transformer.layer.0.sa_layer_norm.weight', 'transformer.layer.0.sa_layer_norm.bias', 'transformer.layer.0.ffn.lin1.weight', 'transformer.layer.0.ffn.lin1.bias', 'transformer.layer.0.ffn.lin2.weight', 'transformer.layer.0.ffn.lin2.bias', 'transformer.layer.0.output_layer_norm.weight', 'transformer.layer.0.output_layer_norm.bias', 'transformer.layer.1.attention.q_lin.weight', 'transformer.layer.1.attention.q_lin.bias', 'transformer.layer.1.attention.k_lin.weight', 'transformer.layer.1.attention.k_lin.bias', 'transformer.layer.1.attention.v_lin.weight', 'transformer.layer.1.attention.v_lin.bias', 'transformer.layer.1.attention.out_lin.weight', 'transformer.layer.1.attention.out_lin.bias', 'transformer.layer.1.sa_layer_norm.weight', 'transformer.layer.1.sa_layer_norm.bias', 'transformer.layer.1.ffn.lin1.weight', 'transformer.layer.1.ffn.lin1.bias', 'transformer.layer.1.ffn.lin2.weight', 'transformer.layer.1.ffn.lin2.bias', 'transformer.layer.1.output_layer_norm.weight', 'transformer.layer.1.output_layer_norm.bias', 'transformer.layer.2.attention.q_lin.weight', 'transformer.layer.2.attention.q_lin.bias', 'transformer.layer.2.attention.k_lin.weight', 'transformer.layer.2.attention.k_lin.bias', 'transformer.layer.2.attention.v_lin.weight', 'transformer.layer.2.attention.v_lin.bias', 'transformer.layer.2.attention.out_lin.weight', 'transformer.layer.2.attention.out_lin.bias', 'transformer.layer.2.sa_layer_norm.weight', 'transformer.layer.2.sa_layer_norm.bias', 'transformer.layer.2.ffn.lin1.weight', 'transformer.layer.2.ffn.lin1.bias', 'transformer.layer.2.ffn.lin2.weight', 'transformer.layer.2.ffn.lin2.bias', 'transformer.layer.2.output_layer_norm.weight', 'transformer.layer.2.output_layer_norm.bias', 'transformer.layer.3.attention.q_lin.weight', 'transformer.layer.3.attention.q_lin.bias', 'transformer.layer.3.attention.k_lin.weight', 'transformer.layer.3.attention.k_lin.bias', 'transformer.layer.3.attention.v_lin.weight', 'transformer.layer.3.attention.v_lin.bias', 'transformer.layer.3.attention.out_lin.weight', 'transformer.layer.3.attention.out_lin.bias', 'transformer.layer.3.sa_layer_norm.weight', 'transformer.layer.3.sa_layer_norm.bias', 'transformer.layer.3.ffn.lin1.weight', 'transformer.layer.3.ffn.lin1.bias', 'transformer.layer.3.ffn.lin2.weight', 'transformer.layer.3.ffn.lin2.bias', 'transformer.layer.3.output_layer_norm.weight', 'transformer.layer.3.output_layer_norm.bias', 'transformer.layer.4.attention.q_lin.weight', 'transformer.layer.4.attention.q_lin.bias', 'transformer.layer.4.attention.k_lin.weight', 'transformer.layer.4.attention.k_lin.bias', 'transformer.layer.4.attention.v_lin.weight', 'transformer.layer.4.attention.v_lin.bias', 'transformer.layer.4.attention.out_lin.weight', 'transformer.layer.4.attention.out_lin.bias', 'transformer.layer.4.sa_layer_norm.weight', 'transformer.layer.4.sa_layer_norm.bias', 'transformer.layer.4.ffn.lin1.weight', 'transformer.layer.4.ffn.lin1.bias', 'transformer.layer.4.ffn.lin2.weight', 'transformer.layer.4.ffn.lin2.bias', 
'transformer.layer.4.output_layer_norm.weight', 'transformer.layer.4.output_layer_norm.bias', 'transformer.layer.5.attention.q_lin.weight', 'transformer.layer.5.attention.q_lin.bias', 'transformer.layer.5.attention.k_lin.weight', 'transformer.layer.5.attention.k_lin.bias', 'transformer.layer.5.attention.v_lin.weight', 'transformer.layer.5.attention.v_lin.bias', 'transformer.layer.5.attention.out_lin.weight', 'transformer.layer.5.attention.out_lin.bias', 'transformer.layer.5.sa_layer_norm.weight', 'transformer.layer.5.sa_layer_norm.bias', 'transformer.layer.5.ffn.lin1.weight', 'transformer.layer.5.ffn.lin1.bias', 'transformer.layer.5.ffn.lin2.weight', 'transformer.layer.5.ffn.lin2.bias', 'transformer.layer.5.output_layer_norm.weight', 'transformer.layer.5.output_layer_norm.bias'])

At first glance it seems that the weights for the individual attention heads are missing. Where are the weights for the different heads?

My second question is a yes-or-no one: it seems the answer to my first question may be that the weights for the different heads have been concatenated. On inspection, the projection matrices are 768x768. Is this really 12 768x64 projection matrices concatenated?
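For example, inspecting one of the query projections (key names taken from the state dict above):

model.state_dict()['transformer.layer.0.attention.q_lin.weight'].shape
# torch.Size([768, 768])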

Where is the documentation for this? I can't find any explanation of these state_dict keys anywhere in the Hugging Face documentation.

EDIT:

I tried loading a pre-trained BERT model using TensorFlow instead, but it's the same issue: the Wq and Wk matrices are both 768x768. My hunch is that, since each of the Wq matrices for the 12 different heads was intended to be 64xdim, this matrix stacks the projection matrices for each of the heads row by row. But how do I know I'm not getting it backwards or transposed, without any documentation for either TensorFlow or PyTorch on how this state is defined?

Upvotes: 0

Views: 424

Answers (1)

Gabriel A.

Reputation: 641

The weights for the heads are concatenated; not only that, it is also common to concatenate the weights for Q, K, and V into a single matrix for self-attention (DistilBERT keeps them separate as q_lin, k_lin, and v_lin, but GPT-style code such as minGPT fuses them).

The easiest (most common) way to make the split is along the last dimension, as the resulting matrix would still be contiguous in memory. But ultimately, that depends on the implementation.

Take, for instance, this code from minGPT:

# on init
c_attn = nn.Linear(n_embed, 3 * n_embed)  # fused Q, K, V projection
...
# on forward
B, T, C = x.size()  # batch size, sequence length, n_embed
q, k, v = c_attn(x).split(n_embed, dim=2)  # each (B, T, n_embed)
q = q.view(B, T, n_head, n_embed // n_head)  # .transpose(1, 2) to get (B, n_head, T, head_dim)

As you can see, both the Q/K/V projections and the individual heads are concatenated along the last dimension, and they are split/reshaped back out accordingly.
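To make the shapes concrete, here is a tiny runnable sketch of that same split, using hypothetical toy sizes rather than minGPT's actual configuration:

import torch
import torch.nn as nn

n_embed, n_head = 8, 2  # toy sizes for illustration
B, T = 1, 4
c_attn = nn.Linear(n_embed, 3 * n_embed)  # fused Q, K, V projection

x = torch.randn(B, T, n_embed)
q, k, v = c_attn(x).split(n_embed, dim=2)                    # each (B, T, n_embed)
q = q.view(B, T, n_head, n_embed // n_head).transpose(1, 2)  # (B, n_head, T, head_dim)
print(q.shape)  # torch.Size([1, 2, 4, 4])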

For your DistilBERT, you can look at the source code of Hugging Face Transformers' MultiHeadSelfAttention:

def shape(x: torch.Tensor) -> torch.Tensor:
    """separate heads"""
    return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

# query: torch.tensor(bs, seq_length, dim)
# self.q_lin(query) still has shape (bs, seq_length, dim)
q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
# the result is already transposed for the next matrix product

"Is this really 12 768x64 projection matrices concatenated?" Yes, the split/concatenation is along the last dimension.
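If you want to convince yourself of the orientation for DistilBERT specifically, here is a minimal check you could run. It is only a sketch, assuming the module layout implied by the state-dict keys above (dim = 768, 12 heads of size 64); note that nn.Linear stores its weight as (out_features, in_features), so head h lives in rows h*64:(h+1)*64 of q_lin.weight:

import torch
from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased")
attn = model.transformer.layer[0].attention
head_dim = 768 // attn.n_heads  # 64

x = torch.randn(1, 5, 768)                      # (bs, seq_length, dim)
q = attn.q_lin(x)                               # (bs, seq_length, dim)
q_heads = q.view(1, 5, attn.n_heads, head_dim)  # split the last dimension into heads

# recompute head 3 from the corresponding 64 rows of the stored weight
h = 3
W_h = attn.q_lin.weight[h * head_dim:(h + 1) * head_dim, :]  # (64, 768)
b_h = attn.q_lin.bias[h * head_dim:(h + 1) * head_dim]       # (64,)
print(torch.allclose(q_heads[:, :, h, :], x @ W_h.T + b_h, atol=1e-6))  # expected: True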

The best documentation is the source.

Upvotes: 0
