fix: eliminate double LLM judge call and eval buffer pollution

evaluate() was calling _llm_judge twice per item (once via
compute_reward, once directly) — double the API cost for no benefit.
Now extracts correctness from compute_reward's buffer instead.

Also: compute_reward appends to training metric buffers during eval,
which would pollute wandb training charts. Now rolls back buffer
entries added during eval so training metrics stay clean.
This commit is contained in:
teknium1
2026-03-09 20:57:46 -07:00
parent bf8350ac18
commit 975fd86dc4

View File

@@ -475,14 +475,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
)
result = await agent.run(messages)
# Extract final response and compute reward
ctx = ToolContext(task_id)
try:
reward = await self.compute_reward(item, result, ctx)
finally:
ctx.cleanup()
# Extract final response for logging
# Extract final response and tool usage from messages
final_response = ""
tool_call_count = 0
for msg in reversed(result.messages):
@@ -491,12 +484,32 @@ class WebResearchEnv(HermesAgentBaseEnv):
if msg.get("role") == "assistant" and msg.get("tool_calls"):
tool_call_count += len(msg["tool_calls"])
# Score correctness separately for the metric
correctness = await self._llm_judge(
question=item["question"],
expected=item["answer"],
model_answer=final_response,
# Compute reward (includes LLM judge for correctness)
# Temporarily save buffer lengths so we can extract the
# correctness score without calling judge twice, and avoid
# polluting training metric buffers with eval data.
buf_len = len(self._correctness_buffer)
ctx = ToolContext(task_id)
try:
reward = await self.compute_reward(item, result, ctx)
finally:
ctx.cleanup()
# Extract correctness from the buffer (compute_reward appended it)
# then remove eval entries from training buffers
correctness = (
self._correctness_buffer[buf_len]
if len(self._correctness_buffer) > buf_len
else 0.0
)
# Roll back buffers to avoid polluting training metrics
for buf in (
self._reward_buffer, self._correctness_buffer,
self._tool_usage_buffer, self._efficiency_buffer,
self._diversity_buffer,
):
if len(buf) > buf_len:
buf.pop()
samples.append({
"prompt": item["question"],