feat: add tool calling test suite for 1-bit models (#101)

10 test cases covering file read, terminal execution, web search, multi-step workflows, and schema parsing.
Closes #101
2026-04-16 01:53:01 +00:00
parent 3cd8750cbb
commit 3caeaf13eb

@@ -0,0 +1,255 @@
#!/usr/bin/env python3
"""
Tool Calling Test Suite for 1-Bit Models (Issue #101)
Tests whether Bonsai 1-bit models can handle tool calling at all.
Evaluates: file read, terminal execution, web search, multi-step workflows, schema parsing.
Usage:
python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --backend ollama
python3 benchmarks/test_tool_calling_1bit.py --results benchmarks/tool_calling_results.json
"""
import argparse
import json
import os
import re
import sys
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
import requests


class ToolCallType(Enum):
    FILE_READ = "file_read"
    TERMINAL_EXEC = "terminal_exec"
    WEB_SEARCH = "web_search"
    MULTI_STEP = "multi_step"
    SCHEMA_PARSING = "schema_parsing"


@dataclass
class ToolCallTest:
    name: str
    tool_type: ToolCallType
    prompt: str
    expected_tool: Optional[str]
    expected_params: Dict[str, Any]
    validation_fn: Optional[str] = None
    difficulty: str = "easy"


@dataclass
class TestResult:
    test_name: str
    tool_type: str
    passed: bool
    latency_ms: float
    response_text: str
    parsed_tool: Optional[str] = None
    parsed_params: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    quality_score: float = 0.0


@dataclass
class BenchmarkResult:
    model: str
    backend: str
    timestamp: str
    results: List[TestResult] = field(default_factory=list)
    summary: Dict[str, Any] = field(default_factory=dict)


TOOL_CALL_TESTS = [
    ToolCallTest("simple_file_read", ToolCallType.FILE_READ,
                 "Use read_file to read README.md. Tools: read_file(path: str)",
                 "read_file", {"path": "README.md"}, difficulty="easy"),
    ToolCallTest("absolute_path_read", ToolCallType.FILE_READ,
                 "Use read_file to read /etc/hostname. Tools: read_file(path: str)",
                 "read_file", {"path": "/etc/hostname"}, difficulty="easy"),
    ToolCallTest("simple_terminal", ToolCallType.TERMINAL_EXEC,
                 "Use terminal to run: echo hello world. Tools: terminal(command: str)",
                 "terminal", {"command": "echo hello world"}, difficulty="easy"),
    ToolCallTest("terminal_ls", ToolCallType.TERMINAL_EXEC,
                 "Use terminal to list files. Tools: terminal(command: str)",
                 "terminal", {}, validation_fn="validate_ls", difficulty="medium"),
    ToolCallTest("web_search", ToolCallType.WEB_SEARCH,
                 "Use web_search for Python. Tools: web_search(query: str)",
                 "web_search", {"query": "Python"}, difficulty="easy"),
    ToolCallTest("read_then_analyze", ToolCallType.MULTI_STEP,
                 "First read README.md then analyze. Tools: read_file(path: str)",
                 "read_file", {"path": "README.md"}, difficulty="medium"),
    ToolCallTest("nested_params", ToolCallType.SCHEMA_PARSING,
                 "Use complex_tool(name=test, config={verbose:true}, tags=[a,b]). Tools: complex_tool(name: str, config: dict, tags: list)",
                 "complex_tool", {"name": "test"}, difficulty="hard"),
    ToolCallTest("optional_params", ToolCallType.SCHEMA_PARSING,
                 "Use search for ML with limit 5. Tools: search(query: str, limit: int=10)",
                 "search", {"query": "ML", "limit": 5}, difficulty="medium"),
    ToolCallTest("sequential_calls", ToolCallType.MULTI_STEP,
                 "First run pwd, then read README.md. Tools: terminal(command: str), read_file(path: str)",
                 "terminal", {"command": "pwd"}, difficulty="hard"),
    ToolCallTest("no_tool_needed", ToolCallType.FILE_READ,
                 "What is 2+2? Tools: read_file(path: str)",
                 None, {}, difficulty="easy"),
]
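
# Illustrative only: a new case is added by appending another ToolCallTest to the list
# above, giving a prompt, the expected tool name, and expected params (or a
# validation_fn key registered in VALIDATORS below). The example case is hypothetical:
#   ToolCallTest("cat_readme", ToolCallType.TERMINAL_EXEC,
#                "Use terminal to show README.md. Tools: terminal(command: str)",
#                "terminal", {"command": "cat README.md"}, difficulty="medium"),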


def validate_ls(params):
    cmd = params.get("command", "").strip()
    return cmd in ["ls", "ls -l", "ls -la", "ls -1", "dir"] or cmd.startswith("ls ")


# Names referenced by ToolCallTest.validation_fn resolve to callables here.
VALIDATORS = {"validate_ls": validate_ls}


def parse_tool_call(response: str) -> Tuple[Optional[str], Optional[Dict]]:
    # JSON formats: {"tool": ..., "params": {...}} and {"name": ..., "arguments": {...}}
    patterns = [
        r'"tool"\s*:\s*"([^"]+)"\s*,\s*"params"\s*:\s*({[^}]+})',
        r'"name"\s*:\s*"([^"]+)"\s*,\s*"arguments"\s*:\s*({[^}]+})',
    ]
    for pattern in patterns:
        match = re.search(pattern, response, re.DOTALL)
        if match:
            try:
                return match.group(1), json.loads(match.group(2))
            except json.JSONDecodeError:
                continue
    # Function-call format: tool_name(param=value, ...)
    match = re.search(r'(\w+)\(([^)]+)\)', response)
    if match:
        tool_name = match.group(1)
        params = {}
        for m in re.finditer(r'(\w+)\s*=\s*"?([^",)]+)"?', match.group(2)):
            params[m.group(1)] = m.group(2).strip().strip('"\'')
        return tool_name, params
    return None, None
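
# Illustrative only: both response styles below should parse to
# ("read_file", {"path": "README.md"}).
#   parse_tool_call('{"tool": "read_file", "params": {"path": "README.md"}}')
#   parse_tool_call('read_file(path="README.md")')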


def call_model(prompt: str, model: str, backend: str, url: str, timeout: int = 60) -> Tuple[str, float]:
    start = time.time()
    try:
        if backend == "ollama":
            resp = requests.post(f"{url}/api/generate", json={
                "model": model, "prompt": prompt, "stream": False,
                "options": {"num_predict": 256, "temperature": 0.1}
            }, timeout=timeout)
            resp.raise_for_status()
            text = resp.json().get("response", "")
        else:
            text = f"ERROR: Unknown backend {backend}"
    except Exception as e:
        text = f"ERROR: {e}"
    return text, (time.time() - start) * 1000
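
# Only the Ollama /api/generate endpoint is implemented above. A sketch for an
# OpenAI-compatible server (an assumption, not part of this suite) would POST to
# f"{url}/v1/chat/completions" with {"model": model, "messages": [{"role": "user",
# "content": prompt}]} and read choices[0]["message"]["content"] from the JSON reply.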


def run_test(test: ToolCallTest, model: str, backend: str, url: str) -> TestResult:
    response, latency = call_model(test.prompt, model, backend, url)
    if response.startswith("ERROR:"):
        return TestResult(test.name, test.tool_type.value, False, latency, response, error=response)
    parsed_tool, parsed_params = parse_tool_call(response)
    passed = False
    quality = 0.0
    if test.expected_tool is None:
        # Negative case: the model should not emit any tool call.
        passed = parsed_tool is None
        quality = 1.0 if passed else 0.0
    elif parsed_tool:
        tool_match = parsed_tool.lower() == test.expected_tool.lower()
        if test.validation_fn and test.validation_fn in VALIDATORS:
            params_match = VALIDATORS[test.validation_fn](parsed_params or {})
        else:
            params_match = all(
                k in (parsed_params or {}) and
                (str(v).lower() in str(parsed_params.get(k, "")).lower()
                 if isinstance(v, str)
                 # Non-string expectations (e.g. limit=5) may arrive as strings from the
                 # function-call regex, so also accept a stringified match.
                 else parsed_params.get(k) == v or str(parsed_params.get(k)) == str(v))
                for k, v in test.expected_params.items()
            ) if test.expected_params else True
        passed = tool_match and params_match
        # Half credit for naming the right tool, half for acceptable parameters.
        quality = (0.5 if tool_match else 0) + (0.5 if params_match else 0)
    return TestResult(test.name, test.tool_type.value, passed, latency, response[:500],
                      parsed_tool, parsed_params, quality_score=quality)


def run_all_tests(model: str, backend: str, url: str) -> BenchmarkResult:
    results = BenchmarkResult(model, backend, datetime.now(timezone.utc).isoformat())
    print(f"Testing {model} ({backend})")
    print("=" * 50)
    for test in TOOL_CALL_TESTS:
        result = run_test(test, model, backend, url)
        results.results.append(result)
        status = "PASS" if result.passed else "FAIL"
        print(f" {status} {test.name} ({result.latency_ms:.0f}ms, q={result.quality_score:.0%})")
    total = len(results.results)
    passed = sum(1 for r in results.results if r.passed)
    results.summary = {
        "total": total, "passed": passed, "failed": total - passed,
        "pass_rate": passed / total if total else 0,
        "avg_latency_ms": sum(r.latency_ms for r in results.results) / total if total else 0,
        "avg_quality": sum(r.quality_score for r in results.results) / total if total else 0,
    }
    return results


def generate_report(results: BenchmarkResult) -> str:
    s = results.summary
    lines = [
        "# Tool Calling Test Results - 1-Bit Models", "",
        f"**Model:** {results.model} ",
        f"**Backend:** {results.backend} ",
        f"**Timestamp:** {results.timestamp}", "",
        "## Summary", "",
        f"- Pass Rate: {s['passed']}/{s['total']} ({s['pass_rate']:.0%})",
        f"- Avg Latency: {s['avg_latency_ms']:.0f}ms",
        f"- Avg Quality: {s['avg_quality']:.0%}", "",
        "## Detailed Results", "",
    ]
    for r, t in zip(results.results, TOOL_CALL_TESTS):
        lines.append(f"- {'PASS' if r.passed else 'FAIL'} {r.test_name} ({t.difficulty}, {r.latency_ms:.0f}ms)")
    lines.extend(["", "## Conclusion", ""])
    if s['pass_rate'] >= 0.8:
        lines.append(f"**VIABLE** - {s['pass_rate']:.0%} pass rate.")
    elif s['pass_rate'] >= 0.5:
        lines.append(f"**MARGINAL** - {s['pass_rate']:.0%} pass rate.")
    else:
        lines.append(f"**NOT VIABLE** - {s['pass_rate']:.0%} pass rate.")
    lines.extend(["", "### Alternatives", "- Qwen3.5 3B Q4", "- Phi-3 Mini", "- Llama 3.2 3B"])
    return "\n".join(lines)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="bonsai-1bit")
    parser.add_argument("--backend", default="ollama")
    parser.add_argument("--url", default="http://localhost:11434")
    parser.add_argument("--results", help="Save results JSON")
    parser.add_argument("--report", help="Save report markdown")
    args = parser.parse_args()
    results = run_all_tests(args.model, args.backend, args.url)
    print(f"\nSUMMARY: {results.summary['passed']}/{results.summary['total']} passed")
    if args.results:
        os.makedirs(os.path.dirname(args.results) or ".", exist_ok=True)
        with open(args.results, "w") as f:
            json.dump(asdict(results), f, indent=2)
    if args.report:
        os.makedirs(os.path.dirname(args.report) or ".", exist_ok=True)
        with open(args.report, "w") as f:
            f.write(generate_report(results))


if __name__ == "__main__":
    main()
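
# Illustrative full invocation (the report path is hypothetical; flags match main() above):
#   python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --backend ollama \
#       --results benchmarks/tool_calling_results.json --report benchmarks/tool_calling_report.md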