Algorithm Distillation

Algorithm Distillation (AD) is a technique that applies knowledge distillation to reinforcement learning by training a transformer model on the complete learning histories of RL agents, enabling it to perform in-context reinforcement learning at inference time. Introduced by Laskin et al. (2022) at DeepMind in “In-Context Reinforcement Learning with Algorithm Distillation,” AD encodes the improvement process of a source RL algorithm into the weights of a causal transformer, so the model learns not just a policy but the algorithm for improving a policy.1) This allows the model to explore, exploit, and improve its performance across episodes purely through in-context learning, without any weight updates during deployment.

In-Context Reinforcement Learning

The central idea of Algorithm Distillation is in-context reinforcement learning (ICRL): a causal transformer processes sequences of observations, actions, and rewards from RL episodes as context, and autoregressively predicts the next action conditioned on the full interaction history up to that point.

Formally, the model maximizes the likelihood of the next action given the full history:

$$\mathcal{L}(\theta) = \mathbb{E}_{\tau}\left[\sum_{t=0}^{T-1} \log p_\theta(a_t \mid \tau_{<t})\right]$$

where $\tau_{<t} = (o_0, a_0, r_0, o_1, a_1, r_1, \ldots, o_{t-1}, a_{t-1}, r_{t-1}, o_t)$ is the sequence of all prior observations, actions, and rewards. Through self-attention over this growing context, the model can reference its own learning progress across episodes, identifying which actions led to high rewards and adjusting its behavior accordingly, all without any gradient updates at inference time.
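As a concrete illustration of this objective, here is a minimal sketch of the action-prediction loss over a tokenized history. It assumes PyTorch, a shared discrete token vocabulary for observations, actions, and rewards, and a toy GPT-style model; none of these specifics come from the paper.

import torch
import torch.nn as nn


class TinyCausalTransformer(nn.Module):
    """Toy GPT-style model over a shared obs/action/reward token vocabulary."""

    def __init__(self, vocab_size: int, d_model: int = 64, n_heads: int = 4,
                 n_layers: int = 2, max_len: int = 512):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=4 * d_model,
                                           batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, n_layers)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, time) integer ids -> (batch, time, vocab) logits
        t = tokens.size(1)
        x = self.tok(tokens) + self.pos(torch.arange(t, device=tokens.device))
        causal = torch.triu(torch.full((t, t), float("-inf"),
                                       device=tokens.device), diagonal=1)
        return self.head(self.blocks(x, mask=causal))


def ad_loss(model: nn.Module, tokens: torch.Tensor,
            is_action: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy, masked so only action predictions count."""
    logits = model(tokens[:, :-1])  # predict token t+1 from the prefix up to t
    targets = tokens[:, 1:]
    per_token = nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1),
        reduction="none")
    mask = is_action[:, 1:].reshape(-1).float()  # 1.0 where target is an action
    return (per_token * mask).sum() / mask.sum()


# Toy usage: an (obs, act, rew, obs, act, rew, ...) stream, so every third
# token starting at index 1 is an action token.
model = TinyCausalTransformer(vocab_size=10)
tokens = torch.randint(0, 10, (2, 30))
is_action = torch.zeros(2, 30)
is_action[:, 1::3] = 1.0
print(ad_loss(model, tokens, is_action).item())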

This is fundamentally different from standard imitation learning, which trains on expert demonstrations of a fixed policy. AD trains on the process of learning, capturing how a policy evolves from random exploration to competent behavior.

Training on Learning Histories

The training procedure for AD consists of two phases:

Phase 1 - Generate Learning Histories: A source RL algorithm (e.g., A3C, DQN, UCB) is run across a distribution of tasks, producing complete learning histories that capture the agent's trajectory from initial random behavior through progressive improvement. These histories include the full sequence of observations, actions, and rewards across all training episodes.

Phase 2 - Distill into Transformer: The causal transformer is trained via autoregressive behavioral cloning on these learning histories. The key training signal is action prediction: given the history up to time $t$, predict the action $a_t$.

A critical finding is that distilling only expert behavior fails to enable ICRL.2) The transformer must be exposed to the full learning trajectory, including early suboptimal behavior, to learn the improvement algorithm itself. The diversity of the learning process, from exploration through exploitation, is what teaches the model how to adapt.
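To make this concrete, the sketch below shows one way such training contexts could be sampled: as long subsequences of stored histories, so a single context can contain both early exploratory behavior and later competent behavior. The function name and data layout here are illustrative assumptions, not details from the paper.

import random


def sample_training_contexts(histories: list, context_len: int = 200,
                             n_samples: int = 4) -> list:
    """Sample across-episode subsequences from full learning histories.

    histories: one token list per task, covering the source algorithm's
    whole run from random behavior through progressive improvement.
    """
    contexts = []
    for _ in range(n_samples):
        history = random.choice(histories)
        start = random.randint(0, max(0, len(history) - context_len))
        contexts.append(history[start:start + context_len])
    return contexts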

Architecture and Implementation

AD uses a standard causal transformer architecture (similar to GPT), trained with the autoregressive action-prediction objective above. Observations, actions, and rewards are tokenized into a single sequence, and the context window must be long enough to span multiple episodes, so that the model sees learning progress rather than a snapshot of behavior.

The transformer effectively learns a meta-policy $\pi_\theta(a_t \mid \tau_{<t})$ that maps interaction histories to actions, where the mapping itself implements an RL-like improvement process.
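As a rough sketch of what deployment looks like, the loop below acts with frozen weights and adapts only by growing its context. The `model.predict_action` interface and the `(obs, reward, done)` step signature are assumptions for illustration, not an API from the paper.

def rollout_in_context(model, env, n_episodes: int = 20) -> list:
    """Act with fixed weights; the growing context is the agent's only memory."""
    context = []  # accumulates (tag, value) pairs across ALL episodes
    for _ in range(n_episodes):
        obs = env.reset()
        done = False
        while not done:
            context.append(("obs", obs))
            action = model.predict_action(context)  # hypothetical interface
            obs, reward, done = env.step(action)    # assumed step signature
            context.append(("act", action))
            context.append(("rew", reward))
            # No gradient step anywhere: improvement comes from attention
            # over the accumulated history, the ICRL behavior AD distills.
    return context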

Comparison with Meta-Learning and Traditional RL

Approach                     | Adaptation Mechanism                       | Weight Updates at Test Time   | Key Advantage
Traditional RL (A3C, DQN)    | Gradient-based policy optimization         | Yes (many updates)            | Well-understood convergence
RL-squared ($\text{RL}^2$)3) | RNN hidden state as learned RL algorithm   | No (hidden-state adaptation)  | Adapts to new tasks
Algorithm Distillation       | In-context learning via attention          | No (fixed weights)            | Zero-shot adaptation, better data efficiency
Expert Imitation Learning    | Behavioral cloning on expert data          | No                            | Simple but no improvement capability
Decision Transformer         | Sequence modeling on offline trajectories  | No                            | Frames RL as sequence prediction4)

Key results from Laskin et al., 2022: AD improved its policy entirely in-context on held-out tasks, with weights frozen at deployment; it was more data-efficient than the source RL algorithm that generated its training data; and expert-only distillation baselines failed to improve in-context at all. Experiments covered sparse-reward navigation tasks (goal-finding with delayed rewards), combinatorial domains (key-door puzzles requiring sequential subgoal completion), and pixel-based environments requiring visual understanding and sequential reasoning.

AD demonstrates that the process of learning can itself be distilled into a neural network through sequence modeling, opening a path toward foundation models that can adapt to new environments through pure in-context learning rather than explicit optimization.

Code Example: In-Context RL Trajectory

import random
 
 
class BanditEnv:
    """Multi-armed bandit environment with hidden reward probabilities."""
 
    def __init__(self, n_arms: int = 4):
        self.probs = [random.random() for _ in range(n_arms)]
 
    def pull(self, arm: int) -> float:
        return 1.0 if random.random() < self.probs[arm] else 0.0
 
 
def generate_learning_history(env: BanditEnv, n_episodes: int = 50) -> list[dict]:
    """Generate an RL learning trajectory using epsilon-greedy exploration."""
    n_arms = len(env.probs)
    counts = [0] * n_arms
    values = [0.0] * n_arms
    history = []
 
    for episode in range(n_episodes):
        epsilon = max(0.05, 1.0 - episode / (n_episodes * 0.6))
        if random.random() < epsilon:
            action = random.randint(0, n_arms - 1)
        else:
            action = max(range(n_arms), key=lambda a: values[a])
 
        reward = env.pull(action)
        counts[action] += 1
        values[action] += (reward - values[action]) / counts[action]
        history.append({"episode": episode, "action": action, "reward": reward,
                        "epsilon": round(epsilon, 3), "values": [round(v, 3) for v in values]})
    return history
 
 
def format_as_training_sequence(history: list[dict]) -> str:
    """Format trajectory as a token sequence for transformer training."""
    tokens = []
    for step in history:
        tokens.append(f"<obs>ep={step['episode']}</obs>"
                      f"<act>{step['action']}</act>"
                      f"<rew>{step['reward']:.0f}</rew>")
    return " ".join(tokens)
 
 
def show_learning_progress(history: list[dict], window: int = 10):
    """Show how the agent's reward improves over the trajectory."""
    for i in range(0, len(history), window):
        chunk = history[i:i + window]
        avg_reward = sum(s["reward"] for s in chunk) / len(chunk)
        actions = [s["action"] for s in chunk]
        print(f"Episodes {i:3d}-{i+window-1:3d}: avg_reward={avg_reward:.2f} actions={actions}")
 
 
env = BanditEnv(n_arms=4)
print(f"True arm probabilities: {[round(p, 3) for p in env.probs]}")
print(f"Optimal arm: {max(range(len(env.probs)), key=lambda a: env.probsa)}\n")
 
history = generate_learning_history(env, n_episodes=50)
show_learning_progress(history)
 
sequence = format_as_training_sequence(history[:5])
print(f"\nTraining sequence (first 5 steps):\n{sequence}")

References

1), 2) Laskin et al., 2022. “In-Context Reinforcement Learning with Algorithm Distillation.” https://arxiv.org/abs/2210.14215
3) Duan et al., 2016. “RL²: Fast Reinforcement Learning via Slow Reinforcement Learning.” https://arxiv.org/abs/1611.02779
4) Chen et al., 2021. “Decision Transformer: Reinforcement Learning via Sequence Modeling.” https://arxiv.org/abs/2106.01345