#!/usr/bin/env python3 """ Tool Calling Test Suite for 1-Bit Models (Issue #101) Tests whether Bonsai 1-bit models can handle tool calling at all. Evaluates: file read, terminal execution, web search, multi-step workflows, schema parsing. Usage: python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --backend ollama python3 benchmarks/test_tool_calling_1bit.py --results benchmarks/tool_calling_results.json """ import argparse import json import os import re import sys import time from dataclasses import dataclass, field, asdict from datetime import datetime, timezone from enum import Enum from typing import Any, Dict, List, Optional, Tuple import requests class ToolCallType(Enum): FILE_READ = "file_read" TERMINAL_EXEC = "terminal_exec" WEB_SEARCH = "web_search" MULTI_STEP = "multi_step" SCHEMA_PARSING = "schema_parsing" @dataclass class ToolCallTest: name: str tool_type: ToolCallType prompt: str expected_tool: Optional[str] expected_params: Dict[str, Any] validation_fn: Optional[str] = None difficulty: str = "easy" @dataclass class TestResult: test_name: str tool_type: str passed: bool latency_ms: float response_text: str parsed_tool: Optional[str] = None parsed_params: Optional[Dict[str, Any]] = None error: Optional[str] = None quality_score: float = 0.0 @dataclass class BenchmarkResult: model: str backend: str timestamp: str results: List[TestResult] = field(default_factory=list) summary: Dict[str, Any] = field(default_factory=dict) TOOL_CALL_TESTS = [ ToolCallTest("simple_file_read", ToolCallType.FILE_READ, "Use read_file to read README.md. Tools: read_file(path: str)", "read_file", {"path": "README.md"}, difficulty="easy"), ToolCallTest("absolute_path_read", ToolCallType.FILE_READ, "Use read_file to read /etc/hostname. Tools: read_file(path: str)", "read_file", {"path": "/etc/hostname"}, difficulty="easy"), ToolCallTest("simple_terminal", ToolCallType.TERMINAL_EXEC, "Use terminal to run: echo hello world. Tools: terminal(command: str)", "terminal", {"command": "echo hello world"}, difficulty="easy"), ToolCallTest("terminal_ls", ToolCallType.TERMINAL_EXEC, "Use terminal to list files. Tools: terminal(command: str)", "terminal", {}, validation_fn="validate_ls", difficulty="medium"), ToolCallTest("web_search", ToolCallType.WEB_SEARCH, "Use web_search for Python. Tools: web_search(query: str)", "web_search", {"query": "Python"}, difficulty="easy"), ToolCallTest("read_then_analyze", ToolCallType.MULTI_STEP, "First read README.md then analyze. Tools: read_file(path: str)", "read_file", {"path": "README.md"}, difficulty="medium"), ToolCallTest("nested_params", ToolCallType.SCHEMA_PARSING, "Use complex_tool(name=test, config={verbose:true}, tags=[a,b]). Tools: complex_tool(name: str, config: dict, tags: list)", "complex_tool", {"name": "test"}, difficulty="hard"), ToolCallTest("optional_params", ToolCallType.SCHEMA_PARSING, "Use search for ML with limit 5. Tools: search(query: str, limit: int=10)", "search", {"query": "ML", "limit": 5}, difficulty="medium"), ToolCallTest("sequential_calls", ToolCallType.MULTI_STEP, "First run pwd, then read README.md. Tools: terminal(command: str), read_file(path: str)", "terminal", {"command": "pwd"}, difficulty="hard"), ToolCallTest("no_tool_needed", ToolCallType.FILE_READ, "What is 2+2? Tools: read_file(path: str)", None, {}, difficulty="easy"), ] def validate_ls(params): cmd = params.get("command", "").strip() return cmd in ["ls", "ls -l", "ls -la", "ls -1", "dir"] or cmd.startswith("ls ") VALIDATORS = {"validate_ls": validate_ls} def parse_tool_call(response: str) -> Tuple[Optional[str], Optional[Dict]]: # JSON format patterns = [ r'"tool"\s*:\s*"([^"]+)"\s*,\s*"params"\s*:\s*({[^}]+})', r'"name"\s*:\s*"([^"]+)"\s*,\s*"arguments"\s*:\s*({[^}]+})', ] for pattern in patterns: match = re.search(pattern, response, re.DOTALL) if match: try: return match.group(1), json.loads(match.group(2)) except: continue # Function call format match = re.search(r'(\w+)\(([^)]+)\)', response) if match: tool_name = match.group(1) params = {} for m in re.finditer(r'(\w+)\s*=\s*"?([^",)]+)"?', match.group(2)): params[m.group(1)] = m.group(2).strip().strip('"\'') return tool_name, params return None, None def call_model(prompt: str, model: str, backend: str, url: str, timeout: int = 60) -> Tuple[str, float]: start = time.time() try: if backend == "ollama": resp = requests.post(f"{url}/api/generate", json={ "model": model, "prompt": prompt, "stream": False, "options": {"num_predict": 256, "temperature": 0.1} }, timeout=timeout) resp.raise_for_status() text = resp.json().get("response", "") else: text = f"ERROR: Unknown backend {backend}" except Exception as e: text = f"ERROR: {e}" return text, (time.time() - start) * 1000 def run_test(test: ToolCallTest, model: str, backend: str, url: str) -> TestResult: response, latency = call_model(test.prompt, model, backend, url) if response.startswith("ERROR:"): return TestResult(test.name, test.tool_type.value, False, latency, response, error=response) parsed_tool, parsed_params = parse_tool_call(response) passed = False quality = 0.0 if test.expected_tool is None: passed = parsed_tool is None quality = 1.0 if passed else 0.0 elif parsed_tool: tool_match = parsed_tool.lower() == test.expected_tool.lower() if test.validation_fn and test.validation_fn in VALIDATORS: params_match = VALIDATORS[test.validation_fn](parsed_params or {}) else: params_match = all( k in (parsed_params or {}) and (str(v).lower() in str(parsed_params.get(k, "")).lower() if isinstance(v, str) else parsed_params.get(k) == v) for k, v in test.expected_params.items() ) if test.expected_params else True passed = tool_match and params_match quality = (0.5 if tool_match else 0) + (0.5 if params_match else 0) return TestResult(test.name, test.tool_type.value, passed, latency, response[:500], parsed_tool, parsed_params, quality_score=quality) def run_all_tests(model: str, backend: str, url: str) -> BenchmarkResult: results = BenchmarkResult(model, backend, datetime.now(timezone.utc).isoformat()) print(f"Testing {model} ({backend})") print("=" * 50) for test in TOOL_CALL_TESTS: result = run_test(test, model, backend, url) results.results.append(result) status = "PASS" if result.passed else "FAIL" print(f" {status} {test.name} ({result.latency_ms:.0f}ms, q={result.quality_score:.0%})") total = len(results.results) passed = sum(1 for r in results.results if r.passed) results.summary = { "total": total, "passed": passed, "failed": total - passed, "pass_rate": passed / total if total else 0, "avg_latency_ms": sum(r.latency_ms for r in results.results) / total if total else 0, "avg_quality": sum(r.quality_score for r in results.results) / total if total else 0, } return results def generate_report(results: BenchmarkResult) -> str: s = results.summary lines = [ "# Tool Calling Test Results - 1-Bit Models", "", f"**Model:** {results.model} ", f"**Backend:** {results.backend} ", f"**Timestamp:** {results.timestamp}", "", "## Summary", "", f"- Pass Rate: {s['passed']}/{s['total']} ({s['pass_rate']:.0%})", f"- Avg Latency: {s['avg_latency_ms']:.0f}ms", f"- Avg Quality: {s['avg_quality']:.0%}", "", "## Detailed Results", "", ] for r, t in zip(results.results, TOOL_CALL_TESTS): lines.append(f"- {'PASS' if r.passed else 'FAIL'} {r.test_name} ({t.difficulty}, {r.latency_ms:.0f}ms)") lines.extend(["", "## Conclusion", ""]) if s['pass_rate'] >= 0.8: lines.append(f"**VIABLE** - {s['pass_rate']:.0%} pass rate.") elif s['pass_rate'] >= 0.5: lines.append(f"**MARGINAL** - {s['pass_rate']:.0%} pass rate.") else: lines.append(f"**NOT VIABLE** - {s['pass_rate']:.0%} pass rate.") lines.extend(["", "### Alternatives", "- Qwen3.5 3B Q4", "- Phi-3 Mini", "- Llama 3.2 3B"]) return "\n".join(lines) def main(): parser = argparse.ArgumentParser() parser.add_argument("--model", default="bonsai-1bit") parser.add_argument("--backend", default="ollama") parser.add_argument("--url", default="http://localhost:11434") parser.add_argument("--results", help="Save results JSON") parser.add_argument("--report", help="Save report markdown") args = parser.parse_args() results = run_all_tests(args.model, args.backend, args.url) print(f"\nSUMMARY: {results.summary['passed']}/{results.summary['total']} passed") if args.results: os.makedirs(os.path.dirname(args.results) or ".", exist_ok=True) with open(args.results, "w") as f: json.dump(asdict(results), f, indent=2) if args.report: os.makedirs(os.path.dirname(args.report) or ".", exist_ok=True) with open(args.report, "w") as f: f.write(generate_report(results)) if __name__ == "__main__": main()