====== Attention Mechanism ======

Attention mechanisms allow neural networks to dynamically focus on relevant parts of input sequences, forming the core computational primitive of Transformer-based models. From self-attention within a single sequence to cross-attention between encoder and decoder, attention enables models to capture long-range dependencies with $O(1)$ sequential operations.

===== Scaled Dot-Product Attention =====

The fundamental attention operation computes a weighted sum of values based on query-key similarity:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

where $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, and $V \in \mathbb{R}^{m \times d_v}$. The scaling factor $\sqrt{d_k}$ prevents the dot products from growing large in magnitude, which would push the softmax into regions with vanishingly small gradients.

**Computational complexity**: $O(n^2 d)$ time and $O(n^2)$ space for the attention matrix, making it quadratic in sequence length.

===== Types of Attention =====

**Self-Attention**: Query, key, and value matrices all derive from the same sequence. Each token attends to every other token in the sequence, capturing intra-sequence dependencies such as coreference, syntax, and semantic relationships.

**Cross-Attention**: Queries come from one sequence (e.g., the decoder), while keys and values come from another (e.g., the encoder output). This aligns representations across sequences, as in machine translation where decoder tokens attend to relevant source tokens.

**Causal (Masked) Attention**: Self-attention with a mask preventing tokens from attending to future positions. Essential for autoregressive language models, where each token may only depend on preceding tokens.
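As a minimal sketch of the causal case, the mask can be built from a lower-triangular matrix and applied to the scores before the softmax (single head, no batch dimension; the function name is illustrative, not from a library):

<code python>
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    # Q, K, V: (seq_len, d_k); single head, unbatched, for illustration only
    n, d_k = Q.shape
    scores = Q @ K.T / (d_k ** 0.5)
    # Lower-triangular mask: position i may attend only to positions j <= i
    mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ V

Q = K = V = torch.randn(4, 8)
out = causal_attention(Q, K, V)
# Row 0 can attend only to position 0, so out[0] equals V[0] exactly
</code>

Masked positions receive $-\infty$ before the softmax, so they contribute exactly zero weight rather than a small positive one.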
===== Multi-Head Attention =====

Multi-head attention projects inputs into $h$ separate subspaces and performs attention in parallel:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

Each head has its own learned projections $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, and $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$, where $d_k = d_{model} / h$. This allows different heads to capture diverse patterns (syntactic, semantic, positional).

<code>
graph LR
    Input["Input (Q, K, V)"] --> Split["Linear Projections"]
    Split --> H1["Head 1: Attention"]
    Split --> H2["Head 2: Attention"]
    Split --> H3["Head 3: Attention"]
    Split --> Hh["Head h: Attention"]
    H1 --> Cat["Concat"]
    H2 --> Cat
    H3 --> Cat
    Hh --> Cat
    Cat --> WO["Linear W_O"]
    WO --> Out["Multi-Head Output"]
</code>

===== KV Cache =====

During autoregressive generation, the KV cache stores previously computed key and value tensors to avoid redundant computation:

  - At step $t$, only the new token's query $q_t$ is computed fresh
  - Keys and values from all previous steps are retrieved from the cache
  - The new $k_t, v_t$ are appended to the cache
  - Attention is computed as $\text{softmax}(q_t K_{1:t}^T / \sqrt{d_k}) V_{1:t}$

This reduces per-step complexity from $O(t^2 d)$ to $O(td)$ but requires $O(n \cdot h \cdot d_k)$ memory per layer, which becomes the primary memory bottleneck for long sequences.

===== Attention Variants =====

^ Variant ^ KV Heads ^ Quality ^ Cache Size ^ Use Case ^
| Multi-Head (MHA) | $h$ | Highest | $O(h \cdot n \cdot d_k)$ | Training, quality-critical |
| Multi-Query (MQA) | 1 | Lower | $O(n \cdot d_k)$ | Fast inference |
| Grouped-Query (GQA) | $g < h$ | High | $O(g \cdot n \cdot d_k)$ | Production LLMs (Llama 2/3) |

**Multi-Query Attention (MQA)**: All query heads share a single set of keys and values, reducing KV cache size by a factor of $h$.
This trades some model quality for dramatically faster inference.

**Grouped-Query Attention (GQA)**: Groups of query heads share KV heads. With $g$ groups, the KV cache is $g/h$ the size of full MHA. Balances quality and efficiency; adopted by Llama 2 70B and Llama 3.

===== Flash Attention =====

Flash Attention (Dao et al., 2022) is an IO-aware exact attention algorithm that minimizes memory reads and writes between GPU HBM (slow) and SRAM (fast):

  * **Tiling**: Computes attention in blocks that fit in SRAM, avoiding materializing the full $n \times n$ attention matrix in HBM
  * **Online softmax**: Computes the softmax incrementally across blocks using running statistics
  * **Kernel fusion**: Fuses the attention computation into a single GPU kernel, eliminating intermediate reads and writes
  * **Recomputation**: During the backward pass, recomputes attention instead of storing it, trading compute for memory

Flash Attention achieves a **2-4x wall-clock speedup** over standard attention and enables sequences up to 64K tokens, with memory usage linear in sequence length ($O(n)$ instead of $O(n^2)$) because the attention matrix is never materialized.
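The per-step loop described in the KV Cache section can be sketched as follows (single head, unbatched, and with pre-projected $q_t, k_t, v_t$ for brevity; the function and cache layout are illustrative, not from a library):

<code python>
import torch
import torch.nn.functional as F

def generate_step(q_t, k_t, v_t, cache):
    """One decode step: append the new key/value, attend over the whole cache."""
    cache["K"] = torch.cat([cache["K"], k_t[None, :]], dim=0)  # (t, d_k)
    cache["V"] = torch.cat([cache["V"], v_t[None, :]], dim=0)  # (t, d_v)
    d_k = q_t.shape[-1]
    scores = cache["K"] @ q_t / (d_k ** 0.5)   # (t,): q_t against all cached keys
    weights = F.softmax(scores, dim=-1)
    return weights @ cache["V"]                # (d_v,)

d = 8
cache = {"K": torch.empty(0, d), "V": torch.empty(0, d)}
outputs = [generate_step(torch.randn(d), torch.randn(d), torch.randn(d), cache)
           for _ in range(5)]
# After 5 steps the cache holds keys and values for all 5 positions
</code>

Each step costs $O(td)$ because only one query is formed and matched against $t$ cached keys; nothing from earlier steps is recomputed.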
===== Code Example =====

<code python>
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Scaled dot-product attention with an optional (e.g., causal) mask."""
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)

def grouped_query_attention(Q, K, V, n_heads, n_kv_heads):
    """GQA: multiple query heads share fewer KV heads.

    Q: (batch, seq_len, n_heads * d_k)
    K, V: (batch, seq_len, n_kv_heads * d_k) -- already projected down,
    so the view below splits them into exactly n_kv_heads heads.
    """
    batch, seq_len, d_model = Q.shape
    d_k = d_model // n_heads
    heads_per_group = n_heads // n_kv_heads
    Q = Q.view(batch, seq_len, n_heads, d_k).transpose(1, 2)
    K = K.view(batch, seq_len, n_kv_heads, d_k).transpose(1, 2)
    V = V.view(batch, seq_len, n_kv_heads, d_k).transpose(1, 2)
    # Expand KV heads to match the query head count
    K = K.repeat_interleave(heads_per_group, dim=1)
    V = V.repeat_interleave(heads_per_group, dim=1)
    out = scaled_dot_product_attention(Q, K, V)
    return out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
</code>

===== References =====

  * [[https://arxiv.org/abs/1706.03762|Vaswani et al. - Attention Is All You Need (2017)]]
  * [[https://arxiv.org/abs/2205.14135|Dao et al. - FlashAttention: Fast and Memory-Efficient Exact Attention (2022)]]
  * [[https://arxiv.org/abs/1911.02150|Shazeer - Fast Transformer Decoding: One Write-Head is All You Need (MQA, 2019)]]
  * [[https://arxiv.org/abs/2305.13245|Ainslie et al. - GQA: Training Generalized Multi-Query Transformer Models (2023)]]

===== See Also =====

  * [[transformer_architecture|Transformer Architecture]]
  * [[inference_optimization|Inference Optimization]]
  * [[model_context_window|Model Context Window]]