# hermes-agent/environments/web_research_env.py
"""
WebResearchEnv — RL Environment for Multi-Step Web Research
============================================================
Trains models to do accurate, efficient, multi-source web research.
Reward signals:
- Answer correctness (LLM judge, 0.0-1.0)
- Source diversity (used ≥2 distinct domains)
- Efficiency (penalizes excessive tool calls)
- Tool usage (bonus for actually using web tools)
Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
HuggingFace: google/frames-benchmark
Fallback: built-in sample questions (no HF token needed)
Usage:
# Phase 1 (OpenAI-compatible server)
python environments/web_research_env.py serve \
--openai.base_url http://localhost:8000/v1 \
--openai.model_name YourModel \
--openai.server_type openai
# With eval split
python environments/web_research_env.py serve \
--openai.base_url http://localhost:8000/v1 \
--openai.model_name YourModel \
--env.eval_every 50 \
--env.eval_size 20
# Standalone eval (no training server needed)
python environments/web_research_env.py eval \
--openai.base_url http://localhost:8000/v1 \
--openai.model_name YourModel
Built by: github.com/jackx707
Inspired by: GroceryMind — production Hermes agent doing live web research
across German grocery stores (firecrawl + hermes-agent)
"""
from __future__ import annotations
import asyncio
import json
import logging
import random
import re
from typing import Any, Optional
from urllib.parse import urlparse
# ---------------------------------------------------------------------------
# Optional HuggingFace datasets import
# ---------------------------------------------------------------------------
try:
from datasets import load_dataset
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
from environments.hermes_base_env import HermesAgentBaseEnv
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Fallback sample dataset (used when HuggingFace is unavailable)
# These are multi-hop questions that require real web search to answer.
# ---------------------------------------------------------------------------
# Each entry: question text, free-form reference answer, a difficulty label,
# and the number of research "hops" the question is expected to require.
SAMPLE_QUESTIONS: list[dict] = [
    {
        "question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
        "answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
        "answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What programming language was used to write the original version of the web framework used by Instagram?",
        "answer": "Django, which Instagram was built on, is written in Python.",
        "difficulty": "easy",
        "hops": 2,
    },
    {
        "question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
        "answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
        "difficulty": "hard",
        "hops": 3,
    },
    {
        "question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
        "answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "How many employees does the parent company of Instagram have?",
        "answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
        "answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "Which company acquired the startup founded by the creator of Oculus VR?",
        "answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the market cap of the company that owns the most popular search engine in Russia?",
        "answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
        "answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
        "difficulty": "hard",
        "hops": 2,
    },
]
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
class WebResearchEnv(HermesAgentBaseEnv):
    """
    RL environment for training multi-step web research skills.

    The model is given a factual question requiring 2-3 hops of web research
    and must use web_search / web_extract tools to find and synthesize the
    answer.

    Reward is multi-signal:
        60% — answer correctness (LLM judge)
        20% — tool usage (did the model actually search the web?)
        20% — efficiency (penalizes more than 5 tool calls)
    Bonus +0.1 for source diversity (≥2 distinct domains cited); the total
    reward is clamped to [0, 1].
    """

    name = "web-research"

    # Default toolsets for this environment — web + file for saving notes
    default_toolsets = ["web", "file"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._items: list[dict] = []       # training questions
        self._eval_items: list[dict] = []  # held-out eval questions
        self._index: int = 0               # cursor for cycling through items
        self._total_scored: int = 0        # number of rollouts scored so far
        self._total_reward: float = 0.0    # running reward sum (for the mean)

    # ------------------------------------------------------------------
    # 1. Setup — load dataset
    # ------------------------------------------------------------------
    async def setup(self) -> None:
        """Load the FRAMES benchmark or fall back to built-in samples."""
        if HF_AVAILABLE:
            try:
                logger.info("Loading FRAMES benchmark from HuggingFace...")
                ds = load_dataset("google/frames-benchmark", split="test")
                self._items = [
                    {
                        "question": row["Prompt"],
                        "answer": row["Answer"],
                        "difficulty": row.get("reasoning_types", "unknown"),
                        "hops": 2,
                    }
                    for row in ds
                ]
                # Hold out 10% (at least 20 items) for eval.
                eval_size = max(20, len(self._items) // 10)
                random.shuffle(self._items)
                self._eval_items = self._items[:eval_size]
                self._items = self._items[eval_size:]
                logger.info(
                    f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items "
                    f"from FRAMES benchmark."
                )
                return
            except Exception as e:
                logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.")
        # Fallback: shuffle a *copy* so repeated setup() calls never mutate
        # the module-level SAMPLE_QUESTIONS constant in place.
        samples = list(SAMPLE_QUESTIONS)
        random.shuffle(samples)
        split = max(1, len(samples) * 8 // 10)  # ~80/20 train/eval split
        self._items = samples[:split]
        self._eval_items = samples[split:]
        logger.info(
            f"Using built-in sample dataset: {len(self._items)} train / "
            f"{len(self._eval_items)} eval items."
        )

    # ------------------------------------------------------------------
    # 2. get_next_item — return the next question
    # ------------------------------------------------------------------
    async def get_next_item(self) -> dict:
        """Return the next item, cycling through the dataset.

        Raises:
            RuntimeError: if the dataset is empty (setup() not called or
                it loaded nothing).
        """
        if not self._items:
            raise RuntimeError("Dataset is empty. Did you call setup()?")
        item = self._items[self._index % len(self._items)]
        self._index += 1
        return item

    # ------------------------------------------------------------------
    # 3. format_prompt — build the user-facing prompt
    # ------------------------------------------------------------------
    def format_prompt(self, item: dict) -> str:
        """
        Format the research question as a task prompt.
        Instructs the model to use web search and cite sources.
        """
        return (
            f"Research the following question thoroughly using web search. "
            f"You MUST search the web to find current, accurate information — "
            f"do not rely solely on your training data.\n\n"
            f"Question: {item['question']}\n\n"
            f"Requirements:\n"
            f"- Use web_search and/or web_extract tools to find information\n"
            f"- Search at least 2 different sources\n"
            f"- Provide a concise, accurate answer (2-4 sentences)\n"
            f"- Cite the sources you used"
        )

    # ------------------------------------------------------------------
    # 4. compute_reward — multi-signal scoring
    # ------------------------------------------------------------------
    async def compute_reward(
        self,
        item: dict,
        result: dict,
        ctx: Any,  # ToolContext
    ) -> float:
        """
        Multi-signal reward function:
            0.6 * correctness — LLM judge comparing answer to ground truth
            0.2 * tool_used   — binary: did the model use web tools?
            0.2 * efficiency  — penalizes more than 5 tool calls
            +0.1 bonus        — source diversity (≥2 distinct domains)

        Returns:
            A float clamped to [0, 1].
        """
        final_response: str = result.get("final_response", "")
        tools_used: list[str] = result.get("tools_used", [])
        # Fall back to len(tools_used) when the runner reports no explicit count.
        tool_call_count: int = result.get("tool_call_count", len(tools_used))

        # ---- Signal 1: Answer correctness (LLM judge) ----------------
        correctness = await self._llm_judge(
            question=item["question"],
            expected=item["answer"],
            model_answer=final_response,
            ctx=ctx,
        )

        # ---- Signal 2: Web tool usage --------------------------------
        web_tools = {"web_search", "web_extract", "search", "firecrawl"}
        tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0

        # ---- Signal 3: Efficiency ------------------------------------
        # Ideal: 2-5 tool calls. Penalise from the 6th call onward.
        # NOTE(review): the slope changes at 11 calls, dropping efficiency
        # from 0.60 (10 calls) to 0.28 (11 calls); it reaches 0.0 by ~14
        # calls. Confirm the step at 11 is intended reward shaping.
        if tool_call_count <= 5:
            efficiency = 1.0
        elif tool_call_count <= 10:
            efficiency = 1.0 - (tool_call_count - 5) * 0.08
        else:
            efficiency = max(0.0, 1.0 - (tool_call_count - 5) * 0.12)

        # ---- Bonus: Source diversity ---------------------------------
        domains = self._extract_domains(final_response)
        diversity_bonus = 0.1 if len(domains) >= 2 else 0.0

        # ---- Combine ------------------------------------------------
        reward = (
            0.6 * correctness
            + 0.2 * tool_used
            + 0.2 * efficiency
            + diversity_bonus
        )
        reward = min(1.0, max(0.0, reward))  # clamp to [0, 1]

        # Track running stats for the train/mean_reward_so_far metric.
        self._total_scored += 1
        self._total_reward += reward
        logger.debug(
            f"Reward breakdown — correctness={correctness:.2f}, "
            f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
            f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}"
        )
        return reward

    # ------------------------------------------------------------------
    # 5. evaluate — run on held-out eval split
    # ------------------------------------------------------------------
    async def evaluate(
        self,
        *args: Any,
        eval_size: Optional[int] = None,
        **kwargs: Any,
    ) -> dict:
        """
        Run evaluation on the held-out split.

        Args:
            eval_size: optional cap on the number of eval items to run.

        Returns:
            A dict of metrics for logging (empty if no eval items exist).
        """
        items = self._eval_items
        if eval_size:
            items = items[:eval_size]
        if not items:
            logger.warning("No eval items available.")
            return {}
        logger.info(f"Running eval on {len(items)} questions...")
        rewards = []
        correctness_scores = []
        for item in items:
            try:
                # Run the agent on each eval question. NOTE: in standalone
                # mode _run_agent_on_item is a stub returning an empty
                # result, so rewards will be 0 unless the framework
                # overrides it.
                result = await self._run_agent_on_item(item)
                reward = await self.compute_reward(item, result, ctx=None)
                rewards.append(reward)
                # Also track raw correctness separately from the blended reward.
                if result.get("final_response"):
                    correctness_scores.append(
                        await self._llm_judge(
                            question=item["question"],
                            expected=item["answer"],
                            model_answer=result["final_response"],
                            ctx=None,
                        )
                    )
            except Exception as e:
                # A failed rollout counts as zero reward rather than
                # aborting the whole eval pass.
                logger.error(f"Eval error on item: {e}")
                rewards.append(0.0)
        metrics = {
            "eval/mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
            "eval/mean_correctness": (
                sum(correctness_scores) / len(correctness_scores)
                if correctness_scores else 0.0
            ),
            "eval/n_items": len(rewards),
            "train/mean_reward_so_far": (
                self._total_reward / self._total_scored
                if self._total_scored > 0 else 0.0
            ),
        }
        logger.info(
            f"Eval complete — mean_reward={metrics['eval/mean_reward']:.3f}, "
            f"mean_correctness={metrics['eval/mean_correctness']:.3f}"
        )
        return metrics

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    async def _llm_judge(
        self,
        question: str,
        expected: str,
        model_answer: str,
        ctx: Any,
    ) -> float:
        """
        Use an LLM to judge whether `model_answer` correctly addresses
        `question` compared to `expected`. Returns a float in [0, 1].

        Uses the agent's own inference client if ctx is available,
        otherwise falls back to a lightweight keyword heuristic.
        """
        if not model_answer or not model_answer.strip():
            return 0.0
        # Build judge prompt
        judge_prompt = (
            "You are an impartial judge evaluating the quality of an AI research answer.\n\n"
            f"Question: {question}\n\n"
            f"Reference answer: {expected}\n\n"
            f"Model answer: {model_answer}\n\n"
            "Score the model answer on a scale from 0.0 to 1.0 where:\n"
            " 1.0 = fully correct and complete\n"
            " 0.7 = mostly correct with minor gaps\n"
            " 0.4 = partially correct\n"
            " 0.1 = mentions relevant topic but wrong or very incomplete\n"
            " 0.0 = completely wrong or no answer\n\n"
            "Consider: factual accuracy, completeness, and relevance.\n"
            "Respond with ONLY a JSON object: {\"score\": <float>, \"reason\": \"<one sentence>\"}"
        )
        # Try using ctx for inference (Phase 2 / live training).
        if ctx is not None and hasattr(ctx, "chat_completion"):
            try:
                response = await ctx.chat_completion(
                    messages=[{"role": "user", "content": judge_prompt}],
                    max_tokens=100,
                    temperature=0.0,
                )
                text = response.get("content", "")
                parsed = self._parse_judge_json(text)
                if parsed is not None:
                    return float(parsed)
            except Exception as e:
                logger.debug(f"LLM judge via ctx failed: {e}. Using heuristic.")
        # Fallback: keyword overlap heuristic.
        return self._heuristic_score(expected, model_answer)

    @staticmethod
    def _parse_judge_json(text: str) -> Optional[float]:
        """Extract the score float from an LLM judge JSON response.

        Returns None when no in-range score can be recovered.
        """
        try:
            # Strip markdown code fences if present.
            clean = re.sub(r"```(?:json)?|```", "", text).strip()
            data = json.loads(clean)
            score = float(data.get("score", -1))
            if 0.0 <= score <= 1.0:
                return score
        except Exception:
            pass
        # Regex fallback — also reached when the JSON parsed but carried a
        # missing/out-of-range score (the original only fell back on a
        # parse *exception*, silently dropping recoverable responses).
        match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
        if match:
            try:
                # Guard: the character class can match non-numbers like "0..5".
                score = float(match.group(1))
            except ValueError:
                return None
            if 0.0 <= score <= 1.0:
                return score
        return None

    @staticmethod
    def _heuristic_score(expected: str, model_answer: str) -> float:
        """
        Lightweight keyword overlap score as fallback when no LLM is available.
        Extracts meaningful tokens and blends Jaccard similarity with recall
        of the expected answer's tokens (recall-weighted 60/40).
        """
        stopwords = {
            "the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
            "at", "to", "for", "with", "and", "or", "but", "it", "its",
            "this", "that", "as", "by", "from", "be", "has", "have", "had",
        }

        def tokenize(text: str) -> set:
            # Keep alphanumeric words longer than 2 chars, minus stopwords.
            tokens = re.findall(r'\b[a-zA-Z0-9]+\b', text.lower())
            return {t for t in tokens if t not in stopwords and len(t) > 2}

        expected_tokens = tokenize(expected)
        answer_tokens = tokenize(model_answer)
        if not expected_tokens:
            return 0.5  # Can't judge
        overlap = len(expected_tokens & answer_tokens)
        union = len(expected_tokens | answer_tokens)
        jaccard = overlap / union if union > 0 else 0.0
        # Recall-weighted: reward covering expected content.
        recall = overlap / len(expected_tokens)
        return min(1.0, 0.4 * jaccard + 0.6 * recall)

    @staticmethod
    def _extract_domains(text: str) -> set:
        """
        Extract unique domains from URLs cited in the response.
        Used to measure source diversity.
        """
        urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
        domains = set()
        for url in urls:
            try:
                parsed = urlparse(url)
                # Normalize: drop a leading "www." prefix. str.removeprefix
                # removes the exact prefix, whereas the previous
                # lstrip("www.") stripped any run of the characters 'w' and
                # '.' and mangled domains such as "wwf.org" → "f.org".
                domain = parsed.netloc.lower().removeprefix("www.")
                if domain:
                    domains.add(domain)
            except Exception:
                pass  # malformed URL — diversity measurement is best-effort
        return domains

    async def _run_agent_on_item(self, item: dict) -> dict:
        """
        Stub for running the agent during eval. In Phase 1/2, this is handled
        by the Atropos framework's rollout mechanism. Provided here for
        standalone eval compatibility.
        """
        # In real usage, the framework calls get_next_item + format_prompt
        # and runs the agent. This stub returns an empty result for safety.
        return {
            "final_response": "",
            "tools_used": [],
            "tool_call_count": 0,
        }
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Delegate to the base-class CLI, which handles the serve/eval
    # subcommands and --openai.*/--env.* flags shown in the module docstring.
    WebResearchEnv.cli()