As LLMs process longer contexts, key-value (KV) cache memory becomes the dominant bottleneck during inference. KV-Distill introduces a principled distillation approach that compresses KV caches nearly losslessly, enabling dramatically extended effective context windows without the degradation typical of pruning or quantization.
During autoregressive generation, Transformer models cache the key and value projections of every prior token. This KV cache:

- grows linearly with sequence length (and with the number of layers and attention heads);
- must stay resident in accelerator memory for the entire generation;
- can exceed the memory footprint of the model weights themselves at long context lengths.
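To make the scale concrete, here is a back-of-the-envelope estimate; the helper function and the example configuration are illustrative, not taken from the paper:

```python
def kv_cache_bytes(n_layers, n_heads, head_dim, seq_len, batch=1, dtype_bytes=2):
    """Bytes needed for keys + values (the leading factor of 2) across all layers."""
    return 2 * n_layers * n_heads * head_dim * seq_len * batch * dtype_bytes


# A 7B-class config (32 layers, 32 heads, head_dim 128) at 32k tokens in fp16:
gb = kv_cache_bytes(32, 32, 128, 32_768) / 1e9
print(f"{gb:.1f} GB")  # → 17.2 GB, comparable to the fp16 weights themselves
```

At a 90% cache reduction, the same budget would hold roughly ten times the context.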
KV-Distill (arXiv:2503.10337) treats KV cache compression as a distillation problem: a parameter-efficient adapter learns to sub-select the most important tokens from the cache, trained so that the model's next-token distribution conditioned on the compressed cache matches its distribution conditioned on the full cache:
$\mathcal{L} = D_{\text{KL}}\left( P_{\text{LM}}(\cdot | \text{KV}_{\text{full}}) \| P_{\text{LM}}(\cdot | \text{KV}_{\text{compressed}}) \right)$
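A minimal PyTorch sketch of this objective, with the full-cache distribution as the (detached) teacher and the compressed-cache distribution as the student; the function name and tensor shapes are assumptions, not the paper's code:

```python
import torch
import torch.nn.functional as F


def kv_distill_loss(logits_full, logits_compressed):
    """KL(P_full || P_compressed), averaged over the batch.

    Both inputs: (batch, seq_len, vocab) logits from the same model,
    one conditioned on the full KV cache, one on the compressed cache.
    """
    log_p_full = F.log_softmax(logits_full.detach(), dim=-1)  # teacher, no grad
    log_p_comp = F.log_softmax(logits_compressed, dim=-1)     # student

    # F.kl_div takes the approximating distribution's log-probs first and,
    # with log_target=True, the reference distribution's log-probs second.
    return F.kl_div(log_p_comp, log_p_full, log_target=True, reduction="batchmean")
```

Only the adapter parameters receive gradients through the student branch, so the base model stays frozen.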
| Task | KV-Distill | Baselines |
|---|---|---|
| Extractive tasks (needle-in-haystack) | Near-perfect at 90% reduction | Significant degradation |
| Long-context QA | Approaches uncompressed performance | Moderate loss |
| Summarization | Approaches uncompressed performance | Moderate loss |
| Domain-specific (fine-tuned) | Up to 99% compression (100x ratio) | N/A |
By reducing KV cache length, KV-Distill enables:

- longer effective context windows within a fixed memory budget;
- larger batch sizes, and therefore higher serving throughput;
- cheaper deployment of long-context workloads on smaller accelerators.
In code, the inference-time flow looks roughly like this:

```python
import torch


class KVDistillCompressor:
    """Simplified KV-Distill cache compression."""

    def __init__(self, base_model, adapter, compression_ratio=0.1):
        self.base_model = base_model
        self.adapter = adapter  # Parameter-efficient adapter
        self.keep_ratio = compression_ratio  # Fraction of tokens to retain

    def compress_kv_cache(self, input_ids):
        """Compress the KV cache via learned token selection."""
        # Generate the full KV cache with the adapted model
        with torch.no_grad():
            outputs = self.adapter(input_ids, use_cache=True)
            full_kv = outputs.past_key_values

        # Score token importance (learned during distillation)
        importance = self.adapter.score_tokens(full_kv)
        n_keep = int(len(importance) * self.keep_ratio)
        top_indices = torch.topk(importance, n_keep).indices.sort().values

        # Sub-select the important tokens from every layer's (K, V) pair
        compressed_kv = tuple(
            (k[:, :, top_indices, :], v[:, :, top_indices, :])
            for k, v in full_kv
        )
        return compressed_kv

    def generate(self, input_ids, compressed_kv, max_new_tokens=512):
        """Generate using the compressed KV cache with the unmodified base model."""
        return self.base_model.generate(
            input_ids[:, -1:],  # Only the last token is needed as new input
            past_key_values=compressed_kv,
            max_new_tokens=max_new_tokens,
        )
```
| Method | Reduces Length | Reduces Precision | Training Required | Composable |
|---|---|---|---|---|
| KV-Distill | Yes | No | Yes (adapter) | Yes |
| Pruning (H2O) | Yes | No | No | Yes |
| Quantization (INT4/INT8) | No | Yes | Optional | Yes |
| KV-Distill + Quantization | Yes | Yes | Yes | – |
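Because KV-Distill shortens the cache while quantization shrinks each remaining entry, the two compose naturally. A minimal sketch of stacking per-tensor symmetric int8 quantization on top of a length-compressed cache; these helpers are hypothetical, not part of KV-Distill:

```python
import torch


def quantize_kv_int8(compressed_kv):
    """Symmetric per-tensor int8 quantization of a length-compressed KV cache.

    Returns, per layer, (int8 keys, int8 values, (key scale, value scale)).
    """
    quantized = []
    for k, v in compressed_kv:
        k_scale = k.abs().max().clamp(min=1e-8) / 127.0
        v_scale = v.abs().max().clamp(min=1e-8) / 127.0
        quantized.append((
            (k / k_scale).round().to(torch.int8),
            (v / v_scale).round().to(torch.int8),
            (k_scale, v_scale),
        ))
    return quantized


def dequantize_kv(quantized):
    """Recover float tensors before attention; error is bounded by half a scale step."""
    return tuple(
        (kq.float() * ks, vq.float() * vs)
        for kq, vq, (ks, vs) in quantized
    )
```

Combined with a 90% length reduction, the 4x precision saving yields roughly a 40x smaller cache overall, at the cost of a small, bounded rounding error per entry.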