256 lines
9.8 KiB
Python
256 lines
9.8 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Tool Calling Test Suite for 1-Bit Models (Issue #101)
|
||
|
|
|
||
|
|
Tests whether Bonsai 1-bit models can handle tool calling at all.
|
||
|
|
Evaluates: file read, terminal execution, web search, multi-step workflows, schema parsing.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --backend ollama
|
||
|
|
python3 benchmarks/test_tool_calling_1bit.py --results benchmarks/tool_calling_results.json
|
||
|
|
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import re
|
||
|
|
import sys
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass, field, asdict
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from enum import Enum
|
||
|
|
from typing import Any, Dict, List, Optional, Tuple
|
||
|
|
|
||
|
|
import requests
|
||
|
|
|
||
|
|
|
||
|
|
class ToolCallType(Enum):
    """Categories of tool-calling capability exercised by the test suite."""

    FILE_READ = "file_read"            # single read_file invocation
    TERMINAL_EXEC = "terminal_exec"    # shell command via the terminal tool
    WEB_SEARCH = "web_search"          # web_search query
    MULTI_STEP = "multi_step"          # workflows requiring more than one call
    SCHEMA_PARSING = "schema_parsing"  # complex / optional parameter schemas
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class ToolCallTest:
    """Specification of one tool-calling test case sent to the model."""

    # Unique identifier used in result rows and report lines.
    name: str
    # Capability category being exercised.
    tool_type: ToolCallType
    # Prompt sent verbatim to the model (includes the tool schema inline).
    prompt: str
    # Tool the model is expected to call; None means no tool call should occur.
    expected_tool: Optional[str]
    # Parameters that must appear in the parsed call (see run_test for matching rules).
    expected_params: Dict[str, Any]
    # Optional key into VALIDATORS for a custom parameter check instead of
    # literal comparison against expected_params.
    validation_fn: Optional[str] = None
    # "easy" | "medium" | "hard" — reporting only, does not affect scoring.
    difficulty: str = "easy"
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class TestResult:
    """Outcome of a single ToolCallTest run against one model."""

    # Name of the ToolCallTest this result belongs to.
    test_name: str
    # ToolCallType.value of the test (stored as plain string for JSON output).
    tool_type: str
    # Overall pass/fail verdict (tool AND params must match).
    passed: bool
    # Wall-clock latency of the model call in milliseconds.
    latency_ms: float
    # Raw model response (truncated to 500 chars by run_test).
    response_text: str
    # Tool name extracted from the response, if any.
    parsed_tool: Optional[str] = None
    # Parameters extracted from the response, if any.
    parsed_params: Optional[Dict[str, Any]] = None
    # Backend error text when the call itself failed ("ERROR: ..." responses).
    error: Optional[str] = None
    # Partial credit: 0.5 for the right tool + 0.5 for matching params.
    quality_score: float = 0.0
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class BenchmarkResult:
    """Aggregate results for a full suite run against one model/backend."""

    # Model identifier as passed on the command line.
    model: str
    # Backend name (currently only "ollama" is implemented).
    backend: str
    # ISO-8601 UTC timestamp of when the run started.
    timestamp: str
    # Per-test results, in TOOL_CALL_TESTS order.
    results: List[TestResult] = field(default_factory=list)
    # Totals computed by run_all_tests: total/passed/failed/pass_rate/
    # avg_latency_ms/avg_quality.
    summary: Dict[str, Any] = field(default_factory=dict)
|
||
|
|
|
||
|
|
|
||
|
|
# Test matrix: one ToolCallTest per capability/difficulty combination.
# Positional args: (name, tool_type, prompt, expected_tool, expected_params).
TOOL_CALL_TESTS = [
    # --- file reads ---
    ToolCallTest("simple_file_read", ToolCallType.FILE_READ,
                 "Use read_file to read README.md. Tools: read_file(path: str)",
                 "read_file", {"path": "README.md"}, difficulty="easy"),
    ToolCallTest("absolute_path_read", ToolCallType.FILE_READ,
                 "Use read_file to read /etc/hostname. Tools: read_file(path: str)",
                 "read_file", {"path": "/etc/hostname"}, difficulty="easy"),
    # --- terminal execution ---
    ToolCallTest("simple_terminal", ToolCallType.TERMINAL_EXEC,
                 "Use terminal to run: echo hello world. Tools: terminal(command: str)",
                 "terminal", {"command": "echo hello world"}, difficulty="easy"),
    # Uses a custom validator since several ls variants are acceptable.
    ToolCallTest("terminal_ls", ToolCallType.TERMINAL_EXEC,
                 "Use terminal to list files. Tools: terminal(command: str)",
                 "terminal", {}, validation_fn="validate_ls", difficulty="medium"),
    # --- web search ---
    ToolCallTest("web_search", ToolCallType.WEB_SEARCH,
                 "Use web_search for Python. Tools: web_search(query: str)",
                 "web_search", {"query": "Python"}, difficulty="easy"),
    # --- multi-step workflows (only the FIRST expected call is checked) ---
    ToolCallTest("read_then_analyze", ToolCallType.MULTI_STEP,
                 "First read README.md then analyze. Tools: read_file(path: str)",
                 "read_file", {"path": "README.md"}, difficulty="medium"),
    # --- schema parsing ---
    ToolCallTest("nested_params", ToolCallType.SCHEMA_PARSING,
                 "Use complex_tool(name=test, config={verbose:true}, tags=[a,b]). Tools: complex_tool(name: str, config: dict, tags: list)",
                 "complex_tool", {"name": "test"}, difficulty="hard"),
    ToolCallTest("optional_params", ToolCallType.SCHEMA_PARSING,
                 "Use search for ML with limit 5. Tools: search(query: str, limit: int=10)",
                 "search", {"query": "ML", "limit": 5}, difficulty="medium"),
    ToolCallTest("sequential_calls", ToolCallType.MULTI_STEP,
                 "First run pwd, then read README.md. Tools: terminal(command: str), read_file(path: str)",
                 "terminal", {"command": "pwd"}, difficulty="hard"),
    # --- negative test: the model should answer directly, calling no tool ---
    ToolCallTest("no_tool_needed", ToolCallType.FILE_READ,
                 "What is 2+2? Tools: read_file(path: str)",
                 None, {}, difficulty="easy"),
]
|
||
|
|
|
||
|
|
|
||
|
|
def validate_ls(params):
    """Accept any reasonable directory-listing command.

    Matches "ls" with arbitrary arguments, a handful of common exact
    variants, and Windows-style "dir".
    """
    cmd = params.get("command", "").strip()
    if cmd.startswith("ls "):
        return True
    return cmd in ("ls", "ls -l", "ls -la", "ls -1", "dir")
|
||
|
|
|
||
|
|
|
||
|
|
# Registry mapping ToolCallTest.validation_fn names to custom param validators.
VALIDATORS = {"validate_ls": validate_ls}
|
||
|
|
|
||
|
|
|
||
|
|
def _coerce_param(value: str) -> Any:
    """Convert a parsed parameter string to int or float when it looks numeric."""
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value


def parse_tool_call(response: str) -> Tuple[Optional[str], Optional[Dict]]:
    """Extract a tool call from raw model output.

    Tries two formats in order:
      1. JSON-ish: ``{"tool": ..., "params": {...}}`` or
         ``{"name": ..., "arguments": {...}}``
      2. Function-call syntax: ``tool_name(key=value, ...)``

    Returns:
        ``(tool_name, params)`` on the first successful parse, or
        ``(None, None)`` when no tool call can be found.

    Function-call parameter values are coerced to int/float when numeric,
    so equality checks against typed expected_params (e.g. ``{"limit": 5}``)
    can succeed instead of always failing on ``"5" == 5``.
    """
    # JSON format. NOTE: {[^}]+} cannot span nested objects, so deeply
    # nested params fall through to the function-call parser below.
    patterns = [
        r'"tool"\s*:\s*"([^"]+)"\s*,\s*"params"\s*:\s*({[^}]+})',
        r'"name"\s*:\s*"([^"]+)"\s*,\s*"arguments"\s*:\s*({[^}]+})',
    ]
    for pattern in patterns:
        match = re.search(pattern, response, re.DOTALL)
        if match:
            try:
                return match.group(1), json.loads(match.group(2))
            except json.JSONDecodeError:
                # Malformed params payload — try the next format instead of
                # silently swallowing unrelated exceptions (was a bare except).
                continue

    # Function call format: tool(a=1, b="x")
    match = re.search(r'(\w+)\(([^)]+)\)', response)
    if match:
        tool_name = match.group(1)
        params = {}
        for m in re.finditer(r'(\w+)\s*=\s*"?([^",)]+)"?', match.group(2)):
            params[m.group(1)] = _coerce_param(m.group(2).strip().strip('"\''))
        return tool_name, params

    return None, None
|
||
|
|
|
||
|
|
|
||
|
|
def call_model(prompt: str, model: str, backend: str, url: str, timeout: int = 60) -> Tuple[str, float]:
    """Send a prompt to the model backend.

    Returns:
        ``(response_text, latency_ms)``. Failures never raise — any error
        is encoded as a response string starting with ``"ERROR:"``.
    """
    started = time.time()

    def elapsed_ms() -> float:
        return (time.time() - started) * 1000

    if backend != "ollama":
        return f"ERROR: Unknown backend {backend}", elapsed_ms()

    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        # Low temperature keeps tool-call output near-deterministic.
        "options": {"num_predict": 256, "temperature": 0.1},
    }
    try:
        resp = requests.post(f"{url}/api/generate", json=payload, timeout=timeout)
        resp.raise_for_status()
        text = resp.json().get("response", "")
    except Exception as exc:  # network / HTTP / JSON failures become ERROR strings
        text = f"ERROR: {exc}"
    return text, elapsed_ms()
