"""Benchmark runner — executes scenarios through the heartbeat loop.

Wires each ``BenchmarkScenario`` into a ``MockWorldAdapter`` (or a
supplied adapter), runs the heartbeat for up to ``max_cycles``, and
collects ``BenchmarkMetrics``.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import subprocess
|
|
import time
|
|
from datetime import UTC, datetime
|
|
|
|
from infrastructure.world.adapters.mock import MockWorldAdapter
|
|
from infrastructure.world.benchmark.metrics import BenchmarkMetrics, ScenarioResult
|
|
from infrastructure.world.benchmark.scenarios import BenchmarkScenario
|
|
from infrastructure.world.interface import WorldInterface
|
|
from loop.heartbeat import Heartbeat
|
|
|
|
logger = logging.getLogger(__name__)

# Rough estimate of the metabolic cost of one heartbeat cycle: each of the
# three phases (gather + reason + act) touches the LLM router once, at
# ~1 unit of cost per phase.
_COST_PER_CYCLE = 3.0  # three phases per cycle
|
|
|
|
|
|
class BenchmarkRunner:
|
|
"""Run benchmark scenarios and collect metrics.
|
|
|
|
Parameters
|
|
----------
|
|
adapter_factory:
|
|
Optional callable that returns a ``WorldInterface`` for a given
|
|
scenario. Defaults to building a ``MockWorldAdapter`` from the
|
|
scenario's start state.
|
|
heartbeat_interval:
|
|
Seconds between heartbeat ticks (0 for immediate).
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
adapter_factory=None,
|
|
heartbeat_interval: float = 0.0,
|
|
) -> None:
|
|
self._adapter_factory = adapter_factory or self._default_adapter
|
|
self._interval = heartbeat_interval
|
|
|
|
# -- public API --------------------------------------------------------
|
|
|
|
async def run(
|
|
self,
|
|
scenarios: list[BenchmarkScenario],
|
|
) -> BenchmarkMetrics:
|
|
"""Execute all *scenarios* and return aggregated metrics."""
|
|
metrics = BenchmarkMetrics(
|
|
timestamp=datetime.now(UTC).isoformat(),
|
|
commit_sha=self._git_sha(),
|
|
)
|
|
suite_start = time.monotonic()
|
|
|
|
for scenario in scenarios:
|
|
logger.info("Benchmark: starting '%s'", scenario.name)
|
|
result = await self._run_scenario(scenario)
|
|
metrics.results.append(result)
|
|
status = "PASS" if result.success else "FAIL"
|
|
logger.info(
|
|
"Benchmark: '%s' %s (%d/%d cycles, %d ms)",
|
|
scenario.name,
|
|
status,
|
|
result.cycles_used,
|
|
result.max_cycles,
|
|
result.wall_time_ms,
|
|
)
|
|
|
|
metrics.total_time_ms = int((time.monotonic() - suite_start) * 1000)
|
|
return metrics
|
|
|
|
# -- internal ----------------------------------------------------------
|
|
|
|
async def _run_scenario(self, scenario: BenchmarkScenario) -> ScenarioResult:
|
|
"""Run a single scenario through the heartbeat loop."""
|
|
result = ScenarioResult(
|
|
scenario_name=scenario.name,
|
|
max_cycles=scenario.max_cycles,
|
|
tags=list(scenario.tags),
|
|
)
|
|
|
|
adapter = self._adapter_factory(scenario)
|
|
adapter.connect()
|
|
|
|
hb = Heartbeat(world=adapter, interval=self._interval)
|
|
actions: list[dict] = []
|
|
|
|
start = time.monotonic()
|
|
try:
|
|
for cycle in range(1, scenario.max_cycles + 1):
|
|
record = await hb.run_once()
|
|
result.cycles_used = cycle
|
|
|
|
# Track LLM calls (each cycle has 3 phases that may call LLM)
|
|
result.llm_calls += 3
|
|
|
|
# Accumulate actions for goal predicate
|
|
if record.action_taken and record.action_taken != "idle":
|
|
actions.append(
|
|
{
|
|
"action": record.action_taken,
|
|
"target": record.observation.get("location", ""),
|
|
"status": record.action_status,
|
|
}
|
|
)
|
|
|
|
# Update adapter location if scenario simulates movement
|
|
current_location = self._get_current_location(adapter)
|
|
|
|
# Check goal predicate
|
|
if scenario.goal_predicate is not None:
|
|
if scenario.goal_predicate(actions, current_location):
|
|
result.success = True
|
|
break
|
|
elif cycle == scenario.max_cycles:
|
|
# No predicate — success if we survived all cycles
|
|
result.success = True
|
|
|
|
except Exception as exc:
|
|
logger.warning("Benchmark scenario '%s' crashed: %s", scenario.name, exc)
|
|
result.error = str(exc)
|
|
finally:
|
|
adapter.disconnect()
|
|
|
|
result.wall_time_ms = int((time.monotonic() - start) * 1000)
|
|
result.metabolic_cost = result.cycles_used * _COST_PER_CYCLE
|
|
return result
|
|
|
|
@staticmethod
|
|
def _default_adapter(scenario: BenchmarkScenario) -> WorldInterface:
|
|
"""Build a MockWorldAdapter from a scenario's starting state."""
|
|
return MockWorldAdapter(
|
|
location=scenario.start_location,
|
|
entities=list(scenario.entities),
|
|
events=list(scenario.events),
|
|
)
|
|
|
|
@staticmethod
|
|
def _get_current_location(adapter: WorldInterface) -> str:
|
|
"""Read the current location from the adapter."""
|
|
try:
|
|
perception = adapter.observe()
|
|
return perception.location
|
|
except Exception:
|
|
return ""
|
|
|
|
@staticmethod
|
|
def _git_sha() -> str:
|
|
"""Best-effort: return the current git commit SHA."""
|
|
try:
|
|
result = subprocess.run(
|
|
["git", "rev-parse", "--short", "HEAD"],
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=5,
|
|
)
|
|
return result.stdout.strip() if result.returncode == 0 else ""
|
|
except (OSError, subprocess.TimeoutExpired):
|
|
return ""
|