#!/usr/bin/env python3
"""
1-Bit Model Tool Calling Test Suite (Issue #101).

Tests whether quantized/1-bit models can handle structured tool calling.
Designed to be run against any OpenAI-compatible endpoint (llama-server, Ollama).

The core question: does 1-bit quantization destroy the precise JSON output
required for tool calling? This suite measures it empirically.

Usage:
    # Against local llama-server
    python3 benchmarks/test_bonsai_tool_calling.py \
        --url http://localhost:8081/v1/chat/completions \
        --model bonsai-1b

    # Against Ollama
    python3 benchmarks/test_bonsai_tool_calling.py \
        --url http://localhost:11434/api/chat \
        --model bonsai:latest \
        --backend ollama

    # Dry run (validate test cases without model)
    python3 benchmarks/test_bonsai_tool_calling.py --dry-run
"""
import argparse
import json
import os
import re
import sys
import time
from dataclasses import dataclass, field, asdict
from enum import Enum
from typing import Any, Callable, List, Dict, Optional, Tuple

# NOTE: `requests` is imported lazily inside the HTTP helpers (call_openai_compatible,
# call_ollama, run_test) so that --dry-run and offline validation work on machines
# where requests is not installed.


class ToolCallCategory(Enum):
    """Categories of tool call complexity."""
    SIMPLE_READ = "simple_read"
    TERMINAL_CMD = "terminal_cmd"
    WEB_SEARCH = "web_search"
    MULTI_STEP = "multi_step"
    NESTED_PARAMS = "nested_params"
    ARRAY_PARAMS = "array_params"
    OPTIONAL_PARAMS = "optional_params"
    MULTI_TOOL_SELECT = "multi_tool_select"


class TestResult(Enum):
    """Outcome of a single test case run."""
    PASS = "PASS"
    FAIL = "FAIL"
    PARTIAL = "PARTIAL"
    TIMEOUT = "TIMEOUT"
    ERROR = "ERROR"
    SKIP = "SKIP"


# ── Tool schemas (hermes-compatible) ─────────────────────────

TOOL_SCHEMAS = [
    {
        "type": "function",
        "function": {
            "name": "read_file",
            "description": "Read a text file with line numbers.",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "File path to read"},
                    "offset": {"type": "integer", "description": "Start line (1-indexed)", "default": 1},
                    "limit": {"type": "integer", "description": "Max lines to read", "default": 500},
                },
                "required": ["path"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "terminal",
            "description": "Execute a shell command.",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {"type": "string", "description": "Shell command to execute"},
                    "timeout": {"type": "integer", "description": "Timeout in seconds", "default": 30},
                    "workdir": {"type": "string", "description": "Working directory"},
                },
                "required": ["command"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "web_search",
            "description": "Search the web for information.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "Search query"},
                    "max_results": {"type": "integer", "description": "Max results to return", "default": 5},
                },
                "required": ["query"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": "Write content to a file, creating directories as needed.",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "File path to write"},
                    "content": {"type": "string", "description": "Content to write"},
                },
                "required": ["path", "content"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "patch",
            "description": "Apply a targeted find-and-replace edit to a file.",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "File path to edit"},
                    "old_string": {"type": "string", "description": "Text to find"},
                    "new_string": {"type": "string", "description": "Replacement text"},
                    "replace_all": {"type": "boolean", "description": "Replace all occurrences", "default": False},
                },
                "required": ["path", "old_string", "new_string"],
            },
        },
    },
]


# ── Test case definitions ────────────────────────────────────

@dataclass
class ToolCallTestCase:
    """A single tool calling test case.

    Attributes:
        id: Unique test identifier (used by --test-id).
        category: Complexity category for reporting.
        prompt: User message sent to the model.
        tools: Tool schemas offered to the model for this case.
        expected_tool: Name of the tool the model is expected to call.
        expected_params: Params that must match exactly in the tool call.
        param_validators: Per-param predicates for fuzzy checks (e.g. the
            command just has to contain "ls" — exact match is too strict).
        description: Human-readable summary for the report.
        difficulty: 1-5, higher = harder.
    """
    id: str
    category: ToolCallCategory
    prompt: str
    tools: List[dict]
    expected_tool: str
    # FIX: original annotations used the builtins `any`/`callable` as type
    # arguments; use typing.Any / Callable so type checkers don't choke.
    expected_params: Dict[str, Any]
    param_validators: Dict[str, Callable[[Any], bool]] = field(default_factory=dict)
    description: str = ""
    difficulty: int = 1  # 1-5, higher = harder