|
||
|
|
|
||
|
|
|
||
|
|
def run_test(test: ToolCallTest, model: str, backend: str, url: str) -> TestResult:
    """Execute one test case against the model and score the result.

    Scoring: quality is 0.5 for naming the expected tool plus 0.5 for
    matching params; the test passes only when both halves match. For a
    negative test (expected_tool is None) it passes iff no tool call was
    parsed at all.
    """
    response, latency = call_model(test.prompt, model, backend, url)

    # Backend failure: record the error verbatim, skip parsing entirely.
    if response.startswith("ERROR:"):
        return TestResult(test.name, test.tool_type.value, False, latency, response, error=response)

    parsed_tool, parsed_params = parse_tool_call(response)
    passed = False
    quality = 0.0

    if test.expected_tool is None:
        # Negative test: the model should answer directly without any tool.
        passed = parsed_tool is None
        quality = 1.0 if passed else 0.0
    elif parsed_tool:
        # Tool names compare case-insensitively.
        tool_match = parsed_tool.lower() == test.expected_tool.lower()
        if test.validation_fn and test.validation_fn in VALIDATORS:
            # Custom validator takes precedence over literal param comparison.
            params_match = VALIDATORS[test.validation_fn](parsed_params or {})
        else:
            # String expectations match by case-insensitive substring; any
            # other type must compare equal. Empty expected_params always match.
            params_match = all(
                k in (parsed_params or {}) and
                (str(v).lower() in str(parsed_params.get(k, "")).lower() if isinstance(v, str) else parsed_params.get(k) == v)
                for k, v in test.expected_params.items()
            ) if test.expected_params else True
        passed = tool_match and params_match
        quality = (0.5 if tool_match else 0) + (0.5 if params_match else 0)

    # Response text is truncated to 500 chars to keep result JSON compact.
    return TestResult(test.name, test.tool_type.value, passed, latency, response[:500],
                      parsed_tool, parsed_params, quality_score=quality)
