First Atropos environment to populate distill_token_ids / distill_logprobs on ScoredDataGroup, enabling on-policy distillation training.

Based on OpenClaw-RL (Princeton, arXiv:2603.10165):
- Extracts hindsight hints from next-state signals (tool results, errors)
- Uses LLM judge with majority voting for hint extraction
- Scores student tokens under hint-enhanced distribution via get_logprobs
- Packages teacher's top-K predictions as distillation targets

Architecture:
- AgenticOPDEnv extends HermesAgentBaseEnv
- Overrides collect_trajectories to add OPD pipeline after standard rollouts
- Uses Atropos's built-in get_logprobs (VLLM prompt_logprobs) for teacher scoring
- No external servers needed — same VLLM backend handles both rollouts and scoring

Task: Coding problems with test verification (8 built-in tasks, HF dataset support)
Reward: correctness (0.7) + efficiency (0.15) + tool usage (0.15)
OPD: Per-turn hint extraction → enhanced prompt → teacher top-K logprobs
Configurable: opd_enabled, distill_topk, prm_votes, hint truncation length
Metrics: opd/mean_hints_per_rollout, opd/mean_turns_scored, opd/hint_rate
1214 lines · 50 KiB · Python
"""
|
|
AgenticOPDEnv — On-Policy Distillation for Agentic Tool-Calling Tasks
|
|
=====================================================================
|
|
|
|
First Atropos environment to populate the distill_token_ids / distill_logprobs
|
|
fields on ScoredDataGroup, enabling on-policy distillation (OPD) training.
|
|
|
|
Key idea (from OpenClaw-RL, Princeton 2026):
|
|
Every time an agent receives a next-state signal (tool result, error trace,
|
|
test verdict), that signal contains hindsight information about how the
|
|
agent's PREVIOUS response could have been better. This environment:
|
|
|
|
1. Runs standard agentic rollouts (tool-calling agent loop)
|
|
2. Walks the conversation to find (assistant_turn, next_state) pairs
|
|
3. Uses an LLM judge to extract "hints" from next-state signals
|
|
4. Builds an enhanced prompt (original context + hint)
|
|
5. Scores the student's response tokens under the enhanced distribution
|
|
using VLLM's prompt_logprobs (via Atropos's get_logprobs API)
|
|
6. Packages the teacher's top-K predictions as distill_token_ids /
|
|
distill_logprobs on the ScoredDataGroup
|
|
|
|
The trainer then computes per-token advantages:
|
|
A_t = teacher_logprob(token_t) - student_logprob(token_t)
|
|
Positive → teacher approves this token (upweight)
|
|
Negative → teacher disapproves (downweight)
|
|
|
|
This gives dense, token-level training signal from every tool interaction,
|
|
instead of just a scalar reward at the end of the trajectory.
|
|
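Concretely, the trainer-side computation can be sketched as follows. This is a
minimal illustration, not part of this module: the helper name and the floor
value for tokens outside the teacher's top-K are assumptions.

```python
def per_token_advantages(student_logprobs, distill_token_ids,
                         distill_logprobs, chosen_tokens):
    """A_t = teacher_logprob(token_t) - student_logprob(token_t).

    Tokens absent from the teacher's top-K get a floor logprob, so they
    are strongly downweighted rather than silently skipped.
    """
    floor = -20.0  # hypothetical floor for tokens outside the top-K
    advantages = []
    for t, token in enumerate(chosen_tokens):
        teacher_lp = floor
        for tid, lp in zip(distill_token_ids[t], distill_logprobs[t]):
            if tid == token:
                teacher_lp = lp
                break
        advantages.append(teacher_lp - student_logprobs[t])
    return advantages

# One position, top-2 teacher candidates; the student's chosen token (7)
# is the teacher's favourite, so it gets a positive advantage.
adv = per_token_advantages(
    student_logprobs=[-2.0],
    distill_token_ids=[[7, 9]],
    distill_logprobs=[[-0.5, -1.5]],
    chosen_tokens=[7],
)
# → [1.5]  (teacher -0.5 minus student -2.0: upweight this token)
```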

Task: Coding tasks with test verification (rich next-state signals from
test results, error messages, terminal output). Falls back to built-in
coding problems if no HuggingFace dataset is configured.

Requirements:
    - VLLM backend (server_type: vllm) — needed for prompt logprob scoring
    - Phase 2 mode (ManagedServer) — needed for token-level tracking

Usage:
    # Process mode (offline data generation with OPD)
    python environments/agentic_opd_env.py process \\
        --env.total_steps 10 --env.group_size 2 \\
        --env.data_path_to_save_groups output.jsonl \\
        --openai.base_url http://localhost:8000/v1 \\
        --openai.model_name Qwen/Qwen3-4B

    # Serve mode (connected to Atropos trainer)
    python environments/agentic_opd_env.py serve \\
        --openai.base_url http://localhost:8000/v1 \\
        --openai.model_name Qwen/Qwen3-4B

    # Evaluate mode
    python environments/agentic_opd_env.py evaluate \\
        --env.eval_size 10 \\
        --openai.base_url http://localhost:8000/v1 \\
        --openai.model_name Qwen/Qwen3-4B

Reference: Wang et al., "OpenClaw-RL: Train Any Agent Simply by Talking"
           arXiv:2603.10165, March 2026
"""

from __future__ import annotations

import asyncio
import copy
import json
import logging
import os
import random
import re
import sys
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from pydantic import Field

# Ensure hermes-agent root is on path
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

from atroposlib.envs.base import ScoredDataGroup, ScoredDataItem
from atroposlib.envs.server_handling.server_manager import APIServerConfig
from atroposlib.type_definitions import Item

from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
from environments.agent_loop import AgentResult, HermesAgentLoop
from environments.tool_context import ToolContext

logger = logging.getLogger(__name__)

# ═══════════════════════════════════════════════════════════════════════
# Built-in coding tasks (fallback when no HF dataset is configured)
# ═══════════════════════════════════════════════════════════════════════

BUILTIN_CODING_TASKS = [
    {
        "task": "Write a Python function `fizzbuzz(n)` that returns a list of strings from 1 to n. "
        "For multiples of 3 return 'Fizz', for multiples of 5 return 'Buzz', "
        "for multiples of both return 'FizzBuzz', otherwise the number as a string.",
        "test_code": (
            "from solution import fizzbuzz\n"
            "assert fizzbuzz(15) == ['1','2','Fizz','4','Buzz','Fizz','7','8','Fizz','Buzz','11','Fizz','13','14','FizzBuzz']\n"
            "assert fizzbuzz(1) == ['1']\n"
            "assert fizzbuzz(0) == []\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `is_palindrome(s)` that checks if a string is a palindrome, "
        "ignoring case and non-alphanumeric characters. Return True or False.",
        "test_code": (
            "from solution import is_palindrome\n"
            "assert is_palindrome('A man, a plan, a canal: Panama') == True\n"
            "assert is_palindrome('race a car') == False\n"
            "assert is_palindrome('') == True\n"
            "assert is_palindrome('Was it a car or a cat I saw?') == True\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `two_sum(nums, target)` that returns the indices of the two "
        "numbers in `nums` that add up to `target`. Assume exactly one solution exists. "
        "Return a list of two indices [i, j] where i < j.",
        "test_code": (
            "from solution import two_sum\n"
            "assert two_sum([2, 7, 11, 15], 9) == [0, 1]\n"
            "assert two_sum([3, 2, 4], 6) == [1, 2]\n"
            "assert two_sum([3, 3], 6) == [0, 1]\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `flatten(lst)` that takes an arbitrarily nested list and "
        "returns a flat list of all elements. For example, flatten([1, [2, [3, 4], 5]]) "
        "should return [1, 2, 3, 4, 5].",
        "test_code": (
            "from solution import flatten\n"
            "assert flatten([1, [2, [3, 4], 5]]) == [1, 2, 3, 4, 5]\n"
            "assert flatten([]) == []\n"
            "assert flatten([1, 2, 3]) == [1, 2, 3]\n"
            "assert flatten([[[[1]]]]) == [1]\n"
            "assert flatten([1, [2], [[3]], [[[4]]]]) == [1, 2, 3, 4]\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "medium",
    },
    {
        "task": "Write a Python function `longest_common_prefix(strs)` that finds the longest "
        "common prefix string amongst a list of strings. If there is no common prefix, "
        "return an empty string.",
        "test_code": (
            "from solution import longest_common_prefix\n"
            "assert longest_common_prefix(['flower', 'flow', 'flight']) == 'fl'\n"
            "assert longest_common_prefix(['dog', 'racecar', 'car']) == ''\n"
            "assert longest_common_prefix(['interspecies', 'interstellar', 'interstate']) == 'inters'\n"
            "assert longest_common_prefix(['a']) == 'a'\n"
            "assert longest_common_prefix([]) == ''\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `group_anagrams(strs)` that groups anagrams together. "
        "Return a list of lists, where each inner list contains strings that are anagrams of "
        "each other. The order of groups and strings within groups does not matter.",
        "test_code": (
            "from solution import group_anagrams\n"
            "result = group_anagrams(['eat', 'tea', 'tan', 'ate', 'nat', 'bat'])\n"
            "result_sorted = sorted([sorted(g) for g in result])\n"
            "assert result_sorted == [['ate', 'eat', 'tea'], ['bat'], ['nat', 'tan']]\n"
            "assert group_anagrams([]) == []\n"
            "assert group_anagrams(['a']) == [['a']]\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "medium",
    },
    {
        "task": "Write a Python function `valid_parentheses(s)` that determines if a string "
        "containing just '(', ')', '{', '}', '[' and ']' is valid. A string is valid if "
        "open brackets are closed by the same type and in the correct order.",
        "test_code": (
            "from solution import valid_parentheses\n"
            "assert valid_parentheses('()') == True\n"
            "assert valid_parentheses('()[]{}') == True\n"
            "assert valid_parentheses('(]') == False\n"
            "assert valid_parentheses('([)]') == False\n"
            "assert valid_parentheses('{[]}') == True\n"
            "assert valid_parentheses('') == True\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `merge_intervals(intervals)` that merges overlapping "
        "intervals. Each interval is a list [start, end]. Return the merged intervals sorted "
        "by start time.",
        "test_code": (
            "from solution import merge_intervals\n"
            "assert merge_intervals([[1,3],[2,6],[8,10],[15,18]]) == [[1,6],[8,10],[15,18]]\n"
            "assert merge_intervals([[1,4],[4,5]]) == [[1,5]]\n"
            "assert merge_intervals([[1,4],[0,4]]) == [[0,4]]\n"
            "assert merge_intervals([]) == []\n"
            "assert merge_intervals([[1,2]]) == [[1,2]]\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "medium",
    },
]

# ═══════════════════════════════════════════════════════════════════════
# Hint extraction prompts (adapted from OpenClaw-RL)
# ═══════════════════════════════════════════════════════════════════════

_HINT_JUDGE_SYSTEM = (
    "You are a process reward model used for hindsight hint extraction.\n"
    "You are given:\n"
    "1) The assistant response at turn t.\n"
    "2) The next state at turn t+1, along with its **role**.\n\n"
    "## Understanding the next state's role\n"
    "- role='user': A reply from the user (follow-up, correction, new request, etc.).\n"
    "- role='tool': The return value of a tool the assistant invoked. "
    "This content was NOT available before the assistant's action — "
    "it exists BECAUSE the assistant called the tool. "
    "A successful, non-error tool output generally means the assistant's "
    "action was appropriate; do NOT treat it as information the assistant "
    "should have already known.\n\n"
    "Your goal is to decide whether the next state reveals useful hindsight information\n"
    "that could have helped improve the assistant response at turn t.\n\n"
    "Output format rules (strict):\n"
    "- You MUST include exactly one final decision token: \\boxed{1} or \\boxed{-1}.\n"
    "- If and only if the decision is \\boxed{1}, provide a concise, information-dense hint in 1-3 sentences,\n"
    "  wrapped between [HINT_START] and [HINT_END].\n"
    "- If the decision is \\boxed{-1}, do not provide a hint block.\n"
    "- The hint must be concrete and actionable for improving the previous response."
)

_BOXED_RE = re.compile(r"\\boxed\{(-?\d+)\}")
_HINT_RE = re.compile(r"\[HINT_START\](.*?)\[HINT_END\]", re.DOTALL)

def _build_hint_judge_messages(
    response_text: str, next_state_text: str, next_state_role: str = "tool"
) -> list[dict]:
    """Build messages for the hint extraction judge."""
    user = (
        f"## Assistant response (turn t)\n{response_text}\n\n"
        f"## Next state (turn t+1) [role: {next_state_role}]\n{next_state_text}\n\n"
        "Now output your decision and (if positive) the hint in the required format."
    )
    return [
        {"role": "system", "content": _HINT_JUDGE_SYSTEM},
        {"role": "user", "content": user},
    ]

def _parse_hint_result(text: str) -> tuple[int | None, str]:
    """Parse the judge's boxed decision and hint text."""
    boxed = _BOXED_RE.findall(text)
    score = int(boxed[-1]) if boxed else None
    if score not in (1, -1):
        score = None
    hint_matches = _HINT_RE.findall(text)
    hint = hint_matches[-1].strip() if hint_matches else ""
    return score, hint

def _select_best_hint(votes: list[dict]) -> dict | None:
    """Select the best hint from the judges' votes.

    Keeps only votes with a positive decision and a substantive hint,
    then returns the longest (most information-dense) one.
    """
    good = [
        v
        for v in votes
        if v.get("score") == 1
        and isinstance(v.get("hint"), str)
        and len(v["hint"].strip()) > 10
    ]
    if not good:
        return None
    return max(good, key=lambda v: len(v["hint"].strip()))

def _append_hint_to_messages(messages: list[dict], hint: str) -> list[dict]:
    """Clone messages and append the hint to the last user message."""
    cloned = copy.deepcopy(messages)
    if not cloned:
        return [{"role": "user", "content": f"[user's hint / instruction]\n{hint}"}]

    # Find the last user message
    target_idx = None
    for i in range(len(cloned) - 1, -1, -1):
        if cloned[i].get("role") == "user":
            target_idx = i
            break
    if target_idx is None:
        target_idx = len(cloned) - 1

    content = cloned[target_idx].get("content", "")
    if isinstance(content, list):
        content = " ".join(
            c.get("text", "") if isinstance(c, dict) else str(c) for c in content
        )
    suffix = f"\n\n[user's hint / instruction]\n{hint.strip()}"
    cloned[target_idx]["content"] = (content + suffix).strip()
    return cloned

# ═══════════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════════


class AgenticOPDConfig(HermesAgentEnvConfig):
    """Configuration for the agentic OPD environment."""

    # --- OPD settings ---
    opd_enabled: bool = Field(
        default=True,
        description="Enable on-policy distillation pipeline. When disabled, "
        "the environment behaves like a standard agentic env (no distill fields).",
    )
    distill_topk: int = Field(
        default=50,
        description="Number of top-K teacher logprobs per position for distillation.",
    )
    prm_votes: int = Field(
        default=3,
        description="Number of independent judge queries for majority-voted hint extraction.",
    )
    hint_max_next_state_chars: int = Field(
        default=4000,
        description="Maximum characters of next-state text to include in the hint judge prompt. "
        "Tool results can be very long — truncating prevents judge context overflow.",
    )

    # --- Reward settings ---
    correctness_weight: float = Field(
        default=0.7,
        description="Weight for test pass/fail in reward.",
    )
    efficiency_weight: float = Field(
        default=0.15,
        description="Weight for efficiency (fewer turns = better).",
    )
    tool_usage_weight: float = Field(
        default=0.15,
        description="Weight for appropriate tool usage signal.",
    )

    # --- Dataset ---
    dataset_name: Optional[str] = Field(
        default=None,
        description="HuggingFace dataset with coding tasks. "
        "Expected fields: 'task' (problem description) and 'test_code' (pytest/assert tests). "
        "Falls back to built-in tasks if not set or unavailable.",
    )

    # --- Eval ---
    eval_size: int = Field(
        default=10,
        description="Number of held-out items for evaluation.",
    )
    eval_split_ratio: float = Field(
        default=0.15,
        description="Fraction of dataset to hold out for evaluation.",
    )

# ═══════════════════════════════════════════════════════════════════════
# Environment
# ═══════════════════════════════════════════════════════════════════════


class AgenticOPDEnv(HermesAgentBaseEnv):
    """
    RL environment with on-policy distillation from next-state signals.

    Runs coding tasks where the agent writes code and runs tests.
    Tool results (test pass/fail, error traces) serve as next-state signals
    for hint extraction and teacher logprob scoring.

    This is the first Atropos environment to populate distill_token_ids
    and distill_logprobs on ScoredDataGroup for OPD training.
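Shape-wise, the added fields nest rollout → token position → top-K candidate.
A small sketch with made-up sizes (the numbers here are illustrative only):

```python
group_size, seq_len, k = 2, 4, 3  # hypothetical sizes
distill_token_ids = [
    [[0] * k for _ in range(seq_len)] for _ in range(group_size)
]
distill_logprobs = [
    [[0.0] * k for _ in range(seq_len)] for _ in range(group_size)
]
assert len(distill_token_ids) == group_size    # one entry per rollout
assert len(distill_token_ids[0]) == seq_len    # one row per token position
assert len(distill_token_ids[0][0]) == k       # top-K candidates per row
```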
    """

    name = "agentic-opd"
    env_config_cls = AgenticOPDConfig

    # Default toolsets: terminal for running code, file for writing it
    default_toolsets = ["terminal", "file"]

    @classmethod
    def config_init(cls) -> Tuple[AgenticOPDConfig, List[APIServerConfig]]:
        """Default configuration."""
        env_config = AgenticOPDConfig(
            # Toolsets
            enabled_toolsets=["terminal", "file"],
            # Agent loop
            max_agent_turns=15,
            agent_temperature=1.0,
            system_prompt=(
                "You are a skilled Python programmer. When given a coding task:\n"
                "1. Write the solution to a file called 'solution.py'\n"
                "2. Write the test code to a file called 'test_solution.py'\n"
                "3. Run the tests with: python test_solution.py\n"
                "4. If tests fail, read the error output carefully, fix your code, and re-run\n"
                "5. Once all tests pass, report success\n\n"
                "Be efficient — write clean code and fix errors methodically."
            ),
            # OPD
            opd_enabled=True,
            distill_topk=50,
            prm_votes=3,
            # Training
            group_size=4,
            total_steps=500,
            steps_per_eval=50,
            use_wandb=True,
            wandb_name="agentic-opd",
        )

        server_configs = [
            APIServerConfig(
                base_url="http://localhost:8000/v1",
                model_name="Qwen/Qwen3-4B",
                server_type="vllm",
            )
        ]

        return env_config, server_configs

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._items: list[dict] = []
        self._eval_items: list[dict] = []
        self._index: int = 0

        # Metric buffers
        self._reward_buffer: list[float] = []
        self._correctness_buffer: list[float] = []
        self._efficiency_buffer: list[float] = []
        self._tool_usage_buffer: list[float] = []
        self._hints_extracted_buffer: list[int] = []
        self._opd_turns_scored_buffer: list[int] = []

    # ═══════════════════════════════════════════════════════════════════
    # 1. setup — load dataset
    # ═══════════════════════════════════════════════════════════════════

    async def setup(self) -> None:
        """Load coding tasks from HuggingFace or use the built-in set."""
        if self.config.dataset_name:
            try:
                from datasets import load_dataset

                logger.info("Loading dataset '%s'...", self.config.dataset_name)
                ds = load_dataset(
                    self.config.dataset_name, split=self.config.dataset_split
                )
                task_field = self.config.prompt_field
                self._items = [
                    {
                        "task": row.get(task_field, row.get("task", "")),
                        "test_code": row.get("test_code", row.get("tests", "")),
                        "difficulty": row.get("difficulty", "unknown"),
                    }
                    for row in ds
                    if row.get(task_field, row.get("task", ""))
                ]
                if self._items:
                    random.shuffle(self._items)
                    eval_size = max(
                        self.config.eval_size,
                        int(len(self._items) * self.config.eval_split_ratio),
                    )
                    self._eval_items = self._items[:eval_size]
                    self._items = self._items[eval_size:]
                    logger.info(
                        "Loaded %d train / %d eval items from '%s'",
                        len(self._items),
                        len(self._eval_items),
                        self.config.dataset_name,
                    )
                    return
            except Exception as e:
                logger.warning(
                    "Could not load dataset '%s': %s. Using built-in tasks.",
                    self.config.dataset_name,
                    e,
                )

        # Fall back to built-in tasks
        items = copy.deepcopy(BUILTIN_CODING_TASKS)
        random.shuffle(items)
        split = max(1, len(items) * 85 // 100)
        self._items = items[:split]
        self._eval_items = items[split:]
        logger.info(
            "Using built-in coding tasks: %d train / %d eval items",
            len(self._items),
            len(self._eval_items),
        )

    # ═══════════════════════════════════════════════════════════════════
    # 2. get_next_item
    # ═══════════════════════════════════════════════════════════════════

    async def get_next_item(self) -> dict:
        """Return the next coding task, cycling through the dataset."""
        if not self._items:
            raise RuntimeError("Dataset is empty. Did you call setup()?")
        item = self._items[self._index % len(self._items)]
        self._index += 1
        return item

    # ═══════════════════════════════════════════════════════════════════
    # 3. format_prompt
    # ═══════════════════════════════════════════════════════════════════

    def format_prompt(self, item: dict) -> str:
        """Format the coding task as a user prompt."""
        prompt = (
            f"Solve the following coding task.\n\n"
            f"## Task\n{item['task']}\n\n"
        )
        if item.get("test_code"):
            prompt += (
                f"## Tests\nThe following test code will be used to verify your solution:\n"
                f"```python\n{item['test_code']}```\n\n"
            )
        prompt += (
            "## Instructions\n"
            "1. Write your solution to `solution.py`\n"
            "2. Write the test code to `test_solution.py`\n"
            "3. Run `python test_solution.py` to verify\n"
            "4. Fix any failures and re-run until all tests pass\n"
        )
        return prompt

    # ═══════════════════════════════════════════════════════════════════
    # 4. compute_reward
    # ═══════════════════════════════════════════════════════════════════

    async def compute_reward(
        self,
        item: dict,
        result: AgentResult,
        ctx: ToolContext,
    ) -> float:
        """
        Multi-signal reward:
        - correctness (0.7): Did the tests pass?
        - efficiency (0.15): Fewer turns = better
        - tool_usage (0.15): Did the agent actually write + run code?
        """
        cfg = self.config

        # ---- Signal 1: Test correctness ----
        # Check whether test_solution.py exists and passes in the agent's sandbox
        correctness = 0.0
        try:
            test_result = ctx.terminal("python test_solution.py 2>&1", timeout=30)
            output = test_result.get("output", "")
            exit_code = test_result.get("exit_code", 1)
            if exit_code == 0 and "passed" in output.lower():
                correctness = 1.0
            elif exit_code == 0:
                correctness = 0.8  # Ran without error but no explicit "passed"
            elif "assert" in output.lower() and "error" in output.lower():
                correctness = 0.2  # Partial — code runs but assertions fail
            else:
                correctness = 0.1  # Code errors out entirely
        except Exception as e:
            logger.debug("Test execution failed in reward: %s", e)
            correctness = 0.0

        # ---- Signal 2: Efficiency ----
        max_turns = cfg.max_agent_turns
        turns_used = result.turns_used
        if turns_used <= 3:
            efficiency = 1.0
        elif turns_used <= max_turns // 2:
            efficiency = 0.8
        elif turns_used <= max_turns * 3 // 4:
            efficiency = 0.5
        else:
            efficiency = 0.2

        # ---- Signal 3: Tool usage ----
        tools_used = set()
        for msg in result.messages:
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                for tc in msg["tool_calls"]:
                    fn = tc.get("function", {}) if isinstance(tc, dict) else {}
                    name = fn.get("name", "")
                    if name:
                        tools_used.add(name)

        # Good: used both terminal and file tools
        if "terminal" in tools_used and ("write_file" in tools_used or "patch" in tools_used):
            tool_usage = 1.0
        elif "terminal" in tools_used:
            tool_usage = 0.6
        elif tools_used:
            tool_usage = 0.3
        else:
            tool_usage = 0.0

        # ---- Combine ----
        reward = (
            cfg.correctness_weight * correctness
            + cfg.efficiency_weight * efficiency
            + cfg.tool_usage_weight * tool_usage
        )
        reward = min(1.0, max(0.0, reward))

        # Track metrics
        self._reward_buffer.append(reward)
        self._correctness_buffer.append(correctness)
        self._efficiency_buffer.append(efficiency)
        self._tool_usage_buffer.append(tool_usage)

        logger.debug(
            "Reward: correctness=%.2f, efficiency=%.2f, tool_usage=%.2f → %.3f",
            correctness,
            efficiency,
            tool_usage,
            reward,
        )
        return reward

    # ═══════════════════════════════════════════════════════════════════
    # 5. collect_trajectories — OPD pipeline
    # ═══════════════════════════════════════════════════════════════════

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[
        Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
        List[Item],
    ]:
        """
        Override collect_trajectories to add the OPD pipeline.

        1. Run standard rollouts via super() → ScoredDataGroup with tokens/masks/scores
        2. For each rollout, extract hints from next-state signals
        3. Score student tokens under the enhanced (hint-augmented) distribution
        4. Add distill_token_ids / distill_logprobs to the ScoredDataGroup
        """
        # Step 1: Run standard rollouts
        scored_group, backlog = await super().collect_trajectories(item)

        # Step 2: OPD pipeline (only if enabled and we have a VLLM server)
        if (
            self.config.opd_enabled
            and scored_group is not None
            and isinstance(scored_group, dict)
            and self._use_managed_server()
        ):
            await self._apply_opd_pipeline(scored_group)

        return scored_group, backlog

    async def _apply_opd_pipeline(self, group: ScoredDataGroup) -> None:
        """
        Apply on-policy distillation to each rollout in the group.

        For each rollout's messages:
        1. Find (assistant, next_state) turn pairs
        2. Extract hints via LLM judge with majority voting
        3. Build enhanced prompt (original + hint)
        4. Score student tokens under the enhanced distribution via get_logprobs
        5. Add distill_token_ids / distill_logprobs to the group
        """
        messages_list = group.get("messages", [])
        tokens_list = group.get("tokens", [])

        if not messages_list or not tokens_list:
            logger.debug("OPD: No messages or tokens to process")
            return

        all_distill_token_ids: List[Optional[List[List[int]]]] = []
        all_distill_logprobs: List[Optional[List[List[float]]]] = []

        for seq_idx, (messages, student_tokens) in enumerate(
            zip(messages_list, tokens_list)
        ):
            try:
                distill_ids, distill_lps = await self._opd_for_sequence(
                    messages, student_tokens
                )
                all_distill_token_ids.append(distill_ids)
                all_distill_logprobs.append(distill_lps)
            except Exception as e:
                logger.warning("OPD failed for sequence %d: %s", seq_idx, e)
                all_distill_token_ids.append(None)
                all_distill_logprobs.append(None)

        # Only set distill fields if at least one sequence succeeded.
        # Count successes before backfilling so the log stays accurate.
        n_succeeded = sum(1 for d in all_distill_token_ids if d is not None)
        if n_succeeded:
            # Replace None entries with zero-padded arrays matching token length.
            # Build fresh inner lists so rows are not aliased to one another.
            for i in range(len(all_distill_token_ids)):
                if all_distill_token_ids[i] is None and i < len(tokens_list):
                    seq_len = len(tokens_list[i])
                    k = self.config.distill_topk
                    all_distill_token_ids[i] = [[0] * k for _ in range(seq_len)]
                    all_distill_logprobs[i] = [[0.0] * k for _ in range(seq_len)]

            group["distill_token_ids"] = all_distill_token_ids
            group["distill_logprobs"] = all_distill_logprobs
            logger.info(
                "OPD: Set distill fields on %d/%d sequences",
                n_succeeded,
                len(all_distill_token_ids),
            )

    async def _opd_for_sequence(
        self, messages: List[Dict], student_tokens: List[int]
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Run OPD for a single rollout sequence.

        1. Walk the conversation to find (assistant, next_state) pairs
        2. Extract hints from next-state signals
        3. For each hint-augmented turn, score student tokens via get_logprobs
        4. Merge per-turn teacher logprobs into a full-sequence distill array

        Returns:
            (distill_token_ids, distill_logprobs), each of shape [seq_len][top_k]
        """
        k = self.config.distill_topk
        seq_len = len(student_tokens)

        # Initialize with zeros (no distill info = neutral)
        distill_token_ids: List[List[int]] = [[0] * k for _ in range(seq_len)]
        distill_logprobs: List[List[float]] = [[0.0] * k for _ in range(seq_len)]

        # Find (assistant, next_state) turn pairs
        turn_pairs = self._extract_turn_pairs(messages)
        if not turn_pairs:
            return distill_token_ids, distill_logprobs

        hints_extracted = 0
        turns_scored = 0

        for pair in turn_pairs:
            try:
                hint = await self._extract_hint(
                    pair["assistant_text"],
                    pair["next_state_text"],
                    pair["next_state_role"],
                )
                if not hint:
                    continue

                hints_extracted += 1

                # Build the enhanced prompt with the hint
                enhanced_messages = _append_hint_to_messages(
                    pair["context_messages"], hint
                )

                # Tokenize the enhanced prompt
                if not self.tokenizer:
                    logger.warning("OPD: No tokenizer available, skipping scoring")
                    continue

                enhanced_prompt = self.tokenizer.apply_chat_template(
                    enhanced_messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )

                # Tokenize the assistant response to score
                response_text = pair["assistant_text"]
                enhanced_full_text = enhanced_prompt + response_text
                enhanced_ids = self.tokenizer(
                    enhanced_full_text, add_special_tokens=False
                )["input_ids"]

                response_ids = self.tokenizer(
                    response_text, add_special_tokens=False
                )["input_ids"]
                response_len = len(response_ids)

                if response_len == 0:
                    continue

                # Score via get_logprobs — the teacher scores the student's
                # tokens under the enhanced (hint-augmented) distribution
                try:
                    logprob_result = await self.server.get_logprobs(
                        input_ids=enhanced_ids,
                        top_k=k,
                        split="eval",  # Use the eval semaphore to not block training
                    )
                except Exception as e:
                    logger.debug("get_logprobs failed: %s", e)
                    continue

                teacher_topk_ids = logprob_result.get("prompt_topk_token_ids", [])
                teacher_topk_lps = logprob_result.get("prompt_topk_logprobs", [])

                if not teacher_topk_ids:
                    continue

                # Extract only the response positions (last response_len entries)
                if len(teacher_topk_ids) >= response_len:
                    resp_topk_ids = teacher_topk_ids[-response_len:]
                    resp_topk_lps = teacher_topk_lps[-response_len:]
                else:
                    # Pad from the left if the response was shorter than expected
                    pad_len = response_len - len(teacher_topk_ids)
                    resp_topk_ids = [[0] * k for _ in range(pad_len)] + teacher_topk_ids
                    resp_topk_lps = [[0.0] * k for _ in range(pad_len)] + teacher_topk_lps

                # Map these back to the student's full sequence positions:
                # find where this assistant turn's tokens appear in the full sequence
                turn_start = self._find_token_span(student_tokens, response_ids)
                if turn_start is not None:
                    for j in range(min(response_len, seq_len - turn_start)):
                        pos = turn_start + j
                        if pos < seq_len and j < len(resp_topk_ids):
                            # Pad/truncate to exactly k entries
                            ids = resp_topk_ids[j][:k]
                            lps = resp_topk_lps[j][:k]
                            while len(ids) < k:
                                ids.append(0)
                                lps.append(0.0)
                            distill_token_ids[pos] = ids
                            distill_logprobs[pos] = lps
                    turns_scored += 1

            except Exception as e:
                logger.debug("OPD turn processing failed: %s", e)
                continue

        # Track OPD metrics
        self._hints_extracted_buffer.append(hints_extracted)
        self._opd_turns_scored_buffer.append(turns_scored)

        logger.debug(
            "OPD sequence: %d turn pairs, %d hints extracted, %d turns scored",
            len(turn_pairs),
            hints_extracted,
            turns_scored,
        )
        return distill_token_ids, distill_logprobs

    def _extract_turn_pairs(
        self, messages: List[Dict]
    ) -> List[Dict[str, Any]]:
        """
        Walk the conversation messages to find (assistant, next_state) pairs.

        A "turn pair" is an assistant message with content (the response)
        followed by one or more tool results or a user reply (the next state).

        Returns a list of dicts:
            {
                "context_messages": messages up to (not including) the assistant turn,
                "assistant_text": the assistant's response text,
                "next_state_text": the next state content (tool result or user reply),
                "next_state_role": "tool" or "user",
            }
|
|
"""
|
|
pairs = []
|
|
i = 0
|
|
while i < len(messages):
|
|
msg = messages[i]
|
|
if msg.get("role") == "assistant" and msg.get("content"):
|
|
# Found an assistant message with content
|
|
assistant_text = msg["content"]
|
|
context = messages[:i] # Everything before this turn
|
|
|
|
# Look ahead for next state
|
|
j = i + 1
|
|
# Skip tool_calls-only assistant messages and collect tool results
|
|
next_states = []
|
|
while j < len(messages):
|
|
next_msg = messages[j]
|
|
if next_msg.get("role") == "tool":
|
|
next_states.append(next_msg)
|
|
j += 1
|
|
elif next_msg.get("role") == "user":
|
|
next_states.append(next_msg)
|
|
break
|
|
else:
|
|
break
|
|
|
|
if next_states:
|
|
# Combine all next-state content
|
|
next_text_parts = []
|
|
next_role = next_states[0].get("role", "tool")
|
|
for ns in next_states:
|
|
content = ns.get("content", "")
|
|
if content:
|
|
# Truncate very long tool outputs
|
|
max_chars = self.config.hint_max_next_state_chars
|
|
if len(content) > max_chars:
|
|
content = content[:max_chars] + "\n...[truncated]"
|
|
next_text_parts.append(content)
|
|
|
|
next_text = "\n---\n".join(next_text_parts)
|
|
if next_text.strip():
|
|
pairs.append(
|
|
{
|
|
"context_messages": context,
|
|
"assistant_text": assistant_text,
|
|
"next_state_text": next_text,
|
|
"next_state_role": next_role,
|
|
}
|
|
)
|
|
i += 1
|
|
return pairs
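    # Illustrative behavior of the pairing above (hypothetical messages):
    # given [user, assistant("A1"), tool("T1"), assistant("A2"), user("U2")],
    # _extract_turn_pairs yields two pairs, ("A1" -> "T1") and ("A2" -> "U2"),
    # each with context_messages set to everything before its assistant turn.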

    async def _extract_hint(
        self,
        assistant_text: str,
        next_state_text: str,
        next_state_role: str,
    ) -> Optional[str]:
        """
        Extract a hindsight hint from a next-state signal using a
        majority-voted LLM judge.

        Returns the hint string if the judges vote positively, None otherwise.
        """
        judge_messages = _build_hint_judge_messages(
            response_text=assistant_text,
            next_state_text=next_state_text,
            next_state_role=next_state_role,
        )

        # Majority voting across multiple judge queries, issued concurrently
        votes = []
        tasks = []
        for _ in range(self.config.prm_votes):
            tasks.append(
                self.server.chat_completion(
                    messages=judge_messages,
                    n=1,
                    max_tokens=500,
                    temperature=0.7,
                    split="eval",
                )
            )

        results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in results:
            if isinstance(result, Exception):
                logger.debug("Hint judge call failed: %s", result)
                votes.append({"score": None, "hint": ""})
                continue
            try:
                text = result.choices[0].message.content or ""
                score, hint = _parse_hint_result(text)
                votes.append({"score": score, "hint": hint})
            except Exception as e:
                logger.debug("Hint parse failed: %s", e)
                votes.append({"score": None, "hint": ""})

        selected = _select_best_hint(votes)
        if selected is None:
            return None
        return selected["hint"]

    @staticmethod
    def _find_token_span(
        full_tokens: List[int], sub_tokens: List[int]
    ) -> Optional[int]:
        """
        Find where sub_tokens appears as a contiguous run in full_tokens.
        Returns the start index of the last occurrence, or None if not found.

        Uses a sliding-window search, scanning from the end since assistant
        responses typically sit near the end of the sequence.
        """
        if not sub_tokens or not full_tokens:
            return None
        sub_len = len(sub_tokens)
        full_len = len(full_tokens)
        if sub_len > full_len:
            return None

        # Search backwards (assistant responses are usually near the end)
        for i in range(full_len - sub_len, -1, -1):
            if full_tokens[i : i + sub_len] == sub_tokens:
                return i
        return None
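    # Illustrative behavior (hypothetical token IDs): the backwards scan
    # returns the rightmost occurrence, e.g.
    # _find_token_span([9, 1, 2, 3, 1, 2], [1, 2]) -> 4, while
    # _find_token_span([9, 1, 2], [7]) -> None.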

    # ═══════════════════════════════════════════════════════════════════
    # 6. evaluate
    # ═══════════════════════════════════════════════════════════════════

    async def evaluate(self, *args, **kwargs) -> None:
        """
        Evaluate on held-out coding tasks using the full agent loop.
        No OPD during eval, just standard agentic evaluation.
        """
        if not self._eval_items:
            logger.warning("No eval items available.")
            return

        eval_size = min(self.config.eval_size, len(self._eval_items))
        eval_items = self._eval_items[:eval_size]

        logger.info("Running eval on %d coding tasks...", len(eval_items))
        start_time = time.time()
        samples = []

        tools, valid_names = self._resolve_tools_for_group()

        for i, item in enumerate(eval_items):
            task_id = str(uuid.uuid4())
            logger.info(
                "Eval [%d/%d]: %s...", i + 1, len(eval_items), item["task"][:60]
            )

            try:
                messages: List[Dict[str, Any]] = []
                if self.config.system_prompt:
                    messages.append(
                        {"role": "system", "content": self.config.system_prompt}
                    )
                messages.append(
                    {"role": "user", "content": self.format_prompt(item)}
                )

                agent = HermesAgentLoop(
                    server=self.server,
                    tool_schemas=tools,
                    valid_tool_names=valid_names,
                    max_turns=self.config.max_agent_turns,
                    task_id=task_id,
                    temperature=0.0,
                    max_tokens=self.config.max_token_length,
                    extra_body=self.config.extra_body,
                )
                result = await agent.run(messages)

                # Compute reward (track buffer lengths to roll back eval pollution)
                buf_len = len(self._correctness_buffer)
                hints_len = len(self._hints_extracted_buffer)
                scored_len = len(self._opd_turns_scored_buffer)
                ctx = ToolContext(task_id)
                try:
                    reward = await self.compute_reward(item, result, ctx)
                finally:
                    ctx.cleanup()

                # Extract correctness and roll back training buffers
                correctness = (
                    self._correctness_buffer[buf_len]
                    if len(self._correctness_buffer) > buf_len
                    else 0.0
                )
                for buf in (
                    self._reward_buffer,
                    self._correctness_buffer,
                    self._efficiency_buffer,
                    self._tool_usage_buffer,
                ):
                    if len(buf) > buf_len:
                        buf.pop()

                # Also roll back the OPD buffers if they were touched; they
                # grow independently of the reward buffers, so each gets its
                # own pre-reward baseline.
                if len(self._hints_extracted_buffer) > hints_len:
                    self._hints_extracted_buffer.pop()
                if len(self._opd_turns_scored_buffer) > scored_len:
                    self._opd_turns_scored_buffer.pop()

                # Extract the final assistant response
                final_response = ""
                for msg in reversed(result.messages):
                    if msg.get("role") == "assistant" and msg.get("content"):
                        final_response = msg["content"]
                        break

                samples.append(
                    {
                        "prompt": item["task"][:200],
                        "response": final_response[:500],
                        "correctness": correctness,
                        "reward": reward,
                        "turns": result.turns_used,
                    }
                )

                logger.info(
                    "  → correctness=%.2f, reward=%.3f, turns=%d",
                    correctness,
                    reward,
                    result.turns_used,
                )

            except Exception as e:
                logger.error("Eval error: %s", e)
                samples.append(
                    {
                        "prompt": item["task"][:200],
                        "response": f"ERROR: {e}",
                        "correctness": 0.0,
                        "reward": 0.0,
                        "turns": 0,
                    }
                )

        end_time = time.time()

        correctness_scores = [s["correctness"] for s in samples]
        rewards = [s["reward"] for s in samples]
        n = len(samples)

        eval_metrics = {
            "eval/mean_correctness": sum(correctness_scores) / n if n else 0.0,
            "eval/mean_reward": sum(rewards) / n if n else 0.0,
            "eval/pass_rate": (
                sum(1 for c in correctness_scores if c >= 0.8) / n if n else 0.0
            ),
            "eval/n_items": n,
        }

        logger.info(
            "Eval complete: correctness=%.3f, reward=%.3f, pass_rate=%.0f%%",
            eval_metrics["eval/mean_correctness"],
            eval_metrics["eval/mean_reward"],
            eval_metrics["eval/pass_rate"] * 100,
        )

        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
        )

    # ═══════════════════════════════════════════════════════════════════
    # 7. wandb_log: custom OPD metrics
    # ═══════════════════════════════════════════════════════════════════

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
        """Log reward breakdown and OPD-specific metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}

        if self._reward_buffer:
            n = len(self._reward_buffer)
            wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
            wandb_metrics["train/mean_correctness"] = (
                sum(self._correctness_buffer) / n
            )
            wandb_metrics["train/mean_efficiency"] = (
                sum(self._efficiency_buffer) / n
            )
            wandb_metrics["train/mean_tool_usage"] = (
                sum(self._tool_usage_buffer) / n
            )
            wandb_metrics["train/pass_rate"] = (
                sum(1 for c in self._correctness_buffer if c >= 0.8) / n
            )
            wandb_metrics["train/total_rollouts"] = n

            self._reward_buffer.clear()
            self._correctness_buffer.clear()
            self._efficiency_buffer.clear()
            self._tool_usage_buffer.clear()

        # OPD-specific metrics
        if self._hints_extracted_buffer:
            n = len(self._hints_extracted_buffer)
            wandb_metrics["opd/mean_hints_per_rollout"] = (
                sum(self._hints_extracted_buffer) / n
            )
            wandb_metrics["opd/mean_turns_scored"] = (
                sum(self._opd_turns_scored_buffer) / n
            )
            wandb_metrics["opd/hint_rate"] = (
                sum(1 for h in self._hints_extracted_buffer if h > 0) / n
            )
            wandb_metrics["opd/total_hints"] = sum(self._hints_extracted_buffer)
            wandb_metrics["opd/total_scored_turns"] = sum(
                self._opd_turns_scored_buffer
            )

            self._hints_extracted_buffer.clear()
            self._opd_turns_scored_buffer.clear()

        await super().wandb_log(wandb_metrics)


# ═══════════════════════════════════════════════════════════════════════
# Entry point
# ═══════════════════════════════════════════════════════════════════════

if __name__ == "__main__":
    AgenticOPDEnv.cli()