TEST_CASES = [
    # ── Level 1: Simple reads ──────────────────────────────
    ToolCallTestCase(
        id="simple-read-1",
        category=ToolCallCategory.SIMPLE_READ,
        prompt="Read the file at /tmp/test.txt",
        tools=[TOOL_SCHEMAS[0]],
        expected_tool="read_file",
        expected_params={"path": "/tmp/test.txt"},
        description="Exact path, single required param",
        difficulty=1,
    ),
    ToolCallTestCase(
        id="simple-read-with-limit",
        category=ToolCallCategory.SIMPLE_READ,
        prompt="Read the first 10 lines of /var/log/system.log",
        tools=[TOOL_SCHEMAS[0]],
        expected_tool="read_file",
        expected_params={"path": "/var/log/system.log"},
        param_validators={"limit": lambda v: isinstance(v, int) and v <= 20},
        description="Required + optional param",
        difficulty=2,
    ),

    # ── Level 2: Terminal commands ─────────────────────────
    ToolCallTestCase(
        id="terminal-simple",
        category=ToolCallCategory.TERMINAL_CMD,
        prompt="List all files in the current directory",
        tools=[TOOL_SCHEMAS[1]],
        expected_tool="terminal",
        expected_params={},
        param_validators={
            "command": lambda v: isinstance(v, str) and any(
                cmd in v for cmd in ["ls", "dir", "find"]
            )
        },
        description="Generate appropriate shell command",
        difficulty=2,
    ),
    ToolCallTestCase(
        id="terminal-pipe",
        category=ToolCallCategory.TERMINAL_CMD,
        prompt="Count how many Python files are in /tmp recursively",
        tools=[TOOL_SCHEMAS[1]],
        expected_tool="terminal",
        expected_params={},
        param_validators={
            "command": lambda v: isinstance(v, str) and (
                "find" in v or "ls" in v or "python" in v or ".py" in v
            )
        },
        description="Needs piped or recursive command",
        difficulty=3,
    ),

    # ── Level 3: Web search ────────────────────────────────
    ToolCallTestCase(
        id="web-search-simple",
        category=ToolCallCategory.WEB_SEARCH,
        prompt="Search for the current price of Bitcoin",
        tools=[TOOL_SCHEMAS[2]],
        expected_tool="web_search",
        expected_params={"query": "Bitcoin price"},
        param_validators={
            "query": lambda v: isinstance(v, str) and len(v) > 3 and "bitcoin" in v.lower()
        },
        description="Extract search query from natural language",
        difficulty=2,
    ),

    # ── Level 4: Multi-tool selection ──────────────────────
    ToolCallTestCase(
        id="multi-tool-select-read",
        category=ToolCallCategory.MULTI_TOOL_SELECT,
        prompt="Read the file at /etc/hostname",
        tools=TOOL_SCHEMAS[:3],  # read_file, terminal, web_search
        expected_tool="read_file",
        expected_params={"path": "/etc/hostname"},
        description="Choose correct tool from 3 options",
        difficulty=3,
    ),
    ToolCallTestCase(
        id="multi-tool-select-terminal",
        category=ToolCallCategory.MULTI_TOOL_SELECT,
        prompt="Check how much disk space is available",
        tools=TOOL_SCHEMAS[:3],
        expected_tool="terminal",
        expected_params={},
        param_validators={
            "command": lambda v: isinstance(v, str) and any(
                cmd in v for cmd in ["df", "du", "disk"]
            )
        },
        description="Choose terminal over read_file for system info",
        difficulty=3,
    ),
    ToolCallTestCase(
        id="multi-tool-select-search",
        category=ToolCallCategory.MULTI_TOOL_SELECT,
        prompt="What is the weather in Tokyo right now?",
        tools=TOOL_SCHEMAS[:3],
        expected_tool="web_search",
        expected_params={},
        param_validators={
            "query": lambda v: isinstance(v, str) and "weather" in v.lower() and "tokyo" in v.lower()
        },
        description="Choose web_search for real-time info",
        difficulty=3,
    ),

    # ── Level 5: Nested/complex params ─────────────────────
    ToolCallTestCase(
        id="write-file-with-content",
        category=ToolCallCategory.NESTED_PARAMS,
        prompt="Create a file at /tmp/hello.txt with the content 'Hello, World!'",
        tools=[TOOL_SCHEMAS[3]],
        expected_tool="write_file",
        expected_params={"path": "/tmp/hello.txt"},
        param_validators={
            "content": lambda v: isinstance(v, str) and "hello" in v.lower()
        },
        description="Two required string params",
        difficulty=3,
    ),
    ToolCallTestCase(
        id="patch-edit",
        category=ToolCallCategory.NESTED_PARAMS,
        prompt="In the file /tmp/config.yaml, replace 'debug: false' with 'debug: true'",
        tools=[TOOL_SCHEMAS[4]],
        expected_tool="patch",
        expected_params={"path": "/tmp/config.yaml"},
        param_validators={
            "old_string": lambda v: isinstance(v, str) and "debug: false" in v,
            "new_string": lambda v: isinstance(v, str) and "debug: true" in v,
        },
        description="Three required params, find-and-replace",
        difficulty=4,
    ),

    # ── Level 6: Multi-step reasoning ──────────────────────
    ToolCallTestCase(
        id="multi-step-read-then-write",
        category=ToolCallCategory.MULTI_STEP,
        prompt="Read /tmp/source.txt and write its contents to /tmp/backup.txt",
        tools=[TOOL_SCHEMAS[0], TOOL_SCHEMAS[3]],  # read_file + write_file
        expected_tool="read_file",  # First step should be reading
        expected_params={"path": "/tmp/source.txt"},
        description="Requires planning: read first, then write",
        difficulty=5,
    ),
]


