feat: add WebResearchEnv RL environment for multi-step web research

This commit is contained in:
jackx707
2026-03-05 14:34:36 +00:00
parent ada3713e77
commit 15561ec425
2 changed files with 563 additions and 0 deletions

View File

@@ -0,0 +1,46 @@
# datagen-config-examples/web_research.yaml
#
# Batch data generation config for WebResearchEnv.
# Generates tool-calling trajectories for multi-step web research tasks.
#
# Usage:
# python batch_runner.py \
# --config datagen-config-examples/web_research.yaml \
# --run_name web_research_v1
environment: web-research
# Toolsets available to the agent during data generation
toolsets:
- web
- file
# How many parallel workers to use
num_workers: 4
# Questions per batch
batch_size: 20
# Total trajectories to generate (comment out to run full dataset)
max_items: 500
# Model to use for generation (override with --model flag)
model: openrouter/nousresearch/hermes-3-llama-3.1-405b
# System prompt additions (ephemeral — not saved to trajectories)
ephemeral_system_prompt: |
You are a highly capable research agent. When asked a factual question,
always use web_search to find current, accurate information before answering.
Cite at least 2 sources. Be concise and accurate.
# Output directory
output_dir: data/web_research_v1
# Trajectory compression settings (for fitting into training token budgets)
compression:
enabled: true
target_max_tokens: 16000
# Eval settings
eval_every: 100 # Run eval every N trajectories
eval_size: 25 # Number of held-out questions per eval run

View File

