tstseby

Reputation: 1319

How does the Transformer Model Compute Self Attention?

In the Transformer model (https://arxiv.org/pdf/1706.03762.pdf), self-attention is computed using a softmax over the Query (Q) and Key (K) vectors:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

I am trying to understand the matrix multiplications:

Q = batch_size x seq_length x embed_size

K = batch_size x seq_length x embed_size

QK^T = batch_size x seq_length x seq_length

Softmax QK^T = Softmax (batch_size x seq_length x seq_length)

How is the softmax computed since there are seq_length x seq_length values per batch element?

A reference to the PyTorch computation would be very helpful.

Cheers!

Upvotes: 1

Views: 1204

Answers (2)

Wasi Ahmad

Reputation: 37681

How is the softmax computed since there are seq_length x seq_length values per batch element?

The softmax is performed w.r.t. the last axis (torch.nn.Softmax(dim=-1)(tensor), where tensor has shape batch_size x seq_length x seq_length) to get, for each element in the input sequence, the probability of attending to every other element.
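As a minimal sketch of that call (the scores tensor below is just random data standing in for QK^T, and the sizes are illustrative):

import torch

batch_size, seq_length = 4, 10

# random stand-in for the QK^T scores, shape: batch_size x seq_length x seq_length
scores = torch.randn(batch_size, seq_length, seq_length)

# softmax over the last axis: for every query position, the weights over all
# key positions form a probability distribution
attn = torch.nn.Softmax(dim=-1)(scores)

print(attn.shape)        # torch.Size([4, 10, 10])
print(attn.sum(dim=-1))  # a tensor of ones: each row sums to 1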


Let's assume we have the text sequence "Thinking Machines", so after computing QK^T we have a matrix of shape 2 x 2 (where seq_length = 2).

I am using the following illustration (reference) to explain the self-attention computation. As you know, the scaled dot product QK^T / sqrt(d_k) is computed first, and then the softmax is taken for each sequence element.

Here, the softmax is performed for the first sequence element, "Thinking". The raw scores of 14 and 12 are turned into probabilities of 0.88 and 0.12 by the softmax. These probabilities indicate that the token "Thinking" would attend to itself with 88% probability and to the token "Machines" with 12% probability. Similarly, the attention probabilities are computed for the token "Machines" too.

[Illustration: attention scores and their softmax for the tokens "Thinking" and "Machines"]
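A quick way to check those numbers (14 and 12 are the already-scaled scores from the illustration):

import torch

scores = torch.tensor([14.0, 12.0])   # scores for "Thinking" attending to ["Thinking", "Machines"]
probs = torch.softmax(scores, dim=-1)
print(probs)                          # tensor([0.8808, 0.1192]), i.e. roughly 0.88 and 0.12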


Note: I strongly suggest reading this excellent article on the Transformer. For the implementation, you can take a look at OpenNMT.

Upvotes: 3

ziedaniel1

Reputation: 360

The QKᵀ multiplication is a batched matrix multiplication -- it's doing a separate seq_length x embed_size by embed_size x seq_length multiplication batch_size times. Each one gives a result of size seq_length x seq_length, which is how we end up with QKᵀ having the shape batch_size x seq_length x seq_length.

Gabriela Melo's suggested resource uses the following PyTorch code for this operation:

torch.matmul(query, key.transpose(-2, -1))

This works because torch.matmul does a batched matrix multiplication when an input has at least 3 dimensions (see https://pytorch.org/docs/stable/torch.html#torch.matmul).
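Putting it together, here is a minimal end-to-end sketch of the shapes (the sizes and variable names are illustrative, not taken from OpenNMT):

import math
import torch

batch_size, seq_length, embed_size = 4, 10, 64

query = torch.randn(batch_size, seq_length, embed_size)
key = torch.randn(batch_size, seq_length, embed_size)

# batched matmul: (b x s x e) @ (b x e x s) -> (b x s x s)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(embed_size)
attn = torch.softmax(scores, dim=-1)

print(scores.shape)  # torch.Size([4, 10, 10])
print(attn.shape)    # torch.Size([4, 10, 10])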

Upvotes: 2
