From 3caeaf13eb26be9d721516abd4fd9a677148fbfe Mon Sep 17 00:00:00 2001
From: Alexander Whitestone
Date: Thu, 16 Apr 2026 01:53:01 +0000
Subject: [PATCH] feat: add tool calling test suite for 1-bit models (#101)

Adds 10 test cases covering file read, terminal execution, web search,
multi-step workflows, and schema parsing.

Closes #101
---
 benchmarks/test_tool_calling_1bit.py | 262 +++++++++++++++++++++++++++
 1 file changed, 262 insertions(+)
 create mode 100644 benchmarks/test_tool_calling_1bit.py

diff --git a/benchmarks/test_tool_calling_1bit.py b/benchmarks/test_tool_calling_1bit.py
new file mode 100644
index 00000000..72ee2aa7
--- /dev/null
+++ b/benchmarks/test_tool_calling_1bit.py
@@ -0,0 +1,262 @@
+#!/usr/bin/env python3
+"""
+Tool Calling Test Suite for 1-Bit Models (Issue #101)
+
+Tests whether Bonsai 1-bit models can handle tool calling at all.
+Evaluates: file read, terminal execution, web search, multi-step workflows, schema parsing.
+
+Usage:
+    python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --backend ollama
+    python3 benchmarks/test_tool_calling_1bit.py --results benchmarks/tool_calling_results.json
+"""
+
+import argparse
+import json
+import os
+import re
+import sys
+import time
+from dataclasses import dataclass, field, asdict
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+import requests
+
+
+class ToolCallType(Enum):
+    FILE_READ = "file_read"
+    TERMINAL_EXEC = "terminal_exec"
+    WEB_SEARCH = "web_search"
+    MULTI_STEP = "multi_step"
+    SCHEMA_PARSING = "schema_parsing"
+
+
+@dataclass
+class ToolCallTest:
+    name: str
+    tool_type: ToolCallType
+    prompt: str
+    expected_tool: Optional[str]
+    expected_params: Dict[str, Any]
+    validation_fn: Optional[str] = None
+    difficulty: str = "easy"
+
+
+@dataclass
+class TestResult:
+    test_name: str
+    tool_type: str
+    passed: bool
+    latency_ms: float
+    response_text: str
+    parsed_tool: Optional[str] = None
+    parsed_params: Optional[Dict[str, Any]] = None
+    error: Optional[str] = None
+    quality_score: float = 0.0
+
+
+@dataclass
+class BenchmarkResult:
+    model: str
+    backend: str
+    timestamp: str
+    results: List[TestResult] = field(default_factory=list)
+    summary: Dict[str, Any] = field(default_factory=dict)
+
+
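+# Each prompt below names its available tool(s) inline ("Tools: ..."), so every
+# test case is self-contained; no separate system prompt or tool schema is sent.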
+TOOL_CALL_TESTS = [
+    ToolCallTest("simple_file_read", ToolCallType.FILE_READ,
+                 "Use read_file to read README.md. Tools: read_file(path: str)",
+                 "read_file", {"path": "README.md"}, difficulty="easy"),
+    ToolCallTest("absolute_path_read", ToolCallType.FILE_READ,
+                 "Use read_file to read /etc/hostname. Tools: read_file(path: str)",
+                 "read_file", {"path": "/etc/hostname"}, difficulty="easy"),
+    ToolCallTest("simple_terminal", ToolCallType.TERMINAL_EXEC,
+                 "Use terminal to run: echo hello world. Tools: terminal(command: str)",
+                 "terminal", {"command": "echo hello world"}, difficulty="easy"),
+    ToolCallTest("terminal_ls", ToolCallType.TERMINAL_EXEC,
+                 "Use terminal to list files. Tools: terminal(command: str)",
+                 "terminal", {}, validation_fn="validate_ls", difficulty="medium"),
+    ToolCallTest("web_search", ToolCallType.WEB_SEARCH,
+                 "Use web_search for Python. Tools: web_search(query: str)",
+                 "web_search", {"query": "Python"}, difficulty="easy"),
+    ToolCallTest("read_then_analyze", ToolCallType.MULTI_STEP,
+                 "First read README.md then analyze. Tools: read_file(path: str)",
+                 "read_file", {"path": "README.md"}, difficulty="medium"),
+    ToolCallTest("nested_params", ToolCallType.SCHEMA_PARSING,
+                 "Use complex_tool(name=test, config={verbose:true}, tags=[a,b]). Tools: complex_tool(name: str, config: dict, tags: list)",
+                 "complex_tool", {"name": "test"}, difficulty="hard"),
+    ToolCallTest("optional_params", ToolCallType.SCHEMA_PARSING,
+                 "Use search for ML with limit 5. Tools: search(query: str, limit: int=10)",
+                 "search", {"query": "ML", "limit": 5}, difficulty="medium"),
+    ToolCallTest("sequential_calls", ToolCallType.MULTI_STEP,
+                 "First run pwd, then read README.md. Tools: terminal(command: str), read_file(path: str)",
+                 "terminal", {"command": "pwd"}, difficulty="hard"),
+    ToolCallTest("no_tool_needed", ToolCallType.FILE_READ,
+                 "What is 2+2? Tools: read_file(path: str)",
+                 None, {}, difficulty="easy"),
+]
+
+
+def validate_ls(params):
+    cmd = params.get("command", "").strip()
+    return cmd in ["ls", "ls -l", "ls -la", "ls -1", "dir"] or cmd.startswith("ls ")
+
+
+VALIDATORS = {"validate_ls": validate_ls}
+
+
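+# Illustrative examples of the call formats parse_tool_call() below is written
+# to match (hypothetical model outputs, not strings taken from the tests):
+#   {"tool": "read_file", "params": {"path": "README.md"}}
+#   {"name": "read_file", "arguments": {"path": "README.md"}}
+#   read_file(path="README.md")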
print("=" * 50) + + for test in TOOL_CALL_TESTS: + result = run_test(test, model, backend, url) + results.results.append(result) + status = "PASS" if result.passed else "FAIL" + print(f" {status} {test.name} ({result.latency_ms:.0f}ms, q={result.quality_score:.0%})") + + total = len(results.results) + passed = sum(1 for r in results.results if r.passed) + results.summary = { + "total": total, "passed": passed, "failed": total - passed, + "pass_rate": passed / total if total else 0, + "avg_latency_ms": sum(r.latency_ms for r in results.results) / total if total else 0, + "avg_quality": sum(r.quality_score for r in results.results) / total if total else 0, + } + return results + + +def generate_report(results: BenchmarkResult) -> str: + s = results.summary + lines = [ + "# Tool Calling Test Results - 1-Bit Models", "", + f"**Model:** {results.model} ", + f"**Backend:** {results.backend} ", + f"**Timestamp:** {results.timestamp}", "", + "## Summary", "", + f"- Pass Rate: {s['passed']}/{s['total']} ({s['pass_rate']:.0%})", + f"- Avg Latency: {s['avg_latency_ms']:.0f}ms", + f"- Avg Quality: {s['avg_quality']:.0%}", "", + "## Detailed Results", "", + ] + for r, t in zip(results.results, TOOL_CALL_TESTS): + lines.append(f"- {'PASS' if r.passed else 'FAIL'} {r.test_name} ({t.difficulty}, {r.latency_ms:.0f}ms)") + lines.extend(["", "## Conclusion", ""]) + if s['pass_rate'] >= 0.8: + lines.append(f"**VIABLE** - {s['pass_rate']:.0%} pass rate.") + elif s['pass_rate'] >= 0.5: + lines.append(f"**MARGINAL** - {s['pass_rate']:.0%} pass rate.") + else: + lines.append(f"**NOT VIABLE** - {s['pass_rate']:.0%} pass rate.") + lines.extend(["", "### Alternatives", "- Qwen3.5 3B Q4", "- Phi-3 Mini", "- Llama 3.2 3B"]) + return "\n".join(lines) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="bonsai-1bit") + parser.add_argument("--backend", default="ollama") + parser.add_argument("--url", default="http://localhost:11434") + parser.add_argument("--results", help="Save results JSON") + parser.add_argument("--report", help="Save report markdown") + args = parser.parse_args() + + results = run_all_tests(args.model, args.backend, args.url) + print(f"\nSUMMARY: {results.summary['passed']}/{results.summary['total']} passed") + + if args.results: + os.makedirs(os.path.dirname(args.results) or ".", exist_ok=True) + with open(args.results, "w") as f: + json.dump(asdict(results), f, indent=2) + + if args.report: + os.makedirs(os.path.dirname(args.report) or ".", exist_ok=True) + with open(args.report, "w") as f: + f.write(generate_report(results)) + + +if __name__ == "__main__": + main()