#!/usr/bin/env python3 """ 1-Bit Model Tool Calling Test Suite (Issue #101). Tests whether quantized/1-bit models can handle structured tool calling. Designed to be run against any OpenAI-compatible endpoint (llama-server, Ollama). The core question: does 1-bit quantization destroy the precise JSON output required for tool calling? This suite measures it empirically. Usage: # Against local llama-server python3 benchmarks/test_bonsai_tool_calling.py \ --url http://localhost:8081/v1/chat/completions \ --model bonsai-1b # Against Ollama python3 benchmarks/test_bonsai_tool_calling.py \ --url http://localhost:11434/api/chat \ --model bonsai:latest \ --backend ollama # Dry run (validate test cases without model) python3 benchmarks/test_bonsai_tool_calling.py --dry-run """ import argparse import json import os import re import sys import time from dataclasses import dataclass, field, asdict from enum import Enum from typing import List, Dict, Optional, Tuple import requests class ToolCallCategory(Enum): """Categories of tool call complexity.""" SIMPLE_READ = "simple_read" TERMINAL_CMD = "terminal_cmd" WEB_SEARCH = "web_search" MULTI_STEP = "multi_step" NESTED_PARAMS = "nested_params" ARRAY_PARAMS = "array_params" OPTIONAL_PARAMS = "optional_params" MULTI_TOOL_SELECT = "multi_tool_select" class TestResult(Enum): PASS = "PASS" FAIL = "FAIL" PARTIAL = "PARTIAL" TIMEOUT = "TIMEOUT" ERROR = "ERROR" SKIP = "SKIP" # ── Tool schemas (hermes-compatible) ───────────────────────── TOOL_SCHEMAS = [ { "type": "function", "function": { "name": "read_file", "description": "Read a text file with line numbers.", "parameters": { "type": "object", "properties": { "path": {"type": "string", "description": "File path to read"}, "offset": {"type": "integer", "description": "Start line (1-indexed)", "default": 1}, "limit": {"type": "integer", "description": "Max lines to read", "default": 500}, }, "required": ["path"], }, }, }, { "type": "function", "function": { "name": "terminal", "description": "Execute a shell command.", "parameters": { "type": "object", "properties": { "command": {"type": "string", "description": "Shell command to execute"}, "timeout": {"type": "integer", "description": "Timeout in seconds", "default": 30}, "workdir": {"type": "string", "description": "Working directory"}, }, "required": ["command"], }, }, }, { "type": "function", "function": { "name": "web_search", "description": "Search the web for information.", "parameters": { "type": "object", "properties": { "query": {"type": "string", "description": "Search query"}, "max_results": {"type": "integer", "description": "Max results to return", "default": 5}, }, "required": ["query"], }, }, }, { "type": "function", "function": { "name": "write_file", "description": "Write content to a file, creating directories as needed.", "parameters": { "type": "object", "properties": { "path": {"type": "string", "description": "File path to write"}, "content": {"type": "string", "description": "Content to write"}, }, "required": ["path", "content"], }, }, }, { "type": "function", "function": { "name": "patch", "description": "Apply a targeted find-and-replace edit to a file.", "parameters": { "type": "object", "properties": { "path": {"type": "string", "description": "File path to edit"}, "old_string": {"type": "string", "description": "Text to find"}, "new_string": {"type": "string", "description": "Replacement text"}, "replace_all": {"type": "boolean", "description": "Replace all occurrences", "default": False}, }, "required": ["path", "old_string", 
"new_string"], }, }, }, ] # ── Test case definitions ──────────────────────────────────── @dataclass class ToolCallTestCase: """A single tool calling test case.""" id: str category: ToolCallCategory prompt: str tools: List[dict] expected_tool: str expected_params: Dict[str, any] param_validators: Dict[str, callable] = field(default_factory=dict) description: str = "" difficulty: int = 1 # 1-5, higher = harder TEST_CASES = [ # ── Level 1: Simple reads ────────────────────────────── ToolCallTestCase( id="simple-read-1", category=ToolCallCategory.SIMPLE_READ, prompt="Read the file at /tmp/test.txt", tools=[TOOL_SCHEMAS[0]], expected_tool="read_file", expected_params={"path": "/tmp/test.txt"}, description="Exact path, single required param", difficulty=1, ), ToolCallTestCase( id="simple-read-with-limit", category=ToolCallCategory.SIMPLE_READ, prompt="Read the first 10 lines of /var/log/system.log", tools=[TOOL_SCHEMAS[0]], expected_tool="read_file", expected_params={"path": "/var/log/system.log"}, param_validators={"limit": lambda v: isinstance(v, int) and v <= 20}, description="Required + optional param", difficulty=2, ), # ── Level 2: Terminal commands ───────────────────────── ToolCallTestCase( id="terminal-simple", category=ToolCallCategory.TERMINAL_CMD, prompt="List all files in the current directory", tools=[TOOL_SCHEMAS[1]], expected_tool="terminal", expected_params={}, param_validators={ "command": lambda v: isinstance(v, str) and any( cmd in v for cmd in ["ls", "dir", "find"] ) }, description="Generate appropriate shell command", difficulty=2, ), ToolCallTestCase( id="terminal-pipe", category=ToolCallCategory.TERMINAL_CMD, prompt="Count how many Python files are in /tmp recursively", tools=[TOOL_SCHEMAS[1]], expected_tool="terminal", expected_params={}, param_validators={ "command": lambda v: isinstance(v, str) and ( "find" in v or "ls" in v or "python" in v or ".py" in v ) }, description="Needs piped or recursive command", difficulty=3, ), # ── Level 3: Web search ──────────────────────────────── ToolCallTestCase( id="web-search-simple", category=ToolCallCategory.WEB_SEARCH, prompt="Search for the current price of Bitcoin", tools=[TOOL_SCHEMAS[2]], expected_tool="web_search", expected_params={"query": "Bitcoin price"}, param_validators={ "query": lambda v: isinstance(v, str) and len(v) > 3 and "bitcoin" in v.lower() }, description="Extract search query from natural language", difficulty=2, ), # ── Level 4: Multi-tool selection ────────────────────── ToolCallTestCase( id="multi-tool-select-read", category=ToolCallCategory.MULTI_TOOL_SELECT, prompt="Read the file at /etc/hostname", tools=TOOL_SCHEMAS[:3], # read_file, terminal, web_search expected_tool="read_file", expected_params={"path": "/etc/hostname"}, description="Choose correct tool from 3 options", difficulty=3, ), ToolCallTestCase( id="multi-tool-select-terminal", category=ToolCallCategory.MULTI_TOOL_SELECT, prompt="Check how much disk space is available", tools=TOOL_SCHEMAS[:3], expected_tool="terminal", expected_params={}, param_validators={ "command": lambda v: isinstance(v, str) and any( cmd in v for cmd in ["df", "du", "disk"] ) }, description="Choose terminal over read_file for system info", difficulty=3, ), ToolCallTestCase( id="multi-tool-select-search", category=ToolCallCategory.MULTI_TOOL_SELECT, prompt="What is the weather in Tokyo right now?", tools=TOOL_SCHEMAS[:3], expected_tool="web_search", expected_params={}, param_validators={ "query": lambda v: isinstance(v, str) and "weather" in v.lower() and "tokyo" in 
        description="Choose web_search for real-time info",
        difficulty=3,
    ),
    # ── Level 5: Nested/complex params ─────────────────────
    ToolCallTestCase(
        id="write-file-with-content",
        category=ToolCallCategory.NESTED_PARAMS,
        prompt="Create a file at /tmp/hello.txt with the content 'Hello, World!'",
        tools=[TOOL_SCHEMAS[3]],
        expected_tool="write_file",
        expected_params={"path": "/tmp/hello.txt"},
        param_validators={
            "content": lambda v: isinstance(v, str) and "hello" in v.lower()
        },
        description="Two required string params",
        difficulty=3,
    ),
    ToolCallTestCase(
        id="patch-edit",
        category=ToolCallCategory.NESTED_PARAMS,
        prompt="In the file /tmp/config.yaml, replace 'debug: false' with 'debug: true'",
        tools=[TOOL_SCHEMAS[4]],
        expected_tool="patch",
        expected_params={"path": "/tmp/config.yaml"},
        param_validators={
            "old_string": lambda v: isinstance(v, str) and "debug: false" in v,
            "new_string": lambda v: isinstance(v, str) and "debug: true" in v,
        },
        description="Three required params, find-and-replace",
        difficulty=4,
    ),
    # ── Level 6: Multi-step reasoning ──────────────────────
    ToolCallTestCase(
        id="multi-step-read-then-write",
        category=ToolCallCategory.MULTI_STEP,
        prompt="Read /tmp/source.txt and write its contents to /tmp/backup.txt",
        tools=[TOOL_SCHEMAS[0], TOOL_SCHEMAS[3]],  # read_file + write_file
        expected_tool="read_file",  # First step should be reading
        expected_params={"path": "/tmp/source.txt"},
        description="Requires planning: read first, then write",
        difficulty=5,
    ),
]


# ── Test runner ──────────────────────────────────────────────

@dataclass
class TestRunResult:
    """Result of running a single test case."""

    test_id: str
    category: str
    difficulty: int
    result: str  # TestResult value
    expected_tool: str
    actual_tool: str
    expected_params: dict
    actual_params: dict
    param_scores: Dict[str, bool] = field(default_factory=dict)
    response_text: str = ""
    latency_s: float = 0.0
    tokens_per_sec: float = 0.0
    error: str = ""
    raw_response: dict = field(default_factory=dict)


def call_openai_compatible(
    messages: list,
    tools: list,
    url: str,
    model: str,
    timeout: int = 120,
) -> dict:
    """Call an OpenAI-compatible chat completions endpoint."""
    payload = {
        "model": model,
        "messages": messages,
        "tools": tools,
        "tool_choice": "auto",
        "max_tokens": 512,
        "temperature": 0.0,
    }
    resp = requests.post(url, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()


def call_ollama(
    messages: list,
    tools: list,
    url: str,
    model: str,
    timeout: int = 120,
) -> dict:
    """Call Ollama /api/chat endpoint."""
    # Convert OpenAI tool format to Ollama format
    ollama_tools = []
    for t in tools:
        fn = t["function"]
        ollama_tools.append({
            "type": "function",
            "function": {
                "name": fn["name"],
                "description": fn["description"],
                "parameters": fn["parameters"],
            },
        })
    resp = requests.post(url, json={
        "model": model,
        "messages": messages,
        "tools": ollama_tools,
        "stream": False,
    }, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    # Normalize to OpenAI format
    result = {"choices": [{"message": {}}]}
    msg = data.get("message", {})
    result["choices"][0]["message"]["content"] = msg.get("content", "")
    if msg.get("tool_calls"):
        result["choices"][0]["message"]["tool_calls"] = msg["tool_calls"]
    return result


def validate_tool_call(
    response: dict,
    test: ToolCallTestCase,
) -> Tuple[TestResult, str, dict, Dict[str, bool]]:
    """
    Validate a model response against a test case.

    Returns:
        (result, actual_tool, actual_params, param_scores)
    """
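    # Shape of the (OpenAI-style) response this function walks, with illustrative
    # values only — the arguments field may also arrive as an already-parsed dict:
    #
    #   {"choices": [{"message": {"content": "", "tool_calls": [
    #       {"function": {"name": "read_file",
    #                     "arguments": "{\"path\": \"/tmp/test.txt\"}"}}]}}]}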
    try:
        choice = response["choices"][0]
        msg = choice["message"]
    except (KeyError, IndexError):
        return TestResult.FAIL, "", {}, {}

    # Check if model called a tool
    tool_calls = msg.get("tool_calls", [])
    if not tool_calls:
        # Model responded with text instead — check if it at least mentioned the tool
        content = msg.get("content", "")
        if test.expected_tool in content:
            return TestResult.PARTIAL, "text_only", {"content": content}, {}
        return TestResult.FAIL, "none", {}, {}

    tc = tool_calls[0]
    actual_tool = tc.get("function", {}).get("name", "")

    # Parse arguments
    try:
        args_str = tc.get("function", {}).get("arguments", "{}")
        if isinstance(args_str, str):
            actual_params = json.loads(args_str)
        else:
            actual_params = args_str
    except json.JSONDecodeError:
        return TestResult.FAIL, actual_tool, {}, {"json_parse": False}

    # Check tool name
    if actual_tool != test.expected_tool:
        return TestResult.FAIL, actual_tool, actual_params, {"tool_match": False}

    # Validate expected params
    param_scores = {"tool_match": True}
    all_pass = True
    for key, expected_val in test.expected_params.items():
        if key in actual_params and actual_params[key] == expected_val:
            param_scores[f"param_{key}"] = True
        else:
            param_scores[f"param_{key}"] = False
            all_pass = False

    # Run custom validators
    for key, validator in test.param_validators.items():
        if key in actual_params:
            try:
                passed = validator(actual_params[key])
                param_scores[f"validator_{key}"] = bool(passed)
                if not passed:
                    all_pass = False
            except Exception:
                param_scores[f"validator_{key}"] = False
                all_pass = False
        else:
            param_scores[f"validator_{key}"] = False
            all_pass = False

    # PASS only if every expected param matched exactly and every validator passed
    if all_pass:
        return TestResult.PASS, actual_tool, actual_params, param_scores
    return TestResult.PARTIAL, actual_tool, actual_params, param_scores


def run_test(
    test: ToolCallTestCase,
    url: str,
    model: str,
    backend: str = "openai",
    timeout: int = 120,
) -> TestRunResult:
    """Run a single test case against the model."""
    messages = [{"role": "user", "content": test.prompt}]

    start = time.time()
    try:
        if backend == "ollama":
            response = call_ollama(messages, test.tools, url, model, timeout)
        else:
            response = call_openai_compatible(messages, test.tools, url, model, timeout)
        elapsed = time.time() - start

        result, actual_tool, actual_params, param_scores = validate_tool_call(response, test)

        # Extract text response
        try:
            text = response["choices"][0]["message"].get("content", "")
        except (KeyError, IndexError):
            text = ""

        return TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=result.value,
            expected_tool=test.expected_tool,
            actual_tool=actual_tool,
            expected_params=test.expected_params,
            actual_params=actual_params,
            param_scores=param_scores,
            response_text=text[:200],
            latency_s=round(elapsed, 3),
            raw_response=response,
        )
    except requests.exceptions.Timeout:
        return TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=TestResult.TIMEOUT.value,
            expected_tool=test.expected_tool,
            actual_tool="",
            expected_params=test.expected_params,
            actual_params={},
            error=f"Timeout after {timeout}s",
        )
    except Exception as e:
        return TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=TestResult.ERROR.value,
            expected_tool=test.expected_tool,
actual_tool="", expected_params=test.expected_params, actual_params={}, error=str(e)[:200], ) def run_dry_run() -> List[TestRunResult]: """Validate test cases without a model.""" results = [] for test in TEST_CASES: results.append(TestRunResult( test_id=test.id, category=test.category.value, difficulty=test.difficulty, result=TestResult.SKIP.value, expected_tool=test.expected_tool, actual_tool="(dry run)", expected_params=test.expected_params, actual_params={}, )) return results def generate_report(results: List[TestRunResult], model: str) -> str: """Generate markdown report.""" lines = [ f"# 1-Bit Model Tool Calling Test Results", f"", f"**Model:** {model}", f"**Date:** {time.strftime('%Y-%m-%d %H:%M:%S')}", f"**Test cases:** {len(results)}", f"", ] # Summary table by_result = {} for r in results: by_result[r.result] = by_result.get(r.result, 0) + 1 lines.append("## Summary") lines.append("") lines.append("| Result | Count |") lines.append("|--------|-------|") for result, count in sorted(by_result.items()): lines.append(f"| {result} | {count} |") lines.append("") pass_count = by_result.get("PASS", 0) total = len(results) pass_rate = (pass_count / total * 100) if total > 0 else 0 lines.append(f"**Pass rate: {pass_rate:.0f}%** ({pass_count}/{total})") lines.append("") # By difficulty lines.append("## Results by Difficulty") lines.append("") lines.append("| Difficulty | PASS | PARTIAL | FAIL | Other |") lines.append("|-----------|------|---------|------|-------|") for diff in range(1, 6): diff_results = [r for r in results if r.difficulty == diff] if not diff_results: continue p = sum(1 for r in diff_results if r.result == "PASS") pa = sum(1 for r in diff_results if r.result == "PARTIAL") f = sum(1 for r in diff_results if r.result in ("FAIL", "ERROR", "TIMEOUT")) o = len(diff_results) - p - pa - f lines.append(f"| {diff}/5 | {p} | {pa} | {f} | {o} |") lines.append("") # Detailed results lines.append("## Detailed Results") lines.append("") for r in results: icon = {"PASS": "✅", "PARTIAL": "⚠️", "FAIL": "❌", "ERROR": "💥", "TIMEOUT": "⏱"}.get(r.result, "❓") lines.append(f"### {icon} {r.test_id} (difficulty {r.difficulty}/5)") lines.append(f"- **Category:** {r.category}") lines.append(f"- **Expected tool:** `{r.expected_tool}`") lines.append(f"- **Actual tool:** `{r.actual_tool}`") if r.latency_s > 0: lines.append(f"- **Latency:** {r.latency_s}s") if r.param_scores: lines.append(f"- **Param scores:** {json.dumps(r.param_scores)}") if r.error: lines.append(f"- **Error:** {r.error}") lines.append("") # Viability verdict lines.append("## Viability Verdict") lines.append("") if pass_rate >= 80: lines.append("**VERDICT: VIABLE** — 1-bit model can handle tool calling for production use.") elif pass_rate >= 50: lines.append("**VERDICT: CONDITIONALLY VIABLE** — Works for simple tools, struggles with complex params. Consider for edge deployment with guardrails.") elif pass_rate >= 20: lines.append("**VERDICT: MARGINAL** — Can select correct tool sometimes, but parameter accuracy is too low for production. Investigate alternative quantization (2-bit, 3-bit).") else: lines.append("**VERDICT: NOT VIABLE** — 1-bit quantization destroys tool calling capability. 
    lines.append("")

    return "\n".join(lines)


def main():
    parser = argparse.ArgumentParser(description="Test tool calling on 1-bit models")
    parser.add_argument("--url", default="http://localhost:8081/v1/chat/completions",
                        help="Model API endpoint")
    parser.add_argument("--model", default="bonsai-1b", help="Model name")
    parser.add_argument("--backend", default="openai", choices=["openai", "ollama"],
                        help="API backend type")
    parser.add_argument("--timeout", type=int, default=120, help="Request timeout in seconds")
    parser.add_argument("--dry-run", action="store_true", help="Validate tests without model")
    parser.add_argument("--output", default="benchmarks/bonsai-tool-calling-results.json",
                        help="Output file for results")
    parser.add_argument("--report", default="benchmarks/bonsai-tool-calling.md",
                        help="Output file for markdown report")
    parser.add_argument("--test-id", help="Run a single test by ID")
    args = parser.parse_args()

    print("=" * 60)
    print(" 1-Bit Model Tool Calling Test Suite")
    print("=" * 60)

    if args.dry_run:
        print("\n[DRY RUN] Validating test cases...")
        results = run_dry_run()
        print(f"  {len(results)} test cases validated")
        for r in results:
            print(f"  ✓ {r.test_id} — expects {r.expected_tool} (difficulty {r.difficulty}/5)")
    else:
        print(f"\nModel: {args.model}")
        print(f"Endpoint: {args.url}")
        print(f"Backend: {args.backend}")
        print()

        tests = TEST_CASES
        if args.test_id:
            tests = [t for t in tests if t.id == args.test_id]
            if not tests:
                print(f"Test '{args.test_id}' not found")
                sys.exit(1)

        results = []
        for i, test in enumerate(tests):
            print(f"  [{i+1}/{len(tests)}] {test.id} (difficulty {test.difficulty}/5)... ",
                  end="", flush=True)
            result = run_test(test, args.url, args.model, args.backend, args.timeout)
            results.append(result)
            icon = {"PASS": "✅", "PARTIAL": "⚠️", "FAIL": "❌", "ERROR": "💥", "TIMEOUT": "⏱"}.get(result.result, "❓")
            print(f"{icon} {result.result} ({result.latency_s}s)")

    # Save results
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w") as f:
        json.dump([asdict(r) for r in results], f, indent=2)
    print(f"\nResults saved to {args.output}")

    # Generate report
    report = generate_report(results, args.model)
    with open(args.report, "w") as f:
        f.write(report)
    print(f"Report saved to {args.report}")

    # Print summary
    pass_count = sum(1 for r in results if r.result == "PASS")
    total = len(results)
    print(f"\n{'='*60}")
    print(f" Results: {pass_count}/{total} passed ({pass_count/total*100:.0f}%)")


if __name__ == "__main__":
    main()