feat: add WebResearchEnv RL environment for multi-step web research
This commit is contained in:
46
datagen-config-examples/web_research.yaml
Normal file
46
datagen-config-examples/web_research.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
# datagen-config-examples/web_research.yaml
|
||||
#
|
||||
# Batch data generation config for WebResearchEnv.
|
||||
# Generates tool-calling trajectories for multi-step web research tasks.
|
||||
#
|
||||
# Usage:
|
||||
# python batch_runner.py \
|
||||
# --config datagen-config-examples/web_research.yaml \
|
||||
# --run_name web_research_v1
|
||||
|
||||
environment: web-research
|
||||
|
||||
# Toolsets available to the agent during data generation
|
||||
toolsets:
|
||||
- web
|
||||
- file
|
||||
|
||||
# How many parallel workers to use
|
||||
num_workers: 4
|
||||
|
||||
# Questions per batch
|
||||
batch_size: 20
|
||||
|
||||
# Total trajectories to generate (comment out to run full dataset)
|
||||
max_items: 500
|
||||
|
||||
# Model to use for generation (override with --model flag)
|
||||
model: openrouter/nousresearch/hermes-3-llama-3.1-405b
|
||||
|
||||
# System prompt additions (ephemeral — not saved to trajectories)
|
||||
ephemeral_system_prompt: |
|
||||
You are a highly capable research agent. When asked a factual question,
|
||||
always use web_search to find current, accurate information before answering.
|
||||
Cite at least 2 sources. Be concise and accurate.
|
||||
|
||||
# Output directory
|
||||
output_dir: data/web_research_v1
|
||||
|
||||
# Trajectory compression settings (for fitting into training token budgets)
|
||||
compression:
|
||||
enabled: true
|
||||
target_max_tokens: 16000
|
||||
|
||||
# Eval settings
|
||||
eval_every: 100 # Run eval every N trajectories
|
||||
eval_size: 25 # Number of held-out questions per eval run
|
||||
517
environments/web_research_env.py
Normal file
517
environments/web_research_env.py
Normal file
@@ -0,0 +1,517 @@
|
||||
"""
|
||||
WebResearchEnv — RL Environment for Multi-Step Web Research
|
||||
============================================================
|
||||
|
||||
Trains models to do accurate, efficient, multi-source web research.
|
||||
|
||||
Reward signals:
|
||||
- Answer correctness (LLM judge, 0.0–1.0)
|
||||
- Source diversity (used ≥2 distinct domains)
|
||||
- Efficiency (penalizes excessive tool calls)
|
||||
- Tool usage (bonus for actually using web tools)
|
||||
|
||||
Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
|
||||
HuggingFace: google/frames-benchmark
|
||||
Fallback: built-in sample questions (no HF token needed)
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenAI-compatible server)
|
||||
python environments/web_research_env.py serve \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
--openai.model_name YourModel \
|
||||
--openai.server_type openai
|
||||
|
||||
# With eval split
|
||||
python environments/web_research_env.py serve \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
--openai.model_name YourModel \
|
||||
--env.eval_every 50 \
|
||||
--env.eval_size 20
|
||||
|
||||
# Standalone eval (no training server needed)
|
||||
python environments/web_research_env.py eval \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
--openai.model_name YourModel
|
||||
|
||||
Built by: github.com/jackx707
|
||||
Inspired by: GroceryMind — production Hermes agent doing live web research
|
||||
across German grocery stores (firecrawl + hermes-agent)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional HuggingFace datasets import
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback sample dataset (used when HuggingFace is unavailable)
|
||||
# These are multi-hop questions that require real web search to answer.
|
||||
# ---------------------------------------------------------------------------
|
||||
SAMPLE_QUESTIONS = [
|
||||
{
|
||||
"question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
|
||||
"answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
|
||||
"answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What programming language was used to write the original version of the web framework used by Instagram?",
|
||||
"answer": "Django, which Instagram was built on, is written in Python.",
|
||||
"difficulty": "easy",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
|
||||
"answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
|
||||
"difficulty": "hard",
|
||||
"hops": 3,
|
||||
},
|
||||
{
|
||||
"question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
|
||||
"answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "How many employees does the parent company of Instagram have?",
|
||||
"answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
|
||||
"answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "Which company acquired the startup founded by the creator of Oculus VR?",
|
||||
"answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What is the market cap of the company that owns the most popular search engine in Russia?",
|
||||
"answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
|
||||
"answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
RL environment for training multi-step web research skills.
|
||||
|
||||
The model is given a factual question requiring 2-3 hops of web research
|
||||
and must use web_search / web_extract tools to find and synthesize the answer.
|
||||
|
||||
Reward is multi-signal:
|
||||
60% — answer correctness (LLM judge)
|
||||
20% — tool usage (did the model actually search the web?)
|
||||
20% — efficiency (penalizes >6 tool calls)
|
||||
|
||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
||||
"""
|
||||
|
||||
name = "web-research"
|
||||
|
||||
# Default toolsets for this environment — web + file for saving notes
|
||||
default_toolsets = ["web", "file"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._items: list[dict] = []
|
||||
self._eval_items: list[dict] = []
|
||||
self._index: int = 0
|
||||
self._total_scored: int = 0
|
||||
self._total_reward: float = 0.0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Setup — load dataset
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Load the FRAMES benchmark or fall back to built-in samples."""
|
||||
if HF_AVAILABLE:
|
||||
try:
|
||||
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
||||
ds = load_dataset("google/frames-benchmark", split="test")
|
||||
self._items = [
|
||||
{
|
||||
"question": row["Prompt"],
|
||||
"answer": row["Answer"],
|
||||
"difficulty": row.get("reasoning_types", "unknown"),
|
||||
"hops": 2,
|
||||
}
|
||||
for row in ds
|
||||
]
|
||||
# Hold out 10% for eval
|
||||
eval_size = max(20, len(self._items) // 10)
|
||||
random.shuffle(self._items)
|
||||
self._eval_items = self._items[:eval_size]
|
||||
self._items = self._items[eval_size:]
|
||||
logger.info(
|
||||
f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items "
|
||||
f"from FRAMES benchmark."
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.")
|
||||
|
||||
# Fallback
|
||||
random.shuffle(SAMPLE_QUESTIONS)
|
||||
split = max(1, len(SAMPLE_QUESTIONS) * 8 // 10)
|
||||
self._items = SAMPLE_QUESTIONS[:split]
|
||||
self._eval_items = SAMPLE_QUESTIONS[split:]
|
||||
logger.info(
|
||||
f"Using built-in sample dataset: {len(self._items)} train / "
|
||||
f"{len(self._eval_items)} eval items."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. get_next_item — return the next question
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_next_item(self) -> dict:
|
||||
"""Return the next item, cycling through the dataset."""
|
||||
if not self._items:
|
||||
raise RuntimeError("Dataset is empty. Did you call setup()?")
|
||||
item = self._items[self._index % len(self._items)]
|
||||
self._index += 1
|
||||
return item
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. format_prompt — build the user-facing prompt
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def format_prompt(self, item: dict) -> str:
|
||||
"""
|
||||
Format the research question as a task prompt.
|
||||
Instructs the model to use web search and cite sources.
|
||||
"""
|
||||
return (
|
||||
f"Research the following question thoroughly using web search. "
|
||||
f"You MUST search the web to find current, accurate information — "
|
||||
f"do not rely solely on your training data.\n\n"
|
||||
f"Question: {item['question']}\n\n"
|
||||
f"Requirements:\n"
|
||||
f"- Use web_search and/or web_extract tools to find information\n"
|
||||
f"- Search at least 2 different sources\n"
|
||||
f"- Provide a concise, accurate answer (2-4 sentences)\n"
|
||||
f"- Cite the sources you used"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. compute_reward — multi-signal scoring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def compute_reward(
|
||||
self,
|
||||
item: dict,
|
||||
result: dict,
|
||||
ctx: Any, # ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Multi-signal reward function:
|
||||
|
||||
0.6 * correctness — LLM judge comparing answer to ground truth
|
||||
0.2 * tool_used — binary: did the model use web tools?
|
||||
0.2 * efficiency — penalizes wasteful tool usage
|
||||
+0.1 bonus — source diversity (≥2 distinct domains)
|
||||
"""
|
||||
final_response: str = result.get("final_response", "")
|
||||
tools_used: list[str] = result.get("tools_used", [])
|
||||
tool_call_count: int = result.get("tool_call_count", len(tools_used))
|
||||
|
||||
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=final_response,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
||||
# ---- Signal 2: Web tool usage --------------------------------
|
||||
web_tools = {"web_search", "web_extract", "search", "firecrawl"}
|
||||
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
||||
|
||||
# ---- Signal 3: Efficiency ------------------------------------
|
||||
# Ideal: 2-5 tool calls. Penalise beyond 6, hard cap at 15.
|
||||
if tool_call_count <= 5:
|
||||
efficiency = 1.0
|
||||
elif tool_call_count <= 10:
|
||||
efficiency = 1.0 - (tool_call_count - 5) * 0.08
|
||||
else:
|
||||
efficiency = max(0.0, 1.0 - (tool_call_count - 5) * 0.12)
|
||||
|
||||
# ---- Bonus: Source diversity ---------------------------------
|
||||
domains = self._extract_domains(final_response)
|
||||
diversity_bonus = 0.1 if len(domains) >= 2 else 0.0
|
||||
|
||||
# ---- Combine ------------------------------------------------
|
||||
reward = (
|
||||
0.6 * correctness
|
||||
+ 0.2 * tool_used
|
||||
+ 0.2 * efficiency
|
||||
+ diversity_bonus
|
||||
)
|
||||
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
||||
|
||||
# Track running stats
|
||||
self._total_scored += 1
|
||||
self._total_reward += reward
|
||||
|
||||
logger.debug(
|
||||
f"Reward breakdown — correctness={correctness:.2f}, "
|
||||
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
||||
f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}"
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. evaluate — run on held-out eval split
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
*args: Any,
|
||||
eval_size: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""
|
||||
Run evaluation on the held-out split.
|
||||
Returns a dict of metrics for logging.
|
||||
"""
|
||||
items = self._eval_items
|
||||
if eval_size:
|
||||
items = items[:eval_size]
|
||||
|
||||
if not items:
|
||||
logger.warning("No eval items available.")
|
||||
return {}
|
||||
|
||||
logger.info(f"Running eval on {len(items)} questions...")
|
||||
|
||||
rewards = []
|
||||
correctness_scores = []
|
||||
|
||||
for item in items:
|
||||
try:
|
||||
# Run the agent on each eval question
|
||||
result = await self._run_agent_on_item(item)
|
||||
reward = await self.compute_reward(item, result, ctx=None)
|
||||
rewards.append(reward)
|
||||
|
||||
# Also track raw correctness separately
|
||||
if result.get("final_response"):
|
||||
correctness_scores.append(
|
||||
await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=result["final_response"],
|
||||
ctx=None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Eval error on item: {e}")
|
||||
rewards.append(0.0)
|
||||
|
||||
metrics = {
|
||||
"eval/mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
|
||||
"eval/mean_correctness": (
|
||||
sum(correctness_scores) / len(correctness_scores)
|
||||
if correctness_scores else 0.0
|
||||
),
|
||||
"eval/n_items": len(rewards),
|
||||
"train/mean_reward_so_far": (
|
||||
self._total_reward / self._total_scored
|
||||
if self._total_scored > 0 else 0.0
|
||||
),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Eval complete — mean_reward={metrics['eval/mean_reward']:.3f}, "
|
||||
f"mean_correctness={metrics['eval/mean_correctness']:.3f}"
|
||||
)
|
||||
return metrics
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _llm_judge(
|
||||
self,
|
||||
question: str,
|
||||
expected: str,
|
||||
model_answer: str,
|
||||
ctx: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Use an LLM to judge whether `model_answer` correctly addresses
|
||||
`question` compared to `expected`. Returns a float in [0, 1].
|
||||
|
||||
Uses the agent's own inference client if ctx is available,
|
||||
otherwise falls back to a lightweight heuristic.
|
||||
"""
|
||||
if not model_answer or not model_answer.strip():
|
||||
return 0.0
|
||||
|
||||
# Build judge prompt
|
||||
judge_prompt = (
|
||||
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
||||
f"Question: {question}\n\n"
|
||||
f"Reference answer: {expected}\n\n"
|
||||
f"Model answer: {model_answer}\n\n"
|
||||
"Score the model answer on a scale from 0.0 to 1.0 where:\n"
|
||||
" 1.0 = fully correct and complete\n"
|
||||
" 0.7 = mostly correct with minor gaps\n"
|
||||
" 0.4 = partially correct\n"
|
||||
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
||||
" 0.0 = completely wrong or no answer\n\n"
|
||||
"Consider: factual accuracy, completeness, and relevance.\n"
|
||||
"Respond with ONLY a JSON object: {\"score\": <float>, \"reason\": \"<one sentence>\"}"
|
||||
)
|
||||
|
||||
# Try using ctx for inference (Phase 2 / live training)
|
||||
if ctx is not None and hasattr(ctx, "chat_completion"):
|
||||
try:
|
||||
response = await ctx.chat_completion(
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
max_tokens=100,
|
||||
temperature=0.0,
|
||||
)
|
||||
text = response.get("content", "")
|
||||
parsed = self._parse_judge_json(text)
|
||||
if parsed is not None:
|
||||
return float(parsed)
|
||||
except Exception as e:
|
||||
logger.debug(f"LLM judge via ctx failed: {e}. Using heuristic.")
|
||||
|
||||
# Fallback: keyword overlap heuristic
|
||||
return self._heuristic_score(expected, model_answer)
|
||||
|
||||
@staticmethod
|
||||
def _parse_judge_json(text: str) -> Optional[float]:
|
||||
"""Extract the score float from LLM judge JSON response."""
|
||||
try:
|
||||
# Strip markdown code fences if present
|
||||
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
||||
data = json.loads(clean)
|
||||
score = float(data.get("score", -1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
except Exception:
|
||||
# Try regex fallback
|
||||
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _heuristic_score(expected: str, model_answer: str) -> float:
|
||||
"""
|
||||
Lightweight keyword overlap score as fallback when no LLM is available.
|
||||
Extracts meaningful tokens and computes Jaccard similarity.
|
||||
"""
|
||||
stopwords = {
|
||||
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
||||
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
||||
"this", "that", "as", "by", "from", "be", "has", "have", "had",
|
||||
}
|
||||
|
||||
def tokenize(text: str) -> set:
|
||||
tokens = re.findall(r'\b[a-zA-Z0-9]+\b', text.lower())
|
||||
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
||||
|
||||
expected_tokens = tokenize(expected)
|
||||
answer_tokens = tokenize(model_answer)
|
||||
|
||||
if not expected_tokens:
|
||||
return 0.5 # Can't judge
|
||||
|
||||
overlap = len(expected_tokens & answer_tokens)
|
||||
union = len(expected_tokens | answer_tokens)
|
||||
|
||||
jaccard = overlap / union if union > 0 else 0.0
|
||||
# Recall-weighted: reward covering expected content
|
||||
recall = overlap / len(expected_tokens)
|
||||
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
||||
|
||||
@staticmethod
|
||||
def _extract_domains(text: str) -> set:
|
||||
"""
|
||||
Extract unique domains from URLs cited in the response.
|
||||
Used to measure source diversity.
|
||||
"""
|
||||
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
||||
domains = set()
|
||||
for url in urls:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
# Normalize: strip www.
|
||||
domain = parsed.netloc.lower().lstrip("www.")
|
||||
if domain:
|
||||
domains.add(domain)
|
||||
except Exception:
|
||||
pass
|
||||
return domains
|
||||
|
||||
async def _run_agent_on_item(self, item: dict) -> dict:
|
||||
"""
|
||||
Stub for running agent during eval. In Phase 1/2, this is handled
|
||||
by the Atropos framework's rollout mechanism. Provided here for
|
||||
standalone eval compatibility.
|
||||
"""
|
||||
# In real usage, the framework calls get_next_item + format_prompt
|
||||
# and runs the agent. This stub returns an empty result for safety.
|
||||
return {
|
||||
"final_response": "",
|
||||
"tools_used": [],
|
||||
"tool_call_count": 0,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
WebResearchEnv.cli()
|
||||
Reference in New Issue
Block a user