# ── Test runner ──────────────────────────────────────────────

@dataclass
class TestRunResult:
    """Result of running a single test case."""
    test_id: str
    category: str
    difficulty: int
    result: str  # TestResult value
    expected_tool: str
    actual_tool: str
    expected_params: dict
    actual_params: dict
    param_scores: Dict[str, bool] = field(default_factory=dict)
    response_text: str = ""
    latency_s: float = 0.0
    tokens_per_sec: float = 0.0
    error: str = ""
    raw_response: dict = field(default_factory=dict)


def call_openai_compatible(
    messages: list,
    tools: list,
    url: str,
    model: str,
    timeout: int = 120,
) -> dict:
    """Call an OpenAI-compatible chat completions endpoint.

    Raises requests.HTTPError on non-2xx responses and
    requests.exceptions.Timeout when the endpoint is too slow.
    """
    import requests  # lazy: keeps --dry-run usable without requests installed

    payload = {
        "model": model,
        "messages": messages,
        "tools": tools,
        "tool_choice": "auto",
        "max_tokens": 512,
        # temperature 0 for deterministic, reproducible tool selection
        "temperature": 0.0,
    }
    resp = requests.post(url, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()


def call_ollama(
    messages: list,
    tools: list,
    url: str,
    model: str,
    timeout: int = 120,
) -> dict:
    """Call Ollama /api/chat endpoint and normalize the reply to OpenAI shape."""
    import requests  # lazy: keeps --dry-run usable without requests installed

    # Convert OpenAI tool format to Ollama format
    ollama_tools = []
    for t in tools:
        fn = t["function"]
        ollama_tools.append({
            "type": "function",
            "function": {
                "name": fn["name"],
                "description": fn["description"],
                "parameters": fn["parameters"],
            },
        })

    resp = requests.post(url, json={
        "model": model,
        "messages": messages,
        "tools": ollama_tools,
        "stream": False,
    }, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()

    # Normalize to OpenAI format so validate_tool_call handles both backends
    result = {"choices": [{"message": {}}]}
    msg = data.get("message", {})
    result["choices"][0]["message"]["content"] = msg.get("content", "")
    if msg.get("tool_calls"):
        result["choices"][0]["message"]["tool_calls"] = msg["tool_calls"]
    return result


def validate_tool_call(
    response: dict,
    test: ToolCallTestCase,
) -> Tuple[TestResult, str, dict, Dict[str, bool]]:
    """
    Validate a model response against a test case.

    Checks, in order: response shape, presence of a tool call, tool name,
    exact-match params, then custom validators. Any param/validator miss
    downgrades PASS to PARTIAL; a wrong tool or no tool is a FAIL.

    Returns: (result, actual_tool, actual_params, param_scores)
    """
    try:
        choice = response["choices"][0]
        msg = choice["message"]
    except (KeyError, IndexError):
        return TestResult.FAIL, "", {}, {}

    # Check if model called a tool
    tool_calls = msg.get("tool_calls") or []
    if not tool_calls:
        # Model responded with text instead — check if it at least mentioned the tool.
        # FIX: endpoints often send "content": null alongside tool calls; the
        # `or ""` prevents a TypeError on the `in` check when content is None.
        content = msg.get("content") or ""
        if test.expected_tool in content:
            return TestResult.PARTIAL, "text_only", {"content": content}, {}
        return TestResult.FAIL, "none", {}, {}

    tc = tool_calls[0]
    actual_tool = tc.get("function", {}).get("name", "")

    # Parse arguments (OpenAI sends a JSON string, Ollama sends a dict)
    try:
        args_str = tc.get("function", {}).get("arguments", "{}")
        if isinstance(args_str, str):
            actual_params = json.loads(args_str)
        else:
            actual_params = args_str
    except json.JSONDecodeError:
        return TestResult.FAIL, actual_tool, {}, {"json_parse": False}

    # Check tool name
    if actual_tool != test.expected_tool:
        return TestResult.FAIL, actual_tool, actual_params, {
            "tool_match": False
        }

    # Validate expected params (exact equality)
    param_scores = {"tool_match": True}
    all_pass = True

    for key, expected_val in test.expected_params.items():
        matched = key in actual_params and actual_params[key] == expected_val
        param_scores[f"param_{key}"] = matched
        if not matched:
            all_pass = False

    # Run custom validators (fuzzy per-param predicates)
    for key, validator in test.param_validators.items():
        if key in actual_params:
            try:
                passed = bool(validator(actual_params[key]))
            except Exception:
                # A crashing validator counts as a failed check, never a crash
                passed = False
            param_scores[f"validator_{key}"] = passed
            if not passed:
                all_pass = False
        else:
            param_scores[f"validator_{key}"] = False
            all_pass = False

    if all_pass:
        return TestResult.PASS, actual_tool, actual_params, param_scores
    return TestResult.PARTIAL, actual_tool, actual_params, param_scores


def run_test(
    test: ToolCallTestCase,
    url: str,
    model: str,
    backend: str = "openai",
    timeout: int = 120,
) -> TestRunResult:
    """Run a single test case against the model and return a TestRunResult.

    Network failures are captured in the result (TIMEOUT/ERROR) rather than
    raised, so one flaky request never aborts the whole suite.
    """
    import requests  # lazy: needed here only for the Timeout exception type

    messages = [{"role": "user", "content": test.prompt}]

    start = time.time()
    try:
        if backend == "ollama":
            response = call_ollama(messages, test.tools, url, model, timeout)
        else:
            response = call_openai_compatible(messages, test.tools, url, model, timeout)
        elapsed = time.time() - start

        result, actual_tool, actual_params, param_scores = validate_tool_call(response, test)

        # Extract text response (truncated for the report)
        try:
            text = response["choices"][0]["message"].get("content", "")
        except (KeyError, IndexError):
            text = ""

        return TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=result.value,
            expected_tool=test.expected_tool,
            actual_tool=actual_tool,
            expected_params=test.expected_params,
            actual_params=actual_params,
            param_scores=param_scores,
            response_text=(text or "")[:200],
            latency_s=round(elapsed, 3),
            raw_response=response,
        )

    except requests.exceptions.Timeout:
        return TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=TestResult.TIMEOUT.value,
            expected_tool=test.expected_tool,
            actual_tool="",
            expected_params=test.expected_params,
            actual_params={},
            # FIX: record how long we actually waited instead of leaving 0.0
            latency_s=round(time.time() - start, 3),
            error=f"Timeout after {timeout}s",
        )
    except Exception as e:
        return TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=TestResult.ERROR.value,
            expected_tool=test.expected_tool,
            actual_tool="",
            expected_params=test.expected_params,
            actual_params={},
            latency_s=round(time.time() - start, 3),
            error=str(e)[:200],
        )


