Merge: WebResearchEnv Atropos standards compliance

This commit is contained in:
teknium1
2026-03-09 17:45:57 -07:00


@@ -16,21 +16,18 @@ Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
Usage:
# Phase 1 (OpenAI-compatible server)
python environments/web_research_env.py serve \\
--openai.base_url http://localhost:8000/v1 \\
--openai.model_name YourModel \\
--openai.server_type openai
# Process mode (offline data generation)
python environments/web_research_env.py process \\
--env.data_path_to_save_groups data/web_research.jsonl
# Standalone eval
python environments/web_research_env.py evaluate \\
--openai.base_url http://localhost:8000/v1 \\
--openai.model_name YourModel
Built by: github.com/jackx707
@@ -43,11 +40,21 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
import random
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
from pydantic import Field
# Ensure hermes-agent root is on path
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
sys.path.insert(0, str(_repo_root))
# ---------------------------------------------------------------------------
# Optional HuggingFace datasets import
# ---------------------------------------------------------------------------
@@ -57,13 +64,19 @@ try:
except ImportError:
HF_AVAILABLE = False
from atroposlib.envs.base import ScoredDataGroup
from atroposlib.envs.server_handling.server_manager import APIServerConfig
from atroposlib.type_definitions import Item
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
from environments.agent_loop import AgentResult
from environments.tool_context import ToolContext
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Fallback sample dataset (used when HuggingFace is unavailable)
# Multi-hop questions requiring real web search to answer.
# ---------------------------------------------------------------------------
SAMPLE_QUESTIONS = [
{
@@ -129,6 +142,58 @@ SAMPLE_QUESTIONS = [
]
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
class WebResearchEnvConfig(HermesAgentEnvConfig):
"""Configuration for the web research RL environment."""
# Reward weights
correctness_weight: float = Field(
default=0.6,
description="Weight for answer correctness in reward (LLM judge score).",
)
tool_usage_weight: float = Field(
default=0.2,
description="Weight for tool usage signal (did the model actually use web tools?).",
)
efficiency_weight: float = Field(
default=0.2,
description="Weight for efficiency signal (penalizes excessive tool calls).",
)
diversity_bonus: float = Field(
default=0.1,
description="Bonus reward for citing ≥2 distinct domains.",
)
# Efficiency thresholds
efficient_max_calls: int = Field(
default=5,
description="Maximum tool calls before efficiency penalty begins.",
)
heavy_penalty_calls: int = Field(
default=10,
description="Tool call count where efficiency penalty steepens.",
)
# Eval
eval_size: int = Field(
default=20,
description="Number of held-out items for evaluation.",
)
eval_split_ratio: float = Field(
default=0.1,
description="Fraction of dataset to hold out for evaluation (0.01.0).",
)
# Dataset
dataset_name: str = Field(
default="google/frames-benchmark",
description="HuggingFace dataset name for research questions.",
)
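# Any field above can presumably be overridden from the CLI via the
# --env.<field> flags shown in the module docstring, e.g.:
#   --env.correctness_weight 0.7 --env.efficiency_weight 0.1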
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
@@ -143,23 +208,60 @@ class WebResearchEnv(HermesAgentBaseEnv):
Reward is multi-signal:
60% — answer correctness (LLM judge)
20% — tool usage (did the model actually search the web?)
20% — efficiency (penalizes >5 tool calls)
Bonus +0.1 for source diversity (≥2 distinct domains cited).
"""
name = "web-research"
env_config_cls = WebResearchEnvConfig
# Default toolsets for this environment — web + file for saving notes
default_toolsets = ["web", "file"]
@classmethod
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
"""Default configuration for the web research environment."""
env_config = WebResearchEnvConfig(
enabled_toolsets=["web", "file"],
max_agent_turns=15,
agent_temperature=1.0,
system_prompt=(
"You are a highly capable research agent. When asked a factual question, "
"always use web_search to find current, accurate information before answering. "
"Cite at least 2 sources. Be concise and accurate."
),
group_size=4,
total_steps=1000,
steps_per_eval=100,
use_wandb=True,
wandb_name="web-research",
)
server_configs = [
APIServerConfig(
base_url="https://openrouter.ai/api/v1",
model_name="anthropic/claude-sonnet-4.5",
server_type="openai",
api_key=os.getenv("OPENROUTER_API_KEY", ""),
health_check=False,
)
]
return env_config, server_configs
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._items: list[dict] = []
self._eval_items: list[dict] = []
self._index: int = 0
self._total_scored: int = 0
self._total_reward: float = 0.0
# Metrics tracking for wandb
self._reward_buffer: list[float] = []
self._correctness_buffer: list[float] = []
self._tool_usage_buffer: list[float] = []
self._efficiency_buffer: list[float] = []
self._diversity_buffer: list[float] = []
# ------------------------------------------------------------------
# 1. Setup — load dataset
@@ -170,7 +272,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
if HF_AVAILABLE:
try:
logger.info("Loading FRAMES benchmark from HuggingFace...")
ds = load_dataset("google/frames-benchmark", split="test")
ds = load_dataset(self.config.dataset_name, split="test")
self._items = [
{
"question": row["Prompt"],
@@ -180,8 +282,11 @@ class WebResearchEnv(HermesAgentBaseEnv):
}
for row in ds
]
# Hold out for eval
eval_size = max(
self.config.eval_size,
int(len(self._items) * self.config.eval_split_ratio),
)
random.shuffle(self._items)
self._eval_items = self._items[:eval_size]
self._items = self._items[eval_size:]
@@ -220,10 +325,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
# ------------------------------------------------------------------
def format_prompt(self, item: dict) -> str:
"""
Format the research question as a task prompt.
Instructs the model to use web search and cite sources.
"""
"""Format the research question as a task prompt."""
return (
f"Research the following question thoroughly using web search. "
f"You MUST search the web to find current, accurate information — "
@@ -243,27 +345,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
async def compute_reward(
self,
item: dict,
result: AgentResult,
ctx: ToolContext,
) -> float:
"""
Multi-signal reward function:
correctness_weight * correctness — LLM judge comparing answer to ground truth
tool_usage_weight * tool_used — binary: did the model use web tools?
efficiency_weight * efficiency — penalizes wasteful tool usage
+ diversity_bonus — source diversity (≥2 distinct domains)
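Worked example with the default weights: correctness 0.8, web tool used,
4 tool calls (full efficiency), 2 domains cited →
0.6*0.8 + 0.2*1.0 + 0.2*1.0 + 0.1 = 0.98.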
"""
final_response: str = result.final_response or ""
tool_calls = getattr(result, "tool_calls", None) or []
tools_used: list[str] = [tc.tool_name for tc in tool_calls]
tool_call_count: int = result.turns_used or len(tools_used)
cfg = self.config
# ---- Signal 1: Answer correctness (LLM judge) ----------------
correctness = await self._llm_judge(
question=item["question"],
expected=item["answer"],
model_answer=final_response,
)
# ---- Signal 2: Web tool usage --------------------------------
@@ -271,35 +376,37 @@ class WebResearchEnv(HermesAgentBaseEnv):
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
# ---- Signal 3: Efficiency ------------------------------------
# Full credit up to efficient_max_calls; linear penalty beyond it,
# steeper once tool_call_count exceeds heavy_penalty_calls.
if tool_call_count <= cfg.efficient_max_calls:
efficiency = 1.0
elif tool_call_count <= cfg.heavy_penalty_calls:
efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
else:
efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
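# Illustrative values with the defaults (5/10): 7 calls → 0.84,
# 10 calls → 0.60, 12 calls → 0.16, 14+ calls → 0.0.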
# ---- Bonus: Source diversity ---------------------------------
domains = self._extract_domains(final_response)
diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
# ---- Combine ------------------------------------------------
reward = (
cfg.correctness_weight * correctness
+ cfg.tool_usage_weight * tool_used
+ cfg.efficiency_weight * efficiency
+ diversity
)
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
# Track running stats
self._total_scored += 1
self._total_reward += reward
# Track for wandb
self._reward_buffer.append(reward)
self._correctness_buffer.append(correctness)
self._tool_usage_buffer.append(tool_used)
self._efficiency_buffer.append(efficiency)
self._diversity_buffer.append(diversity)
logger.debug(
f"Reward breakdown — correctness={correctness:.2f}, "
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}"
f"diversity={diversity:.1f} → total={reward:.3f}"
)
return reward
@@ -308,68 +415,117 @@ class WebResearchEnv(HermesAgentBaseEnv):
# 5. evaluate — run on held-out eval split
# ------------------------------------------------------------------
async def evaluate(self, *args, **kwargs) -> None:
"""Run evaluation on the held-out split using the agent loop."""
import time
items = self._eval_items
if not items:
logger.warning("No eval items available.")
return
logger.info(f"Running eval on {len(items)} questions...")
eval_size = min(self.config.eval_size, len(items))
eval_items = items[:eval_size]
logger.info(f"Running eval on {len(eval_items)} questions...")
start_time = time.time()
samples = []
for item in eval_items:
try:
# Single-shot chat completion for eval; scored on correctness only
prompt = self.format_prompt(item)
completion = await self.server.chat_completion(
messages=[
{"role": "system", "content": self.config.system_prompt or ""},
{"role": "user", "content": prompt},
],
n=1,
max_tokens=self.config.max_token_length,
temperature=0.0,
split="eval",
)
response_content = (
completion.choices[0].message.content if completion.choices else ""
)
# Score the response
correctness = await self._llm_judge(
question=item["question"],
expected=item["answer"],
model_answer=response_content,
)
samples.append({
"prompt": item["question"],
"response": response_content,
"expected": item["answer"],
"correctness": correctness,
})
except Exception as e:
logger.error(f"Eval error on item: {e}")
samples.append({
"prompt": item["question"],
"response": f"ERROR: {e}",
"expected": item["answer"],
"correctness": 0.0,
})
end_time = time.time()
# Compute metrics
correctness_scores = [s["correctness"] for s in samples]
eval_metrics = {
"eval/mean_correctness": (
sum(correctness_scores) / len(correctness_scores)
if correctness_scores else 0.0
),
"eval/n_items": len(rewards),
"train/mean_reward_so_far": (
self._total_reward / self._total_scored
if self._total_scored > 0 else 0.0
),
"eval/n_items": len(samples),
}
await self.evaluate_log(
metrics=eval_metrics,
samples=samples,
start_time=start_time,
end_time=end_time,
)
# ------------------------------------------------------------------
# 6. wandb_log — custom metrics
# ------------------------------------------------------------------
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
"""Log reward breakdown metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
if self._reward_buffer:
n = len(self._reward_buffer)
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
wandb_metrics["train/total_rollouts"] = n
# Accuracy buckets
wandb_metrics["train/correct_rate"] = (
sum(1 for c in self._correctness_buffer if c >= 0.7) / n
)
wandb_metrics["train/tool_usage_rate"] = (
sum(1 for t in self._tool_usage_buffer if t > 0) / n
)
# Clear buffers
self._reward_buffer.clear()
self._correctness_buffer.clear()
self._tool_usage_buffer.clear()
self._efficiency_buffer.clear()
self._diversity_buffer.clear()
await super().wandb_log(wandb_metrics)
# ------------------------------------------------------------------
# Private helpers
@@ -380,19 +536,14 @@ class WebResearchEnv(HermesAgentBaseEnv):
question: str,
expected: str,
model_answer: str,
) -> float:
"""
Use the server's LLM to judge answer correctness.
Falls back to keyword heuristic if LLM call fails.
"""
if not model_answer or not model_answer.strip():
return 0.0
# Build judge prompt
judge_prompt = (
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
f"Question: {question}\n\n"
@@ -405,39 +556,36 @@ class WebResearchEnv(HermesAgentBaseEnv):
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
" 0.0 = completely wrong or no answer\n\n"
"Consider: factual accuracy, completeness, and relevance.\n"
"Respond with ONLY a JSON object: {\"score\": <float>, \"reason\": \"<one sentence>\"}"
'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
)
try:
response = await self.server.chat_completion(
messages=[{"role": "user", "content": judge_prompt}],
n=1,
max_tokens=150,
temperature=0.0,
split="eval",
)
text = response.choices[0].message.content if response.choices else ""
parsed = self._parse_judge_json(text)
if parsed is not None:
return float(parsed)
except Exception as e:
logger.debug(f"LLM judge failed: {e}. Using heuristic.")
# Fallback: keyword overlap heuristic
return self._heuristic_score(expected, model_answer)
@staticmethod
def _parse_judge_json(text: str) -> Optional[float]:
"""Extract the score float from LLM judge JSON response."""
try:
# Strip markdown code fences if present
clean = re.sub(r"```(?:json)?|```", "", text).strip()
data = json.loads(clean)
score = float(data.get("score", -1))
if 0.0 <= score <= 1.0:
return score
except Exception:
# Try regex fallback
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
if match:
score = float(match.group(1))
@@ -447,10 +595,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
@staticmethod
def _heuristic_score(expected: str, model_answer: str) -> float:
"""
Lightweight keyword overlap score as fallback when no LLM is available.
Extracts meaningful tokens and computes Jaccard similarity.
"""
"""Lightweight keyword overlap score as fallback."""
stopwords = {
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
"at", "to", "for", "with", "and", "or", "but", "it", "its",
@@ -458,35 +603,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
}
def tokenize(text: str) -> set:
tokens = re.findall(r'\b\w+\b', text.lower())
return {t for t in tokens if t not in stopwords and len(t) > 2}
expected_tokens = tokenize(expected)
answer_tokens = tokenize(model_answer)
if not expected_tokens:
return 0.5
overlap = len(expected_tokens & answer_tokens)
union = len(expected_tokens | answer_tokens)
jaccard = overlap / union if union > 0 else 0.0
# Recall-weighted: reward covering expected content
recall = overlap / len(expected_tokens)
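# Illustrative: expected tokens {eiffel, tower, paris} all covered, with
# three extra answer tokens → jaccard = 3/6 = 0.5, recall = 1.0,
# score = 0.4*0.5 + 0.6*1.0 = 0.8.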
return min(1.0, 0.4 * jaccard + 0.6 * recall)
@staticmethod
def _extract_domains(text: str) -> set:
"""
Extract unique domains from URLs cited in the response.
Used to measure source diversity.
"""
"""Extract unique domains from URLs cited in the response."""
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
domains = set()
for url in urls:
try:
parsed = urlparse(url)
# Normalize: strip www.
domain = parsed.netloc.lower().removeprefix("www.")
if domain:
domains.add(domain)
@@ -494,20 +634,6 @@ class WebResearchEnv(HermesAgentBaseEnv):
pass
return domains
# ---------------------------------------------------------------------------
# Entry point