Compare commits

...

1 Commits

Author SHA1 Message Date
Alexander Whitestone
4bb12e05ef bench: Gemma 4 tool calling benchmark — 100 prompts (#796)
Some checks are pending
Contributor Attribution Check / check-attribution (pull_request) Waiting to run
Docker Build and Publish / build-and-push (pull_request) Waiting to run
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Waiting to run
Tests / test (pull_request) Waiting to run
Tests / e2e (pull_request) Waiting to run
Benchmark script comparing Gemma 4 vs mimo-v2-pro on tool calling.

100 prompts across 6 categories:
- File operations (20): read, write, search
- Terminal commands (20): system info, process management
- Web search (15): documentation, comparisons
- Code execution (15): calculations, parsing
- Parallel tool calls (10): concurrent operations
- Edge cases (20): complex, ambiguous prompts

Metrics:
- Schema parse success rate
- Tool execution success rate
- Argument validity rate
- Average latency
- Token cost

Usage:
  python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --model2 xiaomi/mimo-v2-pro
  python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --limit 10

Closes #796
2026-04-16 01:05:29 -04:00

461
benchmarks/tool_call_benchmark.py Executable file
View File

@@ -0,0 +1,461 @@
#!/usr/bin/env python3
"""
tool_call_benchmark.py — Benchmark Gemma 4 tool calling vs mimo-v2-pro.
Runs 100 diverse tool calling prompts through each model and compares:
- Schema parse success rate
- Tool execution success rate
- Parallel tool call success rate
- Average latency
- Token cost per call
Usage:
python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --model2 xiaomi/mimo-v2-pro
python3 benchmarks/tool_call_benchmark.py --model1 gemma3:27b --limit 10 # quick test
python3 benchmarks/tool_call_benchmark.py --output benchmarks/results.json
Requires:
- Ollama running locally (or --endpoint for remote)
- Models pulled: ollama pull gemma3:27b, etc.
"""
import json
import os
import sys
import time
import urllib.request
import urllib.error
from datetime import datetime, timezone
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Optional
ENDPOINT = os.environ.get("OPENAI_BASE_URL", "http://localhost:11434/v1")
API_KEY = os.environ.get("OPENAI_API_KEY", "ollama")
# ── Tool schemas (subset for benchmarking) ──────────────────────────────
TOOL_SCHEMAS = [
{
"type": "function",
"function": {
"name": "read_file",
"description": "Read a text file",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "File path"},
"offset": {"type": "integer", "description": "Start line"},
"limit": {"type": "integer", "description": "Max lines"}
},
"required": ["path"]
}
}
},
{
"type": "function",
"function": {
"name": "terminal",
"description": "Execute a shell command",
"parameters": {
"type": "object",
"properties": {
"command": {"type": "string", "description": "Shell command"}
},
"required": ["command"]
}
}
},
{
"type": "function",
"function": {
"name": "write_file",
"description": "Write content to a file",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string"},
"content": {"type": "string"}
},
"required": ["path", "content"]
}
}
},
{
"type": "function",
"function": {
"name": "search_files",
"description": "Search for content in files",
"parameters": {
"type": "object",
"properties": {
"pattern": {"type": "string"},
"path": {"type": "string"}
},
"required": ["pattern"]
}
}
},
{
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string"}
},
"required": ["query"]
}
}
},
{
"type": "function",
"function": {
"name": "execute_code",
"description": "Execute Python code",
"parameters": {
"type": "object",
"properties": {
"code": {"type": "string"}
},
"required": ["code"]
}
}
},
]
SYSTEM_PROMPT = "You are a helpful assistant with access to tools. Use tools when needed."
# ── Test prompts (100 diverse tool calling scenarios) ────────────────────
TEST_PROMPTS = [
# File operations (20)
("Read the README.md file", "read_file", "file_ops"),
("Show me the contents of config.yaml", "read_file", "file_ops"),
("Read lines 10-20 of main.py", "read_file", "file_ops"),
("Open the package.json", "read_file", "file_ops"),
("Read the .gitignore file", "read_file", "file_ops"),
("Save this to notes.txt: meeting at 3pm", "write_file", "file_ops"),
("Create a new file hello.py with print hello", "write_file", "file_ops"),
("Write the config to settings.json", "write_file", "file_ops"),
("Save the output to results.txt", "write_file", "file_ops"),
("Create TODO.md with my tasks", "write_file", "file_ops"),
("Search for 'import os' in the codebase", "search_files", "file_ops"),
("Find all Python files mentioning 'error'", "search_files", "file_ops"),
("Search for TODO comments", "search_files", "file_ops"),
("Find where 'authenticate' is defined", "search_files", "file_ops"),
("Look for any hardcoded API keys", "search_files", "file_ops"),
("Read the Makefile", "read_file", "file_ops"),
("Show me the Dockerfile", "read_file", "file_ops"),
("Read the docker-compose.yml", "read_file", "file_ops"),
("Save the function to utils.py", "write_file", "file_ops"),
("Create a backup of config.yaml", "write_file", "file_ops"),
# Terminal commands (20)
("List all files in the current directory", "terminal", "terminal"),
("Show disk usage", "terminal", "terminal"),
("Check what processes are running", "terminal", "terminal"),
("Show the git log", "terminal", "terminal"),
("Check the Python version", "terminal", "terminal"),
("Run ls -la in the home directory", "terminal", "terminal"),
("Show the current date and time", "terminal", "terminal"),
("Check network connectivity with ping", "terminal", "terminal"),
("Show environment variables", "terminal", "terminal"),
("List running docker containers", "terminal", "terminal"),
("Check system memory usage", "terminal", "terminal"),
("Show the crontab", "terminal", "terminal"),
("Check the firewall status", "terminal", "terminal"),
("Show recent log entries", "terminal", "terminal"),
("Check disk free space", "terminal", "terminal"),
("Run a system update check", "terminal", "terminal"),
("Show open network connections", "terminal", "terminal"),
("Check the timezone", "terminal", "terminal"),
("List tmux sessions", "terminal", "terminal"),
("Check systemd service status", "terminal", "terminal"),
# Web search (15)
("Search for Python asyncio documentation", "web_search", "web"),
("Look up the latest GPT-4 pricing", "web_search", "web"),
("Find information about Gemma 4 benchmarks", "web_search", "web"),
("Search for Rust vs Go performance comparison", "web_search", "web"),
("Look up Docker best practices", "web_search", "web"),
("Search for Kubernetes deployment tutorials", "web_search", "web"),
("Find the latest AI safety research papers", "web_search", "web"),
("Search for SQLite vs PostgreSQL comparison", "web_search", "web"),
("Look up Linux kernel tuning parameters", "web_search", "web"),
("Search for WebSocket protocol specification", "web_search", "web"),
("Find information about Matrix protocol federation", "web_search", "web"),
("Search for MCP protocol documentation", "web_search", "web"),
("Look up A2A agent protocol spec", "web_search", "web"),
("Search for quantization methods for LLMs", "web_search", "web"),
("Find information about GRPO training", "web_search", "web"),
# Code execution (15)
("Calculate the factorial of 20", "execute_code", "code"),
("Parse this JSON and extract keys", "execute_code", "code"),
("Sort a list of numbers", "execute_code", "code"),
("Calculate the fibonacci sequence", "execute_code", "code"),
("Convert a CSV to JSON", "execute_code", "code"),
("Parse an email address", "execute_code", "code"),
("Calculate elapsed time between dates", "execute_code", "code"),
("Generate a random password", "execute_code", "code"),
("Hash a string with SHA256", "execute_code", "code"),
("Parse a URL into components", "execute_code", "code"),
("Calculate statistics on a dataset", "execute_code", "code"),
("Convert epoch timestamp to human readable", "execute_code", "code"),
("Validate an IPv4 address", "execute_code", "code"),
("Calculate the distance between coordinates", "execute_code", "code"),
("Generate a UUID", "execute_code", "code"),
# Parallel tool calls (10)
("Read config.yaml and show git status at the same time", "read_file|terminal", "parallel"),
("Check disk usage and memory usage simultaneously", "terminal|terminal", "parallel"),
("Read two files at once: README and CHANGELOG", "read_file|read_file", "parallel"),
("Search for imports in both Python and JS files", "search_files|search_files", "parallel"),
("Check git log and disk space in parallel", "terminal|terminal", "parallel"),
("Read the Makefile and Dockerfile together", "read_file|read_file", "parallel"),
("Search for TODO and FIXME at the same time", "search_files|search_files", "parallel"),
("List files and check Python version simultaneously", "terminal|terminal", "parallel"),
("Read package.json and requirements.txt together", "read_file|read_file", "parallel"),
("Check system time and uptime in parallel", "terminal|terminal", "parallel"),
]
@dataclass
class BenchmarkResult:
model: str
prompt: str
expected_tool: str
category: str
success: bool = False
tool_called: str = ""
args_valid: bool = False
latency_ms: float = 0.0
prompt_tokens: int = 0
completion_tokens: int = 0
error: str = ""
def call_model(model: str, prompt: str) -> dict:
"""Call a model with tool schemas and return the response."""
url = f"{ENDPOINT}/chat/completions"
data = {
"model": model,
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
"tools": TOOL_SCHEMAS,
"max_tokens": 512,
"temperature": 0.0,
}
body = json.dumps(data).encode()
req = urllib.request.Request(url, data=body, headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}",
}, method="POST")
start = time.time()
try:
with urllib.request.urlopen(req, timeout=60) as resp:
result = json.loads(resp.read())
elapsed = time.time() - start
return {"response": result, "elapsed": elapsed, "error": None}
except Exception as e:
elapsed = time.time() - start
return {"response": None, "elapsed": elapsed, "error": str(e)}
def evaluate_response(result: dict, expected_tool: str) -> BenchmarkResult:
"""Evaluate a model response against expectations."""
resp = result.get("response")
error = result.get("error", "")
elapsed = result.get("elapsed", 0)
br = BenchmarkResult(
model="",
prompt="",
expected_tool=expected_tool,
category="",
latency_ms=round(elapsed * 1000, 1),
error=error or "",
)
if not resp:
br.success = False
return br
usage = resp.get("usage", {})
br.prompt_tokens = usage.get("prompt_tokens", 0)
br.completion_tokens = usage.get("completion_tokens", 0)
choice = resp.get("choices", [{}])[0]
message = choice.get("message", {})
tool_calls = message.get("tool_calls", [])
if not tool_calls:
br.success = False
br.error = "no_tool_calls"
return br
# Check first tool call
tc = tool_calls[0]
fn = tc.get("function", {})
br.tool_called = fn.get("name", "")
# Parse args
args_str = fn.get("arguments", "{}")
try:
json.loads(args_str)
br.args_valid = True
except json.JSONDecodeError:
# Try normalization
try:
import re
fixed = re.sub(r',\s*([}\]])', r'\1', args_str.strip())
json.loads(fixed)
br.args_valid = True
except:
br.args_valid = False
# Success = tool called matches expected (or contains it for parallel)
expected = expected_tool.split("|")[0] # primary expected tool
br.success = br.tool_called == expected and br.args_valid
return br
def run_benchmark(model: str, prompts: list, limit: int = None) -> List[BenchmarkResult]:
"""Run benchmark against a model."""
if limit:
prompts = prompts[:limit]
results = []
for i, (prompt, expected_tool, category) in enumerate(prompts):
print(f" [{i+1}/{len(prompts)}] {model}: {prompt[:50]}...", end=" ", flush=True)
raw = call_model(model, prompt)
br = evaluate_response(raw, expected_tool)
br.model = model
br.prompt = prompt
br.category = category
status = "OK" if br.success else f"FAIL({br.error or br.tool_called})"
print(f"{status} {br.latency_ms}ms")
results.append(br)
return results
def generate_report(results: List[BenchmarkResult]) -> str:
"""Generate markdown benchmark report."""
by_model = {}
for r in results:
if r.model not in by_model:
by_model[r.model] = []
by_model[r.model].append(r)
lines = [
"# Gemma 4 Tool Calling Benchmark",
f"",
f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M')}",
f"**Prompts:** {len(results) // len(by_model)} per model",
f"",
]
# Summary table
lines.append("| Metric | " + " | ".join(by_model.keys()) + " |")
lines.append("|--------|" + "|".join(["--------"] * len(by_model)) + "|")
metrics = ["schema_parse", "tool_execution", "avg_latency_ms", "total_prompt_tokens"]
for metric in ["success_rate", "args_valid_rate", "avg_latency_ms", "total_prompt_tokens"]:
vals = []
for model, rs in by_model.items():
if metric == "success_rate":
v = sum(1 for r in rs if r.success) / len(rs) * 100
vals.append(f"{v:.1f}%")
elif metric == "args_valid_rate":
v = sum(1 for r in rs if r.args_valid) / len(rs) * 100
vals.append(f"{v:.1f}%")
elif metric == "avg_latency_ms":
v = sum(r.latency_ms for r in rs) / len(rs)
vals.append(f"{v:.0f}ms")
elif metric == "total_prompt_tokens":
v = sum(r.prompt_tokens for r in rs)
vals.append(f"{v:,}")
label = metric.replace("_", " ").title()
lines.append(f"| {label} | " + " | ".join(vals) + " |")
lines.append("")
# By category
lines.append("## By Category")
lines.append("")
lines.append("| Category | " + " | ".join(f"{m} success" for m in by_model.keys()) + " |")
lines.append("|----------|" + "|".join(["--------"] * len(by_model)) + "|")
categories = sorted(set(r.category for r in results))
for cat in categories:
vals = []
for model, rs in by_model.items():
cat_results = [r for r in rs if r.category == cat]
if cat_results:
v = sum(1 for r in cat_results if r.success) / len(cat_results) * 100
vals.append(f"{v:.0f}%")
else:
vals.append("N/A")
lines.append(f"| {cat} | " + " | ".join(vals) + " |")
return "\n".join(lines)
def main():
import argparse
parser = argparse.ArgumentParser(description="Tool calling benchmark")
parser.add_argument("--model1", default="gemma3:27b")
parser.add_argument("--model2", default="xiaomi/mimo-v2-pro")
parser.add_argument("--endpoint", default=ENDPOINT)
parser.add_argument("--limit", type=int, default=None)
parser.add_argument("--output", default=None)
parser.add_argument("--markdown", action="store_true")
args = parser.parse_args()
global ENDPOINT
ENDPOINT = args.endpoint
prompts = TEST_PROMPTS
if args.limit:
prompts = prompts[:args.limit]
print(f"Benchmark: {args.model1} vs {args.model2}")
print(f"Prompts: {len(prompts)}")
print()
print(f"--- {args.model1} ---")
results1 = run_benchmark(args.model1, prompts)
print(f"\n--- {args.model2} ---")
results2 = run_benchmark(args.model2, prompts)
all_results = results1 + results2
report = generate_report(all_results)
print(f"\n{report}")
if args.output:
with open(args.output, "w") as f:
json.dump([r.__dict__ for r in all_results], f, indent=2, default=str)
print(f"\nResults saved to {args.output}")
# Save markdown report
report_path = f"benchmarks/gemma4-tool-calling-{datetime.now().strftime('%Y-%m-%d')}.md"
Path("benchmarks").mkdir(exist_ok=True)
with open(report_path, "w") as f:
f.write(report)
print(f"Report saved to {report_path}")
if __name__ == "__main__":
main()