From 1a5f31d6317899fd27fd5f3745721c9fbe4b2e01 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Fri, 13 Mar 2026 02:45:08 -0700 Subject: [PATCH] feat: add agentic on-policy distillation (OPD) environment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First Atropos environment to populate distill_token_ids / distill_logprobs on ScoredDataGroup, enabling on-policy distillation training. Based on OpenClaw-RL (Princeton, arXiv:2603.10165): - Extracts hindsight hints from next-state signals (tool results, errors) - Uses LLM judge with majority voting for hint extraction - Scores student tokens under hint-enhanced distribution via get_logprobs - Packages teacher's top-K predictions as distillation targets Architecture: - AgenticOPDEnv extends HermesAgentBaseEnv - Overrides collect_trajectories to add OPD pipeline after standard rollouts - Uses Atropos's built-in get_logprobs (VLLM prompt_logprobs) for teacher scoring - No external servers needed — same VLLM backend handles both rollouts and scoring Task: Coding problems with test verification (8 built-in tasks, HF dataset support) Reward: correctness (0.7) + efficiency (0.15) + tool usage (0.15) OPD: Per-turn hint extraction → enhanced prompt → teacher top-K logprobs Configurable: opd_enabled, distill_topk, prm_votes, hint truncation length Metrics: opd/mean_hints_per_rollout, opd/mean_turns_scored, opd/hint_rate --- environments/agentic_opd_env.py | 1213 +++++++++++++++++++++++++++++++ 1 file changed, 1213 insertions(+) create mode 100644 environments/agentic_opd_env.py diff --git a/environments/agentic_opd_env.py b/environments/agentic_opd_env.py new file mode 100644 index 000000000..b96271237 --- /dev/null +++ b/environments/agentic_opd_env.py @@ -0,0 +1,1213 @@ +""" +AgenticOPDEnv — On-Policy Distillation for Agentic Tool-Calling Tasks +===================================================================== + +First Atropos environment to populate the distill_token_ids / distill_logprobs +fields on ScoredDataGroup, enabling on-policy distillation (OPD) training. + +Key idea (from OpenClaw-RL, Princeton 2026): + Every time an agent receives a next-state signal (tool result, error trace, + test verdict), that signal contains hindsight information about how the + agent's PREVIOUS response could have been better. This environment: + + 1. Runs standard agentic rollouts (tool-calling agent loop) + 2. Walks the conversation to find (assistant_turn, next_state) pairs + 3. Uses an LLM judge to extract "hints" from next-state signals + 4. Builds an enhanced prompt (original context + hint) + 5. Scores the student's response tokens under the enhanced distribution + using VLLM's prompt_logprobs (via Atropos's get_logprobs API) + 6. Packages the teacher's top-K predictions as distill_token_ids / + distill_logprobs on the ScoredDataGroup + +The trainer then computes per-token advantages: + A_t = teacher_logprob(token_t) - student_logprob(token_t) + Positive → teacher approves this token (upweight) + Negative → teacher disapproves (downweight) + +This gives dense, token-level training signal from every tool interaction, +instead of just a scalar reward at the end of the trajectory. + +Task: Coding tasks with test verification (rich next-state signals from +test results, error messages, terminal output). Falls back to built-in +coding problems if no HuggingFace dataset is configured. + +Requirements: + - VLLM backend (server_type: vllm) — needed for prompt logprob scoring + - Phase 2 mode (ManagedServer) — needed for token-level tracking + +Usage: + # Process mode (offline data generation with OPD) + python environments/agentic_opd_env.py process \\ + --env.total_steps 10 --env.group_size 2 \\ + --env.data_path_to_save_groups output.jsonl \\ + --openai.base_url http://localhost:8000/v1 \\ + --openai.model_name Qwen/Qwen3-4B + + # Serve mode (connected to Atropos trainer) + python environments/agentic_opd_env.py serve \\ + --openai.base_url http://localhost:8000/v1 \\ + --openai.model_name Qwen/Qwen3-4B + + # Evaluate mode + python environments/agentic_opd_env.py evaluate \\ + --env.eval_size 10 \\ + --openai.base_url http://localhost:8000/v1 \\ + --openai.model_name Qwen/Qwen3-4B + +Reference: Wang et al., "OpenClaw-RL: Train Any Agent Simply by Talking" + arXiv:2603.10165, March 2026 +""" + +from __future__ import annotations + +import asyncio +import copy +import json +import logging +import os +import random +import re +import sys +import time +import uuid +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from pydantic import Field + +# Ensure hermes-agent root is on path +_repo_root = Path(__file__).resolve().parent.parent +if str(_repo_root) not in sys.path: + sys.path.insert(0, str(_repo_root)) + +from atroposlib.envs.base import ScoredDataGroup, ScoredDataItem +from atroposlib.envs.server_handling.server_manager import APIServerConfig +from atroposlib.type_definitions import Item + +from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig +from environments.agent_loop import AgentResult, HermesAgentLoop +from environments.tool_context import ToolContext + +logger = logging.getLogger(__name__) + + +# ═══════════════════════════════════════════════════════════════════════ +# Built-in coding tasks (fallback when no HF dataset is configured) +# ═══════════════════════════════════════════════════════════════════════ + +BUILTIN_CODING_TASKS = [ + { + "task": "Write a Python function `fizzbuzz(n)` that returns a list of strings from 1 to n. " + "For multiples of 3 return 'Fizz', for multiples of 5 return 'Buzz', " + "for multiples of both return 'FizzBuzz', otherwise the number as a string.", + "test_code": ( + "from solution import fizzbuzz\n" + "assert fizzbuzz(15) == ['1','2','Fizz','4','Buzz','Fizz','7','8','Fizz','Buzz','11','Fizz','13','14','FizzBuzz']\n" + "assert fizzbuzz(1) == ['1']\n" + "assert fizzbuzz(0) == []\n" + "print('All tests passed!')\n" + ), + "difficulty": "easy", + }, + { + "task": "Write a Python function `is_palindrome(s)` that checks if a string is a palindrome, " + "ignoring case and non-alphanumeric characters. Return True or False.", + "test_code": ( + "from solution import is_palindrome\n" + "assert is_palindrome('A man, a plan, a canal: Panama') == True\n" + "assert is_palindrome('race a car') == False\n" + "assert is_palindrome('') == True\n" + "assert is_palindrome('Was it a car or a cat I saw?') == True\n" + "print('All tests passed!')\n" + ), + "difficulty": "easy", + }, + { + "task": "Write a Python function `two_sum(nums, target)` that returns the indices of the two " + "numbers in `nums` that add up to `target`. Assume exactly one solution exists. " + "Return a list of two indices [i, j] where i < j.", + "test_code": ( + "from solution import two_sum\n" + "assert two_sum([2, 7, 11, 15], 9) == [0, 1]\n" + "assert two_sum([3, 2, 4], 6) == [1, 2]\n" + "assert two_sum([3, 3], 6) == [0, 1]\n" + "print('All tests passed!')\n" + ), + "difficulty": "easy", + }, + { + "task": "Write a Python function `flatten(lst)` that takes an arbitrarily nested list and " + "returns a flat list of all elements. For example, flatten([1, [2, [3, 4], 5]]) " + "should return [1, 2, 3, 4, 5].", + "test_code": ( + "from solution import flatten\n" + "assert flatten([1, [2, [3, 4], 5]]) == [1, 2, 3, 4, 5]\n" + "assert flatten([]) == []\n" + "assert flatten([1, 2, 3]) == [1, 2, 3]\n" + "assert flatten([[[[1]]]]) == [1]\n" + "assert flatten([1, [2], [[3]], [[[4]]]]) == [1, 2, 3, 4]\n" + "print('All tests passed!')\n" + ), + "difficulty": "medium", + }, + { + "task": "Write a Python function `longest_common_prefix(strs)` that finds the longest " + "common prefix string amongst a list of strings. If there is no common prefix, " + "return an empty string.", + "test_code": ( + "from solution import longest_common_prefix\n" + "assert longest_common_prefix(['flower', 'flow', 'flight']) == 'fl'\n" + "assert longest_common_prefix(['dog', 'racecar', 'car']) == ''\n" + "assert longest_common_prefix(['interspecies', 'interstellar', 'interstate']) == 'inters'\n" + "assert longest_common_prefix(['a']) == 'a'\n" + "assert longest_common_prefix([]) == ''\n" + "print('All tests passed!')\n" + ), + "difficulty": "easy", + }, + { + "task": "Write a Python function `group_anagrams(strs)` that groups anagrams together. " + "Return a list of lists, where each inner list contains strings that are anagrams of " + "each other. The order of groups and strings within groups does not matter.", + "test_code": ( + "from solution import group_anagrams\n" + "result = group_anagrams(['eat', 'tea', 'tan', 'ate', 'nat', 'bat'])\n" + "result_sorted = sorted([sorted(g) for g in result])\n" + "assert result_sorted == [['ate', 'eat', 'tea'], ['bat'], ['nat', 'tan']]\n" + "assert group_anagrams([]) == []\n" + "assert group_anagrams(['a']) == [['a']]\n" + "print('All tests passed!')\n" + ), + "difficulty": "medium", + }, + { + "task": "Write a Python function `valid_parentheses(s)` that determines if a string " + "containing just '(', ')', '{', '}', '[' and ']' is valid. A string is valid if " + "open brackets are closed by the same type and in the correct order.", + "test_code": ( + "from solution import valid_parentheses\n" + "assert valid_parentheses('()') == True\n" + "assert valid_parentheses('()[]{}') == True\n" + "assert valid_parentheses('(]') == False\n" + "assert valid_parentheses('([)]') == False\n" + "assert valid_parentheses('{[]}') == True\n" + "assert valid_parentheses('') == True\n" + "print('All tests passed!')\n" + ), + "difficulty": "easy", + }, + { + "task": "Write a Python function `merge_intervals(intervals)` that merges overlapping " + "intervals. Each interval is a list [start, end]. Return the merged intervals sorted " + "by start time.", + "test_code": ( + "from solution import merge_intervals\n" + "assert merge_intervals([[1,3],[2,6],[8,10],[15,18]]) == [[1,6],[8,10],[15,18]]\n" + "assert merge_intervals([[1,4],[4,5]]) == [[1,5]]\n" + "assert merge_intervals([[1,4],[0,4]]) == [[0,4]]\n" + "assert merge_intervals([]) == []\n" + "assert merge_intervals([[1,2]]) == [[1,2]]\n" + "print('All tests passed!')\n" + ), + "difficulty": "medium", + }, +] + + +# ═══════════════════════════════════════════════════════════════════════ +# Hint extraction prompts (adapted from OpenClaw-RL) +# ═══════════════════════════════════════════════════════════════════════ + +_HINT_JUDGE_SYSTEM = ( + "You are a process reward model used for hindsight hint extraction.\n" + "You are given:\n" + "1) The assistant response at turn t.\n" + "2) The next state at turn t+1, along with its **role**.\n\n" + "## Understanding the next state's role\n" + "- role='user': A reply from the user (follow-up, correction, new request, etc.).\n" + "- role='tool': The return value of a tool the assistant invoked. " + "This content was NOT available before the assistant's action — " + "it exists BECAUSE the assistant called the tool. " + "A successful, non-error tool output generally means the assistant's " + "action was appropriate; do NOT treat it as information the assistant " + "should have already known.\n\n" + "Your goal is to decide whether the next state reveals useful hindsight information\n" + "that could have helped improve the assistant response at turn t.\n\n" + "Output format rules (strict):\n" + "- You MUST include exactly one final decision token: \\boxed{1} or \\boxed{-1}.\n" + "- If and only if decision is \\boxed{1}, provide a concise, information-dense hint in 1-3 sentences,\n" + " wrapped between [HINT_START] and [HINT_END].\n" + "- If decision is \\boxed{-1}, do not provide a hint block.\n" + "- Hint must be concrete and actionable for improving the previous response." +) + +_BOXED_RE = re.compile(r"\\boxed\{(-?\d+)\}") +_HINT_RE = re.compile(r"\[HINT_START\](.*?)\[HINT_END\]", re.DOTALL) + + +def _build_hint_judge_messages( + response_text: str, next_state_text: str, next_state_role: str = "tool" +) -> list[dict]: + """Build messages for the hint extraction judge.""" + user = ( + f"## Assistant response (turn t)\n{response_text}\n\n" + f"## Next state (turn t+1) [role: {next_state_role}]\n{next_state_text}\n\n" + "Now output your decision and (if positive) the hint in the required format." + ) + return [ + {"role": "system", "content": _HINT_JUDGE_SYSTEM}, + {"role": "user", "content": user}, + ] + + +def _parse_hint_result(text: str) -> tuple[int | None, str]: + """Parse the judge's boxed decision and hint text.""" + boxed = _BOXED_RE.findall(text) + score = int(boxed[-1]) if boxed else None + if score not in (1, -1): + score = None + hint_matches = _HINT_RE.findall(text) + hint = hint_matches[-1].strip() if hint_matches else "" + return score, hint + + +def _select_best_hint(votes: list[dict]) -> dict | None: + """Select the best hint from majority-voted judge results.""" + good = [ + v + for v in votes + if v.get("score") == 1 + and isinstance(v.get("hint"), str) + and len(v["hint"].strip()) > 10 + ] + if not good: + return None + return max(good, key=lambda v: len(v["hint"].strip())) + + +def _append_hint_to_messages(messages: list[dict], hint: str) -> list[dict]: + """Clone messages and append hint to the last user message.""" + cloned = copy.deepcopy(messages) + if not cloned: + return [{"role": "user", "content": f"[user's hint / instruction]\n{hint}"}] + + # Find last user message + target_idx = None + for i in range(len(cloned) - 1, -1, -1): + if cloned[i].get("role") == "user": + target_idx = i + break + if target_idx is None: + target_idx = len(cloned) - 1 + + content = cloned[target_idx].get("content", "") + if isinstance(content, list): + content = " ".join( + c.get("text", "") if isinstance(c, dict) else str(c) for c in content + ) + suffix = f"\n\n[user's hint / instruction]\n{hint.strip()}" + cloned[target_idx]["content"] = (content + suffix).strip() + return cloned + + +# ═══════════════════════════════════════════════════════════════════════ +# Configuration +# ═══════════════════════════════════════════════════════════════════════ + + +class AgenticOPDConfig(HermesAgentEnvConfig): + """Configuration for the agentic OPD environment.""" + + # --- OPD settings --- + opd_enabled: bool = Field( + default=True, + description="Enable on-policy distillation pipeline. When disabled, " + "the environment behaves like a standard agentic env (no distill fields).", + ) + distill_topk: int = Field( + default=50, + description="Number of top-K teacher logprobs per position for distillation.", + ) + prm_votes: int = Field( + default=3, + description="Number of independent judge queries for majority-voted hint extraction.", + ) + hint_max_next_state_chars: int = Field( + default=4000, + description="Maximum characters of next-state text to include in the hint judge prompt. " + "Tool results can be very long — truncating prevents judge context overflow.", + ) + + # --- Reward settings --- + correctness_weight: float = Field( + default=0.7, + description="Weight for test pass/fail in reward.", + ) + efficiency_weight: float = Field( + default=0.15, + description="Weight for efficiency (fewer turns = better).", + ) + tool_usage_weight: float = Field( + default=0.15, + description="Weight for appropriate tool usage signal.", + ) + + # --- Dataset --- + dataset_name: Optional[str] = Field( + default=None, + description="HuggingFace dataset with coding tasks. " + "Expected fields: 'task' (problem description) and 'test_code' (pytest/assert tests). " + "Falls back to built-in tasks if not set or unavailable.", + ) + + # --- Eval --- + eval_size: int = Field( + default=10, + description="Number of held-out items for evaluation.", + ) + eval_split_ratio: float = Field( + default=0.15, + description="Fraction of dataset to hold out for evaluation.", + ) + + +# ═══════════════════════════════════════════════════════════════════════ +# Environment +# ═══════════════════════════════════════════════════════════════════════ + + +class AgenticOPDEnv(HermesAgentBaseEnv): + """ + RL environment with on-policy distillation from next-state signals. + + Runs coding tasks where the agent writes code and runs tests. + Tool results (test pass/fail, error traces) serve as next-state signals + for hint extraction and teacher logprob scoring. + + This is the first Atropos environment to populate distill_token_ids + and distill_logprobs on ScoredDataGroup for OPD training. + """ + + name = "agentic-opd" + env_config_cls = AgenticOPDConfig + + # Default toolsets: terminal for running code, file for writing it + default_toolsets = ["terminal", "file"] + + @classmethod + def config_init(cls) -> Tuple[AgenticOPDConfig, List[APIServerConfig]]: + """Default configuration.""" + env_config = AgenticOPDConfig( + # Toolsets + enabled_toolsets=["terminal", "file"], + # Agent loop + max_agent_turns=15, + agent_temperature=1.0, + system_prompt=( + "You are a skilled Python programmer. When given a coding task:\n" + "1. Write the solution to a file called 'solution.py'\n" + "2. Write the test code to a file called 'test_solution.py'\n" + "3. Run the tests with: python test_solution.py\n" + "4. If tests fail, read the error output carefully, fix your code, and re-run\n" + "5. Once all tests pass, report success\n\n" + "Be efficient — write clean code and fix errors methodically." + ), + # OPD + opd_enabled=True, + distill_topk=50, + prm_votes=3, + # Training + group_size=4, + total_steps=500, + steps_per_eval=50, + use_wandb=True, + wandb_name="agentic-opd", + ) + + server_configs = [ + APIServerConfig( + base_url="http://localhost:8000/v1", + model_name="Qwen/Qwen3-4B", + server_type="vllm", + ) + ] + + return env_config, server_configs + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._items: list[dict] = [] + self._eval_items: list[dict] = [] + self._index: int = 0 + + # Metric buffers + self._reward_buffer: list[float] = [] + self._correctness_buffer: list[float] = [] + self._efficiency_buffer: list[float] = [] + self._tool_usage_buffer: list[float] = [] + self._hints_extracted_buffer: list[int] = [] + self._opd_turns_scored_buffer: list[int] = [] + + # ═══════════════════════════════════════════════════════════════════ + # 1. setup — load dataset + # ═══════════════════════════════════════════════════════════════════ + + async def setup(self) -> None: + """Load coding tasks from HuggingFace or use built-in set.""" + if self.config.dataset_name: + try: + from datasets import load_dataset + + logger.info( + "Loading dataset '%s'...", self.config.dataset_name + ) + ds = load_dataset( + self.config.dataset_name, split=self.config.dataset_split + ) + task_field = self.config.prompt_field + self._items = [ + { + "task": row.get(task_field, row.get("task", "")), + "test_code": row.get("test_code", row.get("tests", "")), + "difficulty": row.get("difficulty", "unknown"), + } + for row in ds + if row.get(task_field, row.get("task", "")) + ] + if self._items: + random.shuffle(self._items) + eval_size = max( + self.config.eval_size, + int(len(self._items) * self.config.eval_split_ratio), + ) + self._eval_items = self._items[:eval_size] + self._items = self._items[eval_size:] + logger.info( + "Loaded %d train / %d eval items from '%s'", + len(self._items), + len(self._eval_items), + self.config.dataset_name, + ) + return + except Exception as e: + logger.warning( + "Could not load dataset '%s': %s. Using built-in tasks.", + self.config.dataset_name, + e, + ) + + # Fallback to built-in tasks + items = copy.deepcopy(BUILTIN_CODING_TASKS) + random.shuffle(items) + split = max(1, len(items) * 85 // 100) + self._items = items[:split] + self._eval_items = items[split:] + logger.info( + "Using built-in coding tasks: %d train / %d eval items", + len(self._items), + len(self._eval_items), + ) + + # ═══════════════════════════════════════════════════════════════════ + # 2. get_next_item + # ═══════════════════════════════════════════════════════════════════ + + async def get_next_item(self) -> dict: + """Return the next coding task, cycling through the dataset.""" + if not self._items: + raise RuntimeError("Dataset is empty. Did you call setup()?") + item = self._items[self._index % len(self._items)] + self._index += 1 + return item + + # ═══════════════════════════════════════════════════════════════════ + # 3. format_prompt + # ═══════════════════════════════════════════════════════════════════ + + def format_prompt(self, item: dict) -> str: + """Format the coding task as a user prompt.""" + prompt = ( + f"Solve the following coding task.\n\n" + f"## Task\n{item['task']}\n\n" + ) + if item.get("test_code"): + prompt += ( + f"## Tests\nThe following test code will be used to verify your solution:\n" + f"```python\n{item['test_code']}```\n\n" + ) + prompt += ( + "## Instructions\n" + "1. Write your solution to `solution.py`\n" + "2. Write the test code to `test_solution.py`\n" + "3. Run `python test_solution.py` to verify\n" + "4. Fix any failures and re-run until all tests pass\n" + ) + return prompt + + # ═══════════════════════════════════════════════════════════════════ + # 4. compute_reward + # ═══════════════════════════════════════════════════════════════════ + + async def compute_reward( + self, + item: dict, + result: AgentResult, + ctx: ToolContext, + ) -> float: + """ + Multi-signal reward: + - correctness (0.7): Did the tests pass? + - efficiency (0.15): Fewer turns = better + - tool_usage (0.15): Did the agent actually write + run code? + """ + cfg = self.config + + # ---- Signal 1: Test correctness ---- + # Check if test_solution.py exists and passes in the agent's sandbox + correctness = 0.0 + try: + test_result = ctx.terminal("python test_solution.py 2>&1", timeout=30) + output = test_result.get("output", "") + exit_code = test_result.get("exit_code", 1) + if exit_code == 0 and "passed" in output.lower(): + correctness = 1.0 + elif exit_code == 0: + correctness = 0.8 # Ran without error but no explicit "passed" + elif "assert" in output.lower() and "error" in output.lower(): + correctness = 0.2 # Partial — code runs but assertions fail + else: + correctness = 0.1 # Code errors out entirely + except Exception as e: + logger.debug("Test execution failed in reward: %s", e) + correctness = 0.0 + + # ---- Signal 2: Efficiency ---- + max_turns = cfg.max_agent_turns + turns_used = result.turns_used + if turns_used <= 3: + efficiency = 1.0 + elif turns_used <= max_turns // 2: + efficiency = 0.8 + elif turns_used <= max_turns * 3 // 4: + efficiency = 0.5 + else: + efficiency = 0.2 + + # ---- Signal 3: Tool usage ---- + tools_used = set() + for msg in result.messages: + if msg.get("role") == "assistant" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + fn = tc.get("function", {}) if isinstance(tc, dict) else {} + name = fn.get("name", "") + if name: + tools_used.add(name) + + # Good: used both terminal and file tools + if "terminal" in tools_used and ("write_file" in tools_used or "patch" in tools_used): + tool_usage = 1.0 + elif "terminal" in tools_used: + tool_usage = 0.6 + elif tools_used: + tool_usage = 0.3 + else: + tool_usage = 0.0 + + # ---- Combine ---- + reward = ( + cfg.correctness_weight * correctness + + cfg.efficiency_weight * efficiency + + cfg.tool_usage_weight * tool_usage + ) + reward = min(1.0, max(0.0, reward)) + + # Track metrics + self._reward_buffer.append(reward) + self._correctness_buffer.append(correctness) + self._efficiency_buffer.append(efficiency) + self._tool_usage_buffer.append(tool_usage) + + logger.debug( + "Reward: correctness=%.2f, efficiency=%.2f, tool_usage=%.2f → %.3f", + correctness, + efficiency, + tool_usage, + reward, + ) + return reward + + # ═══════════════════════════════════════════════════════════════════ + # 5. collect_trajectories — OPD pipeline + # ═══════════════════════════════════════════════════════════════════ + + async def collect_trajectories( + self, item: Item + ) -> Tuple[ + Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]], + List[Item], + ]: + """ + Override collect_trajectories to add the OPD pipeline. + + 1. Run standard rollouts via super() → ScoredDataGroup with tokens/masks/scores + 2. For each rollout, extract hints from next-state signals + 3. Score student tokens under enhanced (hint-augmented) distribution + 4. Add distill_token_ids / distill_logprobs to the ScoredDataGroup + """ + # Step 1: Run standard rollouts + scored_group, backlog = await super().collect_trajectories(item) + + # Step 2: OPD pipeline (only if enabled and we have VLLM server) + if ( + self.config.opd_enabled + and scored_group is not None + and isinstance(scored_group, dict) + and self._use_managed_server() + ): + await self._apply_opd_pipeline(scored_group) + + return scored_group, backlog + + async def _apply_opd_pipeline(self, group: ScoredDataGroup) -> None: + """ + Apply on-policy distillation to each rollout in the group. + + For each rollout's messages: + 1. Find (assistant, next_state) turn pairs + 2. Extract hints via LLM judge with majority voting + 3. Build enhanced prompt (original + hint) + 4. Score student tokens under enhanced distribution via get_logprobs + 5. Add distill_token_ids / distill_logprobs to the group + """ + messages_list = group.get("messages", []) + tokens_list = group.get("tokens", []) + + if not messages_list or not tokens_list: + logger.debug("OPD: No messages or tokens to process") + return + + all_distill_token_ids: List[Optional[List[List[int]]]] = [] + all_distill_logprobs: List[Optional[List[List[float]]]] = [] + + for seq_idx, (messages, student_tokens) in enumerate( + zip(messages_list, tokens_list) + ): + try: + distill_ids, distill_lps = await self._opd_for_sequence( + messages, student_tokens + ) + all_distill_token_ids.append(distill_ids) + all_distill_logprobs.append(distill_lps) + except Exception as e: + logger.warning( + "OPD failed for sequence %d: %s", seq_idx, e + ) + all_distill_token_ids.append(None) + all_distill_logprobs.append(None) + + # Only set distill fields if at least one sequence succeeded + any_succeeded = any(d is not None for d in all_distill_token_ids) + if any_succeeded: + # Replace None entries with zero-padded arrays matching token length + for i in range(len(all_distill_token_ids)): + if all_distill_token_ids[i] is None and i < len(tokens_list): + seq_len = len(tokens_list[i]) + k = self.config.distill_topk + all_distill_token_ids[i] = [[0] * k] * seq_len + all_distill_logprobs[i] = [[0.0] * k] * seq_len + + group["distill_token_ids"] = all_distill_token_ids + group["distill_logprobs"] = all_distill_logprobs + logger.info( + "OPD: Set distill fields on %d/%d sequences", + sum(1 for d in all_distill_token_ids if d is not None), + len(all_distill_token_ids), + ) + + async def _opd_for_sequence( + self, messages: List[Dict], student_tokens: List[int] + ) -> Tuple[List[List[int]], List[List[float]]]: + """ + Run OPD for a single rollout sequence. + + 1. Walk conversation to find (assistant, next_state) pairs + 2. Extract hints from next-state signals + 3. For each hint-augmented turn, score student tokens via get_logprobs + 4. Merge per-turn teacher logprobs into a full-sequence distill array + + Returns: + (distill_token_ids, distill_logprobs) each of shape [seq_len][top_k] + """ + k = self.config.distill_topk + seq_len = len(student_tokens) + + # Initialize with zeros (no distill info = neutral) + distill_token_ids: List[List[int]] = [[0] * k for _ in range(seq_len)] + distill_logprobs: List[List[float]] = [[0.0] * k for _ in range(seq_len)] + + # Find (assistant, next_state) turn pairs + turn_pairs = self._extract_turn_pairs(messages) + if not turn_pairs: + return distill_token_ids, distill_logprobs + + hints_extracted = 0 + turns_scored = 0 + + for pair in turn_pairs: + try: + hint = await self._extract_hint( + pair["assistant_text"], + pair["next_state_text"], + pair["next_state_role"], + ) + if not hint: + continue + + hints_extracted += 1 + + # Build enhanced prompt with hint + enhanced_messages = _append_hint_to_messages( + pair["context_messages"], hint + ) + + # Tokenize the enhanced prompt + if not self.tokenizer: + logger.warning("OPD: No tokenizer available, skipping scoring") + continue + + enhanced_prompt = self.tokenizer.apply_chat_template( + enhanced_messages, + tokenize=False, + add_generation_prompt=True, + ) + + # Tokenize the assistant response to score + response_text = pair["assistant_text"] + enhanced_full_text = enhanced_prompt + response_text + enhanced_ids = self.tokenizer( + enhanced_full_text, add_special_tokens=False + )["input_ids"] + + response_ids = self.tokenizer( + response_text, add_special_tokens=False + )["input_ids"] + response_len = len(response_ids) + + if response_len == 0: + continue + + # Score via get_logprobs — teacher scoring the student's tokens + # under the enhanced (hint-augmented) distribution + try: + logprob_result = await self.server.get_logprobs( + input_ids=enhanced_ids, + top_k=k, + split="eval", # Use eval semaphore to not block training + ) + except Exception as e: + logger.debug("get_logprobs failed: %s", e) + continue + + teacher_topk_ids = logprob_result.get("prompt_topk_token_ids", []) + teacher_topk_lps = logprob_result.get("prompt_topk_logprobs", []) + + if not teacher_topk_ids: + continue + + # Extract only the response positions (last response_len entries) + if len(teacher_topk_ids) >= response_len: + resp_topk_ids = teacher_topk_ids[-response_len:] + resp_topk_lps = teacher_topk_lps[-response_len:] + else: + # Pad from the left if the response was shorter than expected + pad_len = response_len - len(teacher_topk_ids) + resp_topk_ids = [[0] * k] * pad_len + teacher_topk_ids + resp_topk_lps = [[0.0] * k] * pad_len + teacher_topk_lps + + # Map these back to the student's full sequence positions + # Find where this assistant turn's tokens appear in the full sequence + turn_start = self._find_token_span( + student_tokens, response_ids + ) + if turn_start is not None: + for j in range(min(response_len, seq_len - turn_start)): + pos = turn_start + j + if pos < seq_len and j < len(resp_topk_ids): + # Pad/truncate to exactly k entries + ids = resp_topk_ids[j][:k] + lps = resp_topk_lps[j][:k] + while len(ids) < k: + ids.append(0) + lps.append(0.0) + distill_token_ids[pos] = ids + distill_logprobs[pos] = lps + turns_scored += 1 + + except Exception as e: + logger.debug("OPD turn processing failed: %s", e) + continue + + # Track OPD metrics + self._hints_extracted_buffer.append(hints_extracted) + self._opd_turns_scored_buffer.append(turns_scored) + + logger.debug( + "OPD sequence: %d turn pairs, %d hints extracted, %d turns scored", + len(turn_pairs), + hints_extracted, + turns_scored, + ) + return distill_token_ids, distill_logprobs + + def _extract_turn_pairs( + self, messages: List[Dict] + ) -> List[Dict[str, Any]]: + """ + Walk conversation messages to find (assistant, next_state) pairs. + + A "turn pair" is an assistant message with content (the response) + followed by one or more tool results or a user reply (the next state). + + Returns list of dicts: + { + "context_messages": messages up to (not including) the assistant turn, + "assistant_text": the assistant's response text, + "next_state_text": the next state content (tool result or user reply), + "next_state_role": "tool" or "user", + } + """ + pairs = [] + i = 0 + while i < len(messages): + msg = messages[i] + if msg.get("role") == "assistant" and msg.get("content"): + # Found an assistant message with content + assistant_text = msg["content"] + context = messages[:i] # Everything before this turn + + # Look ahead for next state + j = i + 1 + # Skip tool_calls-only assistant messages and collect tool results + next_states = [] + while j < len(messages): + next_msg = messages[j] + if next_msg.get("role") == "tool": + next_states.append(next_msg) + j += 1 + elif next_msg.get("role") == "user": + next_states.append(next_msg) + break + else: + break + + if next_states: + # Combine all next-state content + next_text_parts = [] + next_role = next_states[0].get("role", "tool") + for ns in next_states: + content = ns.get("content", "") + if content: + # Truncate very long tool outputs + max_chars = self.config.hint_max_next_state_chars + if len(content) > max_chars: + content = content[:max_chars] + "\n...[truncated]" + next_text_parts.append(content) + + next_text = "\n---\n".join(next_text_parts) + if next_text.strip(): + pairs.append( + { + "context_messages": context, + "assistant_text": assistant_text, + "next_state_text": next_text, + "next_state_role": next_role, + } + ) + i += 1 + return pairs + + async def _extract_hint( + self, + assistant_text: str, + next_state_text: str, + next_state_role: str, + ) -> Optional[str]: + """ + Extract a hindsight hint from a next-state signal using majority-voted LLM judge. + + Returns the hint string if the judge votes positively, None otherwise. + """ + judge_messages = _build_hint_judge_messages( + response_text=assistant_text, + next_state_text=next_state_text, + next_state_role=next_state_role, + ) + + # Majority voting across multiple judge queries + votes = [] + tasks = [] + for _ in range(self.config.prm_votes): + tasks.append( + self.server.chat_completion( + messages=judge_messages, + n=1, + max_tokens=500, + temperature=0.7, + split="eval", + ) + ) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if isinstance(result, Exception): + logger.debug("Hint judge call failed: %s", result) + votes.append({"score": None, "hint": ""}) + continue + try: + text = result.choices[0].message.content or "" + score, hint = _parse_hint_result(text) + votes.append({"score": score, "hint": hint}) + except Exception as e: + logger.debug("Hint parse failed: %s", e) + votes.append({"score": None, "hint": ""}) + + selected = _select_best_hint(votes) + if selected is None: + return None + return selected["hint"] + + @staticmethod + def _find_token_span( + full_tokens: List[int], sub_tokens: List[int] + ) -> Optional[int]: + """ + Find where sub_tokens appears in full_tokens. + Returns the start index, or None if not found. + + Uses a sliding window search. For long sequences, searches + from the end since assistant responses are typically at the end. + """ + if not sub_tokens or not full_tokens: + return None + sub_len = len(sub_tokens) + full_len = len(full_tokens) + if sub_len > full_len: + return None + + # Search backwards (assistant responses are usually near the end) + for i in range(full_len - sub_len, -1, -1): + if full_tokens[i : i + sub_len] == sub_tokens: + return i + return None + + # ═══════════════════════════════════════════════════════════════════ + # 6. evaluate + # ═══════════════════════════════════════════════════════════════════ + + async def evaluate(self, *args, **kwargs) -> None: + """ + Evaluate on held-out coding tasks using the full agent loop. + No OPD during eval — just standard agentic evaluation. + """ + if not self._eval_items: + logger.warning("No eval items available.") + return + + eval_size = min(self.config.eval_size, len(self._eval_items)) + eval_items = self._eval_items[:eval_size] + + logger.info("Running eval on %d coding tasks...", len(eval_items)) + start_time = time.time() + samples = [] + + tools, valid_names = self._resolve_tools_for_group() + + for i, item in enumerate(eval_items): + task_id = str(uuid.uuid4()) + logger.info( + "Eval [%d/%d]: %s...", i + 1, len(eval_items), item["task"][:60] + ) + + try: + messages: List[Dict[str, Any]] = [] + if self.config.system_prompt: + messages.append( + {"role": "system", "content": self.config.system_prompt} + ) + messages.append( + {"role": "user", "content": self.format_prompt(item)} + ) + + agent = HermesAgentLoop( + server=self.server, + tool_schemas=tools, + valid_tool_names=valid_names, + max_turns=self.config.max_agent_turns, + task_id=task_id, + temperature=0.0, + max_tokens=self.config.max_token_length, + extra_body=self.config.extra_body, + ) + result = await agent.run(messages) + + # Compute reward (track buffer lengths to rollback eval pollution) + buf_len = len(self._correctness_buffer) + ctx = ToolContext(task_id) + try: + reward = await self.compute_reward(item, result, ctx) + finally: + ctx.cleanup() + + # Extract correctness and rollback training buffers + correctness = ( + self._correctness_buffer[buf_len] + if len(self._correctness_buffer) > buf_len + else 0.0 + ) + for buf in ( + self._reward_buffer, + self._correctness_buffer, + self._efficiency_buffer, + self._tool_usage_buffer, + ): + if len(buf) > buf_len: + buf.pop() + + # Also rollback OPD buffers if they were touched + for buf in ( + self._hints_extracted_buffer, + self._opd_turns_scored_buffer, + ): + if len(buf) > buf_len: + buf.pop() + + # Extract final response + final_response = "" + for msg in reversed(result.messages): + if ( + msg.get("role") == "assistant" + and msg.get("content") + and not final_response + ): + final_response = msg["content"] + break + + samples.append( + { + "prompt": item["task"][:200], + "response": final_response[:500], + "correctness": correctness, + "reward": reward, + "turns": result.turns_used, + } + ) + + logger.info( + " → correctness=%.2f, reward=%.3f, turns=%d", + correctness, + reward, + result.turns_used, + ) + + except Exception as e: + logger.error("Eval error: %s", e) + samples.append( + { + "prompt": item["task"][:200], + "response": f"ERROR: {e}", + "correctness": 0.0, + "reward": 0.0, + "turns": 0, + } + ) + + end_time = time.time() + + correctness_scores = [s["correctness"] for s in samples] + rewards = [s["reward"] for s in samples] + n = len(samples) + + eval_metrics = { + "eval/mean_correctness": sum(correctness_scores) / n if n else 0.0, + "eval/mean_reward": sum(rewards) / n if n else 0.0, + "eval/pass_rate": ( + sum(1 for c in correctness_scores if c >= 0.8) / n if n else 0.0 + ), + "eval/n_items": n, + } + + logger.info( + "Eval complete — correctness=%.3f, reward=%.3f, pass_rate=%.0f%%", + eval_metrics["eval/mean_correctness"], + eval_metrics["eval/mean_reward"], + eval_metrics["eval/pass_rate"] * 100, + ) + + await self.evaluate_log( + metrics=eval_metrics, + samples=samples, + start_time=start_time, + end_time=end_time, + ) + + # ═══════════════════════════════════════════════════════════════════ + # 7. wandb_log — custom OPD metrics + # ═══════════════════════════════════════════════════════════════════ + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None: + """Log reward breakdown and OPD-specific metrics to wandb.""" + if wandb_metrics is None: + wandb_metrics = {} + + if self._reward_buffer: + n = len(self._reward_buffer) + wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n + wandb_metrics["train/mean_correctness"] = ( + sum(self._correctness_buffer) / n + ) + wandb_metrics["train/mean_efficiency"] = ( + sum(self._efficiency_buffer) / n + ) + wandb_metrics["train/mean_tool_usage"] = ( + sum(self._tool_usage_buffer) / n + ) + wandb_metrics["train/pass_rate"] = ( + sum(1 for c in self._correctness_buffer if c >= 0.8) / n + ) + wandb_metrics["train/total_rollouts"] = n + + self._reward_buffer.clear() + self._correctness_buffer.clear() + self._efficiency_buffer.clear() + self._tool_usage_buffer.clear() + + # OPD-specific metrics + if self._hints_extracted_buffer: + n = len(self._hints_extracted_buffer) + wandb_metrics["opd/mean_hints_per_rollout"] = ( + sum(self._hints_extracted_buffer) / n + ) + wandb_metrics["opd/mean_turns_scored"] = ( + sum(self._opd_turns_scored_buffer) / n + ) + wandb_metrics["opd/hint_rate"] = ( + sum(1 for h in self._hints_extracted_buffer if h > 0) / n + ) + wandb_metrics["opd/total_hints"] = sum(self._hints_extracted_buffer) + wandb_metrics["opd/total_scored_turns"] = sum( + self._opd_turns_scored_buffer + ) + + self._hints_extracted_buffer.clear() + self._opd_turns_scored_buffer.clear() + + await super().wandb_log(wandb_metrics) + + +# ═══════════════════════════════════════════════════════════════════════ +# Entry point +# ═══════════════════════════════════════════════════════════════════════ + +if __name__ == "__main__": + AgenticOPDEnv.cli()