#!/usr/bin/env python3
"""Benchmark 1: Tool Calling Compliance

Send 10 tool-call prompts and measure JSON compliance rate.
Target: >90% valid JSON.
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
import re
|
||
|
|
import sys
|
||
|
|
import time
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import requests
|
||
|
|
|
||
|
|
OLLAMA_URL = "http://localhost:11434"

# Every benchmark prompt ends with the same JSON instruction, so the table
# is assembled from the tool-specific lead-in plus this shared suffix.
_JSON_INSTRUCTION = "Return ONLY valid JSON with keys: tool, args."

_PROMPT_BODIES = [
    "Call the 'get_weather' tool to retrieve the current weather for San Francisco. ",
    "Invoke the 'read_file' function with path='/etc/hosts'. ",
    "Use the 'search_web' tool to look up 'latest Python release'. ",
    "Call 'create_issue' with title='Fix login bug' and priority='high'. ",
    "Execute the 'list_directory' tool for path='/home/user/projects'. ",
    "Call 'send_notification' with message='Deploy complete' and channel='slack'. ",
    "Invoke 'database_query' with sql='SELECT COUNT(*) FROM users'. ",
    "Use the 'get_git_log' tool with limit=10 and branch='main'. ",
    "Call 'schedule_task' with cron='0 9 * * MON-FRI' and task='generate_report'. ",
    "Invoke 'resize_image' with url='https://example.com/photo.jpg', width=800, height=600. ",
]

# Each case pairs a full prompt with the keys the response must contain.
TOOL_PROMPTS = [
    {"prompt": body + _JSON_INSTRUCTION, "expected_keys": ["tool", "args"]}
    for body in _PROMPT_BODIES
]
|
||
|
|
|
||
|
|
|
||
|
|
def extract_json(text: str) -> Any:
    """Extract the first JSON object or array embedded in *text*.

    Tries, in order:
      1. Parsing the whole (stripped) string directly.
      2. Parsing the contents of a markdown code fence (```json ... ```).
      3. Scanning for each ``{`` or ``[`` and decoding a balanced JSON
         value from that offset via ``json.JSONDecoder.raw_decode``.

    Step 3 replaces the old single-level brace regex, which could not
    parse responses with more than one level of nesting (e.g.
    ``{"tool": ..., "args": {"a": {"b": 1}}}``) and never found arrays.

    Returns the decoded value, or ``None`` when no valid JSON is found.
    """
    text = text.strip()

    # 1. Direct parse — the happy path for a fully compliant model.
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # 2. Markdown fence; the "json" language tag is optional.  The fence
    # body may be an object OR an array, so no brace shape is assumed.
    fence_match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if fence_match:
        try:
            return json.loads(fence_match.group(1))
        except json.JSONDecodeError:
            pass

    # 3. Balanced scan: raw_decode parses one complete JSON value starting
    # at the given offset, so arbitrary nesting depth is handled.  Offsets
    # that do not start a valid value (e.g. a "{" inside prose) are skipped.
    decoder = json.JSONDecoder()
    for candidate in re.finditer(r"[\[{]", text):
        try:
            value, _end = decoder.raw_decode(text, candidate.start())
            return value
        except json.JSONDecodeError:
            continue

    return None
|
||
|
|
|
||
|
|
|
||
|
|
def run_prompt(model: str, prompt: str) -> str:
    """POST *prompt* to the local Ollama generate endpoint; return the raw text.

    Raises ``requests.HTTPError`` on a non-2xx status and the usual
    ``requests`` exceptions on connection failure or timeout.
    """
    # Low temperature + capped token budget keeps responses short and
    # close to deterministic for compliance scoring.
    generation_options = {"temperature": 0.1, "num_predict": 256}
    request_body = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": generation_options,
    }
    response = requests.post(
        f"{OLLAMA_URL}/api/generate", json=request_body, timeout=120
    )
    response.raise_for_status()
    return response.json()["response"]
|
||
|
|
|
||
|
|
|
||
|
|
def run_benchmark(model: str) -> dict:
    """Run every tool-calling prompt against *model* and summarise compliance.

    Returns a report dict with one entry per prompt plus aggregate stats;
    ``passed`` is True when at least 90% of responses parsed as valid JSON.
    Request failures are recorded per-prompt rather than aborting the run.
    """
    results: list[dict] = []
    total_time = 0.0

    for prompt_id, case in enumerate(TOOL_PROMPTS, start=1):
        started = time.time()
        try:
            raw = run_prompt(model, case["prompt"])
        except Exception as exc:  # connection errors, timeouts, bad status
            elapsed = time.time() - started
            entry = {
                "prompt_id": prompt_id,
                "valid_json": False,
                "has_expected_keys": False,
                "elapsed_s": round(elapsed, 2),
                "error": str(exc),
            }
        else:
            elapsed = time.time() - started
            parsed = extract_json(raw)
            json_ok = parsed is not None
            # Key check only makes sense on a dict — an array or scalar
            # may be valid JSON yet still fail the schema.
            keys_ok = (
                json_ok
                and isinstance(parsed, dict)
                and all(key in parsed for key in case["expected_keys"])
            )
            entry = {
                "prompt_id": prompt_id,
                "valid_json": json_ok,
                "has_expected_keys": keys_ok,
                "elapsed_s": round(elapsed, 2),
                "response_snippet": raw[:120],
            }
        results.append(entry)
        # Accumulate the unrounded elapsed time; rounding happens once below.
        total_time += elapsed

    valid_count = sum(1 for entry in results if entry["valid_json"])
    compliance_rate = valid_count / len(TOOL_PROMPTS)

    return {
        "benchmark": "tool_calling",
        "model": model,
        "total_prompts": len(TOOL_PROMPTS),
        "valid_json_count": valid_count,
        "compliance_rate": round(compliance_rate, 3),
        "passed": compliance_rate >= 0.90,
        "total_time_s": round(total_time, 2),
        "results": results,
    }
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Model name may be supplied as the first CLI argument.
    target_model = sys.argv[1] if len(sys.argv) > 1 else "hermes3:8b"
    print(f"Running tool-calling benchmark against {target_model}...")
    report = run_benchmark(target_model)
    print(json.dumps(report, indent=2))
    # Non-zero exit signals benchmark failure to CI pipelines.
    sys.exit(0 if report["passed"] else 1)
|