|
||
|
|
|
||
|
|
|
||
|
|
def run_all_tests(model: str, backend: str, url: str) -> BenchmarkResult:
    """Run every case in TOOL_CALL_TESTS, printing progress and building a summary."""
    bench = BenchmarkResult(model, backend, datetime.now(timezone.utc).isoformat())
    print(f"Testing {model} ({backend})")
    print("=" * 50)

    for case in TOOL_CALL_TESTS:
        outcome = run_test(case, model, backend, url)
        bench.results.append(outcome)
        status = "PASS" if outcome.passed else "FAIL"
        print(f" {status} {case.name} ({outcome.latency_ms:.0f}ms, q={outcome.quality_score:.0%})")

    count = len(bench.results)
    ok = sum(1 for r in bench.results if r.passed)
    # Averages guard against division by zero for an empty test list.
    bench.summary = {
        "total": count, "passed": ok, "failed": count - ok,
        "pass_rate": ok / count if count else 0,
        "avg_latency_ms": sum(r.latency_ms for r in bench.results) / count if count else 0,
        "avg_quality": sum(r.quality_score for r in bench.results) / count if count else 0,
    }
    return bench
|
||
|
|
|
||
|
|
|
||
|
|
def generate_report(results: BenchmarkResult) -> str:
    """Render a markdown report (summary, per-test lines, verdict) from a run."""
    summary = results.summary
    pass_rate = summary['pass_rate']

    out = [
        "# Tool Calling Test Results - 1-Bit Models", "",
        f"**Model:** {results.model} ",
        f"**Backend:** {results.backend} ",
        f"**Timestamp:** {results.timestamp}", "",
        "## Summary", "",
        f"- Pass Rate: {summary['passed']}/{summary['total']} ({pass_rate:.0%})",
        f"- Avg Latency: {summary['avg_latency_ms']:.0f}ms",
        f"- Avg Quality: {summary['avg_quality']:.0%}", "",
        "## Detailed Results", "",
    ]

    # Results are appended in TOOL_CALL_TESTS order, so zipping recovers
    # each test's difficulty label.
    for res, spec in zip(results.results, TOOL_CALL_TESTS):
        out.append(f"- {'PASS' if res.passed else 'FAIL'} {res.test_name} ({spec.difficulty}, {res.latency_ms:.0f}ms)")

    out.extend(["", "## Conclusion", ""])
    if pass_rate >= 0.8:
        verdict = f"**VIABLE** - {pass_rate:.0%} pass rate."
    elif pass_rate >= 0.5:
        verdict = f"**MARGINAL** - {pass_rate:.0%} pass rate."
    else:
        verdict = f"**NOT VIABLE** - {pass_rate:.0%} pass rate."
    out.append(verdict)

    out.extend(["", "### Alternatives", "- Qwen3.5 3B Q4", "- Phi-3 Mini", "- Llama 3.2 3B"])
    return "\n".join(out)
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """CLI entry point: run the suite, then optionally save JSON and markdown artifacts."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default="bonsai-1bit")
    cli.add_argument("--backend", default="ollama")
    cli.add_argument("--url", default="http://localhost:11434")
    cli.add_argument("--results", help="Save results JSON")
    cli.add_argument("--report", help="Save report markdown")
    args = cli.parse_args()

    outcome = run_all_tests(args.model, args.backend, args.url)
    print(f"\nSUMMARY: {outcome.summary['passed']}/{outcome.summary['total']} passed")

    def _write(path, emit):
        # Create the parent directory (if any) before writing the artifact.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w") as fh:
            emit(fh)

    if args.results:
        _write(args.results, lambda fh: json.dump(asdict(outcome), fh, indent=2))

    if args.report:
        _write(args.report, lambda fh: fh.write(generate_report(outcome)))


if __name__ == "__main__":
    main()
|