#!/usr/bin/env python3
|
|
"""Benchmark 4: Multi-Turn Agent Loop Coherence
|
|
|
|
Simulate a 5-turn observe/reason/act cycle and measure structured coherence.
|
|
Each turn must return valid JSON with required fields.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
import sys
|
|
import time
|
|
|
|
import requests
|
|
|
|
# Base URL of the local Ollama server the benchmark talks to.
OLLAMA_URL = "http://localhost:11434"

# System prompt instructing the model to answer every turn with a strict
# JSON object carrying the observe/reason/act fields scored below.
SYSTEM_PROMPT = """\
You are an autonomous AI agent. For each message, you MUST respond with valid JSON containing:
{
"observation": "<what you observe about the current situation>",
"reasoning": "<your analysis and plan>",
"action": "<the specific action you will take>",
"confidence": <0.0-1.0>
}
Respond ONLY with the JSON object. No other text.
"""

# Scripted 5-turn incident-response scenario, sent one turn at a time.
# Later turns refer back to the model's earlier actions, so coherence
# requires the model to track the accumulated conversation history.
TURNS = [
    "You are monitoring a web server. CPU usage just spiked to 95%. What do you observe, reason, and do?",
    "Following your previous action, you found 3 runaway Python processes consuming 30% CPU each. Continue.",
    "You killed the top 2 processes. CPU is now at 45%. A new alert: disk I/O is at 98%. Continue.",
    "You traced the disk I/O to a log rotation script that's stuck. You terminated it. Disk I/O dropped to 20%. Final status check: all metrics are now nominal. Continue.",
    "The incident is resolved. Write a brief post-mortem summary as your final action.",
]

# Keys every per-turn JSON response must contain to count as coherent.
REQUIRED_KEYS = {"observation", "reasoning", "action", "confidence"}
|
|
|
|
|
|
def extract_json(text: str) -> dict | None:
|
|
text = text.strip()
|
|
try:
|
|
return json.loads(text)
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
fence_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
|
|
if fence_match:
|
|
try:
|
|
return json.loads(fence_match.group(1))
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
# Try to find { ... } block
|
|
brace_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)?\}", text, re.DOTALL)
|
|
if brace_match:
|
|
try:
|
|
return json.loads(brace_match.group(0))
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
return None
|
|
|
|
|
|
def run_multi_turn(model: str) -> dict:
    """Run the multi-turn coherence benchmark against *model*.

    Feeds the scripted TURNS to the Ollama chat endpoint one at a time,
    carrying the full conversation history forward so the model can refer
    to its earlier actions.  Each reply is scored for: valid JSON, the
    presence of all REQUIRED_KEYS, and a numeric confidence in [0, 1].

    Returns a summary dict with per-turn results; ``passed`` is True when
    at least 80% of turns are coherent.
    """
    turn_results: list[dict] = []
    total_time = 0.0

    # Conversation history: system prompt first, then alternating
    # user/assistant messages accumulated across turns.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    for i, turn_prompt in enumerate(TURNS, 1):
        messages.append({"role": "user", "content": turn_prompt})
        start = time.time()

        try:
            payload = {
                "model": model,
                "messages": messages,
                "stream": False,
                # Low temperature for repeatability; cap generation length.
                "options": {"temperature": 0.1, "num_predict": 512},
            }
            resp = requests.post(f"{OLLAMA_URL}/api/chat", json=payload, timeout=120)
            resp.raise_for_status()
            raw = resp.json()["message"]["content"]
        except Exception as exc:
            # Best-effort: record the failure and keep the benchmark going.
            elapsed = time.time() - start
            total_time += elapsed
            turn_results.append(
                {
                    "turn": i,
                    "valid_json": False,
                    "has_required_keys": False,
                    "coherent": False,
                    "elapsed_s": round(elapsed, 2),
                    "error": str(exc),
                }
            )
            # Placeholder assistant message preserves the user/assistant
            # role alternation so later turns still form a valid chat.
            messages.append({"role": "assistant", "content": "{}"})
            continue

        elapsed = time.time() - start
        total_time += elapsed

        parsed = extract_json(raw)
        valid = parsed is not None
        has_keys = valid and isinstance(parsed, dict) and REQUIRED_KEYS.issubset(parsed.keys())
        # Confidence must be a real number in [0, 1].  bool is excluded
        # explicitly: in Python, True/False are ints and would otherwise
        # pass both the isinstance and range checks, wrongly accepting a
        # JSON `"confidence": true`.
        confidence = parsed.get("confidence") if has_keys else None
        confidence_valid = (
            has_keys
            and isinstance(confidence, (int, float))
            and not isinstance(confidence, bool)
            and 0.0 <= confidence <= 1.0
        )
        coherent = has_keys and confidence_valid

        turn_results.append(
            {
                "turn": i,
                "valid_json": valid,
                "has_required_keys": has_keys,
                "coherent": coherent,
                "confidence": confidence,
                "elapsed_s": round(elapsed, 2),
                "response_snippet": raw[:200],
            }
        )

        # The real reply goes into the history for the next turn.
        messages.append({"role": "assistant", "content": raw})

    coherent_count = sum(1 for r in turn_results if r["coherent"])
    coherence_rate = coherent_count / len(TURNS)

    return {
        "benchmark": "multi_turn_coherence",
        "model": model,
        "total_turns": len(TURNS),
        "coherent_turns": coherent_count,
        "coherence_rate": round(coherence_rate, 3),
        "passed": coherence_rate >= 0.80,
        "total_time_s": round(total_time, 2),
        "turns": turn_results,
    }
|
|
|
|
|
|
if __name__ == "__main__":
    # Model name may be supplied as the first CLI argument; otherwise use
    # the default benchmark target.
    target = sys.argv[1] if len(sys.argv) > 1 else "hermes3:8b"
    print(f"Running multi-turn coherence benchmark against {target}...")
    report = run_multi_turn(target)
    print(json.dumps(report, indent=2))
    # Exit status mirrors the pass/fail verdict for CI consumption.
    sys.exit(0 if report["passed"] else 1)
|