feat: add tool calling test suite for 1-bit models (#101)
10 test cases covering file read, terminal, web search, multi-step, schema parsing. Closes #101
This commit is contained in:
255
benchmarks/test_tool_calling_1bit.py
Normal file
255
benchmarks/test_tool_calling_1bit.py
Normal file
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tool Calling Test Suite for 1-Bit Models (Issue #101)
|
||||
|
||||
Tests whether Bonsai 1-bit models can handle tool calling at all.
|
||||
Evaluates: file read, terminal execution, web search, multi-step workflows, schema parsing.
|
||||
|
||||
Usage:
|
||||
python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --backend ollama
|
||||
python3 benchmarks/test_tool_calling_1bit.py --results benchmarks/tool_calling_results.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class ToolCallType(Enum):
    """Capability categories exercised by the tool-calling test suite."""

    FILE_READ = "file_read"            # single read_file invocation
    TERMINAL_EXEC = "terminal_exec"    # shell command execution
    WEB_SEARCH = "web_search"          # search-query tool
    MULTI_STEP = "multi_step"          # ordered sequence of tool calls
    SCHEMA_PARSING = "schema_parsing"  # nested / typed parameter schemas
|
||||
|
||||
|
||||
@dataclass
class ToolCallTest:
    """Declarative definition of one tool-calling test case.

    A case pairs a prompt (which embeds the available tool signatures)
    with the tool call the model is expected to emit.
    """

    name: str                            # unique test identifier
    tool_type: ToolCallType              # capability category under test
    prompt: str                          # prompt sent verbatim to the model
    expected_tool: Optional[str]         # tool the model should call; None = no call expected
    expected_params: Dict[str, Any]      # parameters that must appear in the call
    validation_fn: Optional[str] = None  # key into VALIDATORS for a custom param check
    difficulty: str = "easy"             # "easy" | "medium" | "hard" (report label only)
|
||||
|
||||
|
||||
@dataclass
class TestResult:
    """Outcome of a single executed test case."""

    test_name: str        # matches ToolCallTest.name
    tool_type: str        # ToolCallType.value of the test
    passed: bool          # overall pass/fail verdict
    latency_ms: float     # wall-clock model latency in milliseconds
    response_text: str    # raw model output (truncated by the caller)
    parsed_tool: Optional[str] = None               # tool name extracted from the response
    parsed_params: Optional[Dict[str, Any]] = None  # parameters extracted from the response
    error: Optional[str] = None                     # transport/backend error text, if any
    quality_score: float = 0.0                      # 0.0-1.0: 0.5 tool match + 0.5 params match
|
||||
|
||||
|
||||
@dataclass
class BenchmarkResult:
    """Aggregate results for one model/backend benchmark run."""

    model: str       # model identifier as passed on the CLI
    backend: str     # inference backend name (e.g. "ollama")
    timestamp: str   # ISO-8601 UTC time the run started
    results: List[TestResult] = field(default_factory=list)  # per-test outcomes, in run order
    summary: Dict[str, Any] = field(default_factory=dict)    # totals/rates filled by run_all_tests
|
||||
|
||||
|
||||
# The ten scored scenarios.  List order matters only for report output
# (generate_report pairs this list positionally with the recorded results).
TOOL_CALL_TESTS = [
    ToolCallTest(
        name="simple_file_read",
        tool_type=ToolCallType.FILE_READ,
        prompt="Use read_file to read README.md. Tools: read_file(path: str)",
        expected_tool="read_file",
        expected_params={"path": "README.md"},
        difficulty="easy",
    ),
    ToolCallTest(
        name="absolute_path_read",
        tool_type=ToolCallType.FILE_READ,
        prompt="Use read_file to read /etc/hostname. Tools: read_file(path: str)",
        expected_tool="read_file",
        expected_params={"path": "/etc/hostname"},
        difficulty="easy",
    ),
    ToolCallTest(
        name="simple_terminal",
        tool_type=ToolCallType.TERMINAL_EXEC,
        prompt="Use terminal to run: echo hello world. Tools: terminal(command: str)",
        expected_tool="terminal",
        expected_params={"command": "echo hello world"},
        difficulty="easy",
    ),
    ToolCallTest(
        name="terminal_ls",
        tool_type=ToolCallType.TERMINAL_EXEC,
        prompt="Use terminal to list files. Tools: terminal(command: str)",
        expected_tool="terminal",
        expected_params={},
        validation_fn="validate_ls",  # any reasonable ls/dir variant passes
        difficulty="medium",
    ),
    ToolCallTest(
        name="web_search",
        tool_type=ToolCallType.WEB_SEARCH,
        prompt="Use web_search for Python. Tools: web_search(query: str)",
        expected_tool="web_search",
        expected_params={"query": "Python"},
        difficulty="easy",
    ),
    ToolCallTest(
        name="read_then_analyze",
        tool_type=ToolCallType.MULTI_STEP,
        prompt="First read README.md then analyze. Tools: read_file(path: str)",
        expected_tool="read_file",
        expected_params={"path": "README.md"},
        difficulty="medium",
    ),
    ToolCallTest(
        name="nested_params",
        tool_type=ToolCallType.SCHEMA_PARSING,
        prompt="Use complex_tool(name=test, config={verbose:true}, tags=[a,b]). Tools: complex_tool(name: str, config: dict, tags: list)",
        expected_tool="complex_tool",
        expected_params={"name": "test"},  # only the flat param is scored
        difficulty="hard",
    ),
    ToolCallTest(
        name="optional_params",
        tool_type=ToolCallType.SCHEMA_PARSING,
        prompt="Use search for ML with limit 5. Tools: search(query: str, limit: int=10)",
        expected_tool="search",
        expected_params={"query": "ML", "limit": 5},
        difficulty="medium",
    ),
    ToolCallTest(
        name="sequential_calls",
        tool_type=ToolCallType.MULTI_STEP,
        prompt="First run pwd, then read README.md. Tools: terminal(command: str), read_file(path: str)",
        expected_tool="terminal",  # only the first call is scored
        expected_params={"command": "pwd"},
        difficulty="hard",
    ),
    ToolCallTest(
        name="no_tool_needed",
        tool_type=ToolCallType.FILE_READ,
        prompt="What is 2+2? Tools: read_file(path: str)",
        expected_tool=None,  # negative case: the model must NOT call a tool
        expected_params={},
        difficulty="easy",
    ),
]
|
||||
|
||||
|
||||
def validate_ls(params):
    """Return True when *params* holds any reasonable file-listing command.

    Accepts bare "ls", "dir", or "ls" with arguments ("ls -l", "ls -la", ...).
    """
    cmd = params.get("command", "").strip()
    return cmd in ("ls", "dir") or cmd.startswith("ls ")


# Registry mapping ToolCallTest.validation_fn keys to validator callables.
VALIDATORS = {"validate_ls": validate_ls}
|
||||
|
||||
|
||||
def parse_tool_call(response: str) -> Tuple[Optional[str], Optional[Dict]]:
    """Extract a (tool_name, params) pair from a raw model response.

    Two formats are recognized, tried in order:
      1. JSON: {"tool": ..., "params": {...}} or {"name": ..., "arguments": {...}}
      2. Function call: tool_name(key=value, ...)

    Returns (None, None) when no tool call can be parsed.
    """
    # -- JSON format ----------------------------------------------------
    # NOTE: the {[^}]+} group only captures flat (non-nested) param objects;
    # nested dicts truncate at the first '}'.
    patterns = [
        r'"tool"\s*:\s*"([^"]+)"\s*,\s*"params"\s*:\s*({[^}]+})',
        r'"name"\s*:\s*"([^"]+)"\s*,\s*"arguments"\s*:\s*({[^}]+})',
    ]
    for pattern in patterns:
        match = re.search(pattern, response, re.DOTALL)
        if match:
            try:
                return match.group(1), json.loads(match.group(2))
            # Was a bare `except:` (swallowed KeyboardInterrupt/SystemExit);
            # json.loads on a string can only raise JSONDecodeError here.
            except json.JSONDecodeError:
                continue

    # -- Function call format -------------------------------------------
    match = re.search(r'(\w+)\(([^)]+)\)', response)
    if match:
        tool_name = match.group(1)
        params = {}
        for m in re.finditer(r'(\w+)\s*=\s*"?([^",)]+)"?', match.group(2)):
            value = m.group(2).strip().strip('"\'')
            # Coerce bare integers so numeric params compare equal to typed
            # expectations downstream (e.g. limit=5 -> 5, not "5").
            params[m.group(1)] = int(value) if re.fullmatch(r'-?\d+', value) else value
        return tool_name, params

    return None, None
|
||||
|
||||
|
||||
def call_model(prompt: str, model: str, backend: str, url: str, timeout: int = 60) -> Tuple[str, float]:
    """Send *prompt* to the backend and return (response_text, latency_ms).

    Never raises: transport or backend failures come back as a string
    beginning with "ERROR:", which callers treat as a failed test.
    """
    started = time.time()
    try:
        if backend != "ollama":
            text = f"ERROR: Unknown backend {backend}"
        else:
            payload = {
                "model": model, "prompt": prompt, "stream": False,
                # Low temperature for (near-)deterministic tool selection.
                "options": {"num_predict": 256, "temperature": 0.1}
            }
            reply = requests.post(f"{url}/api/generate", json=payload, timeout=timeout)
            reply.raise_for_status()
            text = reply.json().get("response", "")
    except Exception as exc:  # boundary: fold any failure into the result string
        text = f"ERROR: {exc}"
    return text, (time.time() - started) * 1000
|
||||
|
||||
|
||||
def run_test(test: ToolCallTest, model: str, backend: str, url: str) -> TestResult:
    """Execute one test case against the model and score the response.

    Scoring: 0.5 for naming the expected tool + 0.5 for acceptable params.
    A negative test (expected_tool is None) passes iff no call is parsed.
    """
    response, latency = call_model(test.prompt, model, backend, url)

    # Transport/backend failure: record verbatim, skip parsing entirely.
    if response.startswith("ERROR:"):
        return TestResult(test.name, test.tool_type.value, False, latency,
                          response, error=response)

    tool, params = parse_tool_call(response)

    if test.expected_tool is None:
        # Negative case: the model must NOT emit a tool call.
        ok = tool is None
        return TestResult(test.name, test.tool_type.value, ok, latency,
                          response[:500], tool, params,
                          quality_score=1.0 if ok else 0.0)

    if not tool:
        # Expected a call but none could be parsed from the response.
        return TestResult(test.name, test.tool_type.value, False, latency,
                          response[:500], tool, params, quality_score=0.0)

    tool_ok = tool.lower() == test.expected_tool.lower()

    # Params check: custom validator if registered, else per-key comparison
    # (substring match for strings, equality otherwise); vacuously True
    # when the test declares no expected params.
    validator = VALIDATORS.get(test.validation_fn) if test.validation_fn else None
    if validator is not None:
        params_ok = validator(params or {})
    elif not test.expected_params:
        params_ok = True
    else:
        params_ok = True
        for key, want in test.expected_params.items():
            if key not in (params or {}):
                params_ok = False
                break
            have = (params or {}).get(key, "")
            if isinstance(want, str):
                if str(want).lower() not in str(have).lower():
                    params_ok = False
                    break
            elif (params or {}).get(key) != want:
                params_ok = False
                break

    quality = (0.5 if tool_ok else 0) + (0.5 if params_ok else 0)
    return TestResult(test.name, test.tool_type.value, tool_ok and params_ok,
                      latency, response[:500], tool, params, quality_score=quality)
|
||||
|
||||
|
||||
def run_all_tests(model: str, backend: str, url: str) -> BenchmarkResult:
    """Run every case in TOOL_CALL_TESTS, printing progress and aggregating a summary."""
    bench = BenchmarkResult(model, backend, datetime.now(timezone.utc).isoformat())
    print(f"Testing {model} ({backend})")
    print("=" * 50)

    for case in TOOL_CALL_TESTS:
        outcome = run_test(case, model, backend, url)
        bench.results.append(outcome)
        label = "PASS" if outcome.passed else "FAIL"
        print(f" {label} {case.name} ({outcome.latency_ms:.0f}ms, q={outcome.quality_score:.0%})")

    n = len(bench.results)
    n_passed = sum(1 for r in bench.results if r.passed)
    bench.summary = {
        "total": n, "passed": n_passed, "failed": n - n_passed,
        # Guard n == 0 so an empty suite reports zeros instead of dividing by zero.
        "pass_rate": n_passed / n if n else 0,
        "avg_latency_ms": sum(r.latency_ms for r in bench.results) / n if n else 0,
        "avg_quality": sum(r.quality_score for r in bench.results) / n if n else 0,
    }
    return bench
|
||||
|
||||
|
||||
def generate_report(results: BenchmarkResult) -> str:
    """Render a completed benchmark run as a Markdown report string."""
    s = results.summary
    out = [
        "# Tool Calling Test Results - 1-Bit Models", "",
        f"**Model:** {results.model} ",
        f"**Backend:** {results.backend} ",
        f"**Timestamp:** {results.timestamp}", "",
        "## Summary", "",
        f"- Pass Rate: {s['passed']}/{s['total']} ({s['pass_rate']:.0%})",
        f"- Avg Latency: {s['avg_latency_ms']:.0f}ms",
        f"- Avg Quality: {s['avg_quality']:.0%}", "",
        "## Detailed Results", "",
    ]
    # NOTE(review): pairs results positionally with TOOL_CALL_TESTS — assumes
    # the run covered the full suite in order (true for run_all_tests).
    for res, case in zip(results.results, TOOL_CALL_TESTS):
        out.append(f"- {'PASS' if res.passed else 'FAIL'} {res.test_name} ({case.difficulty}, {res.latency_ms:.0f}ms)")
    out.extend(["", "## Conclusion", ""])
    rate = s['pass_rate']
    if rate >= 0.8:
        out.append(f"**VIABLE** - {s['pass_rate']:.0%} pass rate.")
    elif rate >= 0.5:
        out.append(f"**MARGINAL** - {s['pass_rate']:.0%} pass rate.")
    else:
        out.append(f"**NOT VIABLE** - {s['pass_rate']:.0%} pass rate.")
    out.extend(["", "### Alternatives", "- Qwen3.5 3B Q4", "- Phi-3 Mini", "- Llama 3.2 3B"])
    return "\n".join(out)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: run the suite, then optionally persist JSON/Markdown artifacts."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default="bonsai-1bit")
    cli.add_argument("--backend", default="ollama")
    cli.add_argument("--url", default="http://localhost:11434")
    cli.add_argument("--results", help="Save results JSON")
    cli.add_argument("--report", help="Save report markdown")
    args = cli.parse_args()

    results = run_all_tests(args.model, args.backend, args.url)
    print(f"\nSUMMARY: {results.summary['passed']}/{results.summary['total']} passed")

    if args.results:
        # dirname is "" for a bare filename; fall back to the cwd.
        os.makedirs(os.path.dirname(args.results) or ".", exist_ok=True)
        with open(args.results, "w") as fh:
            json.dump(asdict(results), fh, indent=2)

    if args.report:
        os.makedirs(os.path.dirname(args.report) or ".", exist_ok=True)
        with open(args.report, "w") as fh:
            fh.write(generate_report(results))
|
||||
|
||||
|
||||
# Script entry point: parse CLI args, run the suite, optionally write artifacts.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user