def run_dry_run() -> List[TestRunResult]:
    """Validate test cases without a model (every case reported as SKIP)."""
    results = []
    for test in TEST_CASES:
        results.append(TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=TestResult.SKIP.value,
            expected_tool=test.expected_tool,
            actual_tool="(dry run)",
            expected_params=test.expected_params,
            actual_params={},
        ))
    return results


def generate_report(results: List[TestRunResult], model: str) -> str:
    """Generate a markdown report: summary, per-difficulty table, details, verdict."""
    lines = [
        f"# 1-Bit Model Tool Calling Test Results",
        f"",
        f"**Model:** {model}",
        f"**Date:** {time.strftime('%Y-%m-%d %H:%M:%S')}",
        f"**Test cases:** {len(results)}",
        f"",
    ]

    # Summary table
    by_result = {}
    for r in results:
        by_result[r.result] = by_result.get(r.result, 0) + 1

    lines.append("## Summary")
    lines.append("")
    lines.append("| Result | Count |")
    lines.append("|--------|-------|")
    for result, count in sorted(by_result.items()):
        lines.append(f"| {result} | {count} |")
    lines.append("")

    pass_count = by_result.get("PASS", 0)
    total = len(results)
    pass_rate = (pass_count / total * 100) if total > 0 else 0
    lines.append(f"**Pass rate: {pass_rate:.0f}%** ({pass_count}/{total})")
    lines.append("")

    # By difficulty
    lines.append("## Results by Difficulty")
    lines.append("")
    lines.append("| Difficulty | PASS | PARTIAL | FAIL | Other |")
    lines.append("|-----------|------|---------|------|-------|")
    for diff in range(1, 6):
        diff_results = [r for r in results if r.difficulty == diff]
        if not diff_results:
            continue
        passed = sum(1 for r in diff_results if r.result == "PASS")
        partial = sum(1 for r in diff_results if r.result == "PARTIAL")
        # FIX: renamed loop-local `f` (shadowed the conventional file-handle name)
        failed = sum(1 for r in diff_results if r.result in ("FAIL", "ERROR", "TIMEOUT"))
        other = len(diff_results) - passed - partial - failed
        lines.append(f"| {diff}/5 | {passed} | {partial} | {failed} | {other} |")
    lines.append("")

    # Detailed results
    lines.append("## Detailed Results")
    lines.append("")
    for r in results:
        icon = {"PASS": "✅", "PARTIAL": "⚠️", "FAIL": "❌", "ERROR": "💥", "TIMEOUT": "⏱"}.get(r.result, "❓")
        lines.append(f"### {icon} {r.test_id} (difficulty {r.difficulty}/5)")
        lines.append(f"- **Category:** {r.category}")
        lines.append(f"- **Expected tool:** `{r.expected_tool}`")
        lines.append(f"- **Actual tool:** `{r.actual_tool}`")
        if r.latency_s > 0:
            lines.append(f"- **Latency:** {r.latency_s}s")
        if r.param_scores:
            lines.append(f"- **Param scores:** {json.dumps(r.param_scores)}")
        if r.error:
            lines.append(f"- **Error:** {r.error}")
        lines.append("")

    # Viability verdict — thresholds are intentionally coarse (80/50/20)
    lines.append("## Viability Verdict")
    lines.append("")
    if pass_rate >= 80:
        lines.append("**VERDICT: VIABLE** — 1-bit model can handle tool calling for production use.")
    elif pass_rate >= 50:
        lines.append("**VERDICT: CONDITIONALLY VIABLE** — Works for simple tools, struggles with complex params. Consider for edge deployment with guardrails.")
    elif pass_rate >= 20:
        lines.append("**VERDICT: MARGINAL** — Can select correct tool sometimes, but parameter accuracy is too low for production. Investigate alternative quantization (2-bit, 3-bit).")
    else:
        lines.append("**VERDICT: NOT VIABLE** — 1-bit quantization destroys tool calling capability. Recommend minimum 3-bit quantization for tool-using models.")
    lines.append("")

    return "\n".join(lines)


