Agent Distillation: Compressing Agent Systems into Smaller Models

Agent distillation is the process of transferring the full task-solving behavior of large LLM agent systems — including reasoning chains, tool usage patterns, and multi-step decision trajectories — into smaller, deployable models. Unlike general model distillation which uses flat token-level supervision to mimic outputs, agent distillation explicitly handles the compositional structure of agent trajectories, segmenting them into reasoning and action components for fine-grained alignment. This enables models as small as 0.5B-3.8B parameters to achieve performance competitive with models 4-10x larger.

Distinction from General Model Distillation

| Aspect | Agent Distillation | General Model Distillation |
| --- | --- | --- |
| Supervision level | Span-level ([REASON] vs [ACT] masks) | Token-level (flat KL divergence) |
| Focus | Structured trajectories; reasoning-action fidelity | Overall probability distributions |
| Data structure | Segmented reasoning chains + tool calls | Input-output pairs |
| Key advantage | Preserves agent coherence (rationale leads to action) | Simpler and faster for non-agentic tasks |
| Limitation | Requires trajectory parsing and segmentation | Ignores agent structure; degrades reasoning |

The critical difference is that agent distillation models the causal dependency between reasoning and actions — a thought process leads to a specific tool call, which produces an observation that informs the next reasoning step. Flat distillation loses this structure.
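
To make this concrete, here is a toy ReAct-style trajectory (the task and step contents are invented for illustration): each reasoning span motivates the tool call that follows, and each observation feeds the next reasoning step.

```python
# Hypothetical ReAct-style trajectory illustrating the reason -> act -> observe
# loop. Step contents are invented for illustration, not from any dataset.
trajectory = [
    {"type": "reason", "text": "The question asks for a population figure; search first."},
    {"type": "action", "text": 'search("Paris population 2020")'},
    {"type": "observation", "text": "Paris had about 2.1 million residents in 2020."},
    {"type": "reason", "text": "The observation answers the question directly."},
    {"type": "action", "text": 'finish("about 2.1 million")'},
]

# Flat distillation sees one undifferentiated token stream; span-level
# distillation supervises each segment according to its type.
for step in trajectory:
    print(f"[{step['type'].upper()}] {step['text']}")
```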

Trajectory Distillation

Trajectory distillation trains small models to imitate complete agent trajectories from teacher agents: each training example is a full multi-step episode of reasoning, tool calls, and observations recorded from the teacher. Two key approaches, covered below, are reasoning chain distillation and structured agent distillation.
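
Collecting such trajectories can be sketched as a simple teacher rollout loop. `teacher_step` (returns a thought and an action string) and `env_step` (executes a tool call) are hypothetical interfaces, not from any specific framework:

```python
# Sketch of trajectory collection for distillation: roll out a teacher agent,
# record each (reason, action, observation) step, and return the typed episode
# as one training example. teacher_step and env_step are hypothetical stand-ins
# for a real agent loop and tool executor.
def collect_trajectory(task, teacher_step, env_step, max_steps=8):
    """Roll out a teacher agent on `task` and return the typed trajectory."""
    trajectory = []
    context = task
    for _ in range(max_steps):
        thought, action = teacher_step(context)  # teacher emits reasoning + tool call
        trajectory.append({"type": "reason", "text": thought})
        trajectory.append({"type": "action", "text": action})
        if action.startswith("finish"):          # terminal action ends the episode
            break
        observation = env_step(action)           # execute the tool, get feedback
        trajectory.append({"type": "observation", "text": observation})
        context = context + "\n" + observation   # feed the observation back in
    return trajectory
```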

Reasoning Chain Distillation

Reasoning chain distillation specifically supervises the internal thought process: the training signal is applied to the teacher's reasoning tokens, so the student learns how the teacher thinks rather than only what it answers.
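
One minimal way to implement this, assuming span boundaries are already known, is to mask the training labels so cross-entropy is computed only on reasoning tokens (`-100` is PyTorch's default ignore index). The helper below is an illustrative sketch, not a prescribed API:

```python
import torch

# Sketch of reasoning-only supervision: mask every label outside the
# reasoning spans with -100, the index that F.cross_entropy ignores.
# The helper name and span format are illustrative assumptions.
def mask_to_reasoning(labels, reason_spans, ignore_index=-100):
    """Keep labels only inside half-open [start, end) reasoning spans."""
    masked = torch.full_like(labels, ignore_index)
    for start, end in reason_spans:
        masked[:, start:end] = labels[:, start:end]
    return masked

labels = torch.arange(10).unsqueeze(0)        # one sequence of 10 tokens
masked = mask_to_reasoning(labels, [(2, 5)])  # only tokens 2-4 are supervised
```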

Structured Agent Distillation (SAD)

SAD is the first span-level distillation framework specifically for ReAct-style agents: it parses each trajectory into reasoning, action, and observation spans and applies a type-specific loss to each.

Code Example

# Agent distillation: trajectory segmentation and span-level training
from dataclasses import dataclass
from typing import List
import torch
import torch.nn.functional as F
 
@dataclass
class TrajectorySpan:
    type: str        # "reason", "action", "observation"
    tokens: List[int]
    start_idx: int
    end_idx: int
 
def segment_trajectory(trajectory_tokens, span_markers):
    """Parse agent trajectory into typed spans."""
    spans = []
    for marker in span_markers:
        spans.append(TrajectorySpan(
            type=marker["type"],
            tokens=trajectory_tokens[marker["start"]:marker["end"]],
            start_idx=marker["start"],
            end_idx=marker["end"],
        ))
    return spans
 
def compute_span_loss(student_logits, teacher_logits, spans, labels):
    """Compute span-specific losses for agent distillation."""
    total_loss = 0.0
    for span in spans:
        s, e = span.start_idx, span.end_idx
        student_span = student_logits[:, s:e, :]
        if span.type == "reason":
            # KL divergence for reasoning alignment
            teacher_span = F.softmax(teacher_logits[:, s:e, :], dim=-1)
            total_loss += F.kl_div(
                F.log_softmax(student_span, dim=-1),
                teacher_span, reduction="batchmean"
            )
        elif span.type == "action":
            # Cross-entropy for action correctness. Use reshape rather than
            # view: the sliced span tensors are non-contiguous for batch > 1.
            total_loss += F.cross_entropy(
                student_span.reshape(-1, student_span.size(-1)),
                labels[:, s:e].reshape(-1)
            )
        # Observation spans are masked (not supervised)
    return total_loss
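
A quick shape-level sanity check of the two span losses, on random tensors with an arbitrary span layout:

```python
import torch
import torch.nn.functional as F

# Sanity check of span-level losses on random logits (shapes only; the
# span boundaries here are made up). A reasoning span uses KL divergence
# against the teacher; an action span uses cross-entropy against labels.
torch.manual_seed(0)
B, T, V = 2, 12, 50                      # batch, sequence length, vocab size
student = torch.randn(B, T, V)
teacher = torch.randn(B, T, V)
labels = torch.randint(0, V, (B, T))

# reason span [0, 6): match the teacher's token distribution
kl = F.kl_div(F.log_softmax(student[:, 0:6, :], dim=-1),
              F.softmax(teacher[:, 0:6, :], dim=-1),
              reduction="batchmean")

# action span [6, 12): predict the exact action tokens
ce = F.cross_entropy(student[:, 6:12, :].reshape(-1, V),
                     labels[:, 6:12].reshape(-1))

loss = kl + ce
print(float(loss))                        # a finite positive scalar
```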

Phi-3 Mini: Reasoning Parity at 3.8B

Microsoft's Phi-3 Mini demonstrates that aggressive distillation combined with high-quality data curation can achieve reasoning parity with much larger models.

Scale Efficiency Results

| Student Size | Method | Performance vs Next-Tier CoT |
| --- | --- | --- |
| 0.5B | Agent Distillation | Matches 1.5B CoT-distilled |
| 1.5B | Agent Distillation | Matches 3B CoT-distilled |
| 3B | Agent Distillation | Matches 7B CoT-distilled |
| 3.8B (Phi-3 Mini) | Reasoning Distillation | Competitive with 7B-13B models |

The consistent pattern shows agent distillation provides roughly a two- to three-fold model size efficiency gain over standard chain-of-thought distillation: each agent-distilled student matches a CoT-distilled model in the next size tier up.

Loss Function

The combined agent distillation objective:

<latex>\mathcal{L}_{AD} = \alpha \mathcal{L}_{KL}^{reason} + \beta \mathcal{L}_{CE}^{action} + \gamma \mathcal{L}_{curriculum}</latex>

where <latex>\alpha</latex> and <latex>\beta</latex> are learnable span weights, <latex>\mathcal{L}_{KL}^{reason}</latex> is the KL divergence on reasoning spans, <latex>\mathcal{L}_{CE}^{action}</latex> is the cross-entropy on action spans, and <latex>\gamma</latex> weights <latex>\mathcal{L}_{curriculum}</latex>, a complexity-weighted scheduling term.
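
A minimal sketch of this objective with learnable span weights. The softplus reparameterization and the fixed curriculum weight are assumptions; the document does not pin these details down:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch of the combined agent distillation objective. The per-term losses
# (reason KL, action CE, curriculum term) are assumed to be precomputed
# scalars; how gamma is set is an assumption, not specified by the text.
class AgentDistillLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # raw parameters; softplus keeps the learned span weights positive
        self.raw_alpha = nn.Parameter(torch.zeros(()))
        self.raw_beta = nn.Parameter(torch.zeros(()))
        self.gamma = 0.1  # fixed curriculum weight (assumed value)

    def forward(self, reason_kl, action_ce, curriculum):
        alpha = F.softplus(self.raw_alpha)
        beta = F.softplus(self.raw_beta)
        return alpha * reason_kl + beta * action_ce + self.gamma * curriculum

loss_fn = AgentDistillLoss()
total = loss_fn(torch.tensor(1.0), torch.tensor(2.0), torch.tensor(0.5))
```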
