Attention mechanisms allow neural networks to dynamically focus on relevant parts of input sequences, forming the core computational primitive of Transformer-based models. From self-attention within a single sequence to cross-attention between encoder and decoder, attention enables models to capture long-range dependencies with $O(1)$ sequential operations.
The fundamental attention operation computes a weighted sum of values based on query-key similarity:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
Where $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, $V \in \mathbb{R}^{m \times d_v}$. The scaling factor $\sqrt{d_k}$ prevents the dot products from growing large in magnitude, which would push the softmax into regions with vanishingly small gradients.
Computational complexity: $O(n^2 d)$ time and $O(n^2)$ space for the attention matrix, making it quadratic in sequence length.
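The formula can be evaluated directly as a sanity check; the dimensions below are arbitrary illustrative choices.

```python
import torch
import torch.nn.functional as F

n, m, d_k, d_v = 4, 6, 8, 5
Q = torch.randn(n, d_k)   # n queries
K = torch.randn(m, d_k)   # m keys
V = torch.randn(m, d_v)   # m values

# softmax(Q K^T / sqrt(d_k)) V
weights = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1)  # (n, m) attention matrix
out = weights @ V                                   # (n, d_v)
```

Each row of `weights` is a probability distribution over the $m$ keys, so the output is a convex combination of the value vectors; the $n \times m$ attention matrix is the source of the quadratic cost.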
Self-Attention: Query, key, and value matrices all derive from the same sequence. Each token attends to every other token in the sequence, capturing intra-sequence dependencies such as coreference, syntax, and semantic relationships.
Cross-Attention: Queries come from one sequence (e.g., decoder), while keys and values come from another (e.g., encoder output). This aligns representations across sequences, as in machine translation where decoder tokens attend to relevant source tokens.
Causal (Masked) Attention: Self-attention with a mask preventing tokens from attending to future positions. Essential for autoregressive language models where each token can only depend on preceding tokens.
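The causal mask is typically a lower-triangular boolean matrix applied before the softmax; a minimal sketch:

```python
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    # Lower-triangular mask: token i may attend only to positions j <= i
    n = Q.size(-2)
    mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    # Future positions get -inf, so the softmax assigns them zero weight
    scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ V
```

A direct consequence: the first token attends only to itself, so its output is exactly its own value vector.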
Multi-head attention projects inputs into $h$ separate subspaces and performs attention in parallel:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
Each head has its own learned projections $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$, where $d_k = d_{model} / h$. This allows different heads to capture diverse patterns (syntactic, semantic, positional).
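The split-attend-concat pipeline can be sketched as follows; the weight names (`Wq`, `Wk`, `Wv`, `Wo`) are illustrative placeholders for learned $d_{model} \times d_{model}$ projections.

```python
import torch

def multi_head_attention(x, Wq, Wk, Wv, Wo, h):
    # x: (batch, n, d_model); all projection matrices are (d_model, d_model)
    batch, n, d_model = x.shape
    d_k = d_model // h

    def split(t):  # (batch, n, d_model) -> (batch, h, n, d_k)
        return t.view(batch, n, h, d_k).transpose(1, 2)

    Q, K, V = split(x @ Wq), split(x @ Wk), split(x @ Wv)
    weights = (Q @ K.transpose(-2, -1) / d_k ** 0.5).softmax(dim=-1)
    # Concatenate heads back to (batch, n, d_model), then apply output projection W^O
    heads = (weights @ V).transpose(1, 2).reshape(batch, n, d_model)
    return heads @ Wo
```

Note that the total parameter count and output shape match single-head attention at the same $d_{model}$; the heads differ only in which $d_k$-dimensional subspace each one attends over.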
During autoregressive generation, the KV cache stores the key and value tensors computed at earlier steps so they are not recomputed for every new token.
This reduces per-step complexity from $O(t^2 d)$ to $O(td)$ but requires $O(n \cdot h \cdot d_k)$ memory per layer, which becomes the primary memory bottleneck for long sequences.
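One decoding step with a KV cache can be sketched as below: the new token's key and value are appended to the cache, and the single query attends over everything cached so far.

```python
import torch

def decode_step(q_t, k_t, v_t, k_cache, v_cache):
    # q_t, k_t, v_t: (batch, 1, d) tensors for the current token only.
    # Appending to the cache makes each step O(t d) instead of O(t^2 d).
    k_cache = torch.cat([k_cache, k_t], dim=-2)
    v_cache = torch.cat([v_cache, v_t], dim=-2)
    scores = q_t @ k_cache.transpose(-2, -1) / q_t.size(-1) ** 0.5  # (batch, 1, t)
    return scores.softmax(dim=-1) @ v_cache, k_cache, v_cache
```

The cache grows by one entry per step per layer, which is exactly the $O(n \cdot h \cdot d_k)$ per-layer memory cost described above.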
| Variant | KV Heads | Quality | Cache Size | Use Case |
|---|---|---|---|---|
| Multi-Head (MHA) | $h$ | Highest | $O(h \cdot n \cdot d_k)$ | Training, quality-critical |
| Multi-Query (MQA) | 1 | Lower | $O(n \cdot d_k)$ | Fast inference |
| Grouped-Query (GQA) | $g < h$ | High | $O(g \cdot n \cdot d_k)$ | Production LLMs (Llama 2/3) |
Multi-Query Attention (MQA): All query heads share a single set of keys and values, reducing KV cache size by a factor of $h$. Trades some model quality for dramatically faster inference.
Grouped-Query Attention (GQA): Groups of query heads share KV heads. With $g$ groups, the KV cache is $g/h$ the size of full MHA. Balances quality and efficiency; adopted by Llama 2 70B and Llama 3.
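The cache-size arithmetic can be made concrete. Assuming a Llama-2-70B-like configuration (80 layers, 64 query heads, head dimension 128, 8 KV groups, fp16 cache), the sketch below compares MHA and GQA cache footprints at a 4K context:

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_k, seq_len, bytes_per_elem=2):
    # Factor of 2 covers both keys and values; bytes_per_elem=2 assumes fp16
    return 2 * n_layers * n_kv_heads * d_k * seq_len * bytes_per_elem

# Illustrative Llama-2-70B-like config at 4096-token context
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, d_k=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8,  d_k=128, seq_len=4096)
```

Here GQA shrinks the cache by exactly $h/g = 64/8 = 8\times$ (about 1.25 GiB versus 10 GiB per sequence), which is why it dominates in production serving.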
Flash Attention (Dao et al., 2022) is an IO-aware exact attention algorithm that minimizes memory reads/writes between GPU HBM (slow, large) and SRAM (fast, small) by tiling the computation and recomputing the attention matrix block-by-block in the backward pass rather than storing it.
Flash Attention achieves a 2-4x wall-clock speedup over standard attention and reduces memory from $O(n^2)$ to $O(n)$ in sequence length, enabling contexts up to 64K tokens.
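In practice you rarely call Flash Attention directly; PyTorch's `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0) dispatches to a fused FlashAttention-style kernel when the hardware and dtype allow, and otherwise falls back to a math implementation. A minimal sketch:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 8, 256, 64) for _ in range(3))

# Fused causal attention; the backend (flash, memory-efficient, or math)
# is chosen automatically based on device and dtype
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Because the algorithm is exact (not an approximation), the result matches a naive masked-softmax implementation up to floating-point error.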
```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Scaled dot-product attention with optional causal mask
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

def grouped_query_attention(Q, K, V, n_heads, n_kv_heads):
    # GQA: multiple query heads share fewer KV heads.
    # Q is (batch, seq_len, d_model); K and V are already projected to the
    # smaller KV dimension (batch, seq_len, n_kv_heads * d_k).
    batch, seq_len, d_model = Q.shape
    d_k = d_model // n_heads
    heads_per_group = n_heads // n_kv_heads
    Q = Q.view(batch, seq_len, n_heads, d_k).transpose(1, 2)
    K = K.view(batch, seq_len, n_kv_heads, d_k).transpose(1, 2)
    V = V.view(batch, seq_len, n_kv_heads, d_k).transpose(1, 2)
    # Expand KV heads to match the query head count
    K = K.repeat_interleave(heads_per_group, dim=1)
    V = V.repeat_interleave(heads_per_group, dim=1)
    out = scaled_dot_product_attention(Q, K, V)
    return out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
```