Attention Mechanism
The Foundations of Self-Attention and the Transformer
The Attention mechanism is a scheme that dynamically learns "where to focus" within an input sequence. By doing so, it solved the long-range dependency problem of RNNs and became the foundation of the Transformer architecture.
The Birth of Attention
Challenges of the Seq2Seq Model
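In Sequence-to-Sequence (Seq2Seq) tasks such as machine translation, the encoder compressed the entire input sequence into a single fixed-length context vector. This vector became an information bottleneck: the longer the input, the more content had to be squeezed into the same number of dimensions, so information from distant parts of the sentence was easily lost.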
The Solution Brought by Attention
The Attention mechanism (Bahdanau et al., 2014) allowed the decoder to refer to all of the encoder's hidden states at each step.
Computing Attention
Given the decoder state \(s_t\) and the encoder hidden states \(h_1, ..., h_n\):
- Score computation: \(e_{ti} = \text{score}(s_t, h_i)\)
- Normalization (Softmax): \(\alpha_{ti} = \dfrac{\exp(e_{ti})}{\displaystyle\sum_j \exp(e_{tj})}\)
- Context vector: \(c_t = \displaystyle\sum_i \alpha_{ti} h_i\)
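The three steps above can be sketched in plain Python. This is only a minimal illustration: the toy vectors are made up, and the dot product is used as the score function purely for convenience.

```python
import math

def attention_context(scores, encoder_states):
    """Turn raw scores e_ti into weights alpha_ti and a context vector c_t."""
    # Softmax; subtracting the max is a standard numerical-stability trick
    m = max(scores)
    exps = [math.exp(e - m) for e in scores]
    total = sum(exps)
    weights = [x / total for x in exps]
    # Weighted sum of the encoder hidden states
    dim = len(encoder_states[0])
    context = [sum(w * h_i[d] for w, h_i in zip(weights, encoder_states))
               for d in range(dim)]
    return weights, context

# Toy values (hypothetical): three encoder states of dimension 2
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
s_t = [2.0, 0.5]

# Score computation, here with a plain dot product
e_t = [sum(a * b for a, b in zip(s_t, h_i)) for h_i in h]

alpha_t, c_t = attention_context(e_t, h)
print(alpha_t)  # attention weights, sum to 1
print(c_t)      # context vector, same dimension as each h_i
```

The encoder state with the highest score receives the largest weight, so the context vector is pulled toward it.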
Types of Score Functions
1. Dot Product
\[ \text{score}(s, h) = s^\top h \]
2. General (generalized dot product)
\[ \text{score}(s, h) = s^\top W h \]
3. Additive (Bahdanau)
\[ \text{score}(s, h) = v^\top \tanh(W_s s + W_h h) \]
Here \(s\) is the decoder-side state (the "what am I looking for right now" = the query), and \(h\) is each encoder hidden state (the "what is at each position" = the key). \(W,\ W_s,\ W_h\) are weight matrices acquired through training, whose role is to linearly transform \(s\) and \(h\) into a space where they are easy to compare. \(v\) is a weight vector acquired through training that converts the vector produced by \(\tanh\) into a single score (a scalar value). All of them are parameters optimized by gradient descent from the training data.
- Dot Product: has no learnable parameters and is fast. However, it assumes that \(s\) and \(h\) have the same dimension and lie in the same space.
- General: inserts \(W\) in between, so it can compare \(s\) and \(h\) even when they lie in different spaces, making it more flexible.
- Additive: computes the score with a small neural network (\(W_s, W_h, v\)), giving it high expressive power.
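The three score functions could be sketched in PyTorch roughly as follows. The dimensions (`d_s`, `d_h`, `d_a`) and the randomly initialized weights are placeholders chosen for illustration; in a real model \(W\), \(W_s\), \(W_h\), and \(v\) would be trained parameters.

```python
import torch

torch.manual_seed(0)
d_s, d_h, d_a = 4, 6, 5   # hypothetical sizes: decoder, encoder, additive hidden

s = torch.randn(d_s)      # decoder state (query)
H = torch.randn(3, d_h)   # three encoder hidden states (keys), one per row

# General: s^T W h, with W of shape (d_s, d_h) bridging the two spaces
W = torch.randn(d_s, d_h)
general_scores = H @ (W.T @ s)                          # shape (3,)

# Additive: v^T tanh(W_s s + W_h h), a small one-hidden-layer network
W_s = torch.randn(d_a, d_s)
W_h = torch.randn(d_a, d_h)
v = torch.randn(d_a)
additive_scores = torch.tanh(W_s @ s + H @ W_h.T) @ v   # shape (3,)

# Dot product: only defined when both vectors share a dimension,
# so a query of size d_h is used here
s_same = torch.randn(d_h)
dot_scores = H @ s_same                                 # shape (3,)

print(general_scores.shape, additive_scores.shape, dot_scores.shape)
```

Each variant yields one scalar score per encoder position; only the dot-product version requires the query and the keys to have the same dimension.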
Self-Attention
Self-Attention is a scheme in which the input sequence computes Attention with respect to itself. It can directly model the relationship between any pair of positions within the sequence.
Three representations are created from the input \(X\):
- Query (Q): "what am I looking for" → \(Q = XW^Q\)
- Key (K): "what is there" → \(K = XW^K\)
- Value (V): "the actual information" → \(V = XW^V\)
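From these, Scaled Dot-Product Attention is computed as
\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\dfrac{QK^\top}{\sqrt{d_k}}\right)V \]
Dividing by \(\sqrt{d_k}\) for scaling prevents the dot products from becoming too large and saturating the Softmax (which would make the gradients vanish).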
Why Scaling Is Necessary
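When \(d_k\) is large, the dot product \(q \cdot k\) tends to take large values. If the components of \(q\) and \(k\) are independent with mean 0 and variance 1, the dot product has mean 0 and variance \(d_k\), so its typical magnitude grows like \(\sqrt{d_k}\). Large scores push the Softmax output toward extreme values (close to 0 or 1), where its gradient is almost zero. Dividing by \(\sqrt{d_k}\) brings the variance of the scores back to about 1 and keeps the Softmax in a range where gradients flow.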
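The effect of the scaling can be checked empirically. The sketch below assumes that the components of \(q\) and \(k\) are independent standard normal values, which is an idealization rather than a property of trained models.

```python
import math
import random

random.seed(0)

def dot_variance(d_k, trials=2000, scale=False):
    """Sample variance of q.k for q, k with independent N(0, 1) components."""
    values = []
    for _ in range(trials):
        q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        dot = sum(a * b for a, b in zip(q, k))
        values.append(dot / math.sqrt(d_k) if scale else dot)
    mean = sum(values) / trials
    return sum((x - mean) ** 2 for x in values) / trials

for d_k in (4, 64):
    raw = dot_variance(d_k)
    scaled = dot_variance(d_k, scale=True)
    print(f"d_k={d_k}: raw variance ~ {raw:.2f}, scaled variance ~ {scaled:.2f}")
```

Under this assumption the raw variance should come out close to \(d_k\), while the scaled variance stays near 1 regardless of \(d_k\).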
Multi-Head Attention
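Because a single Attention has limited expressive power, multiple "heads" compute Attention from different perspectives, and the results are combined.
\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O, \qquad \text{head}_i = \text{Attention}(QW_i^Q,\ KW_i^K,\ VW_i^V) \]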
How Are the Heads Split?
- Linearly transform the input (each token is \(d_{model}\)-dimensional) with the weight matrices \(W^Q, W^K, W^V\) to produce the full Q, K, V (\(d_{model}\)-dimensional).
- Split that \(d_{model}\)-dimensional vector into \(h\) equal parts, with each head handling a small Q, K, V of dimension \(d_k = d_{model}/h\) (in implementation this is a reshape: \((\text{seq length},\, d_{model}) \to (\text{seq length},\, h,\, d_k)\)). Rather than preparing \(h\) sets of separate weights, this simply splits one large projection into \(h\) parts, and the \(W_i^Q\) in the formula corresponds to the columns of \(W^Q\) that each head is responsible for.
- Each head independently computes Scaled Dot-Product Attention, and the resulting outputs (each \(d_k\)-dimensional) are concatenated back to \(d_{model}\) dimensions, then finally mixed together by \(W^O\).
How to choose the number of heads \(h\): \(h\) is a hyperparameter chosen by a human, selected from the divisors of \(d_{model}\) so that \(d_k = d_{model}/h\) is an integer. Increasing the number of heads keeps the compute roughly constant since each head becomes thinner, but increasing it too much reduces the dimension per head and lowers expressive power. Empirically, 8 to 16 is the standard choice.
Why naive "sequential equal splitting" is sufficient
The split itself is sufficient with the naive approach of simply dividing the post-projection vector by consecutive element indices (\(0\sim d_k-1\) for the first head, \(d_k\sim 2d_k-1\) for the second head, and so on). The reason is that there is a learnable projection \(W^Q\) (a fully connected \(d_{model}\times d_{model}\) matrix) right before the split. Since each output dimension of \(W^Q\) is a linear combination of the entire input, "which features are sent to which head's range" can be freely decided by \(W^Q\) through learning. A fixed slice position or permutation can be absorbed into \(W^Q\), and the expressive power does not change.
Therefore, the individuality of each head (one syntactic, another semantic, ...) arises from training, not from how it is split. Random initialization plus gradient descent breaks the symmetry, and each head's responsible columns converge toward different subspaces. "Naive equal splitting + a learned projection" is mathematically equivalent to "independently learning \(h\) arbitrary projections," and this is sufficient.
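The claim that "one large projection followed by a reshape" equals "one small projection per head" can be checked numerically. The sizes below are toy values picked for the example, and the bias term is omitted for simplicity.

```python
import torch

torch.manual_seed(0)
seq_len, d_model, h = 5, 16, 4   # hypothetical toy sizes
d_k = d_model // h

X = torch.randn(seq_len, d_model)
W_Q = torch.randn(d_model, d_model)   # one large projection (bias omitted)

# Route 1: project once, then reshape (seq_len, d_model) -> (seq_len, h, d_k)
Q_split = (X @ W_Q).view(seq_len, h, d_k)

# Route 2: give each head its own (d_model, d_k) matrix,
# taken as a block of consecutive columns of W_Q
Q_per_head = torch.stack(
    [X @ W_Q[:, i * d_k:(i + 1) * d_k] for i in range(h)], dim=1
)

print(torch.allclose(Q_split, Q_per_head, atol=1e-4))
```

Note that `nn.Linear` stores its weight with shape `(out_features, in_features)`, so in a module-based implementation the per-head blocks are rows of `weight` rather than columns.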
Typical Settings
- Number of heads h: 8 to 16
- Dimension of each head: \(d_k = d_v = d_{model} / h\)
- e.g., \(d_{model} = 512, h = 8 \Rightarrow d_k = 64\)
Positional Encoding
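Self-Attention by itself has no notion of order: permuting the input tokens simply permutes the outputs in the same way (strictly speaking, it is permutation-equivariant). Therefore, positional information must be added explicitly.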
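\[ PE_{(pos,\,2i)} = \sin\!\left(\dfrac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\dfrac{pos}{10000^{2i/d_{model}}}\right) \]
\(pos\): position in the sequence, \(i\): index of the sin/cos dimension pair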
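As a small sketch, the sinusoidal encoding of a single position can be computed directly, assuming the standard definition in which dimension \(2i\) holds \(\sin(pos/10000^{2i/d_{model}})\) and dimension \(2i+1\) holds the matching cosine. The helper name `positional_encoding` is made up for this example.

```python
import math

def positional_encoding(pos, d_model):
    """Sinusoidal encoding of one position; d_model is assumed to be even."""
    pe = [0.0] * d_model
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe[2 * i] = math.sin(angle)       # even dimension
        pe[2 * i + 1] = math.cos(angle)   # odd dimension
    return pe

pe0 = positional_encoding(0, 8)
pe1 = positional_encoding(1, 8)
print([round(x, 3) for x in pe0])
print([round(x, 3) for x in pe1])
```

Position 0 always maps to the alternating pattern of zeros and ones, and every later position rotates each (sin, cos) pair by an angle that depends on the pair's frequency.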
Why use sin and cos
- The values are bounded (\([-1, 1]\)): adding the position index \(pos\) directly would make the values grow without bound and corrupt the word vectors, but sin/cos stay at a constant magnitude anywhere in the sequence.
- Each position is unique: since sin/cos with different wavelengths are arranged per dimension, each position has its own unique "combination of waveforms" and can be distinguished from the others (the same idea as how the second, minute, and hour hands of a clock uniquely determine the time).
- It easily represents relative positions: by the angle-addition formulas of trigonometry, \(PE_{pos+k}\) for a position \(k\) away can be written as a linear transformation (rotation) of \(PE_{pos}\) (see the derivation below). This makes it easy for the model to learn the relative relationship of "being \(k\) apart."
- It is defined for unknown lengths: the same formula computes positions even for sequences longer than those seen in training (a scheme that memorizes a value per position cannot handle unknown positions). How well a trained model actually generalizes to such lengths still has to be checked empirically.
For a single frequency \(\omega = 1/10000^{2i/d_{model}}\), the encoding of position \(pos\) is the pair \((\sin\omega pos,\ \cos\omega pos)\). Shifting the position by \(k\), the angle-addition formulas give
\[ \begin{aligned} \sin\omega(pos+k) &= \sin\omega pos\,\cos\omega k + \cos\omega pos\,\sin\omega k,\\ \cos\omega(pos+k) &= \cos\omega pos\,\cos\omega k - \sin\omega pos\,\sin\omega k. \end{aligned} \]
Collecting them into matrix form, this can be written using the rotation matrix \(R(\omega k)\) as follows.
\[ \begin{pmatrix}\sin\omega(pos+k)\\ \cos\omega(pos+k)\end{pmatrix} = \underbrace{\begin{pmatrix}\cos\omega k & \sin\omega k\\ -\sin\omega k & \cos\omega k\end{pmatrix}}_{R(\omega k)} \begin{pmatrix}\sin\omega pos\\ \cos\omega pos\end{pmatrix} \]
This \(R(\omega k)\) is determined solely by the shift \(k\) and does not depend on \(pos\). In other words, the operation of "moving \(k\) apart" is always the same rotation by angle \(\omega k\), regardless of which position you are currently at. So the model can capture the relative relationship "being \(k\) apart" with a single fixed linear transformation.
Intuition: a clock hand. No matter what time it currently is, "advancing by \(k\) hours" can be expressed as rotating the hand by a fixed angle — the same thing happens for the pair at each frequency.
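The rotation property can be verified numerically for one frequency. The frequency `omega` and the offset `k` below are arbitrary values chosen for the check, and the helper names are made up for this sketch.

```python
import math

def pe_pair(pos, omega):
    """(sin, cos) pair of a single frequency at position pos."""
    return (math.sin(omega * pos), math.cos(omega * pos))

def shift(pair, omega, k):
    """Apply the rotation matrix R(omega * k) to a (sin, cos) pair."""
    s, c = pair
    a = omega * k
    return (math.cos(a) * s + math.sin(a) * c,
            -math.sin(a) * s + math.cos(a) * c)

omega, k = 0.01, 7   # hypothetical frequency and offset
for pos in (0, 3, 250):
    rotated = shift(pe_pair(pos, omega), omega, k)
    direct = pe_pair(pos + k, omega)
    same = all(math.isclose(r, d, abs_tol=1e-12) for r, d in zip(rotated, direct))
    print(pos, same)
```

The same `shift` call works for every starting position, which is exactly the statement that \(R(\omega k)\) does not depend on \(pos\).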
Won't there be confusion when \(\omega k\) exceeds \(2\pi\)?
Looking at only a single wavelength, there indeed would be confusion. \(\sin,\cos\) have period \(2\pi\) in their argument, but here the argument is \(\omega\cdot pos\), so the wavelength (period) with respect to position \(pos\) is \(2\pi/\omega\) (\(\sin\omega(pos+2\pi/\omega)=\sin(\omega\,pos+2\pi)=\sin\omega\,pos\)). That is, two positions one wavelength (\(2\pi/\omega\)) apart take the same value and cannot be distinguished (aliasing).
What prevents this is the design of using many waves with different wavelengths simultaneously. The wavelength of dimension pair \(i\) is \(2\pi\cdot 10000^{2i/d_{model}}\), so it grows geometrically, by a fixed multiplicative factor each time \(i\) increases by one. Concretely, it ranges from \(2\pi\) at \(i=0\) (oscillating finely and quickly) to nearly \(2\pi\cdot 10000 \approx 6.3\times10^4\) at the largest \(i\) (oscillating slowly). The longest wavelength covers about 60,000 positions, so for realistic sequence lengths (usually a few thousand or less) it does not complete a full cycle.
The division of labor is as follows: the long wavelengths do not complete a cycle, so they prevent confusion (aliasing) between far-apart positions, while the short wavelengths oscillate finely, so they clearly distinguish adjacent positions (\(pos\) and \(pos+1\)). In short, short wavelength = local resolution, long wavelength = global uniqueness, and the combination of all wavelengths gives every position in that range its own distinct code.
Intuition: the second, minute, and hour hands (an odometer of gears at different scales). The fast hand (short wavelength) represents the fine differences between adjacent moments, and the slow hand (long wavelength) represents the coarse time. Reading both together uniquely identifies any time over a wide range — positional encoding likewise represents each position uniquely by the set of phases across all wavelengths.
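To make the range of wavelengths concrete, the following sketch lists them for \(d_{model} = 512\), the value used in this article's examples, using the wavelength formula \(2\pi\cdot 10000^{2i/d_{model}}\).

```python
import math

d_model = 512   # value used in this article's examples
wavelengths = [2 * math.pi * 10000 ** (2 * i / d_model)
               for i in range(d_model // 2)]

print(wavelengths[0])                    # shortest wavelength, equal to 2*pi
print(wavelengths[-1])                   # longest, a little under 2*pi*10000
print(wavelengths[1] / wavelengths[0])   # constant ratio between neighbours
```

The longest wavelength is far larger than a sequence of a few thousand tokens, so the slowest pair never wraps around within such a sequence.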
Properties of Positional Encoding
- For any fixed offset \(k\), \(PE_{pos+k}\) can be expressed as a linear function of \(PE_{pos}\)
- This makes it easy to learn relative position information
- It is defined for sequences longer than those seen during training, so extrapolation is possible in principle
Implementation in PyTorch
Scaled Dot-Product Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Scaled Dot-Product Attention

    Args:
        query: (batch, heads, seq_len, d_k)
        key: (batch, heads, seq_len, d_k)
        value: (batch, heads, seq_len, d_v)
        mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)

    Returns:
        output: (batch, heads, seq_len, d_v)
        attention_weights: (batch, heads, seq_len, seq_len)
    """
    d_k = query.size(-1)

    # QK^T / sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask (optional): positions where mask == 0 are excluded
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax over the key dimension
    attention_weights = F.softmax(scores, dim=-1)

    # Attention x V
    output = torch.matmul(attention_weights, value)
    return output, attention_weights
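As a usage sketch for the `mask` argument, the following builds a causal (lower-triangular) mask so that each position can attend only to itself and earlier positions. Causal masking is one common use of the mask and is an assumption here, not something this article prescribes; the masking steps are repeated inline so that the snippet runs on its own.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, d_k = 1, 2, 4, 8   # hypothetical toy sizes
q = torch.randn(batch, heads, seq_len, d_k)
k = torch.randn(batch, heads, seq_len, d_k)

# Lower-triangular mask: 1 = may attend, 0 = blocked.
# Shape (1, 1, seq_len, seq_len) broadcasts over batch and heads.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

# Same steps as in scaled_dot_product_attention
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

print(weights[0, 0])   # entries above the diagonal are zero
```

One caveat: if a mask blocks every key for some query, that row consists only of `-inf` and the Softmax returns NaN, so masks should leave at least one allowed position per row.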
Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear transformation layers
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """(batch, seq_len, d_model) -> (batch, heads, seq_len, d_k)"""
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear transformation
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # Split into heads
        Q = self.split_heads(Q, batch_size)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)

        # Scaled Dot-Product Attention
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate the heads: (batch, heads, seq_len, d_k) -> (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)

        # Final linear transformation
        output = self.W_o(attn_output)
        return output, attn_weights


# Usage example
d_model = 512
num_heads = 8
seq_len = 100
batch_size = 32

mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)
output, attention = mha(x, x, x)  # Self-Attention

print(f"Output shape: {output.shape}")  # (32, 100, 512)
print(f"Attention shape: {attention.shape}")  # (32, 8, 100, 100)
Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the positional encoding
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)

        # Buffer: saved and moved with the module, but not a trained parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


# Usage example
pe = PositionalEncoding(d_model=512)
x = torch.randn(32, 100, 512)
output = pe(x)
print(f"Output shape: {output.shape}")  # (32, 100, 512)
Visualizing Attention
By visualizing the Attention weights, we can understand "where the model is focusing."
Summary
- Attention is a mechanism that dynamically learns "where to focus"
- Self-Attention directly models the relationship between any pair of positions within a sequence
- It uses three representations: Query, Key, Value
- Scores are computed with Scaled Dot-Product and normalized with Softmax
- Multi-Head computes Attention from multiple perspectives
- Positional encoding adds order information
- Together these form the foundation of the Transformer architecture
Frequently Asked Questions (FAQ)
Q1. What is the attention mechanism?
It is a mechanism that aggregates information by weighting how much attention to pay to each position of the input. It consists of three components — Query, Key, and Value — and computes a weighted sum of Values using the similarity between Q and K (\(\text{Attention}(Q,K,V)=\text{softmax}(QK^\top/\sqrt{d_k})\,V\)). It solved the long-range dependency problem of Seq2Seq models.
Q2. What is Self-Attention?
It is attention in which Q, K, and V are all generated from the same sequence. Each position can compute its relevance to every position in the sequence, so any two positions are connected in a single step regardless of how far apart they are (a constant path length, although the amount of computation grows quadratically with the sequence length). It is the core component of the Transformer and dynamically learns the semantic relationships between words in a sentence.
Q3. Why is Multi-Head Attention effective?
It computes attention in parallel across different representation subspaces, then concatenates and projects the results, capturing diverse dependencies simultaneously. For example, a division of labor arises in which one head learns syntactic relations while another learns semantic relations.