#!/usr/bin/env python3
"""Benchmark 1: Tool Calling Compliance

Send 10 tool-call prompts to a local Ollama model and measure the JSON
compliance rate: the fraction of responses from which a JSON value can be
extracted. Target: >= 90% valid JSON.

Pass the Ollama model tag as the first CLI argument (default: hermes3:8b).
The JSON report is printed to stdout; the exit status is 0 on pass, 1 on fail.
"""

from __future__ import annotations

import json
import re
import sys
import time
from typing import Any

OLLAMA_URL = "http://localhost:11434"

# Minimum compliance rate (fraction of prompts yielding valid JSON) to pass.
PASS_THRESHOLD = 0.90

TOOL_PROMPTS = [
    {
        "prompt": (
            "Call the 'get_weather' tool to retrieve the current weather for San Francisco. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
    {
        "prompt": (
            "Invoke the 'read_file' function with path='/etc/hosts'. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
    {
        "prompt": (
            "Use the 'search_web' tool to look up 'latest Python release'. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
    {
        "prompt": (
            "Call 'create_issue' with title='Fix login bug' and priority='high'. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
    {
        "prompt": (
            "Execute the 'list_directory' tool for path='/home/user/projects'. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
    {
        "prompt": (
            "Call 'send_notification' with message='Deploy complete' and channel='slack'. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
    {
        "prompt": (
            "Invoke 'database_query' with sql='SELECT COUNT(*) FROM users'. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
    {
        "prompt": (
            "Use the 'get_git_log' tool with limit=10 and branch='main'. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
    {
        "prompt": (
            "Call 'schedule_task' with cron='0 9 * * MON-FRI' and task='generate_report'. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
    {
        "prompt": (
            "Invoke 'resize_image' with url='https://example.com/photo.jpg', "
            "width=800, height=600. "
            "Return ONLY valid JSON with keys: tool, args."
        ),
        "expected_keys": ["tool", "args"],
    },
]

# Captures the body of a Markdown code fence, optionally tagged ``json``.
_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)
# Shared decoder: raw_decode() parses one JSON value starting at a given offset
# and ignores whatever text follows it.
_DECODER = json.JSONDecoder()


def _first_embedded_object(text: str) -> Any:
    """Return the first JSON object that starts at some ``{`` in *text*, else None.

    Each ``{`` is tried in order with ``JSONDecoder.raw_decode``, so nesting
    depth and braces inside string values are handled by the real JSON parser
    instead of a regex.
    """
    start = text.find("{")
    while start != -1:
        try:
            value, _end = _DECODER.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return value
    return None


def extract_json(text: str) -> Any:
    """Extract the first JSON value from a model response.

    Strategies, in order:

    1. the whole (stripped) response, which may be any JSON value;
    2. the body of each Markdown code fence (with or without a ``json`` tag),
       so fenced objects and arrays are both accepted;
    3. the first JSON object embedded anywhere in surrounding prose.

    Returns the decoded value, or None if nothing could be decoded. Note that a
    response consisting solely of ``null`` also yields None.
    """
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    for body in _FENCE_RE.findall(text):
        try:
            return json.loads(body)
        except json.JSONDecodeError:
            continue

    return _first_embedded_object(text)


def run_prompt(model: str, prompt: str) -> str:
    """Send *prompt* to Ollama's ``/api/generate`` and return the response text.

    Uses a non-streaming request with low temperature and at most 256
    generated tokens. Raises for HTTP error statuses (``raise_for_status``).
    """
    # Imported lazily so the pure parsing helpers in this module can be
    # imported (e.g. by other benchmarks or unit tests) without `requests`.
    import requests

    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": 0.1, "num_predict": 256},
    }
    resp = requests.post(f"{OLLAMA_URL}/api/generate", json=payload, timeout=120)
    resp.raise_for_status()
    return resp.json()["response"]


def run_benchmark(
    model: str,
    prompts: list[dict[str, Any]] | None = None,
    pass_threshold: float = PASS_THRESHOLD,
) -> dict:
    """Run the tool-calling benchmark for a single model.

    Args:
        model: Ollama model tag, e.g. ``"hermes3:8b"``.
        prompts: Cases to run, each a dict with ``"prompt"`` and
            ``"expected_keys"``. Defaults to ``TOOL_PROMPTS``.
        pass_threshold: Minimum compliance rate for ``"passed"`` to be True.

    Returns:
        A summary dict with per-prompt ``"results"``. ``compliance_rate`` is the
        fraction of prompts whose response contained valid JSON. It is 0.0, and
        ``passed`` is False, when no prompts are given.
    """
    cases = TOOL_PROMPTS if prompts is None else prompts
    results: list[dict[str, Any]] = []
    total_time = 0.0

    for i, case in enumerate(cases, 1):
        start = time.perf_counter()  # monotonic, unlike time.time()
        # Deliberately broad: a failed request or malformed reply for one
        # prompt is recorded as a non-compliant result instead of aborting.
        try:
            raw = run_prompt(model, case["prompt"])
            elapsed = time.perf_counter() - start
            parsed = extract_json(raw)
            results.append(
                {
                    "prompt_id": i,
                    "valid_json": parsed is not None,
                    "has_expected_keys": isinstance(parsed, dict)
                    and all(k in parsed for k in case["expected_keys"]),
                    "elapsed_s": round(elapsed, 2),
                    "response_snippet": raw[:120],
                }
            )
        except Exception as exc:
            elapsed = time.perf_counter() - start
            results.append(
                {
                    "prompt_id": i,
                    "valid_json": False,
                    "has_expected_keys": False,
                    "elapsed_s": round(elapsed, 2),
                    "error": str(exc),
                }
            )
        total_time += elapsed

    valid_count = sum(1 for r in results if r["valid_json"])
    # Guard against an empty prompt list, which would otherwise divide by zero.
    compliance_rate = valid_count / len(cases) if cases else 0.0

    return {
        "benchmark": "tool_calling",
        "model": model,
        "total_prompts": len(cases),
        "valid_json_count": valid_count,
        "compliance_rate": round(compliance_rate, 3),
        "passed": bool(cases) and compliance_rate >= pass_threshold,
        "total_time_s": round(total_time, 2),
        "results": results,
    }


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: benchmark the model named by the first argument.

    The progress line goes to stderr so stdout carries only the JSON report.
    Returns 0 if the benchmark passed, 1 otherwise.
    """
    args = sys.argv[1:] if argv is None else argv
    model = args[0] if args else "hermes3:8b"
    print(f"Running tool-calling benchmark against {model}...", file=sys.stderr)
    result = run_benchmark(model)
    print(json.dumps(result, indent=2))
    return 0 if result["passed"] else 1


if __name__ == "__main__":
    sys.exit(main())