diff --git a/datagen-config-examples/web_research.yaml b/datagen-config-examples/web_research.yaml new file mode 100644 index 00000000..6275dbed --- /dev/null +++ b/datagen-config-examples/web_research.yaml @@ -0,0 +1,46 @@ +# datagen-config-examples/web_research.yaml +# +# Batch data generation config for WebResearchEnv. +# Generates tool-calling trajectories for multi-step web research tasks. +# +# Usage: +# python batch_runner.py \ +# --config datagen-config-examples/web_research.yaml \ +# --run_name web_research_v1 + +environment: web-research + +# Toolsets available to the agent during data generation +toolsets: + - web + - file + +# How many parallel workers to use +num_workers: 4 + +# Questions per batch +batch_size: 20 + +# Total trajectories to generate (comment out to run full dataset) +max_items: 500 + +# Model to use for generation (override with --model flag) +model: openrouter/nousresearch/hermes-3-llama-3.1-405b + +# System prompt additions (ephemeral — not saved to trajectories) +ephemeral_system_prompt: | + You are a highly capable research agent. When asked a factual question, + always use web_search to find current, accurate information before answering. + Cite at least 2 sources. Be concise and accurate. + +# Output directory +output_dir: data/web_research_v1 + +# Trajectory compression settings (for fitting into training token budgets) +compression: + enabled: true + target_max_tokens: 16000 + +# Eval settings +eval_every: 100 # Run eval every N trajectories +eval_size: 25 # Number of held-out questions per eval run diff --git a/environments/web_research_env.py b/environments/web_research_env.py new file mode 100644 index 00000000..e73eb45c --- /dev/null +++ b/environments/web_research_env.py @@ -0,0 +1,517 @@ +""" +WebResearchEnv — RL Environment for Multi-Step Web Research +============================================================ + +Trains models to do accurate, efficient, multi-source web research. + +Reward signals: + - Answer correctness (LLM judge, 0.0–1.0) + - Source diversity (used ≥2 distinct domains) + - Efficiency (penalizes excessive tool calls) + - Tool usage (bonus for actually using web tools) + +Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions + HuggingFace: google/frames-benchmark + Fallback: built-in sample questions (no HF token needed) + +Usage: + # Phase 1 (OpenAI-compatible server) + python environments/web_research_env.py serve \ + --openai.base_url http://localhost:8000/v1 \ + --openai.model_name YourModel \ + --openai.server_type openai + + # With eval split + python environments/web_research_env.py serve \ + --openai.base_url http://localhost:8000/v1 \ + --openai.model_name YourModel \ + --env.eval_every 50 \ + --env.eval_size 20 + + # Standalone eval (no training server needed) + python environments/web_research_env.py eval \ + --openai.base_url http://localhost:8000/v1 \ + --openai.model_name YourModel + +Built by: github.com/jackx707 +Inspired by: GroceryMind — production Hermes agent doing live web research + across German grocery stores (firecrawl + hermes-agent) +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import random +import re +from typing import Any, Optional +from urllib.parse import urlparse + +# --------------------------------------------------------------------------- +# Optional HuggingFace datasets import +# --------------------------------------------------------------------------- +try: + from datasets import load_dataset + HF_AVAILABLE = True +except ImportError: + HF_AVAILABLE = False + +from environments.hermes_base_env import HermesAgentBaseEnv + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Fallback sample dataset (used when HuggingFace is unavailable) +# These are multi-hop questions that require real web search to answer. +# --------------------------------------------------------------------------- +SAMPLE_QUESTIONS = [ + { + "question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?", + "answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.", + "difficulty": "medium", + "hops": 2, + }, + { + "question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?", + "answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.", + "difficulty": "medium", + "hops": 2, + }, + { + "question": "What programming language was used to write the original version of the web framework used by Instagram?", + "answer": "Django, which Instagram was built on, is written in Python.", + "difficulty": "easy", + "hops": 2, + }, + { + "question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?", + "answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).", + "difficulty": "hard", + "hops": 3, + }, + { + "question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?", + "answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.", + "difficulty": "medium", + "hops": 2, + }, + { + "question": "How many employees does the parent company of Instagram have?", + "answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.", + "difficulty": "medium", + "hops": 2, + }, + { + "question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?", + "answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.", + "difficulty": "hard", + "hops": 2, + }, + { + "question": "Which company acquired the startup founded by the creator of Oculus VR?", + "answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.", + "difficulty": "medium", + "hops": 2, + }, + { + "question": "What is the market cap of the company that owns the most popular search engine in Russia?", + "answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.", + "difficulty": "hard", + "hops": 2, + }, + { + "question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?", + "answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.", + "difficulty": "hard", + "hops": 2, + }, +] + + +# --------------------------------------------------------------------------- +# Environment +# --------------------------------------------------------------------------- + +class WebResearchEnv(HermesAgentBaseEnv): + """ + RL environment for training multi-step web research skills. + + The model is given a factual question requiring 2-3 hops of web research + and must use web_search / web_extract tools to find and synthesize the answer. + + Reward is multi-signal: + 60% — answer correctness (LLM judge) + 20% — tool usage (did the model actually search the web?) + 20% — efficiency (penalizes >6 tool calls) + + Bonus +0.1 for source diversity (≥2 distinct domains cited). + """ + + name = "web-research" + + # Default toolsets for this environment — web + file for saving notes + default_toolsets = ["web", "file"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._items: list[dict] = [] + self._eval_items: list[dict] = [] + self._index: int = 0 + self._total_scored: int = 0 + self._total_reward: float = 0.0 + + # ------------------------------------------------------------------ + # 1. Setup — load dataset + # ------------------------------------------------------------------ + + async def setup(self) -> None: + """Load the FRAMES benchmark or fall back to built-in samples.""" + if HF_AVAILABLE: + try: + logger.info("Loading FRAMES benchmark from HuggingFace...") + ds = load_dataset("google/frames-benchmark", split="test") + self._items = [ + { + "question": row["Prompt"], + "answer": row["Answer"], + "difficulty": row.get("reasoning_types", "unknown"), + "hops": 2, + } + for row in ds + ] + # Hold out 10% for eval + eval_size = max(20, len(self._items) // 10) + random.shuffle(self._items) + self._eval_items = self._items[:eval_size] + self._items = self._items[eval_size:] + logger.info( + f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items " + f"from FRAMES benchmark." + ) + return + except Exception as e: + logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.") + + # Fallback + random.shuffle(SAMPLE_QUESTIONS) + split = max(1, len(SAMPLE_QUESTIONS) * 8 // 10) + self._items = SAMPLE_QUESTIONS[:split] + self._eval_items = SAMPLE_QUESTIONS[split:] + logger.info( + f"Using built-in sample dataset: {len(self._items)} train / " + f"{len(self._eval_items)} eval items." + ) + + # ------------------------------------------------------------------ + # 2. get_next_item — return the next question + # ------------------------------------------------------------------ + + async def get_next_item(self) -> dict: + """Return the next item, cycling through the dataset.""" + if not self._items: + raise RuntimeError("Dataset is empty. Did you call setup()?") + item = self._items[self._index % len(self._items)] + self._index += 1 + return item + + # ------------------------------------------------------------------ + # 3. format_prompt — build the user-facing prompt + # ------------------------------------------------------------------ + + def format_prompt(self, item: dict) -> str: + """ + Format the research question as a task prompt. + Instructs the model to use web search and cite sources. + """ + return ( + f"Research the following question thoroughly using web search. " + f"You MUST search the web to find current, accurate information — " + f"do not rely solely on your training data.\n\n" + f"Question: {item['question']}\n\n" + f"Requirements:\n" + f"- Use web_search and/or web_extract tools to find information\n" + f"- Search at least 2 different sources\n" + f"- Provide a concise, accurate answer (2-4 sentences)\n" + f"- Cite the sources you used" + ) + + # ------------------------------------------------------------------ + # 4. compute_reward — multi-signal scoring + # ------------------------------------------------------------------ + + async def compute_reward( + self, + item: dict, + result: dict, + ctx: Any, # ToolContext + ) -> float: + """ + Multi-signal reward function: + + 0.6 * correctness — LLM judge comparing answer to ground truth + 0.2 * tool_used — binary: did the model use web tools? + 0.2 * efficiency — penalizes wasteful tool usage + +0.1 bonus — source diversity (≥2 distinct domains) + """ + final_response: str = result.get("final_response", "") + tools_used: list[str] = result.get("tools_used", []) + tool_call_count: int = result.get("tool_call_count", len(tools_used)) + + # ---- Signal 1: Answer correctness (LLM judge) ---------------- + correctness = await self._llm_judge( + question=item["question"], + expected=item["answer"], + model_answer=final_response, + ctx=ctx, + ) + + # ---- Signal 2: Web tool usage -------------------------------- + web_tools = {"web_search", "web_extract", "search", "firecrawl"} + tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0 + + # ---- Signal 3: Efficiency ------------------------------------ + # Ideal: 2-5 tool calls. Penalise beyond 6, hard cap at 15. + if tool_call_count <= 5: + efficiency = 1.0 + elif tool_call_count <= 10: + efficiency = 1.0 - (tool_call_count - 5) * 0.08 + else: + efficiency = max(0.0, 1.0 - (tool_call_count - 5) * 0.12) + + # ---- Bonus: Source diversity --------------------------------- + domains = self._extract_domains(final_response) + diversity_bonus = 0.1 if len(domains) >= 2 else 0.0 + + # ---- Combine ------------------------------------------------ + reward = ( + 0.6 * correctness + + 0.2 * tool_used + + 0.2 * efficiency + + diversity_bonus + ) + reward = min(1.0, max(0.0, reward)) # clamp to [0, 1] + + # Track running stats + self._total_scored += 1 + self._total_reward += reward + + logger.debug( + f"Reward breakdown — correctness={correctness:.2f}, " + f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, " + f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}" + ) + + return reward + + # ------------------------------------------------------------------ + # 5. evaluate — run on held-out eval split + # ------------------------------------------------------------------ + + async def evaluate( + self, + *args: Any, + eval_size: Optional[int] = None, + **kwargs: Any, + ) -> dict: + """ + Run evaluation on the held-out split. + Returns a dict of metrics for logging. + """ + items = self._eval_items + if eval_size: + items = items[:eval_size] + + if not items: + logger.warning("No eval items available.") + return {} + + logger.info(f"Running eval on {len(items)} questions...") + + rewards = [] + correctness_scores = [] + + for item in items: + try: + # Run the agent on each eval question + result = await self._run_agent_on_item(item) + reward = await self.compute_reward(item, result, ctx=None) + rewards.append(reward) + + # Also track raw correctness separately + if result.get("final_response"): + correctness_scores.append( + await self._llm_judge( + question=item["question"], + expected=item["answer"], + model_answer=result["final_response"], + ctx=None, + ) + ) + except Exception as e: + logger.error(f"Eval error on item: {e}") + rewards.append(0.0) + + metrics = { + "eval/mean_reward": sum(rewards) / len(rewards) if rewards else 0.0, + "eval/mean_correctness": ( + sum(correctness_scores) / len(correctness_scores) + if correctness_scores else 0.0 + ), + "eval/n_items": len(rewards), + "train/mean_reward_so_far": ( + self._total_reward / self._total_scored + if self._total_scored > 0 else 0.0 + ), + } + + logger.info( + f"Eval complete — mean_reward={metrics['eval/mean_reward']:.3f}, " + f"mean_correctness={metrics['eval/mean_correctness']:.3f}" + ) + return metrics + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + async def _llm_judge( + self, + question: str, + expected: str, + model_answer: str, + ctx: Any, + ) -> float: + """ + Use an LLM to judge whether `model_answer` correctly addresses + `question` compared to `expected`. Returns a float in [0, 1]. + + Uses the agent's own inference client if ctx is available, + otherwise falls back to a lightweight heuristic. + """ + if not model_answer or not model_answer.strip(): + return 0.0 + + # Build judge prompt + judge_prompt = ( + "You are an impartial judge evaluating the quality of an AI research answer.\n\n" + f"Question: {question}\n\n" + f"Reference answer: {expected}\n\n" + f"Model answer: {model_answer}\n\n" + "Score the model answer on a scale from 0.0 to 1.0 where:\n" + " 1.0 = fully correct and complete\n" + " 0.7 = mostly correct with minor gaps\n" + " 0.4 = partially correct\n" + " 0.1 = mentions relevant topic but wrong or very incomplete\n" + " 0.0 = completely wrong or no answer\n\n" + "Consider: factual accuracy, completeness, and relevance.\n" + "Respond with ONLY a JSON object: {\"score\": , \"reason\": \"\"}" + ) + + # Try using ctx for inference (Phase 2 / live training) + if ctx is not None and hasattr(ctx, "chat_completion"): + try: + response = await ctx.chat_completion( + messages=[{"role": "user", "content": judge_prompt}], + max_tokens=100, + temperature=0.0, + ) + text = response.get("content", "") + parsed = self._parse_judge_json(text) + if parsed is not None: + return float(parsed) + except Exception as e: + logger.debug(f"LLM judge via ctx failed: {e}. Using heuristic.") + + # Fallback: keyword overlap heuristic + return self._heuristic_score(expected, model_answer) + + @staticmethod + def _parse_judge_json(text: str) -> Optional[float]: + """Extract the score float from LLM judge JSON response.""" + try: + # Strip markdown code fences if present + clean = re.sub(r"```(?:json)?|```", "", text).strip() + data = json.loads(clean) + score = float(data.get("score", -1)) + if 0.0 <= score <= 1.0: + return score + except Exception: + # Try regex fallback + match = re.search(r'"score"\s*:\s*([0-9.]+)', text) + if match: + score = float(match.group(1)) + if 0.0 <= score <= 1.0: + return score + return None + + @staticmethod + def _heuristic_score(expected: str, model_answer: str) -> float: + """ + Lightweight keyword overlap score as fallback when no LLM is available. + Extracts meaningful tokens and computes Jaccard similarity. + """ + stopwords = { + "the", "a", "an", "is", "are", "was", "were", "of", "in", "on", + "at", "to", "for", "with", "and", "or", "but", "it", "its", + "this", "that", "as", "by", "from", "be", "has", "have", "had", + } + + def tokenize(text: str) -> set: + tokens = re.findall(r'\b[a-zA-Z0-9]+\b', text.lower()) + return {t for t in tokens if t not in stopwords and len(t) > 2} + + expected_tokens = tokenize(expected) + answer_tokens = tokenize(model_answer) + + if not expected_tokens: + return 0.5 # Can't judge + + overlap = len(expected_tokens & answer_tokens) + union = len(expected_tokens | answer_tokens) + + jaccard = overlap / union if union > 0 else 0.0 + # Recall-weighted: reward covering expected content + recall = overlap / len(expected_tokens) + return min(1.0, 0.4 * jaccard + 0.6 * recall) + + @staticmethod + def _extract_domains(text: str) -> set: + """ + Extract unique domains from URLs cited in the response. + Used to measure source diversity. + """ + urls = re.findall(r'https?://[^\s\)>\]"\']+', text) + domains = set() + for url in urls: + try: + parsed = urlparse(url) + # Normalize: strip www. + domain = parsed.netloc.lower().lstrip("www.") + if domain: + domains.add(domain) + except Exception: + pass + return domains + + async def _run_agent_on_item(self, item: dict) -> dict: + """ + Stub for running agent during eval. In Phase 1/2, this is handled + by the Atropos framework's rollout mechanism. Provided here for + standalone eval compatibility. + """ + # In real usage, the framework calls get_next_item + format_prompt + # and runs the agent. This stub returns an empty result for safety. + return { + "final_response": "", + "tools_used": [], + "tool_call_count": 0, + } + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + WebResearchEnv.cli()