
RAP: Reasoning via Planning with LLM as World Model

RAP (Reasoning via Planning) is a framework introduced by Hao et al. (2023) that repurposes a large language model as both a world model and a reasoning agent, and uses Monte Carlo Tree Search (MCTS) to explore high-reward reasoning paths.1) The paper (925 citations at the time of writing) shows that deliberate planning significantly outperforms autoregressive chain-of-thought reasoning2) on plan generation, math reasoning, and logical inference, because the search enables strategic exploration, lookahead, and backtracking.

arXiv:2305.14992

Dual Role of the LLM

RAP assigns two complementary roles to the same LLM via task-specific prompting:3)

World Model

The LLM predicts the next state $s_{t+1}$ given current state $s_t$ and action $a_t$:

$$s_{t+1} = \text{LLM}_{\text{world}}(s_t, a_t)$$

This enables the agent to simulate future outcomes without actually executing actions, providing lookahead capability analogous to a mental model.
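The world-model role can be sketched as a prompt-in, state-out function. This is a minimal sketch, assuming the LLM is any callable that maps a prompt string to a completion; the prompt template, `world_model_step`, and the offline `fake_llm` stand-in are hypothetical illustrations, not RAP's actual few-shot prompts.

```python
from typing import Callable

# Hypothetical world-model prompt; RAP uses task-specific few-shot prompts instead.
WORLD_MODEL_PROMPT = "Current state: {state}\nAction: {action}\nNext state:"


def world_model_step(llm: Callable[[str], str], state: str, action: str) -> str:
    """s_{t+1} = LLM_world(s_t, a_t): ask the LLM to predict the next state."""
    prompt = WORLD_MODEL_PROMPT.format(state=state, action=action)
    return llm(prompt).strip()


# Offline stand-in for an LLM (hypothetical): it "knows" a single Blocksworld transition.
def fake_llm(prompt: str) -> str:
    if "Action: unstack the red block" in prompt:
        return " the red block is in the hand; the blue block is clear"
    return " unchanged"


next_state = world_model_step(
    fake_llm,
    "the red block is on the blue block",
    "unstack the red block from the blue block",
)
print(next_state)
```

Because the prediction is only text, no action is executed in any real environment; the search can call this function as many times as it likes to look ahead.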

Reasoning Agent

The LLM generates candidate actions at each state:

$$a_t \sim \pi_{\text{LLM}}(a \mid s_t)$$

Actions and states are defined per task: for plan generation, states are world configurations and actions are operations on them (e.g., picking up or stacking a block); for math, states are partial solutions and actions are reasoning steps.
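Sampling candidate actions $a_t \sim \pi_{\text{LLM}}(a \mid s_t)$ can be sketched by querying the same LLM several times with an agent-style prompt. The names `propose_actions`, `AGENT_PROMPT`, and `fake_sampler` are hypothetical; the stand-in cycles through canned completions so the sketch runs deterministically, whereas a real setup would sample at a nonzero temperature.

```python
import itertools
from typing import Callable, List

# Hypothetical agent prompt; RAP's real prompts are task-specific few-shot templates.
AGENT_PROMPT = "Current state: {state}\nNext reasoning step:"


def propose_actions(llm: Callable[[str], str], state: str, k: int = 4) -> List[str]:
    """Draw k samples a_t ~ pi_LLM(a | s_t) and keep the distinct, non-empty ones in order."""
    prompt = AGENT_PROMPT.format(state=state)
    actions: List[str] = []
    for _ in range(k):
        action = llm(prompt).strip()
        if action and action not in actions:
            actions.append(action)
    return actions


# Deterministic stand-in for a temperature-sampled LLM: cycles through canned completions.
_completions = itertools.cycle([
    " Compute 3 * 4 = 12.",
    " Rewrite the expression as 12 + 5.",
    " Compute 3 * 4 = 12.",
])


def fake_sampler(prompt: str) -> str:
    return next(_completions)


candidates = propose_actions(fake_sampler, "Question: what is 3 * 4 + 5? Steps so far: none")
print(candidates)  # ['Compute 3 * 4 = 12.', 'Rewrite the expression as 12 + 5.']
```

Deduplicating here is a design choice of the sketch: it keeps the search tree from growing identical sibling branches when the sampler repeats itself.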

Monte Carlo Tree Search for Reasoning

MCTS builds a search tree over the reasoning space through four iterative phases:4)

  1. Selection: Traverse the tree from the root, choosing actions by UCB1 to balance exploration and exploitation:

$$\text{UCB1}(s, a) = Q(s, a) + c \sqrt{\frac{\ln N(s)}{N(s, a)}}$$

where $Q(s,a)$ is the estimated action value, $N(s)$ is the visit count of state $s$, $N(s,a)$ is the number of times action $a$ has been taken from $s$, and $c$ is the exploration constant.

  2. Expansion: The LLM, acting as the reasoning agent, proposes new actions at the selected leaf node.
  3. Simulation: The LLM, acting as the world model, rolls out future states from the new node to estimate a reward.
  4. Backpropagation: Update visit counts and value estimates along the traversed path.
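A small worked instance of the UCB1 rule, assuming $c = 1.41$ (the default `exploration_weight` in the code example later on this page) and made-up visit statistics for three actions from one state:

```python
import math


def ucb1(q: float, n_parent: int, n_child: int, c: float = 1.41) -> float:
    """UCB1(s, a) = Q(s, a) + c * sqrt(ln N(s) / N(s, a)); untried actions score +inf."""
    if n_child == 0:
        return math.inf
    return q + c * math.sqrt(math.log(n_parent) / n_child)


# Hypothetical statistics: state s visited N(s) = 10 times; each action maps to (Q, N(s, a)).
stats = {"a1": (0.8, 6), "a2": (0.5, 3), "a3": (0.3, 1)}
scores = {a: ucb1(q, 10, n) for a, (q, n) in stats.items()}
best = max(scores, key=scores.get)
print({a: round(v, 3) for a, v in scores.items()}, best)
```

The scores come out to roughly 1.67, 1.74, and 2.44, so selection picks `a3` even though it has the lowest $Q$: with a single visit, its exploration bonus dominates.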

The final answer is selected from the highest-reward complete reasoning trace, optionally aggregated via majority vote across multiple MCTS runs5).
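Majority-vote aggregation can be sketched with `collections.Counter`. This assumes each MCTS run has already been reduced to a final answer string; `majority_vote` and the sample answers are hypothetical.

```python
from collections import Counter
from typing import List


def majority_vote(answers: List[str]) -> str:
    """Return the most frequent final answer; ties go to the answer encountered first."""
    if not answers:
        raise ValueError("need at least one answer")
    return Counter(answers).most_common(1)[0][0]


# Hypothetical final answers from five MCTS runs on one GSM8K-style question.
final = majority_vote(["18", "18", "20", "18", "20"])
print(final)  # 18
```

Voting over answers rather than over full traces means that runs which reach the same result by different reasoning paths still reinforce each other.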

System Architecture

graph TD
  A[Input Problem] --> B[Initialize Root State]
  B --> C[MCTS Iteration]
  C --> D[Selection via UCB1]
  D --> E[Leaf Node]
  E --> F{Fully Expanded?}
  F -- No --> G[Expansion: Agent-LLM Proposes Actions]
  G --> H[New Child Nodes]
  F -- Yes --> I[Select Best Child]
  H --> J[Simulation: World-Model Rollout]
  I --> J
  J --> K[Compute Reward]
  K --> L[Backpropagation: Update Q-values]
  L --> M{Budget Exhausted?}
  M -- No --> C
  M -- Yes --> N[Extract Best Reasoning Trace]
  N --> O[Optional: Majority Vote Aggregation]
  O --> P[Final Answer]

Code Example

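# Simplified RAP with MCTS for LLM reasoning.
# Assumes `llm` exposes four hypothetical methods:
#   generate_actions(state) -> list of candidate actions  (reasoning-agent role)
#   world_model(state, action) -> predicted next state    (world-model role)
#   compute_reward(state) -> float                        (reward estimate)
#   is_terminal(state) -> bool                            (reasoning finished?)
import math

class RAPNode:
    def __init__(self, state, parent=None, action=None):
        self.state = state
        self.parent = parent
        self.action = action        # action that led from parent to this node
        self.children = {}          # action -> RAPNode
        self.visits = 0
        self.total_reward = 0.0

class RAP:
    def __init__(self, llm, exploration_weight=1.41, max_depth=10):
        self.llm = llm
        self.c = exploration_weight
        self.max_depth = max_depth

    def ucb1(self, node, child):
        # Unvisited children are always tried first
        if child.visits == 0:
            return float("inf")
        exploit = child.total_reward / child.visits
        explore = self.c * math.sqrt(math.log(node.visits) / child.visits)
        return exploit + explore

    def select(self, node):
        # Descend by UCB1 until reaching a node with no children (a leaf)
        while node.children:
            node = max(node.children.values(), key=lambda c: self.ucb1(node, c))
        return node

    def expand(self, node):
        # Agent role proposes actions; world-model role predicts each next state
        if self.llm.is_terminal(node.state):
            return
        for action in self.llm.generate_actions(node.state):
            next_state = self.llm.world_model(node.state, action)
            node.children[action] = RAPNode(next_state, parent=node, action=action)

    def simulate(self, node):
        # Greedy rollout: take the first proposed action until terminal or max_depth
        state = node.state
        for _ in range(self.max_depth):
            if self.llm.is_terminal(state):
                break
            actions = self.llm.generate_actions(state)
            if not actions:
                break
            state = self.llm.world_model(state, actions[0])
        return self.llm.compute_reward(state)

    def _backpropagate(self, node, reward):
        # Add the rollout reward to every node on the path back to the root
        while node is not None:
            node.visits += 1
            node.total_reward += reward
            node = node.parent

    def _best_trace(self, root):
        # Follow the visited child with the highest mean reward Q = total_reward / visits
        trace, node = [], root
        while node.children:
            visited = [c for c in node.children.values() if c.visits > 0]
            if not visited:
                break
            node = max(visited, key=lambda c: c.total_reward / c.visits)
            trace.append((node.action, node.state))
        return trace

    def search(self, problem, num_iterations=100):
        root = RAPNode(state=problem)
        for _ in range(num_iterations):
            leaf = self.select(root)
            self.expand(leaf)
            # Simulate from the first new child; fall back to the leaf if it is terminal
            child = next(iter(leaf.children.values()), leaf)
            reward = self.simulate(child)
            self._backpropagate(child, reward)
        return self._best_trace(root)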

Key Results

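| Benchmark | Task | RAP Improvement |
|---|---|---|
| Blocksworld | Embodied plan generation | Substantial gains over CoT; RAP on LLaMA-33B surpasses CoT on GPT-4 with a 33% relative improvement |
| GSM8K | Grade-school math | Higher accuracy than CoT, further improved by ensembling answers across reasoning traces |
| Logical reasoning (PrOntoQA) | Hypothesis verification | Outperforms CoT baselines while producing detailed proofs |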

See Also

References