@@ -0,0 +1,517 @@
"""
WebResearchEnv — RL Environment for Multi-Step Web Research
============================================================
Trains models to do accurate, efficient, multi-source web research.
Reward signals:
- Answer correctness (LLM judge, 0.0–1.0)
- Source diversity (used ≥2 distinct domains)
- Efficiency (penalizes excessive tool calls)
- Tool usage (bonus for actually using web tools)
Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
HuggingFace: google/frames-benchmark
Fallback: built-in sample questions (no HF token needed)
Usage:
# Phase 1 (OpenAI-compatible server)
python environments/web_research_env.py serve \
--openai.base_url http://localhost:8000/v1 \
--openai.model_name YourModel \
--openai.server_type openai
# With eval split
python environments/web_research_env.py serve \
--openai.base_url http://localhost:8000/v1 \
--openai.model_name YourModel \
--env.eval_every 50 \
--env.eval_size 20
# Standalone eval (no training server needed)
python environments/web_research_env.py eval \
--openai.base_url http://localhost:8000/v1 \
--openai.model_name YourModel
Built by: github.com/jackx707
Inspired by: GroceryMind — production Hermes agent doing live web research
across German grocery stores (firecrawl + hermes-agent)
"""
from __future__ import annotations
import asyncio
import json
import logging
import random
import re
from typing import Any, Optional
from urllib.parse import urlparse
# ---------------------------------------------------------------------------
# Optional HuggingFace datasets import
# ---------------------------------------------------------------------------
try:
from datasets import load_dataset
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
from environments.hermes_base_env import HermesAgentBaseEnv
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Fallback sample dataset (used when HuggingFace is unavailable)
# These are multi-hop questions that require real web search to answer.
# ---------------------------------------------------------------------------
# Each item mirrors the row shape produced from FRAMES in setup():
#   question   — multi-hop factual prompt shown to the agent
#   answer     — reference answer consumed by the LLM judge / heuristic scorer
#   difficulty — rough label ("easy"/"medium"/"hard"); informational only
#   hops       — approximate number of distinct research steps required
SAMPLE_QUESTIONS = [
    {
        "question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
        "answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
        "answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What programming language was used to write the original version of the web framework used by Instagram?",
        "answer": "Django, which Instagram was built on, is written in Python.",
        "difficulty": "easy",
        "hops": 2,
    },
    {
        "question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
        "answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
        "difficulty": "hard",
        "hops": 3,
    },
    {
        "question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
        "answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "How many employees does the parent company of Instagram have?",
        "answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
        "answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "Which company acquired the startup founded by the creator of Oculus VR?",
        "answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the market cap of the company that owns the most popular search engine in Russia?",
        "answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
        "answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
        "difficulty": "hard",
        "hops": 2,
    },
]
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
class WebResearchEnv(HermesAgentBaseEnv):
    """
    RL environment for training multi-step web research skills.

    The model is given a factual question requiring 2-3 hops of web research
    and must use web_search / web_extract tools to find and synthesize the answer.

    Reward is multi-signal:
        60% — answer correctness (LLM judge)
        20% — tool usage (did the model actually search the web?)
        20% — efficiency (penalizes >6 tool calls)
    Bonus +0.1 for source diversity (≥2 distinct domains cited).
    """

    name = "web-research"
    # Default toolsets for this environment — web + file for saving notes
    default_toolsets = ["web", "file"]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._items: list[dict] = []       # training questions
        self._eval_items: list[dict] = []  # held-out eval questions
        self._index: int = 0               # round-robin cursor into self._items
        self._total_scored: int = 0        # rollouts scored so far
        self._total_reward: float = 0.0    # running reward sum (for mean stat)

    # ------------------------------------------------------------------
    # 1. Setup — load dataset
    # ------------------------------------------------------------------
    async def setup(self) -> None:
        """Load the FRAMES benchmark or fall back to built-in samples.

        Populates ``self._items`` (train) and ``self._eval_items`` (held-out).
        """
        if HF_AVAILABLE:
            try:
                logger.info("Loading FRAMES benchmark from HuggingFace...")
                ds = load_dataset("google/frames-benchmark", split="test")
                self._items = [
                    {
                        "question": row["Prompt"],
                        "answer": row["Answer"],
                        "difficulty": row.get("reasoning_types", "unknown"),
                        "hops": 2,
                    }
                    for row in ds
                ]
                # Hold out 10% for eval (minimum 20 items)
                eval_size = max(20, len(self._items) // 10)
                random.shuffle(self._items)
                self._eval_items = self._items[:eval_size]
                self._items = self._items[eval_size:]
                logger.info(
                    f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items "
                    f"from FRAMES benchmark."
                )
                return
            except Exception as e:
                logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.")
        # Fallback: shuffle a *copy* so the module-level SAMPLE_QUESTIONS
        # constant is never mutated in place (repeated setup() calls and other
        # importers of the module would otherwise see a reordered dataset).
        samples = list(SAMPLE_QUESTIONS)
        random.shuffle(samples)
        split = max(1, len(samples) * 8 // 10)  # 80/20 train/eval split
        self._items = samples[:split]
        self._eval_items = samples[split:]
        logger.info(
            f"Using built-in sample dataset: {len(self._items)} train / "
            f"{len(self._eval_items)} eval items."
        )

    # ------------------------------------------------------------------
    # 2. get_next_item — return the next question
    # ------------------------------------------------------------------
    async def get_next_item(self) -> dict:
        """Return the next item, cycling through the dataset.

        Raises:
            RuntimeError: if the dataset is empty (``setup()`` not called).
        """
        if not self._items:
            raise RuntimeError("Dataset is empty. Did you call setup()?")
        item = self._items[self._index % len(self._items)]
        self._index += 1
        return item

    # ------------------------------------------------------------------
    # 3. format_prompt — build the user-facing prompt
    # ------------------------------------------------------------------
    def format_prompt(self, item: dict) -> str:
        """
        Format the research question as a task prompt.

        Instructs the model to use web search and cite sources.
        """
        return (
            f"Research the following question thoroughly using web search. "
            f"You MUST search the web to find current, accurate information — "
            f"do not rely solely on your training data.\n\n"
            f"Question: {item['question']}\n\n"
            f"Requirements:\n"
            f"- Use web_search and/or web_extract tools to find information\n"
            f"- Search at least 2 different sources\n"
            f"- Provide a concise, accurate answer (2-4 sentences)\n"
            f"- Cite the sources you used"
        )

    # ------------------------------------------------------------------
    # 4. compute_reward — multi-signal scoring
    # ------------------------------------------------------------------
    async def compute_reward(
        self,
        item: dict,
        result: dict,
        ctx: Any,  # ToolContext (may be None during standalone eval)
    ) -> float:
        """
        Multi-signal reward function:
            0.6 * correctness — LLM judge comparing answer to ground truth
            0.2 * tool_used   — binary: did the model use web tools?
            0.2 * efficiency  — penalizes wasteful tool usage
            +0.1 bonus        — source diversity (≥2 distinct domains)

        Returns a float clamped to [0, 1].
        """
        final_response: str = result.get("final_response", "")
        tools_used: list[str] = result.get("tools_used", [])
        tool_call_count: int = result.get("tool_call_count", len(tools_used))

        # ---- Signal 1: Answer correctness (LLM judge) ----------------
        correctness = await self._llm_judge(
            question=item["question"],
            expected=item["answer"],
            model_answer=final_response,
            ctx=ctx,
        )

        # ---- Signal 2: Web tool usage --------------------------------
        web_tools = {"web_search", "web_extract", "search", "firecrawl"}
        tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0

        # ---- Signal 3: Efficiency ------------------------------------
        # Ideal: 2-5 tool calls. Penalise beyond 6, hard cap at 15.
        if tool_call_count <= 5:
            efficiency = 1.0
        elif tool_call_count <= 10:
            efficiency = 1.0 - (tool_call_count - 5) * 0.08
        else:
            efficiency = max(0.0, 1.0 - (tool_call_count - 5) * 0.12)

        # ---- Bonus: Source diversity ---------------------------------
        domains = self._extract_domains(final_response)
        diversity_bonus = 0.1 if len(domains) >= 2 else 0.0

        # ---- Combine ------------------------------------------------
        reward = (
            0.6 * correctness
            + 0.2 * tool_used
            + 0.2 * efficiency
            + diversity_bonus
        )
        reward = min(1.0, max(0.0, reward))  # clamp to [0, 1]

        # Track running stats for the train/mean_reward_so_far metric
        self._total_scored += 1
        self._total_reward += reward

        logger.debug(
            f"Reward breakdown — correctness={correctness:.2f}, "
            f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
            f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}"
        )
        return reward

    # ------------------------------------------------------------------
    # 5. evaluate — run on held-out eval split
    # ------------------------------------------------------------------
    async def evaluate(
        self,
        *args: Any,
        eval_size: Optional[int] = None,
        **kwargs: Any,
    ) -> dict:
        """
        Run evaluation on the held-out split.

        Args:
            eval_size: optional cap on the number of eval items to run.

        Returns a dict of metrics for logging (empty if no eval items).
        """
        items = self._eval_items
        if eval_size:
            items = items[:eval_size]
        if not items:
            logger.warning("No eval items available.")
            return {}

        logger.info(f"Running eval on {len(items)} questions...")
        rewards = []
        correctness_scores = []
        for item in items:
            try:
                # Run the agent on each eval question
                result = await self._run_agent_on_item(item)
                reward = await self.compute_reward(item, result, ctx=None)
                rewards.append(reward)
                # Also track raw correctness separately
                if result.get("final_response"):
                    correctness_scores.append(
                        await self._llm_judge(
                            question=item["question"],
                            expected=item["answer"],
                            model_answer=result["final_response"],
                            ctx=None,
                        )
                    )
            except Exception as e:
                # A single failing item should not abort the whole eval run
                logger.error(f"Eval error on item: {e}")
                rewards.append(0.0)

        metrics = {
            "eval/mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
            "eval/mean_correctness": (
                sum(correctness_scores) / len(correctness_scores)
                if correctness_scores else 0.0
            ),
            "eval/n_items": len(rewards),
            "train/mean_reward_so_far": (
                self._total_reward / self._total_scored
                if self._total_scored > 0 else 0.0
            ),
        }
        logger.info(
            f"Eval complete — mean_reward={metrics['eval/mean_reward']:.3f}, "
            f"mean_correctness={metrics['eval/mean_correctness']:.3f}"
        )
        return metrics

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    async def _llm_judge(
        self,
        question: str,
        expected: str,
        model_answer: str,
        ctx: Any,
    ) -> float:
        """
        Use an LLM to judge whether `model_answer` correctly addresses
        `question` compared to `expected`. Returns a float in [0, 1].

        Uses the agent's own inference client if ctx is available,
        otherwise falls back to a lightweight heuristic.
        """
        if not model_answer or not model_answer.strip():
            return 0.0

        # Build judge prompt
        judge_prompt = (
            "You are an impartial judge evaluating the quality of an AI research answer.\n\n"
            f"Question: {question}\n\n"
            f"Reference answer: {expected}\n\n"
            f"Model answer: {model_answer}\n\n"
            "Score the model answer on a scale from 0.0 to 1.0 where:\n"
            "  1.0 = fully correct and complete\n"
            "  0.7 = mostly correct with minor gaps\n"
            "  0.4 = partially correct\n"
            "  0.1 = mentions relevant topic but wrong or very incomplete\n"
            "  0.0 = completely wrong or no answer\n\n"
            "Consider: factual accuracy, completeness, and relevance.\n"
            "Respond with ONLY a JSON object: {\"score\": <float>, \"reason\": \"<one sentence>\"}"
        )

        # Try using ctx for inference (Phase 2 / live training)
        if ctx is not None and hasattr(ctx, "chat_completion"):
            try:
                response = await ctx.chat_completion(
                    messages=[{"role": "user", "content": judge_prompt}],
                    max_tokens=100,
                    temperature=0.0,
                )
                text = response.get("content", "")
                parsed = self._parse_judge_json(text)
                if parsed is not None:
                    return float(parsed)
            except Exception as e:
                logger.debug(f"LLM judge via ctx failed: {e}. Using heuristic.")

        # Fallback: keyword overlap heuristic
        return self._heuristic_score(expected, model_answer)

    @staticmethod
    def _parse_judge_json(text: str) -> Optional[float]:
        """Extract the score float from the LLM judge's JSON response.

        Tries strict JSON first, then a regex fallback for malformed output.
        Returns None when no score in [0, 1] can be recovered.
        """
        # Strip markdown code fences if present
        clean = re.sub(r"```(?:json)?|```", "", text).strip()
        try:
            data = json.loads(clean)
            score = float(data.get("score", -1))
        except Exception:
            # Regex fallback: pull the score field out of non-JSON text.
            # Guard the float() conversion too — a capture like "1.2.3"
            # would otherwise raise ValueError out of this helper.
            match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
            if not match:
                return None
            try:
                score = float(match.group(1))
            except ValueError:
                return None
        if 0.0 <= score <= 1.0:
            return score
        return None

    @staticmethod
    def _heuristic_score(expected: str, model_answer: str) -> float:
        """
        Lightweight keyword overlap score as fallback when no LLM is available.

        Extracts meaningful tokens and computes a recall-weighted mix of
        Jaccard similarity and expected-token recall.
        """
        stopwords = {
            "the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
            "at", "to", "for", "with", "and", "or", "but", "it", "its",
            "this", "that", "as", "by", "from", "be", "has", "have", "had",
        }

        def tokenize(text: str) -> set:
            tokens = re.findall(r'\b[a-zA-Z0-9]+\b', text.lower())
            return {t for t in tokens if t not in stopwords and len(t) > 2}

        expected_tokens = tokenize(expected)
        answer_tokens = tokenize(model_answer)
        if not expected_tokens:
            return 0.5  # Can't judge
        overlap = len(expected_tokens & answer_tokens)
        union = len(expected_tokens | answer_tokens)
        jaccard = overlap / union if union > 0 else 0.0
        # Recall-weighted: reward covering expected content
        recall = overlap / len(expected_tokens)
        return min(1.0, 0.4 * jaccard + 0.6 * recall)

    @staticmethod
    def _extract_domains(text: str) -> set:
        """
        Extract unique domains from URLs cited in the response.

        Used to measure source diversity.
        """
        urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
        domains = set()
        for url in urls:
            try:
                parsed = urlparse(url)
                # Normalize: strip a leading "www." prefix.
                # NOTE: removeprefix, not lstrip — lstrip("www.") strips the
                # *characters* w and . and would corrupt e.g. "wwf.org".
                domain = parsed.netloc.lower().removeprefix("www.")
                if domain:
                    domains.add(domain)
            except Exception:
                pass
        return domains

    async def _run_agent_on_item(self, item: dict) -> dict:
        """
        Stub for running agent during eval. In Phase 1/2, this is handled
        by the Atropos framework's rollout mechanism. Provided here for
        standalone eval compatibility.
        """
        # In real usage, the framework calls get_next_item + format_prompt
        # and runs the agent. This stub returns an empty result for safety.
        return {
            "final_response": "",
            "tools_used": [],
            "tool_call_count": 0,
        }
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # cli() is presumably provided by the environment framework base class and
    # dispatches the serve/eval subcommands shown in the module docstring —
    # TODO confirm against HermesAgentBaseEnv.
    WebResearchEnv.cli()