====== KV Cache Compression ======

As LLMs process longer contexts, key-value (KV) cache memory becomes the dominant bottleneck during inference. KV-Distill introduces a principled distillation approach that compresses KV caches nearly losslessly, enabling dramatically extended effective context windows without the degradation typical of pruning or quantization.

===== The KV Cache Bottleneck =====

During autoregressive generation, Transformer models cache the key and value projections from all prior tokens. This KV cache:

  * Grows linearly with sequence length: $\text{Memory} = O(L \times d \times n_{\text{layers}} \times n_{\text{heads}})$
  * Dominates GPU memory during long-context inference
  * Limits the effective context window regardless of the model's theoretical maximum
  * Creates a trade-off between context length and batch size

===== KV-Distill: Student-Teacher Cache Compression =====

**KV-Distill** (arXiv:2503.10337) treats KV cache compression as a distillation problem, training a parameter-efficient adapter that learns to sub-select the most important tokens from the cache.
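The memory formula above can be made concrete with a back-of-envelope estimate. The dimensions below are hypothetical (typical of a 7B-class model with FP16 caches), not figures from the paper:

<code python>
def kv_cache_bytes(seq_len, n_layers, n_heads, head_dim, bytes_per_elem=2):
    """KV cache size: keys + values for every layer, head, and token."""
    # Factor of 2 accounts for storing both keys and values
    return 2 * seq_len * n_layers * n_heads * head_dim * bytes_per_elem

# Hypothetical 7B-class config: 32 layers, 32 heads, head dim 128, FP16
full = kv_cache_bytes(seq_len=128_000, n_layers=32, n_heads=32, head_dim=128)
kept = kv_cache_bytes(seq_len=12_800, n_layers=32, n_heads=32, head_dim=128)
print(f"full cache: {full / 2**30:.1f} GiB")                  # 62.5 GiB
print(f"after 90% token reduction: {kept / 2**30:.2f} GiB")   # 6.25 GiB
</code>

At these sizes the full cache alone approaches the capacity of a single 80 GB GPU before model weights are even loaded, which is why //length// reduction matters so much at extreme context lengths.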
=== Core Approach ===

  * **Teacher**: The uncompressed KV cache from the original model
  * **Student**: A compressed KV cache produced by an adapted model $\text{LM}_\theta$
  * **Objective**: Token-level KL divergence loss matching the output distributions: $\mathcal{L} = D_{\text{KL}}\left( P_{\text{LM}}(\cdot | \text{KV}_{\text{full}}) \| P_{\text{LM}}(\cdot | \text{KV}_{\text{compressed}}) \right)$

=== Compression Pipeline ===

  - The adapted model $\text{LM}_\theta$ encodes the input context into a full KV cache
  - Important tokens are identified and sub-selected via learned criteria
  - The unmodified base LM conditions on this compressed KV cache for generation
  - The KL loss ensures output equivalence between the compressed and full caches

=== Key Design Decisions ===

  * **Question-independent**: Compression happens before any query is seen, making the compressed cache reusable across multiple downstream tasks
  * **Parameter-efficient**: Trained as a lightweight adapter on the pretrained model
  * **Adaptive compression**: Learns which tokens to keep rather than using fixed-ratio rules

===== Comparison to Alternative Approaches =====

=== Pruning ===

  * Removes tokens, heads, or layers from the KV cache
  * Destroys the direct correspondence between retained and original representations
  * Risks catastrophic information loss, especially for extractive tasks (needle-in-a-haystack)
  * KV-Distill's learned sub-selection preserves pretrained capabilities via output matching

=== Quantization ===

  * Reduces the numerical precision of KV cache entries (e.g., FP16 to INT4)
  * Reduces memory per token but does not reduce sequence length
  * Scales poorly for very long contexts -- the number of cached tokens remains the same
  * KV-Distill achieves //length reduction//, which is complementary to quantization

=== Training-Free Methods (e.g., H2O) ===

  * Use heuristics (attention scores, recency) to evict tokens at inference time
  * No training overhead, but suboptimal selection criteria
  * KV-Distill outperforms H2O, particularly on worst-case extractive tasks

===== Benchmark Results =====

^ Metric ^ KV-Distill ^ Baselines ^
| Extractive tasks (needle-in-haystack) | Near-perfect at 90% reduction | Significant degradation |
| Long-context QA | Approaches uncompressed performance | Moderate loss |
| Summarization | Approaches uncompressed performance | Moderate loss |
| Domain-specific (fine-tuned) | Up to 99% compression (100x ratio) | N/A |

===== Extending Effective Context Windows =====

By reducing KV cache length, KV-Distill enables:

  * **Longer sequences**: Models with 128K token limits can effectively process 1M+ tokens by compressing early context
  * **Larger batch sizes**: Freed GPU memory allows more concurrent requests
  * **Reduced latency**: Shorter caches mean faster attention computation: $O(L_{\text{compressed}} \times d)$ per generated token instead of $O(L_{\text{full}} \times d)$
  * **Stacking with quantization**: Length reduction from KV-Distill combines with precision reduction from quantization for multiplicative savings

===== Code Example =====

A simplified sketch of the pipeline; ''score_tokens'' stands in for the selection criterion learned during distillation and is not an actual library API:

<code python>
import torch


class KVDistillCompressor:
    """Simplified KV-Distill cache compression (illustrative sketch)."""

    def __init__(self, base_model, adapter, keep_ratio=0.1):
        self.base_model = base_model
        self.adapter = adapter        # parameter-efficient adapted model
        self.keep_ratio = keep_ratio  # fraction of tokens kept (0.1 = 90% reduction)

    def compress_kv_cache(self, input_ids):
        """Compress the KV cache via learned token selection."""
        # Encode the context with the adapted model to produce the full KV cache
        with torch.no_grad():
            outputs = self.adapter(input_ids, use_cache=True)
        full_kv = outputs.past_key_values

        # Score token importance (hypothetical hook for the learned criterion)
        importance = self.adapter.score_tokens(full_kv)
        n_keep = max(1, int(importance.numel() * self.keep_ratio))
        top_indices = torch.topk(importance, n_keep).indices.sort().values

        # Sub-select important tokens along the sequence dimension
        compressed_kv = tuple(
            (k[:, :, top_indices, :], v[:, :, top_indices, :])
            for k, v in full_kv
        )
        return compressed_kv

    def generate(self, input_ids, compressed_kv, max_new_tokens=512):
        """Generate with the unmodified base model on the compressed cache."""
        return self.base_model.generate(
            input_ids[:, -1:],  # only the last token; prior context lives in the cache
            past_key_values=compressed_kv,
            max_new_tokens=max_new_tokens,
        )
</code>

===== Compression Methods Landscape =====

^ Method ^ Reduces Length ^ Reduces Precision ^ Training Required ^ Composable ^
| KV-Distill | Yes | No | Yes (adapter) | Yes |
| Pruning (H2O) | Yes | No | No | Yes |
| Quantization (INT4/INT8) | No | Yes | Optional | Yes |
| KV-Distill + Quantization | Yes | Yes | Yes | -- |

===== References =====

  * [[https://arxiv.org/abs/2503.10337|KV-Distill: Nearly Lossless Learnable Context Compression for LLMs (arXiv:2503.10337)]]

===== See Also =====

  * [[spreading_activation_memory|Spreading Activation Memory]] -- alternative approaches to managing long-context information
  * [[agentic_uncertainty|Agentic Uncertainty]] -- how context compression interacts with confidence degradation