diff --git a/environments/web_research_env.py b/environments/web_research_env.py index e73eb45c6..a868cd034 100644 --- a/environments/web_research_env.py +++ b/environments/web_research_env.py @@ -16,21 +16,18 @@ Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions Usage: # Phase 1 (OpenAI-compatible server) - python environments/web_research_env.py serve \ - --openai.base_url http://localhost:8000/v1 \ - --openai.model_name YourModel \ + python environments/web_research_env.py serve \\ + --openai.base_url http://localhost:8000/v1 \\ + --openai.model_name YourModel \\ --openai.server_type openai - # With eval split - python environments/web_research_env.py serve \ - --openai.base_url http://localhost:8000/v1 \ - --openai.model_name YourModel \ - --env.eval_every 50 \ - --env.eval_size 20 + # Process mode (offline data generation) + python environments/web_research_env.py process \\ + --env.data_path_to_save_groups data/web_research.jsonl - # Standalone eval (no training server needed) - python environments/web_research_env.py eval \ - --openai.base_url http://localhost:8000/v1 \ + # Standalone eval + python environments/web_research_env.py evaluate \\ + --openai.base_url http://localhost:8000/v1 \\ --openai.model_name YourModel Built by: github.com/jackx707 @@ -43,11 +40,21 @@ from __future__ import annotations import asyncio import json import logging +import os import random import re -from typing import Any, Optional +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse +from pydantic import Field + +# Ensure hermes-agent root is on path +_repo_root = Path(__file__).resolve().parent.parent +if str(_repo_root) not in sys.path: + sys.path.insert(0, str(_repo_root)) + # --------------------------------------------------------------------------- # Optional HuggingFace datasets import # --------------------------------------------------------------------------- @@ -57,13 +64,19 
@@ try: except ImportError: HF_AVAILABLE = False -from environments.hermes_base_env import HermesAgentBaseEnv +from atroposlib.envs.base import ScoredDataGroup +from atroposlib.envs.server_handling.server_manager import APIServerConfig +from atroposlib.type_definitions import Item + +from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig +from environments.agent_loop import AgentResult +from environments.tool_context import ToolContext logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Fallback sample dataset (used when HuggingFace is unavailable) -# These are multi-hop questions that require real web search to answer. +# Multi-hop questions requiring real web search to answer. # --------------------------------------------------------------------------- SAMPLE_QUESTIONS = [ { @@ -129,6 +142,58 @@ SAMPLE_QUESTIONS = [ ] +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +class WebResearchEnvConfig(HermesAgentEnvConfig): + """Configuration for the web research RL environment.""" + + # Reward weights + correctness_weight: float = Field( + default=0.6, + description="Weight for answer correctness in reward (LLM judge score).", + ) + tool_usage_weight: float = Field( + default=0.2, + description="Weight for tool usage signal (did the model actually use web tools?).", + ) + efficiency_weight: float = Field( + default=0.2, + description="Weight for efficiency signal (penalizes excessive tool calls).", + ) + diversity_bonus: float = Field( + default=0.1, + description="Bonus reward for citing ≥2 distinct domains.", + ) + + # Efficiency thresholds + efficient_max_calls: int = Field( + default=5, + description="Maximum tool calls before efficiency penalty begins.", + ) + heavy_penalty_calls: int = Field( + default=10, + description="Tool call count where 
efficiency penalty steepens.", + ) + + # Eval + eval_size: int = Field( + default=20, + description="Number of held-out items for evaluation.", + ) + eval_split_ratio: float = Field( + default=0.1, + description="Fraction of dataset to hold out for evaluation (0.0–1.0).", + ) + + # Dataset + dataset_name: str = Field( + default="google/frames-benchmark", + description="HuggingFace dataset name for research questions.", + ) + + # --------------------------------------------------------------------------- # Environment # --------------------------------------------------------------------------- @@ -143,23 +208,60 @@ class WebResearchEnv(HermesAgentBaseEnv): Reward is multi-signal: 60% — answer correctness (LLM judge) 20% — tool usage (did the model actually search the web?) - 20% — efficiency (penalizes >6 tool calls) + 20% — efficiency (penalizes >5 tool calls) Bonus +0.1 for source diversity (≥2 distinct domains cited). """ name = "web-research" + env_config_cls = WebResearchEnvConfig # Default toolsets for this environment — web + file for saving notes default_toolsets = ["web", "file"] + @classmethod + def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]: + """Default configuration for the web research environment.""" + env_config = WebResearchEnvConfig( + enabled_toolsets=["web", "file"], + max_agent_turns=15, + agent_temperature=1.0, + system_prompt=( + "You are a highly capable research agent. When asked a factual question, " + "always use web_search to find current, accurate information before answering. " + "Cite at least 2 sources. Be concise and accurate." 
+ ), + group_size=4, + total_steps=1000, + steps_per_eval=100, + use_wandb=True, + wandb_name="web-research", + ) + + server_configs = [ + APIServerConfig( + base_url="https://openrouter.ai/api/v1", + model_name="anthropic/claude-sonnet-4.5", + server_type="openai", + api_key=os.getenv("OPENROUTER_API_KEY", ""), + health_check=False, + ) + ] + + return env_config, server_configs + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._items: list[dict] = [] self._eval_items: list[dict] = [] self._index: int = 0 - self._total_scored: int = 0 - self._total_reward: float = 0.0 + + # Metrics tracking for wandb + self._reward_buffer: list[float] = [] + self._correctness_buffer: list[float] = [] + self._tool_usage_buffer: list[float] = [] + self._efficiency_buffer: list[float] = [] + self._diversity_buffer: list[float] = [] # ------------------------------------------------------------------ # 1. Setup — load dataset @@ -170,7 +272,7 @@ class WebResearchEnv(HermesAgentBaseEnv): if HF_AVAILABLE: try: logger.info("Loading FRAMES benchmark from HuggingFace...") - ds = load_dataset("google/frames-benchmark", split="test") + ds = load_dataset(self.config.dataset_name, split="test") self._items = [ { "question": row["Prompt"], @@ -180,8 +282,11 @@ class WebResearchEnv(HermesAgentBaseEnv): } for row in ds ] - # Hold out 10% for eval - eval_size = max(20, len(self._items) // 10) + # Hold out for eval + eval_size = max( + self.config.eval_size, + int(len(self._items) * self.config.eval_split_ratio), + ) random.shuffle(self._items) self._eval_items = self._items[:eval_size] self._items = self._items[eval_size:] @@ -220,10 +325,7 @@ class WebResearchEnv(HermesAgentBaseEnv): # ------------------------------------------------------------------ def format_prompt(self, item: dict) -> str: - """ - Format the research question as a task prompt. - Instructs the model to use web search and cite sources. 
- """ + """Format the research question as a task prompt.""" return ( f"Research the following question thoroughly using web search. " f"You MUST search the web to find current, accurate information — " @@ -243,27 +345,30 @@ class WebResearchEnv(HermesAgentBaseEnv): async def compute_reward( self, item: dict, - result: dict, - ctx: Any, # ToolContext + result: AgentResult, + ctx: ToolContext, ) -> float: """ Multi-signal reward function: - 0.6 * correctness — LLM judge comparing answer to ground truth - 0.2 * tool_used — binary: did the model use web tools? - 0.2 * efficiency — penalizes wasteful tool usage - +0.1 bonus — source diversity (≥2 distinct domains) + correctness_weight * correctness — LLM judge comparing answer to ground truth + tool_usage_weight * tool_used — binary: did the model use web tools? + efficiency_weight * efficiency — penalizes wasteful tool usage + + diversity_bonus — source diversity (≥2 distinct domains) """ - final_response: str = result.get("final_response", "") - tools_used: list[str] = result.get("tools_used", []) - tool_call_count: int = result.get("tool_call_count", len(tools_used)) + final_response: str = result.final_response or "" + tools_used: list[str] = [ + tc.tool_name for tc in (result.tool_calls or []) + ] if hasattr(result, "tool_calls") and result.tool_calls else [] + tool_call_count: int = result.turns_used or len(tools_used) + + cfg = self.config # ---- Signal 1: Answer correctness (LLM judge) ---------------- correctness = await self._llm_judge( question=item["question"], expected=item["answer"], model_answer=final_response, - ctx=ctx, ) # ---- Signal 2: Web tool usage -------------------------------- @@ -271,35 +376,37 @@ class WebResearchEnv(HermesAgentBaseEnv): tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0 # ---- Signal 3: Efficiency ------------------------------------ - # Ideal: 2-5 tool calls. Penalise beyond 6, hard cap at 15. 
- if tool_call_count <= 5: + if tool_call_count <= cfg.efficient_max_calls: efficiency = 1.0 - elif tool_call_count <= 10: - efficiency = 1.0 - (tool_call_count - 5) * 0.08 + elif tool_call_count <= cfg.heavy_penalty_calls: + efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08 else: - efficiency = max(0.0, 1.0 - (tool_call_count - 5) * 0.12) + efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12) # ---- Bonus: Source diversity --------------------------------- domains = self._extract_domains(final_response) - diversity_bonus = 0.1 if len(domains) >= 2 else 0.0 + diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0 # ---- Combine ------------------------------------------------ reward = ( - 0.6 * correctness - + 0.2 * tool_used - + 0.2 * efficiency - + diversity_bonus + cfg.correctness_weight * correctness + + cfg.tool_usage_weight * tool_used + + cfg.efficiency_weight * efficiency + + diversity ) reward = min(1.0, max(0.0, reward)) # clamp to [0, 1] - # Track running stats - self._total_scored += 1 - self._total_reward += reward + # Track for wandb + self._reward_buffer.append(reward) + self._correctness_buffer.append(correctness) + self._tool_usage_buffer.append(tool_used) + self._efficiency_buffer.append(efficiency) + self._diversity_buffer.append(diversity) logger.debug( f"Reward breakdown — correctness={correctness:.2f}, " f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, " - f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}" + f"diversity={diversity:.1f} → total={reward:.3f}" ) return reward @@ -308,68 +415,117 @@ class WebResearchEnv(HermesAgentBaseEnv): # 5. evaluate — run on held-out eval split # ------------------------------------------------------------------ - async def evaluate( - self, - *args: Any, - eval_size: Optional[int] = None, - **kwargs: Any, - ) -> dict: - """ - Run evaluation on the held-out split. - Returns a dict of metrics for logging. 
- """ - items = self._eval_items - if eval_size: - items = items[:eval_size] + async def evaluate(self, *args, **kwargs) -> None: + """Run evaluation on the held-out split using the agent loop.""" + import time + items = self._eval_items if not items: logger.warning("No eval items available.") - return {} + return - logger.info(f"Running eval on {len(items)} questions...") + eval_size = min(self.config.eval_size, len(items)) + eval_items = items[:eval_size] - rewards = [] - correctness_scores = [] + logger.info(f"Running eval on {len(eval_items)} questions...") + start_time = time.time() + samples = [] - for item in items: + for item in eval_items: try: - # Run the agent on each eval question - result = await self._run_agent_on_item(item) - reward = await self.compute_reward(item, result, ctx=None) - rewards.append(reward) + # Use the base env's agent loop for eval (same as training) + prompt = self.format_prompt(item) + completion = await self.server.chat_completion( + messages=[ + {"role": "system", "content": self.config.system_prompt or ""}, + {"role": "user", "content": prompt}, + ], + n=1, + max_tokens=self.config.max_token_length, + temperature=0.0, + split="eval", + ) + + response_content = ( + completion.choices[0].message.content if completion.choices else "" + ) + + # Score the response + correctness = await self._llm_judge( + question=item["question"], + expected=item["answer"], + model_answer=response_content, + ) + + samples.append({ + "prompt": item["question"], + "response": response_content, + "expected": item["answer"], + "correctness": correctness, + }) - # Also track raw correctness separately - if result.get("final_response"): - correctness_scores.append( - await self._llm_judge( - question=item["question"], - expected=item["answer"], - model_answer=result["final_response"], - ctx=None, - ) - ) except Exception as e: logger.error(f"Eval error on item: {e}") - rewards.append(0.0) + samples.append({ + "prompt": item["question"], + "response": 
f"ERROR: {e}", + "expected": item["answer"], + "correctness": 0.0, + }) - metrics = { - "eval/mean_reward": sum(rewards) / len(rewards) if rewards else 0.0, + end_time = time.time() + + # Compute metrics + correctness_scores = [s["correctness"] for s in samples] + eval_metrics = { "eval/mean_correctness": ( sum(correctness_scores) / len(correctness_scores) if correctness_scores else 0.0 ), - "eval/n_items": len(rewards), - "train/mean_reward_so_far": ( - self._total_reward / self._total_scored - if self._total_scored > 0 else 0.0 - ), + "eval/n_items": len(samples), } - logger.info( - f"Eval complete — mean_reward={metrics['eval/mean_reward']:.3f}, " - f"mean_correctness={metrics['eval/mean_correctness']:.3f}" + await self.evaluate_log( + metrics=eval_metrics, + samples=samples, + start_time=start_time, + end_time=end_time, ) - return metrics + + # ------------------------------------------------------------------ + # 6. wandb_log — custom metrics + # ------------------------------------------------------------------ + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None: + """Log reward breakdown metrics to wandb.""" + if wandb_metrics is None: + wandb_metrics = {} + + if self._reward_buffer: + n = len(self._reward_buffer) + wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n + wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n + wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n + wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n + wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n + wandb_metrics["train/total_rollouts"] = n + + # Accuracy buckets + wandb_metrics["train/correct_rate"] = ( + sum(1 for c in self._correctness_buffer if c >= 0.7) / n + ) + wandb_metrics["train/tool_usage_rate"] = ( + sum(1 for t in self._tool_usage_buffer if t > 0) / n + ) + + # Clear buffers + self._reward_buffer.clear() + self._correctness_buffer.clear() + 
self._tool_usage_buffer.clear() + self._efficiency_buffer.clear() + self._diversity_buffer.clear() + + await super().wandb_log(wandb_metrics) # ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ @@ -380,19 +536,14 @@ class WebResearchEnv(HermesAgentBaseEnv): question: str, expected: str, model_answer: str, - ctx: Any, ) -> float: """ - Use an LLM to judge whether `model_answer` correctly addresses - `question` compared to `expected`. Returns a float in [0, 1]. - - Uses the agent's own inference client if ctx is available, - otherwise falls back to a lightweight heuristic. + Use the server's LLM to judge answer correctness. + Falls back to keyword heuristic if LLM call fails. """ if not model_answer or not model_answer.strip(): return 0.0 - # Build judge prompt judge_prompt = ( "You are an impartial judge evaluating the quality of an AI research answer.\n\n" f"Question: {question}\n\n" @@ -405,39 +556,36 @@ class WebResearchEnv(HermesAgentBaseEnv): " 0.1 = mentions relevant topic but wrong or very incomplete\n" " 0.0 = completely wrong or no answer\n\n" "Consider: factual accuracy, completeness, and relevance.\n" - "Respond with ONLY a JSON object: {\"score\": , \"reason\": \"\"}" + 'Respond with ONLY a JSON object: {"score": <float between 0 and 1>, "reason": "<one-sentence justification>"}' ) - # Try using ctx for inference (Phase 2 / live training) - if ctx is not None and hasattr(ctx, "chat_completion"): - try: - response = await ctx.chat_completion( - messages=[{"role": "user", "content": judge_prompt}], - max_tokens=100, - temperature=0.0, - ) - text = response.get("content", "") - parsed = self._parse_judge_json(text) - if parsed is not None: - return float(parsed) - except Exception as e: - logger.debug(f"LLM judge via ctx failed: {e}. 
Using heuristic.") + try: + response = await self.server.chat_completion( + messages=[{"role": "user", "content": judge_prompt}], + n=1, + max_tokens=150, + temperature=0.0, + split="eval", + ) + text = response.choices[0].message.content if response.choices else "" + parsed = self._parse_judge_json(text) + if parsed is not None: + return float(parsed) + except Exception as e: + logger.debug(f"LLM judge failed: {e}. Using heuristic.") - # Fallback: keyword overlap heuristic return self._heuristic_score(expected, model_answer) @staticmethod def _parse_judge_json(text: str) -> Optional[float]: """Extract the score float from LLM judge JSON response.""" try: - # Strip markdown code fences if present clean = re.sub(r"```(?:json)?|```", "", text).strip() data = json.loads(clean) score = float(data.get("score", -1)) if 0.0 <= score <= 1.0: return score except Exception: - # Try regex fallback match = re.search(r'"score"\s*:\s*([0-9.]+)', text) if match: score = float(match.group(1)) @@ -447,10 +595,7 @@ class WebResearchEnv(HermesAgentBaseEnv): @staticmethod def _heuristic_score(expected: str, model_answer: str) -> float: - """ - Lightweight keyword overlap score as fallback when no LLM is available. - Extracts meaningful tokens and computes Jaccard similarity. 
- """ + """Lightweight keyword overlap score as fallback.""" stopwords = { "the", "a", "an", "is", "are", "was", "were", "of", "in", "on", "at", "to", "for", "with", "and", "or", "but", "it", "its", @@ -458,35 +603,30 @@ class WebResearchEnv(HermesAgentBaseEnv): } def tokenize(text: str) -> set: - tokens = re.findall(r'\b[a-zA-Z0-9]+\b', text.lower()) + tokens = re.findall(r'\b\w+\b', text.lower()) return {t for t in tokens if t not in stopwords and len(t) > 2} expected_tokens = tokenize(expected) answer_tokens = tokenize(model_answer) if not expected_tokens: - return 0.5 # Can't judge + return 0.5 overlap = len(expected_tokens & answer_tokens) union = len(expected_tokens | answer_tokens) jaccard = overlap / union if union > 0 else 0.0 - # Recall-weighted: reward covering expected content recall = overlap / len(expected_tokens) return min(1.0, 0.4 * jaccard + 0.6 * recall) @staticmethod def _extract_domains(text: str) -> set: - """ - Extract unique domains from URLs cited in the response. - Used to measure source diversity. - """ + """Extract unique domains from URLs cited in the response.""" urls = re.findall(r'https?://[^\s\)>\]"\']+', text) domains = set() for url in urls: try: parsed = urlparse(url) - # Normalize: strip www. domain = parsed.netloc.lower().lstrip("www.") if domain: domains.add(domain) @@ -494,20 +634,6 @@ class WebResearchEnv(HermesAgentBaseEnv): pass return domains - async def _run_agent_on_item(self, item: dict) -> dict: - """ - Stub for running agent during eval. In Phase 1/2, this is handled - by the Atropos framework's rollout mechanism. Provided here for - standalone eval compatibility. - """ - # In real usage, the framework calls get_next_item + format_prompt - # and runs the agent. This stub returns an empty result for safety. - return { - "final_response": "", - "tools_used": [], - "tool_call_count": 0, - } - # --------------------------------------------------------------------------- # Entry point