In-Depth Breakdown of Self-Attention, Multi-Head and Masked Attention (Part 1)
Table of Contents
- Overview
- 1. Self-Attention
- 1.1. Why Use Self-Attention
- 1.2. A Closer Look at Self-Attention
- 1.3. How Self-Attention Incorporates Context
- 1.4. How to Calculate the Relevance Score
- 1.5. Normalizing
- 1.6. Combining All the Components
- 1.7. Vectorization
- 1.8. What is $d_k$ and Why Divide by $\sqrt{d_k}$
- 1.9. Code Practice: Defining a Self-Attention Model in PyTorch
- 2. Multi-head Attention and Masked Attention
- References
Overview
This article is based on Li Hongyi's lecture on Self-Attention, expanding upon his explanations and incorporating additional understanding and PyTorch code examples. The goal is to help both myself and readers develop a clearer understanding of Self-Attention.
Link to Li Hongyi's Self-Attention Lecture: https://www.youtube.com/watch?v=hYdO9CscNes
Slides can be found below the video.
After reading this article, you should gain the following insights:
- What Self-Attention is and why we use it
- How Self-Attention works
- How Self-Attention is designed
- Detailed explanation of the Self-Attention formula
- Multi-Head Attention
- Masked Attention
1. Self-Attention
1.1. Why Use Self-Attention
Let’s consider a part-of-speech tagging (POS tagging) task as an example. Suppose the input is the sentence `I saw a saw`, and our goal is to identify the part of speech for each word, resulting in `N, V, DET, N` (Noun, Verb, Determiner, Noun).
In this sentence, the first `saw` is a verb, while the second `saw` refers to a noun (a saw). To make this distinction, the model needs to consider the context surrounding each word and determine how much attention each context element should receive. For example, when processing the first `saw`, the model should focus more on `I`, while for the second `saw`, it should give more attention to `a`.
This is where the Attention mechanism comes in: if a task involves an input sequence (a series of vectors) with interdependent relationships, then the Attention mechanism helps capture these relationships.
1.2. A Closer Look at Self-Attention
This illustration shows how Self-Attention works. Self-Attention takes in a sequence (a series of vectors, which could be inputs or outputs from a previous hidden layer) and outputs a sequence of the same length, where each vector incorporates context from the entire sequence.
For instance, if the input sequence is `I`, `saw`, `a`, `saw`, represented by vectors as follows:
After passing through the Self-Attention layer, the sequence might be transformed to something like this:
In this transformed sequence, the first instance of `saw` now reflects `I`, while the second `saw` incorporates `a` to capture contextual meaning.
1.3. How Self-Attention Incorporates Context
As shown, each input vector is compared with all others in the sequence to calculate a relevance score, which is used to generate a new, contextually enriched vector.
For example, to calculate the new representation for $a^1$, we first compute relevance scores $\alpha_{1,1}, \alpha_{1,2}, \alpha_{1,3}, \alpha_{1,4}$ between $a^1$ and every vector in the sequence.
Once we have these $\alpha$ scores, we use them as weights and take a weighted sum of the input vectors to obtain the new, context-aware vector $b^1$.
Similarly, to compute $b^2$, we calculate the relevance between $a^2$ and every vector in the sequence and take the corresponding weighted sum.
There are two common issues when computing this way:
- The sum of the $\alpha$ values might not be 1, which could scale the input vectors up or down unpredictably.
- Directly multiplying the $\alpha$ values by the input vectors $a^i$ can limit the model's expressive power.
To address these issues:
- For issue 1, we often apply a Softmax function to the $\alpha$ values (though other methods are possible).
- For issue 2, we typically multiply each $a^i$ by a learnable matrix $W^v$ to create $v^i$, which is then weighted by the normalized $\alpha$ values.
1.4. How to Calculate the Relevance Score
First, let’s review vector multiplication. When two vectors are multiplied (taking their inner product), the formula is:

$$\vec{a} \cdot \vec{b} = \lVert\vec{a}\rVert\,\lVert\vec{b}\rVert\cos\theta$$
- The smaller the angle between two vectors (the more aligned they are), the larger the inner product, and the higher the relevance. Conversely, the larger the angle, the less relevant they are. If the angle is 90°, the vectors are perpendicular, giving an inner product of 0, indicating no relevance.
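As a quick numerical illustration (using made-up vectors), aligned vectors produce a large inner product while perpendicular ones produce zero:

```python
import numpy as np

a = np.array([1.0, 2.0])
b_aligned = np.array([2.0, 4.0])         # same direction as a (angle 0°)
b_perpendicular = np.array([-2.0, 1.0])  # at 90° to a

print(a @ b_aligned)        # 10.0 -> large inner product, high relevance
print(a @ b_perpendicular)  # 0.0  -> perpendicular, no relevance
```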
From this, it seems straightforward to measure the relevance of two input vectors, say $a^1$ and $a^2$, by taking their inner product directly. However, taking the inner product of the raw input vectors would, as discussed above, limit the model's expressive power.
To address this, Self-Attention introduces two additional matrices, $W^q$ and $W^k$:
- $W^q$ applies a linear transformation to the “main word” (the query) to produce $q$, referred to as the query vector.
- $W^k$ applies a linear transformation to the “context word” (the key) to produce $k$, referred to as the key vector.
With $q$ and $k$, the relevance score between two words is computed as the inner product of the corresponding query and key vectors, for example $\alpha_{1,2} = q^1 \cdot k^2$.
This process is summarized visually in the diagram below:
To calculate the relevance between $a^1$ and $a^2$:
- Use $W^q$ to compute $q^1 = W^q a^1$
- Use $W^k$ to compute $k^2 = W^k a^2$
- Calculate $\alpha_{1,2} = q^1 \cdot k^2$ using the $q^1$ and $k^2$ values
The diagram doesn’t explicitly show $\alpha_{1,1}$, but in actual calculations we include it, meaning we also calculate the relevance score between $a^1$ and itself.
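As a minimal sketch of these steps (with made-up dimensions, and random matrices standing in for the trained $W^q$ and $W^k$):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4                            # input vector dimension (made up)

a1, a2 = rng.normal(size=d), rng.normal(size=d)  # two input vectors
W_q = rng.normal(size=(d, d))    # stand-in for the trained W^q
W_k = rng.normal(size=(d, d))    # stand-in for the trained W^k

q1 = W_q @ a1                    # query vector for the "main" word
k2 = W_k @ a2                    # key vector for the "context" word
alpha_12 = q1 @ k2               # relevance score between a^1 and a^2
print(alpha_12)
```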
1.5. Normalizing
As we mentioned earlier, the sum of the $\alpha$ values is not guaranteed to be 1, so after computing $\alpha_{1,1}, \alpha_{1,2}, \alpha_{1,3}, \alpha_{1,4}$ we pass them through a Softmax:

$$\alpha'_{1,i} = \frac{\exp(\alpha_{1,i})}{\sum_j \exp(\alpha_{1,j})}$$

In the end, the normalized scores $\alpha'_{1,1}, \alpha'_{1,2}, \alpha'_{1,3}, \alpha'_{1,4}$ sum to 1 and can be used directly as weights.
It's not mandatory to use Softmax—feel free to try any other method if you think it might work better. Normalization isn’t strictly required either, although it’s commonly applied in practice.
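For example, here is a minimal sketch of normalizing one set of (made-up) scores with Softmax:

```python
import numpy as np

alpha = np.array([2.0, 1.0, 0.1, 3.0])             # raw scores alpha_{1,1..4} (made up)
alpha_prime = np.exp(alpha) / np.exp(alpha).sum()  # Softmax normalization

print(alpha_prime)        # ≈ [0.236 0.087 0.035 0.642]
print(alpha_prime.sum())  # 1.0 (up to floating-point rounding)
```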
1.6. Combining All the Components
Once we have computed the attention scores $\alpha'$, we multiply each input vector $a^i$ by the matrix $W^v$ to obtain the value vectors $v^i = W^v a^i$.
After obtaining the $v^i$ vectors, we take their weighted sum using the attention scores to produce the output, for example $b^1 = \sum_i \alpha'_{1,i}\, v^i$.
The process of computing the remaining outputs $b^2, b^3, \dots$ is exactly the same, just with the corresponding query.
To formally describe the process:
Given an input sequence $a^1, a^2, \dots, a^n$:

- Compute the query vectors using the formula $q^i = W^q a^i$
- Compute the key vectors using the formula $k^i = W^k a^i$
- Compute attention scores using the formula $\alpha_{i,j} = q^i \cdot k^j$
- Normalize the scores to obtain $\alpha'_{i,j}$ with the formula $\alpha'_{i,j} = \dfrac{\exp(\alpha_{i,j})}{\sum_t \exp(\alpha_{i,t})}$
- Obtain the value vectors using the formula $v^i = W^v a^i$
- Calculate the outputs with the formula $b^i = \sum_j \alpha'_{i,j}\, v^j$

Here, $W^q$, $W^k$, and $W^v$ are trainable parameters.
While the Self-Attention mechanism is now clearer, there's an additional step. The above calculations, if implemented directly in code, would require numerous for-loops, which would be computationally inefficient. To improve efficiency, vectorization is necessary, merging components into vectors and matrices wherever possible.
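To make that point concrete, here is a deliberately naive NumPy sketch of the recipe above, written with explicit loops (toy sizes; the random matrices simply stand in for trained parameters). The vectorized form derived next removes these loops:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, d_k, d_v = 4, 6, 6, 6                      # toy sizes: 4 input vectors of dimension 6
A = [rng.normal(size=d) for _ in range(n)]       # input vectors a^1 ... a^n
W_q = rng.normal(size=(d_k, d))                  # stand-ins for the trained W^q, W^k, W^v
W_k = rng.normal(size=(d_k, d))
W_v = rng.normal(size=(d_v, d))

q = [W_q @ a for a in A]                         # query vectors q^i
k = [W_k @ a for a in A]                         # key vectors  k^i
v = [W_v @ a for a in A]                         # value vectors v^i

B = []
for i in range(n):                               # one output b^i per input a^i
    scores = np.array([q[i] @ k[j] for j in range(n)])   # alpha_{i,j}
    weights = np.exp(scores) / np.exp(scores).sum()      # alpha'_{i,j} via Softmax
    B.append(sum(weights[j] * v[j] for j in range(n)))   # b^i = sum_j alpha'_{i,j} v^j
print(np.stack(B).shape)                         # (4, 6): one output vector per input
```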
1.7. Vectorization
To begin, let's vectorize the input vectors $a^1, a^2, \dots, a^n$ by stacking them, column by column, into a single matrix $I = (a^1, a^2, \dots, a^n)$.
Next, we define the matrices $W^q$, $W^k$, and $W^v$ as before.
Once we’ve defined the dimensions of $W^q$, $W^k$, and $W^v$, all of the query, key, and value vectors can be computed with single matrix multiplications.
The matrix form of the query vectors is $Q = (q^1, q^2, \dots, q^n) = W^q I$.
Similarly, the matrix form of the key vectors is $K = (k^1, k^2, \dots, k^n) = W^k I$.
Likewise, the matrix form of the value vectors is $V = (v^1, v^2, \dots, v^n) = W^v I$.
With matrices $Q$ and $K$, we can compute all of the relevance scores in one step.
The matrix form of the relevance scores is $A = K^T Q$, so that the $i$-th column of $A$ contains the scores $\alpha_{i,1}, \dots, \alpha_{i,n}$ for query $q^i$.
Note: I defined $k^i$ as a column vector, so we transpose $K$ here.
Then, we can calculate the normalized matrix $A'$ by applying Softmax to each column of $A$.
Once we have $A'$ and $V$, we can compute all of the outputs at once.
The matrix form for vectorizing the outputs $b^i$ is $O = (b^1, b^2, \dots, b^n) = V A'$.
Bringing it all together, we can derive the final combined equation:

$$O = V \cdot \operatorname{softmax}(K^T Q), \quad \text{where } Q = W^q I,\; K = W^k I,\; V = W^v I$$
If you’ve seen other resources on this topic, you may notice that the actual final formula is as follows:

$$\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Our formula differs only by a transpose and the scaling factor $\sqrt{d_k}$. In the original formula the sequence elements are treated as row vectors rather than column vectors, so the matrices $Q$, $K$, $V$ and the output there correspond to $Q^T$, $K^T$, $V^T$ and $O^T$ in our formula, respectively.
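As a quick sanity check of this correspondence, here is a small sketch (toy sizes, random matrices) comparing the column-vector form used here, with the $\sqrt{d_k}$ scaling added to both sides, against the row-vector form of the original formula:

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d, d_k, d_v = 4, 6, 5, 7
I = rng.normal(size=(d, n))                          # inputs as columns: I = (a^1 ... a^n)
W_q, W_k, W_v = (rng.normal(size=(d_k, d)),
                 rng.normal(size=(d_k, d)),
                 rng.normal(size=(d_v, d)))

# Column-vector convention used in this article
Q, K, V = W_q @ I, W_k @ I, W_v @ I                  # columns are q^i, k^i, v^i
O = V @ softmax(K.T @ Q / np.sqrt(d_k), axis=0)      # Softmax over each column

# Row-vector convention of the original formula: uses Q^T, K^T, V^T
Qr, Kr, Vr = Q.T, K.T, V.T
Or = softmax(Qr @ Kr.T / np.sqrt(d_k), axis=1) @ Vr

print(np.allclose(O, Or.T))                          # True: identical up to a transpose
```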
1.8. What is $d_k$ and Why Divide by $\sqrt{d_k}$
Firstly, $d_k$ is the dimension of the key (and query) vectors, i.e., the output dimension of $W^k$ and $W^q$; in the formula it is the shared inner dimension of $Q$ and $K^T$.
For example, assume the entries of $Q$ and $K$ have mean 0 and variance 1. Each entry of $QK^T$ is then a sum of $d_k$ such products, so $QK^T$ ends up with mean 0 but variance around $d_k$; dividing by $\sqrt{d_k}$ brings its standard deviation back down to roughly 1.
The mean of a matrix is simply the sum of all its elements divided by the total number of elements, and variance follows similarly.
This effect can be confirmed through the following code (since my math is weak, I’m validating with an experiment—sigh):
```python
import numpy as np

# Generate Q and K with mean 0, std 1
Q = np.random.normal(size=(123, 456))
K = np.random.normal(size=(123, 456))

print("Q.std=%s, K.std=%s, \nQ·K^T.std=%s, Q·K^T/√d.std=%s"
      % (Q.std(), K.std(), Q.dot(K.T).std(), Q.dot(K.T).std() / np.sqrt(456)))
```
```
Q.std=0.9977961671085275, K.std=1.0000574599289282, 
Q·K^T.std=21.240017020263437, Q·K^T/√d.std=0.9946549289466212
```
From the output, we can see that the standard deviations of Q and K are both around 1, but after multiplying the matrices, the standard deviation rises to 21.24. Dividing by $\sqrt{d_k}$ (here $\sqrt{456} \approx 21.35$) brings it back down to roughly 1.
Here's another example where Q and K have random standard deviations, which is closer to a real-world scenario:
```python
# Generate Q and K with random mean and std
Q = np.random.normal(loc=1.56, scale=0.36, size=(123, 456))
K = np.random.normal(loc=-0.34, scale=1.2, size=(123, 456))

print("Q.std=%s, K.std=%s, \nQ·K^T.std=%s, Q·K^T/√d.std=%s"
      % (Q.std(), K.std(), Q.dot(K.T).std(), Q.dot(K.T).std() / np.sqrt(456)))
```
```
Q.std=0.357460640868945, K.std=1.204536717914841, 
Q·K^T.std=37.78368871510589, Q·K^T/√d.std=1.769383337989377
```
In this case, the initial standard deviations of Q and K are about 0.36 and 1.2, respectively. After the matrix multiplication, the standard deviation jumps to 37.78, and dividing by $\sqrt{d_k}$ brings it back down to 1.77, which is much closer to the original scale.
1.9. Code Practice: Defining a Self-Attention Model in PyTorch
Let's define a Self-Attention model in PyTorch using the original formula from the attention mechanism paper:

$$\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V \tag{1}$$
To make the code logic clearer, I've labeled the dimensions of each component:

$$\operatorname{softmax}\!\left(\frac{Q_{n \times d_k} \cdot K^T_{d_k \times n}}{\sqrt{d_k}}\right) \cdot V_{n \times d_v} \tag{2}$$

Here, each variable is defined as follows:

- $n$: input_num, the number of input vectors. For example, if a sentence has 20 words, then $n = 20$.
- $d_k$: dimension of K, the output dimension of both the $W^k$ and $W^q$ matrices (a hyperparameter, typically set equal to the input vector dimension $d$), which determines the width of those linear layers.
- $d_v$: dimension of V, the output dimension of the $W^v$ matrix, which corresponds to the output vector dimension (also a hyperparameter, often set equal to the input dimension $d$).

In this formula, $Q$, $K$, and $V$ are obtained by applying the linear transformations $W^q$, $W^k$, and $W^v$ to the input, whose size is $(n, d)$. The remaining variable is:

- $d$: input_vector_dim, the dimension of the input vectors. For example, if each word is encoded as a 10-dimensional vector, then $d = 10$.
With formulas (1) and (2), we can now define the Self-Attention model. Here's the code:
```python
import numpy as np
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, input_vector_dim: int, dim_k=None, dim_v=None):
        """
        Initializes the SelfAttention module with the following parameters:
        input_vector_dim: The dimension of the input vector, denoted as d in the formula.
                          For example, if each word is encoded as a 10-dimensional vector,
                          this value should be 10.
        dim_k: Dimension of the matrices W^k and W^q. Defaults to input_vector_dim if not specified.
        dim_v: Dimension of the output vector (dimension of b). For example, if you want the
               output vector b to have a dimension of 15, set this value to 15.
               If not provided, it defaults to input_vector_dim.
        """
        super(SelfAttention, self).__init__()

        self.input_vector_dim = input_vector_dim
        # If dim_k and dim_v are not provided, default to input vector dimension.
        if dim_k is None:
            dim_k = input_vector_dim
        if dim_v is None:
            dim_v = input_vector_dim

        """
        In practice, linear layers are often used to represent trainable matrices
        since they simplify backpropagation and parameter updates.
        """
        self.W_q = nn.Linear(input_vector_dim, dim_k, bias=False)
        self.W_k = nn.Linear(input_vector_dim, dim_k, bias=False)
        self.W_v = nn.Linear(input_vector_dim, dim_v, bias=False)

        # This is 1/sqrt(d_k), used for normalization in scaled dot-product attention
        self._norm_fact = 1 / np.sqrt(dim_k)

    def forward(self, x):
        """
        Forward pass:
        x: Input vector with size (batch_size, input_num, input_vector_dim)
        """
        # Calculate Q, K, V using the weight matrices W_q, W_k, and W_v.
        # The size of Q and K is (batch_size, input_num, dim_k),
        # and the size of V is (batch_size, input_num, dim_v).
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # permute rearranges the dimensions of a tensor.
        # Here, we change K's size from (batch_size, input_num, dim_k)
        # to (batch_size, dim_k, input_num).
        # The 0, 1, 2 represent the indices of each dimension; before the transformation,
        # batch_size is at position 0, and input_num is at position 1.
        K_T = K.permute(0, 2, 1)

        # bmm is batch matrix-matrix product, which performs matrix multiplication
        # on a batch of matrices.
        # For more details on bmm, see https://pytorch.org/docs/stable/generated/torch.bmm.html
        atten = nn.Softmax(dim=-1)(torch.bmm(Q, K_T) * self._norm_fact)

        # Finally, multiply by V
        output = torch.bmm(atten, V)

        return output
```
Now, let's use this model. We'll define a batch of size 50, where each input vector has 3 dimensions, and we'll input 5 vectors at a time. After passing through the Attention layer, we aim to encode the output into 5 vectors, each with 4 dimensions:
```python
model = SelfAttention(3, 5, 4)
model(torch.Tensor(50, 5, 3)).size()
```
```
torch.Size([50, 5, 4])
```
The Attention model is often part of a larger model, commonly integrated within other layers, with the Transformer model being the most classic example.
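For instance, here is a minimal, hypothetical sketch of embedding the SelfAttention layer defined above inside a small classifier (the surrounding architecture, mean-pooling plus a linear head, is made up purely for illustration):

```python
class TinyClassifier(nn.Module):
    """A made-up example: self-attention over the input vectors, then mean-pool and classify."""

    def __init__(self, input_vector_dim: int, num_classes: int):
        super().__init__()
        self.attention = SelfAttention(input_vector_dim)   # the module defined above
        self.classifier = nn.Linear(input_vector_dim, num_classes)

    def forward(self, x):
        # x: (batch_size, input_num, input_vector_dim)
        h = self.attention(x)           # (batch_size, input_num, input_vector_dim)
        pooled = h.mean(dim=1)          # average over the sequence dimension
        return self.classifier(pooled)  # (batch_size, num_classes)


logits = TinyClassifier(input_vector_dim=3, num_classes=2)(torch.rand(50, 5, 3))
print(logits.shape)  # torch.Size([50, 2])
```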
2. Multi-head Attention and Masked Attention
Please see the next article: In-Depth Breakdown of Self-Attention, Multi-Head and Masked Attention (Part 2)
References
- Self-Attention: https://www.youtube.com/watch?v=hYdO9CscNes
- annotated-transformer: https://github.com/harvardnlp/annotated-transformer/