Merge: WebResearchEnv Atropos standards compliance
This commit is contained in:
@@ -16,21 +16,18 @@ Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenAI-compatible server)
|
||||
python environments/web_research_env.py serve \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
--openai.model_name YourModel \
|
||||
python environments/web_research_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type openai
|
||||
|
||||
# With eval split
|
||||
python environments/web_research_env.py serve \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
--openai.model_name YourModel \
|
||||
--env.eval_every 50 \
|
||||
--env.eval_size 20
|
||||
# Process mode (offline data generation)
|
||||
python environments/web_research_env.py process \\
|
||||
--env.data_path_to_save_groups data/web_research.jsonl
|
||||
|
||||
# Standalone eval (no training server needed)
|
||||
python environments/web_research_env.py eval \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
# Standalone eval
|
||||
python environments/web_research_env.py evaluate \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel
|
||||
|
||||
Built by: github.com/jackx707
|
||||
@@ -43,11 +40,21 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
# Ensure hermes-agent root is on path
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional HuggingFace datasets import
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -57,13 +64,19 @@ try:
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback sample dataset (used when HuggingFace is unavailable)
|
||||
# These are multi-hop questions that require real web search to answer.
|
||||
# Multi-hop questions requiring real web search to answer.
|
||||
# ---------------------------------------------------------------------------
|
||||
SAMPLE_QUESTIONS = [
|
||||
{
|
||||
@@ -129,6 +142,58 @@ SAMPLE_QUESTIONS = [
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnvConfig(HermesAgentEnvConfig):
|
||||
"""Configuration for the web research RL environment."""
|
||||
|
||||
# Reward weights
|
||||
correctness_weight: float = Field(
|
||||
default=0.6,
|
||||
description="Weight for answer correctness in reward (LLM judge score).",
|
||||
)
|
||||
tool_usage_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for tool usage signal (did the model actually use web tools?).",
|
||||
)
|
||||
efficiency_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for efficiency signal (penalizes excessive tool calls).",
|
||||
)
|
||||
diversity_bonus: float = Field(
|
||||
default=0.1,
|
||||
description="Bonus reward for citing ≥2 distinct domains.",
|
||||
)
|
||||
|
||||
# Efficiency thresholds
|
||||
efficient_max_calls: int = Field(
|
||||
default=5,
|
||||
description="Maximum tool calls before efficiency penalty begins.",
|
||||
)
|
||||
heavy_penalty_calls: int = Field(
|
||||
default=10,
|
||||
description="Tool call count where efficiency penalty steepens.",
|
||||
)
|
||||
|
||||
# Eval
|
||||
eval_size: int = Field(
|
||||
default=20,
|
||||
description="Number of held-out items for evaluation.",
|
||||
)
|
||||
eval_split_ratio: float = Field(
|
||||
default=0.1,
|
||||
description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
|
||||
)
|
||||
|
||||
# Dataset
|
||||
dataset_name: str = Field(
|
||||
default="google/frames-benchmark",
|
||||
description="HuggingFace dataset name for research questions.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -143,23 +208,60 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
Reward is multi-signal:
|
||||
60% — answer correctness (LLM judge)
|
||||
20% — tool usage (did the model actually search the web?)
|
||||
20% — efficiency (penalizes >6 tool calls)
|
||||
20% — efficiency (penalizes >5 tool calls)
|
||||
|
||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
||||
"""
|
||||
|
||||
name = "web-research"
|
||||
env_config_cls = WebResearchEnvConfig
|
||||
|
||||
# Default toolsets for this environment — web + file for saving notes
|
||||
default_toolsets = ["web", "file"]
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
|
||||
"""Default configuration for the web research environment."""
|
||||
env_config = WebResearchEnvConfig(
|
||||
enabled_toolsets=["web", "file"],
|
||||
max_agent_turns=15,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a highly capable research agent. When asked a factual question, "
|
||||
"always use web_search to find current, accurate information before answering. "
|
||||
"Cite at least 2 sources. Be concise and accurate."
|
||||
),
|
||||
group_size=4,
|
||||
total_steps=1000,
|
||||
steps_per_eval=100,
|
||||
use_wandb=True,
|
||||
wandb_name="web-research",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.5",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._items: list[dict] = []
|
||||
self._eval_items: list[dict] = []
|
||||
self._index: int = 0
|
||||
self._total_scored: int = 0
|
||||
self._total_reward: float = 0.0
|
||||
|
||||
# Metrics tracking for wandb
|
||||
self._reward_buffer: list[float] = []
|
||||
self._correctness_buffer: list[float] = []
|
||||
self._tool_usage_buffer: list[float] = []
|
||||
self._efficiency_buffer: list[float] = []
|
||||
self._diversity_buffer: list[float] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Setup — load dataset
|
||||
@@ -170,7 +272,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
if HF_AVAILABLE:
|
||||
try:
|
||||
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
||||
ds = load_dataset("google/frames-benchmark", split="test")
|
||||
ds = load_dataset(self.config.dataset_name, split="test")
|
||||
self._items = [
|
||||
{
|
||||
"question": row["Prompt"],
|
||||
@@ -180,8 +282,11 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
}
|
||||
for row in ds
|
||||
]
|
||||
# Hold out 10% for eval
|
||||
eval_size = max(20, len(self._items) // 10)
|
||||
# Hold out for eval
|
||||
eval_size = max(
|
||||
self.config.eval_size,
|
||||
int(len(self._items) * self.config.eval_split_ratio),
|
||||
)
|
||||
random.shuffle(self._items)
|
||||
self._eval_items = self._items[:eval_size]
|
||||
self._items = self._items[eval_size:]
|
||||
@@ -220,10 +325,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def format_prompt(self, item: dict) -> str:
|
||||
"""
|
||||
Format the research question as a task prompt.
|
||||
Instructs the model to use web search and cite sources.
|
||||
"""
|
||||
"""Format the research question as a task prompt."""
|
||||
return (
|
||||
f"Research the following question thoroughly using web search. "
|
||||
f"You MUST search the web to find current, accurate information — "
|
||||
@@ -243,27 +345,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
async def compute_reward(
|
||||
self,
|
||||
item: dict,
|
||||
result: dict,
|
||||
ctx: Any, # ToolContext
|
||||
result: AgentResult,
|
||||
ctx: ToolContext,
|
||||
) -> float:
|
||||
"""
|
||||
Multi-signal reward function:
|
||||
|
||||
0.6 * correctness — LLM judge comparing answer to ground truth
|
||||
0.2 * tool_used — binary: did the model use web tools?
|
||||
0.2 * efficiency — penalizes wasteful tool usage
|
||||
+0.1 bonus — source diversity (≥2 distinct domains)
|
||||
correctness_weight * correctness — LLM judge comparing answer to ground truth
|
||||
tool_usage_weight * tool_used — binary: did the model use web tools?
|
||||
efficiency_weight * efficiency — penalizes wasteful tool usage
|
||||
+ diversity_bonus — source diversity (≥2 distinct domains)
|
||||
"""
|
||||
final_response: str = result.get("final_response", "")
|
||||
tools_used: list[str] = result.get("tools_used", [])
|
||||
tool_call_count: int = result.get("tool_call_count", len(tools_used))
|
||||
final_response: str = result.final_response or ""
|
||||
tools_used: list[str] = [
|
||||
tc.tool_name for tc in (result.tool_calls or [])
|
||||
] if hasattr(result, "tool_calls") and result.tool_calls else []
|
||||
tool_call_count: int = result.turns_used or len(tools_used)
|
||||
|
||||
cfg = self.config
|
||||
|
||||
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=final_response,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
||||
# ---- Signal 2: Web tool usage --------------------------------
|
||||
@@ -271,35 +376,37 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
||||
|
||||
# ---- Signal 3: Efficiency ------------------------------------
|
||||
# Ideal: 2-5 tool calls. Penalise beyond 6, hard cap at 15.
|
||||
if tool_call_count <= 5:
|
||||
if tool_call_count <= cfg.efficient_max_calls:
|
||||
efficiency = 1.0
|
||||
elif tool_call_count <= 10:
|
||||
efficiency = 1.0 - (tool_call_count - 5) * 0.08
|
||||
elif tool_call_count <= cfg.heavy_penalty_calls:
|
||||
efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
|
||||
else:
|
||||
efficiency = max(0.0, 1.0 - (tool_call_count - 5) * 0.12)
|
||||
efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
|
||||
|
||||
# ---- Bonus: Source diversity ---------------------------------
|
||||
domains = self._extract_domains(final_response)
|
||||
diversity_bonus = 0.1 if len(domains) >= 2 else 0.0
|
||||
diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
|
||||
|
||||
# ---- Combine ------------------------------------------------
|
||||
reward = (
|
||||
0.6 * correctness
|
||||
+ 0.2 * tool_used
|
||||
+ 0.2 * efficiency
|
||||
+ diversity_bonus
|
||||
cfg.correctness_weight * correctness
|
||||
+ cfg.tool_usage_weight * tool_used
|
||||
+ cfg.efficiency_weight * efficiency
|
||||
+ diversity
|
||||
)
|
||||
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
||||
|
||||
# Track running stats
|
||||
self._total_scored += 1
|
||||
self._total_reward += reward
|
||||
# Track for wandb
|
||||
self._reward_buffer.append(reward)
|
||||
self._correctness_buffer.append(correctness)
|
||||
self._tool_usage_buffer.append(tool_used)
|
||||
self._efficiency_buffer.append(efficiency)
|
||||
self._diversity_buffer.append(diversity)
|
||||
|
||||
logger.debug(
|
||||
f"Reward breakdown — correctness={correctness:.2f}, "
|
||||
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
||||
f"diversity_bonus={diversity_bonus:.1f} → total={reward:.3f}"
|
||||
f"diversity={diversity:.1f} → total={reward:.3f}"
|
||||
)
|
||||
|
||||
return reward
|
||||
@@ -308,68 +415,117 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
# 5. evaluate — run on held-out eval split
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
*args: Any,
|
||||
eval_size: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""
|
||||
Run evaluation on the held-out split.
|
||||
Returns a dict of metrics for logging.
|
||||
"""
|
||||
items = self._eval_items
|
||||
if eval_size:
|
||||
items = items[:eval_size]
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""Run evaluation on the held-out split using the agent loop."""
|
||||
import time
|
||||
|
||||
items = self._eval_items
|
||||
if not items:
|
||||
logger.warning("No eval items available.")
|
||||
return {}
|
||||
return
|
||||
|
||||
logger.info(f"Running eval on {len(items)} questions...")
|
||||
eval_size = min(self.config.eval_size, len(items))
|
||||
eval_items = items[:eval_size]
|
||||
|
||||
rewards = []
|
||||
correctness_scores = []
|
||||
logger.info(f"Running eval on {len(eval_items)} questions...")
|
||||
start_time = time.time()
|
||||
samples = []
|
||||
|
||||
for item in items:
|
||||
for item in eval_items:
|
||||
try:
|
||||
# Run the agent on each eval question
|
||||
result = await self._run_agent_on_item(item)
|
||||
reward = await self.compute_reward(item, result, ctx=None)
|
||||
rewards.append(reward)
|
||||
# Use the base env's agent loop for eval (same as training)
|
||||
prompt = self.format_prompt(item)
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_content = (
|
||||
completion.choices[0].message.content if completion.choices else ""
|
||||
)
|
||||
|
||||
# Score the response
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=response_content,
|
||||
)
|
||||
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": response_content,
|
||||
"expected": item["answer"],
|
||||
"correctness": correctness,
|
||||
})
|
||||
|
||||
# Also track raw correctness separately
|
||||
if result.get("final_response"):
|
||||
correctness_scores.append(
|
||||
await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=result["final_response"],
|
||||
ctx=None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Eval error on item: {e}")
|
||||
rewards.append(0.0)
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": item["answer"],
|
||||
"correctness": 0.0,
|
||||
})
|
||||
|
||||
metrics = {
|
||||
"eval/mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
|
||||
end_time = time.time()
|
||||
|
||||
# Compute metrics
|
||||
correctness_scores = [s["correctness"] for s in samples]
|
||||
eval_metrics = {
|
||||
"eval/mean_correctness": (
|
||||
sum(correctness_scores) / len(correctness_scores)
|
||||
if correctness_scores else 0.0
|
||||
),
|
||||
"eval/n_items": len(rewards),
|
||||
"train/mean_reward_so_far": (
|
||||
self._total_reward / self._total_scored
|
||||
if self._total_scored > 0 else 0.0
|
||||
),
|
||||
"eval/n_items": len(samples),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Eval complete — mean_reward={metrics['eval/mean_reward']:.3f}, "
|
||||
f"mean_correctness={metrics['eval/mean_correctness']:.3f}"
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
return metrics
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. wandb_log — custom metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
|
||||
"""Log reward breakdown metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self._reward_buffer:
|
||||
n = len(self._reward_buffer)
|
||||
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
||||
wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
|
||||
wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
|
||||
wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
|
||||
wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
|
||||
wandb_metrics["train/total_rollouts"] = n
|
||||
|
||||
# Accuracy buckets
|
||||
wandb_metrics["train/correct_rate"] = (
|
||||
sum(1 for c in self._correctness_buffer if c >= 0.7) / n
|
||||
)
|
||||
wandb_metrics["train/tool_usage_rate"] = (
|
||||
sum(1 for t in self._tool_usage_buffer if t > 0) / n
|
||||
)
|
||||
|
||||
# Clear buffers
|
||||
self._reward_buffer.clear()
|
||||
self._correctness_buffer.clear()
|
||||
self._tool_usage_buffer.clear()
|
||||
self._efficiency_buffer.clear()
|
||||
self._diversity_buffer.clear()
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
@@ -380,19 +536,14 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
question: str,
|
||||
expected: str,
|
||||
model_answer: str,
|
||||
ctx: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Use an LLM to judge whether `model_answer` correctly addresses
|
||||
`question` compared to `expected`. Returns a float in [0, 1].
|
||||
|
||||
Uses the agent's own inference client if ctx is available,
|
||||
otherwise falls back to a lightweight heuristic.
|
||||
Use the server's LLM to judge answer correctness.
|
||||
Falls back to keyword heuristic if LLM call fails.
|
||||
"""
|
||||
if not model_answer or not model_answer.strip():
|
||||
return 0.0
|
||||
|
||||
# Build judge prompt
|
||||
judge_prompt = (
|
||||
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
||||
f"Question: {question}\n\n"
|
||||
@@ -405,39 +556,36 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
||||
" 0.0 = completely wrong or no answer\n\n"
|
||||
"Consider: factual accuracy, completeness, and relevance.\n"
|
||||
"Respond with ONLY a JSON object: {\"score\": <float>, \"reason\": \"<one sentence>\"}"
|
||||
'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
|
||||
)
|
||||
|
||||
# Try using ctx for inference (Phase 2 / live training)
|
||||
if ctx is not None and hasattr(ctx, "chat_completion"):
|
||||
try:
|
||||
response = await ctx.chat_completion(
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
max_tokens=100,
|
||||
temperature=0.0,
|
||||
)
|
||||
text = response.get("content", "")
|
||||
parsed = self._parse_judge_json(text)
|
||||
if parsed is not None:
|
||||
return float(parsed)
|
||||
except Exception as e:
|
||||
logger.debug(f"LLM judge via ctx failed: {e}. Using heuristic.")
|
||||
try:
|
||||
response = await self.server.chat_completion(
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
n=1,
|
||||
max_tokens=150,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
text = response.choices[0].message.content if response.choices else ""
|
||||
parsed = self._parse_judge_json(text)
|
||||
if parsed is not None:
|
||||
return float(parsed)
|
||||
except Exception as e:
|
||||
logger.debug(f"LLM judge failed: {e}. Using heuristic.")
|
||||
|
||||
# Fallback: keyword overlap heuristic
|
||||
return self._heuristic_score(expected, model_answer)
|
||||
|
||||
@staticmethod
|
||||
def _parse_judge_json(text: str) -> Optional[float]:
|
||||
"""Extract the score float from LLM judge JSON response."""
|
||||
try:
|
||||
# Strip markdown code fences if present
|
||||
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
||||
data = json.loads(clean)
|
||||
score = float(data.get("score", -1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
except Exception:
|
||||
# Try regex fallback
|
||||
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
@@ -447,10 +595,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
|
||||
@staticmethod
|
||||
def _heuristic_score(expected: str, model_answer: str) -> float:
|
||||
"""
|
||||
Lightweight keyword overlap score as fallback when no LLM is available.
|
||||
Extracts meaningful tokens and computes Jaccard similarity.
|
||||
"""
|
||||
"""Lightweight keyword overlap score as fallback."""
|
||||
stopwords = {
|
||||
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
||||
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
||||
@@ -458,35 +603,30 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
}
|
||||
|
||||
def tokenize(text: str) -> set:
|
||||
tokens = re.findall(r'\b[a-zA-Z0-9]+\b', text.lower())
|
||||
tokens = re.findall(r'\b\w+\b', text.lower())
|
||||
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
||||
|
||||
expected_tokens = tokenize(expected)
|
||||
answer_tokens = tokenize(model_answer)
|
||||
|
||||
if not expected_tokens:
|
||||
return 0.5 # Can't judge
|
||||
return 0.5
|
||||
|
||||
overlap = len(expected_tokens & answer_tokens)
|
||||
union = len(expected_tokens | answer_tokens)
|
||||
|
||||
jaccard = overlap / union if union > 0 else 0.0
|
||||
# Recall-weighted: reward covering expected content
|
||||
recall = overlap / len(expected_tokens)
|
||||
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
||||
|
||||
@staticmethod
|
||||
def _extract_domains(text: str) -> set:
|
||||
"""
|
||||
Extract unique domains from URLs cited in the response.
|
||||
Used to measure source diversity.
|
||||
"""
|
||||
"""Extract unique domains from URLs cited in the response."""
|
||||
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
||||
domains = set()
|
||||
for url in urls:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
# Normalize: strip www.
|
||||
domain = parsed.netloc.lower().lstrip("www.")
|
||||
if domain:
|
||||
domains.add(domain)
|
||||
@@ -494,20 +634,6 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
pass
|
||||
return domains
|
||||
|
||||
async def _run_agent_on_item(self, item: dict) -> dict:
|
||||
"""
|
||||
Stub for running agent during eval. In Phase 1/2, this is handled
|
||||
by the Atropos framework's rollout mechanism. Provided here for
|
||||
standalone eval compatibility.
|
||||
"""
|
||||
# In real usage, the framework calls get_next_item + format_prompt
|
||||
# and runs the agent. This stub returns an empty result for safety.
|
||||
return {
|
||||
"final_response": "",
|
||||
"tools_used": [],
|
||||
"tool_call_count": 0,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
|
||||
Reference in New Issue
Block a user