Merge: WebResearchEnv evaluate() with full agent loop + tools
This commit is contained in:
@@ -425,8 +425,16 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""Run evaluation on the held-out split using the agent loop."""
|
||||
"""Run evaluation on the held-out split using the full agent loop with tools.
|
||||
|
||||
Each eval item runs through the same agent loop as training —
|
||||
the model can use web_search, web_extract, etc. to research answers.
|
||||
This measures actual agentic research capability, not just knowledge.
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
items = self._eval_items
|
||||
if not items:
|
||||
@@ -436,43 +444,75 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
eval_size = min(self.config.eval_size, len(items))
|
||||
eval_items = items[:eval_size]
|
||||
|
||||
logger.info(f"Running eval on {len(eval_items)} questions...")
|
||||
logger.info(f"Running eval on {len(eval_items)} questions (with agent loop + tools)...")
|
||||
start_time = time.time()
|
||||
samples = []
|
||||
|
||||
for item in eval_items:
|
||||
# Resolve tools once for all eval items
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
for i, item in enumerate(eval_items):
|
||||
task_id = str(uuid.uuid4())
|
||||
logger.info(f"Eval [{i+1}/{len(eval_items)}]: {item['question'][:80]}...")
|
||||
|
||||
try:
|
||||
# Use the base env's agent loop for eval (same as training)
|
||||
prompt = self.format_prompt(item)
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
n=1,
|
||||
# Build messages
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
# Run the full agent loop with tools
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=0.0, # Deterministic for eval
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
response_content = (
|
||||
completion.choices[0].message.content if completion.choices else ""
|
||||
)
|
||||
# Extract final response and compute reward
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
# Score the response
|
||||
# Extract final response for logging
|
||||
final_response = ""
|
||||
tool_call_count = 0
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
||||
final_response = msg["content"]
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
tool_call_count += len(msg["tool_calls"])
|
||||
|
||||
# Score correctness separately for the metric
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=response_content,
|
||||
model_answer=final_response,
|
||||
)
|
||||
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": response_content,
|
||||
"response": final_response[:500],
|
||||
"expected": item["answer"],
|
||||
"correctness": correctness,
|
||||
"reward": reward,
|
||||
"tool_calls": tool_call_count,
|
||||
"turns": result.turns_used,
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f" → correctness={correctness:.2f}, reward={reward:.3f}, "
|
||||
f"tools={tool_call_count}, turns={result.turns_used}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Eval error on item: {e}")
|
||||
samples.append({
|
||||
@@ -480,20 +520,33 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": item["answer"],
|
||||
"correctness": 0.0,
|
||||
"reward": 0.0,
|
||||
"tool_calls": 0,
|
||||
"turns": 0,
|
||||
})
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Compute metrics
|
||||
# Compute aggregate metrics
|
||||
correctness_scores = [s["correctness"] for s in samples]
|
||||
rewards = [s["reward"] for s in samples]
|
||||
tool_counts = [s["tool_calls"] for s in samples]
|
||||
n = len(samples)
|
||||
|
||||
eval_metrics = {
|
||||
"eval/mean_correctness": (
|
||||
sum(correctness_scores) / len(correctness_scores)
|
||||
if correctness_scores else 0.0
|
||||
),
|
||||
"eval/n_items": len(samples),
|
||||
"eval/mean_correctness": sum(correctness_scores) / n if n else 0.0,
|
||||
"eval/mean_reward": sum(rewards) / n if n else 0.0,
|
||||
"eval/mean_tool_calls": sum(tool_counts) / n if n else 0.0,
|
||||
"eval/tool_usage_rate": sum(1 for t in tool_counts if t > 0) / n if n else 0.0,
|
||||
"eval/n_items": n,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Eval complete — correctness={eval_metrics['eval/mean_correctness']:.3f}, "
|
||||
f"reward={eval_metrics['eval/mean_reward']:.3f}, "
|
||||
f"tool_usage={eval_metrics['eval/tool_usage_rate']:.0%}"
|
||||
)
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
|
||||
Reference in New Issue
Block a user