====== Agent Distillation: Compressing Agent Systems into Smaller Models ======
Agent distillation is the process of transferring the full task-solving behavior of large LLM agent systems — including reasoning chains, tool usage patterns, and multi-step decision trajectories — into smaller, deployable models. Unlike general model distillation, which uses flat token-level supervision to mimic outputs, agent distillation explicitly handles the **compositional structure** of agent trajectories, segmenting them into reasoning and action components for fine-grained alignment. This enables models in the 0.5B-3.8B parameter range to achieve performance competitive with models 4-10x larger.
===== Distinction from General Model Distillation =====
^ Aspect ^ Agent Distillation ^ General Model Distillation ^
| Supervision level | Span-level ([REASON] vs [ACT] masks) | Token-level (flat KL divergence) |
| Focus | Structured trajectories, reasoning-action fidelity | Overall probability distributions |
| Data structure | Segmented reasoning chains + tool calls | Input-output pairs |
| Key advantage | Preserves agent coherence (rationale leads to action) | Simpler, faster for non-agentic tasks |
| Limitation | Requires trajectory parsing and segmentation | Ignores agent structure, degrades reasoning |
The critical difference is that agent distillation models the **causal dependency** between reasoning and actions — a thought process leads to a specific tool call, which produces an observation that informs the next reasoning step. Flat distillation loses this structure.
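As an illustrative sketch of this causal structure (the span labels and field names here are hypothetical, not taken from any specific framework), a ReAct-style trajectory can be represented so that each action carries an explicit reference to the reasoning span that produced it — exactly the linkage that flat token-level distillation discards:

```python
# A minimal, hypothetical representation of one ReAct-style trajectory:
# each action records which reasoning step caused it, and each reasoning
# step records which observation informed it.
trajectory = [
    {"type": "reason", "text": "The question asks for the capital of France; "
                               "I should look it up."},
    {"type": "action", "text": 'search("capital of France")', "caused_by": 0},
    {"type": "observe", "text": "Paris is the capital of France."},
    {"type": "reason", "text": "The observation answers the question directly.",
     "informed_by": 2},
    {"type": "action", "text": 'finish("Paris")', "caused_by": 3},
]

# Walk the causal chain: reason -> action -> observation -> reason -> ...
for i, step in enumerate(trajectory):
    parent = step.get("caused_by", step.get("informed_by"))
    link = f" (depends on step {parent})" if parent is not None else ""
    print(f"{i}: [{step['type'].upper()}]{link}")
```

Span-level supervision operates on this linked structure, whereas flat distillation would see only the concatenated token stream.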
===== Trajectory Distillation =====
Trajectory distillation trains small models to imitate complete agent trajectories from teacher agents:
* **Data generation**: Large teacher model (e.g., GPT-4, Qwen-32B) executes tasks while logging full trajectories — every reasoning step, tool invocation, observation, and decision
* **Trajectory segmentation**: Each trajectory is parsed into typed spans: [REASON], [ACT], [OBSERVE]
* **Supervised training**: Small student model is fine-tuned on segmented trajectories using span-specific loss functions
* **Curriculum training**: Tasks are ordered by complexity, starting with simple single-step actions and progressing to multi-step chains
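The curriculum step above can be sketched as follows; using the number of action spans as the complexity measure is one plausible proxy, not a prescription from any of the cited papers:

```python
# Sketch of curriculum ordering for trajectory distillation: trajectories
# are sorted by complexity (here, the number of action spans — an
# illustrative proxy) so the student sees simple trajectories first.
def curriculum_order(trajectories):
    """Sort trajectories by ascending number of action spans."""
    def complexity(traj):
        return sum(1 for span in traj if span["type"] == "action")
    return sorted(trajectories, key=complexity)

single_step = [{"type": "reason"}, {"type": "action"}]
multi_step = [{"type": "reason"}, {"type": "action"},
              {"type": "observe"}, {"type": "reason"}, {"type": "action"}]

ordered = curriculum_order([multi_step, single_step])
print([len(t) for t in ordered])  # → [2, 5]: single-step trajectory first
```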
Key approaches include:
* **AgentTrek** — Synthesizes web agent trajectories from publicly available tutorials through a three-stage pipeline: harvest tutorial texts, transform into task specs, execute in real environments with VLM verification. Achieves state-of-the-art on WebArena and ScreenSpot benchmarks.
* **Agent Distillation (Kang et al. 2025)** — Introduces "first-thought prefix" prompting to improve teacher trajectory quality, and self-consistent action generation for student robustness. Models at 0.5B achieve performance competitive with 1.5B CoT-distilled models.
===== Reasoning Chain Distillation =====
Reasoning chain distillation specifically supervises the internal thought process:
* Uses separate loss functions for reasoning spans vs action spans
* **KL divergence on [REASON] tokens** aligns multi-step rationale between teacher and student
* **Cross-entropy on [ACT] tokens** ensures correct tool calls
* Masking prevents cross-span interference during gradient computation
* Gradient projection techniques keep reasoning and action learning in separate subspaces
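The masking step above can be sketched as boolean masks over token positions; the helper name and span dictionary layout are illustrative assumptions:

```python
# Sketch of span masking: one boolean mask per supervised span type, so the
# KL (reasoning) loss and cross-entropy (action) loss each see only their
# own token positions and their gradients cannot interfere.
import torch

def build_span_masks(seq_len, spans):
    """Return a boolean mask per span type over seq_len token positions."""
    masks = {"reason": torch.zeros(seq_len, dtype=torch.bool),
             "action": torch.zeros(seq_len, dtype=torch.bool)}
    for span in spans:
        if span["type"] in masks:
            masks[span["type"]][span["start"]:span["end"]] = True
    return masks

spans = [{"type": "reason", "start": 0, "end": 4},
         {"type": "action", "start": 4, "end": 6},
         {"type": "observe", "start": 6, "end": 9}]  # observations unsupervised
masks = build_span_masks(9, spans)
print(masks["reason"].sum().item(), masks["action"].sum().item())  # → 4 2
```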
===== Structured Agent Distillation (SAD) =====
SAD is the first span-level distillation framework specifically for ReAct-style agents:
* Segments ReAct trajectories into Thought/Action/Observation spans
* Applies segment-specific losses with learnable weights
* Uses curriculum training progressing from simple to complex trajectories
* Achieves strong results on ALFWorld, HotPotQA-ReAct, and WebShop benchmarks
* Outperforms token-level baselines by 8-15% on planning tasks
===== Code Example =====
<code python>
# Agent distillation: trajectory segmentation and span-level training
from dataclasses import dataclass
from typing import List

import torch
import torch.nn.functional as F


@dataclass
class TrajectorySpan:
    type: str  # "reason", "action", "observation"
    tokens: List[int]
    start_idx: int
    end_idx: int


def segment_trajectory(trajectory_tokens, span_markers):
    """Parse an agent trajectory into typed spans."""
    spans = []
    for marker in span_markers:
        spans.append(TrajectorySpan(
            type=marker["type"],
            tokens=trajectory_tokens[marker["start"]:marker["end"]],
            start_idx=marker["start"],
            end_idx=marker["end"],
        ))
    return spans


def compute_span_loss(student_logits, teacher_logits, spans, labels):
    """Compute span-specific losses for agent distillation."""
    total_loss = 0.0
    for span in spans:
        s, e = span.start_idx, span.end_idx
        student_span = student_logits[:, s:e, :]
        if span.type == "reason":
            # KL divergence aligns the student's reasoning distribution
            # with the teacher's.
            teacher_span = F.softmax(teacher_logits[:, s:e, :], dim=-1)
            total_loss = total_loss + F.kl_div(
                F.log_softmax(student_span, dim=-1),
                teacher_span, reduction="batchmean",
            )
        elif span.type == "action":
            # Cross-entropy against ground-truth tokens enforces exact
            # tool-call correctness. reshape (not view) handles the
            # non-contiguous slice.
            total_loss = total_loss + F.cross_entropy(
                student_span.reshape(-1, student_span.size(-1)),
                labels[:, s:e].reshape(-1),
            )
        # Observation spans are environment output and are not supervised.
    return total_loss
</code>
===== Phi-3 Mini: Reasoning Parity at 3.8B =====
Microsoft's Phi-3 Mini demonstrates that aggressive distillation combined with high-quality data curation can achieve reasoning parity with much larger models:
* **3.8B parameters** — a fraction of the size of GPT-4 or LLaMA-65B
* Trained on curated synthetic data generated by larger teacher models
* Chain-of-thought distillation from teacher reasoning traces
* Achieves competitive scores on reasoning benchmarks (MMLU, GSM8K, HumanEval)
* Key insight: **data quality matters more than model size** — carefully distilled reasoning traces from strong teachers enable small models to punch above their weight
===== Scale Efficiency Results =====
^ Student Size ^ Method ^ Performance vs Next-Tier CoT ^
| 0.5B | Agent Distillation | Matches 1.5B CoT-distilled |
| 1.5B | Agent Distillation | Matches 3B CoT-distilled |
| 3B | Agent Distillation | Matches 7B CoT-distilled |
| 3.8B (Phi-3 Mini) | Reasoning Distillation | Competitive with 7-13B models |
The consistent pattern shows agent distillation provides roughly a **2-3x model size efficiency gain** over standard chain-of-thought distillation.
===== Loss Function =====
The combined agent distillation objective:
$$\mathcal{L}_{AD} = \alpha \mathcal{L}_{KL}^{reason} + \beta \mathcal{L}_{CE}^{action} + \gamma \mathcal{L}_{curriculum}$$
where $\alpha$ and $\beta$ are learnable span weights, $\mathcal{L}_{KL}^{reason}$ is the KL divergence on reasoning spans, $\mathcal{L}_{CE}^{action}$ is the cross-entropy on action spans, and $\mathcal{L}_{curriculum}$ is a complexity-weighted scheduling term.
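This objective can be sketched in PyTorch as a small module. The softplus parameterization (to keep the learnable weights positive) and the fixed curriculum coefficient are illustrative assumptions, not details prescribed by the cited papers:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AgentDistillationLoss(nn.Module):
    """Combined objective: alpha * L_KL(reason) + beta * L_CE(action)
    + gamma * L_curriculum, with alpha and beta learned during training."""
    def __init__(self, gamma=0.1):
        super().__init__()
        # Raw parameters; softplus maps them to positive effective weights.
        self.raw_alpha = nn.Parameter(torch.zeros(()))
        self.raw_beta = nn.Parameter(torch.zeros(()))
        self.gamma = gamma  # fixed scheduling coefficient (assumption)

    def forward(self, loss_kl_reason, loss_ce_action, loss_curriculum):
        alpha = F.softplus(self.raw_alpha)
        beta = F.softplus(self.raw_beta)
        return (alpha * loss_kl_reason
                + beta * loss_ce_action
                + self.gamma * loss_curriculum)

# The weights receive gradients alongside the student model's parameters,
# so the balance between reasoning and action supervision is learned.
criterion = AgentDistillationLoss()
total = criterion(torch.tensor(0.8), torch.tensor(1.2), torch.tensor(0.3))
```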
===== References =====
* [[https://arxiv.org/abs/2505.13820|"Structured Agent Distillation for ReAct Agents" (SAD)]]
* [[https://arxiv.org/abs/2505.17612|Kang et al. "Distilling LLM Agent into Small Models with Retrieval and Code Tools" (arXiv:2505.17612)]]
* [[https://arxiv.org/abs/2412.09605|Xu et al. "AgentTrek: Agent Trajectory Synthesis via Guiding Replay with Web Tutorials" (arXiv:2412.09605)]]
* [[https://arxiv.org/abs/2310.05915|Chen et al. "FireAct: Toward Language Agent Fine-tuning"]]
* [[https://arxiv.org/abs/2404.14219|Microsoft "Phi-3 Technical Report"]]
* [[https://github.com/Nardien/agent-distillation|Agent Distillation GitHub Repository]]
===== See Also =====
* [[swe_agent|SWE-agent — Agent system whose trajectories can be distilled]]
* [[gorilla|Gorilla — Specialized model training for API calling]]
* [[self_play_agents|Self-Play Agents — Training agents through self-generated data]]
* [[metagpt|MetaGPT — Multi-agent framework producing distillable trajectories]]