====== Agent Distillation: Compressing Agent Systems into Smaller Models ======

Agent distillation is the process of transferring the full task-solving behavior of large LLM agent systems — including reasoning chains, tool usage patterns, and multi-step decision trajectories — into smaller, deployable models. Unlike general model distillation, which uses flat token-level supervision to mimic outputs, agent distillation explicitly handles the **compositional structure** of agent trajectories, segmenting them into reasoning and action components for fine-grained alignment. This enables models as small as 0.5B-3.8B parameters to achieve performance competitive with models 4-10x larger.

===== Distinction from General Model Distillation =====

^ Aspect ^ Agent Distillation ^ General Model Distillation ^
| Supervision level | Span-level ([REASON] vs. [ACT] masks) | Token-level (flat KL divergence) |
| Focus | Structured trajectories, reasoning-action fidelity | Overall probability distributions |
| Data structure | Segmented reasoning chains + tool calls | Input-output pairs |
| Key advantage | Preserves agent coherence (rationale leads to action) | Simpler, faster for non-agentic tasks |
| Limitation | Requires trajectory parsing and segmentation | Ignores agent structure, degrades reasoning |

The critical difference is that agent distillation models the **causal dependency** between reasoning and actions: a thought process leads to a specific tool call, which produces an observation that informs the next reasoning step. Flat distillation loses this structure.
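The contrast in the table can be made concrete with a minimal sketch. The tensor shapes, span boundaries, and mask construction below are illustrative assumptions, not from any specific implementation:

```python
import torch
import torch.nn.functional as F

# Toy logits for a 10-token trajectory with an 8-token vocabulary.
torch.manual_seed(0)
teacher = torch.randn(1, 10, 8)
student = torch.randn(1, 10, 8)

# Flat token-level distillation: a single KL term over every position,
# blind to whether a token belongs to reasoning or to a tool call.
flat_loss = F.kl_div(
    F.log_softmax(student, dim=-1),
    F.softmax(teacher, dim=-1),
    reduction="batchmean",
)

# Span-level distillation: a per-token type mask restricts each loss to
# its span. Here tokens 0-5 are [REASON] and 6-9 are [ACT] (assumed).
reason_mask = torch.zeros(10, dtype=torch.bool)
reason_mask[:6] = True
reason_loss = F.kl_div(
    F.log_softmax(student[:, reason_mask], dim=-1),
    F.softmax(teacher[:, reason_mask], dim=-1),
    reduction="batchmean",
)
```

The masked variant lets each span type receive its own loss and weighting, which is the basis of the segment-specific objectives described below.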
===== Trajectory Distillation =====

Trajectory distillation trains small models to imitate complete agent trajectories from teacher agents:

  * **Data generation**: A large teacher model (e.g., GPT-4, Qwen-32B) executes tasks while logging full trajectories — every reasoning step, tool invocation, observation, and decision
  * **Trajectory segmentation**: Each trajectory is parsed into typed spans: [REASON], [ACT], [OBSERVE]
  * **Supervised training**: The small student model is fine-tuned on the segmented trajectories using span-specific loss functions
  * **Curriculum training**: Tasks are ordered by complexity, starting with simple single-step actions and progressing to multi-step chains

Key approaches include:

  * **AgentTrek** — Synthesizes web-agent trajectories from publicly available tutorials through a three-stage pipeline: harvest tutorial texts, transform them into task specifications, and execute them in real environments with VLM verification. Achieves state-of-the-art results on the WebArena and ScreenSpot benchmarks.
  * **Agent Distillation (Kang et al., 2025)** — Introduces "first-thought prefix" prompting to improve teacher trajectory quality, and self-consistent action generation for student robustness. 0.5B models achieve performance competitive with 1.5B CoT-distilled models.
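The curriculum-training step can be sketched as a simple ordering pass over logged trajectories. The `Trajectory` container, its field names, and the use of action-span count as a complexity proxy are all assumptions for illustration:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Trajectory:
    task_id: str
    spans: List[str]  # span types in order, e.g. ["reason", "action", ...]

def curriculum_order(trajectories: List[Trajectory]) -> List[Trajectory]:
    """Order trajectories from simple to complex, using the number of
    action spans as a rough proxy for task complexity."""
    return sorted(trajectories, key=lambda t: t.spans.count("action"))

trajs = [
    Trajectory("multi", ["reason", "action", "observe", "reason", "action"]),
    Trajectory("single", ["reason", "action"]),
]
ordered = curriculum_order(trajs)
# Single-step tasks are presented to the student before multi-step chains.
```

A real curriculum might instead weight samples or schedule them over epochs, but the principle — feed the student short trajectories before long chains — is the same.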
===== Reasoning Chain Distillation =====

Reasoning chain distillation specifically supervises the internal thought process:

  * Uses separate loss functions for reasoning spans vs. action spans
  * **KL divergence on [REASON] tokens** aligns multi-step rationale between teacher and student
  * **Cross-entropy on [ACT] tokens** ensures correct tool calls
  * Masking prevents cross-span interference during gradient computation
  * Gradient projection techniques keep reasoning and action learning in separate subspaces

===== Structured Agent Distillation (SAD) =====

SAD is the first span-level distillation framework designed specifically for ReAct-style agents:

  * Segments ReAct trajectories into Thought/Action/Observation spans
  * Applies segment-specific losses with learnable weights
  * Uses curriculum training progressing from simple to complex trajectories
  * Achieves strong results on the ALFWorld, HotPotQA-ReAct, and WebShop benchmarks
  * Outperforms token-level baselines by 8-15% on planning tasks

===== Code Example =====

<code python>
# Agent distillation: trajectory segmentation and span-level training
from dataclasses import dataclass
from typing import List

import torch
import torch.nn.functional as F

@dataclass
class TrajectorySpan:
    type: str            # "reason", "action", "observation"
    tokens: List[int]
    start_idx: int
    end_idx: int

def segment_trajectory(trajectory_tokens, span_markers):
    """Parse an agent trajectory into typed spans."""
    spans = []
    for marker in span_markers:
        spans.append(TrajectorySpan(
            type=marker["type"],
            tokens=trajectory_tokens[marker["start"]:marker["end"]],
            start_idx=marker["start"],
            end_idx=marker["end"],
        ))
    return spans

def compute_span_loss(student_logits, teacher_logits, spans, labels):
    """Compute span-specific losses for agent distillation."""
    total_loss = 0.0
    for span in spans:
        s, e = span.start_idx, span.end_idx
        student_span = student_logits[:, s:e, :]
        if span.type == "reason":
            # KL divergence for reasoning alignment
            teacher_span = F.softmax(teacher_logits[:, s:e, :], dim=-1)
            total_loss += F.kl_div(
                F.log_softmax(student_span, dim=-1),
                teacher_span,
                reduction="batchmean",
            )
        elif span.type == "action":
            # Cross-entropy for action correctness
            # (reshape, not view: sliced tensors may be non-contiguous)
            total_loss += F.cross_entropy(
                student_span.reshape(-1, student_span.size(-1)),
                labels[:, s:e].reshape(-1),
            )
        # Observation spans are masked (not supervised)
    return total_loss
</code>

===== Phi-3 Mini: Reasoning Parity at 3.8B =====

Microsoft's Phi-3 Mini demonstrates that aggressive distillation combined with high-quality data curation can achieve reasoning parity with much larger models:

  * **3.8B parameters** — a fraction of the size of GPT-4 or Llama-65B
  * Trained on curated synthetic data generated by larger teacher models
  * Chain-of-thought distillation from teacher reasoning traces
  * Achieves competitive scores on reasoning benchmarks (MMLU, GSM8K, HumanEval)
  * Key insight: **data quality matters more than model size** — carefully distilled reasoning traces from strong teachers let small models punch above their weight

===== Scale Efficiency Results =====

^ Student Size ^ Method ^ Performance vs. Next-Tier CoT ^
| 0.5B | Agent Distillation | Matches 1.5B CoT-distilled |
| 1.5B | Agent Distillation | Matches 3B CoT-distilled |
| 3B | Agent Distillation | Matches 7B CoT-distilled |
| 3.8B (Phi-3 Mini) | Reasoning Distillation | Competitive with 7-13B models |

The consistent pattern is that agent distillation provides roughly a **4x model-size efficiency gain** over standard chain-of-thought distillation.

===== Loss Function =====

The combined agent distillation objective:

\mathcal{L}_{AD} = \alpha \, \mathcal{L}_{KL}^{reason} + \beta \, \mathcal{L}_{CE}^{action} + \gamma \, \mathcal{L}_{curriculum}

where \alpha and \beta are learnable span weights, \mathcal{L}_{KL}^{reason} is the KL divergence on reasoning spans, \mathcal{L}_{CE}^{action} is the cross-entropy on action spans, and \mathcal{L}_{curriculum} is a complexity-weighted scheduling term.
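A self-contained toy run of this objective is sketched below. For brevity, \alpha and \beta are fixed constants rather than learned weights, the curriculum term is omitted, and span boundaries, tensor shapes, and the random inputs are all illustrative assumptions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, seq = 16, 12
teacher_logits = torch.randn(1, seq, vocab)
student_logits = torch.randn(1, seq, vocab, requires_grad=True)
labels = torch.randint(0, vocab, (1, seq))

# Spans as (type, start, end): tokens 0-5 reasoning,
# 6-9 action, 10-11 observation (assumed boundaries).
spans = [("reason", 0, 6), ("action", 6, 10), ("observation", 10, 12)]
alpha, beta = 1.0, 1.0  # fixed here; learnable in the full objective

total = torch.tensor(0.0)
for kind, s, e in spans:
    if kind == "reason":
        # KL alignment of the student's reasoning distribution
        total = total + alpha * F.kl_div(
            F.log_softmax(student_logits[:, s:e], dim=-1),
            F.softmax(teacher_logits[:, s:e], dim=-1),
            reduction="batchmean",
        )
    elif kind == "action":
        # Cross-entropy on the action tokens
        total = total + beta * F.cross_entropy(
            student_logits[:, s:e].reshape(-1, vocab),
            labels[:, s:e].reshape(-1),
        )
    # Observation spans contribute no loss.

total.backward()  # gradients flow only through supervised spans
```

After `backward()`, the student's gradient is zero on the observation span and nonzero on the reasoning and action spans, which is the masking behavior the span-level objective is designed to produce.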
===== References =====

  * [[https://arxiv.org/abs/2505.13820|"Structured Agent Distillation for ReAct Agents" (SAD) (arXiv:2505.13820)]]
  * [[https://arxiv.org/abs/2505.17612|Kang et al., "Distilling LLM Agent into Small Models with Retrieval and Code Tools" (arXiv:2505.17612)]]
  * [[https://arxiv.org/abs/2412.09605|Xu et al., "AgentTrek: Agent Trajectory Synthesis via Guiding Replay with Web Tutorials" (arXiv:2412.09605)]]
  * [[https://arxiv.org/abs/2310.05915|Chen et al., "FireAct: Toward Language Agent Fine-tuning" (arXiv:2310.05915)]]
  * [[https://arxiv.org/abs/2404.14219|Microsoft, "Phi-3 Technical Report" (arXiv:2404.14219)]]
  * [[https://github.com/Nardien/agent-distillation|Agent Distillation GitHub repository]]

===== See Also =====

  * [[swe_agent|SWE-agent — agent system whose trajectories can be distilled]]
  * [[gorilla|Gorilla — specialized model training for API calling]]
  * [[self_play_agents|Self-Play Agents — training agents through self-generated data]]
  * [[metagpt|MetaGPT — multi-agent framework producing distillable trajectories]]