def main():
    """CLI entry point: run the suite (or --dry-run), save JSON + markdown report."""
    parser = argparse.ArgumentParser(description="Test tool calling on 1-bit models")
    parser.add_argument("--url", default="http://localhost:8081/v1/chat/completions",
                        help="Model API endpoint")
    parser.add_argument("--model", default="bonsai-1b", help="Model name")
    parser.add_argument("--backend", default="openai", choices=["openai", "ollama"],
                        help="API backend type")
    parser.add_argument("--timeout", type=int, default=120, help="Request timeout in seconds")
    parser.add_argument("--dry-run", action="store_true", help="Validate tests without model")
    parser.add_argument("--output", default="benchmarks/bonsai-tool-calling-results.json",
                        help="Output file for results")
    parser.add_argument("--report", default="benchmarks/bonsai-tool-calling.md",
                        help="Output file for markdown report")
    parser.add_argument("--test-id", help="Run a single test by ID")

    args = parser.parse_args()

    print("=" * 60)
    print(" 1-Bit Model Tool Calling Test Suite")
    print("=" * 60)

    if args.dry_run:
        print("\n[DRY RUN] Validating test cases...")
        results = run_dry_run()
        print(f"  {len(results)} test cases validated")
        for r in results:
            print(f"  ✓ {r.test_id} — expects {r.expected_tool} (difficulty {r.difficulty}/5)")
    else:
        print(f"\nModel: {args.model}")
        print(f"Endpoint: {args.url}")
        print(f"Backend: {args.backend}")
        print()

        tests = TEST_CASES
        if args.test_id:
            tests = [t for t in tests if t.id == args.test_id]
            if not tests:
                print(f"Test '{args.test_id}' not found")
                sys.exit(1)

        results = []
        for i, test in enumerate(tests):
            print(f"  [{i+1}/{len(tests)}] {test.id} (difficulty {test.difficulty}/5)... ", end="", flush=True)
            result = run_test(test, args.url, args.model, args.backend, args.timeout)
            results.append(result)
            icon = {"PASS": "✅", "PARTIAL": "⚠️", "FAIL": "❌", "ERROR": "💥", "TIMEOUT": "⏱"}.get(result.result, "❓")
            print(f"{icon} {result.result} ({result.latency_s}s)")

    # Save results
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w") as f:
        json.dump([asdict(r) for r in results], f, indent=2)
    print(f"\nResults saved to {args.output}")

    # Generate report
    # FIX: create the report directory too — previously only the --output
    # directory was created, so a custom --report path could crash here.
    os.makedirs(os.path.dirname(args.report) or ".", exist_ok=True)
    report = generate_report(results, args.model)
    with open(args.report, "w") as f:
        f.write(report)
    print(f"Report saved to {args.report}")

    # Print summary (guard against an empty result list, matching generate_report)
    pass_count = sum(1 for r in results if r.result == "PASS")
    total = len(results)
    pct = (pass_count / total * 100) if total > 0 else 0
    print(f"\n{'='*60}")
    print(f" Results: {pass_count}/{total} passed ({pct:.0f}%)")


if __name__ == "__main__":
    main()