diff --git a/benchmarks/test_tool_calling.py b/benchmarks/test_tool_calling.py new file mode 100644 index 00000000..e509f9c5 --- /dev/null +++ b/benchmarks/test_tool_calling.py @@ -0,0 +1,435 @@ +#!/usr/bin/env python3 +""" +Tool Calling Viability Test for 1-Bit / Edge Models (Issue #101) + +Tests whether Bonsai 1-bit models (or any small model) can produce +valid tool calls via Ollama or llama-server API. + +Test suite (5 categories): + 1. Simple tool call: file read with exact path + 2. Terminal command execution + 3. Web search + 4. Multi-step: read file -> analyze -> write result + 5. Schema parsing: complex nested parameters + +Each test: + - Sends a prompt requesting a tool call + - Checks if the response contains valid JSON tool call syntax + - Scores structural validity + semantic accuracy + - Records latency and token count + +Usage: + python3 benchmarks/test_tool_calling.py --model bonsai-1.7b + python3 benchmarks/test_tool_calling.py --model qwen3.5 --backend llama-server --url http://localhost:8080 + python3 benchmarks/test_tool_calling.py --model bonsai-1.7b --output benchmarks/bonsai-tool-calling.md +""" + +import argparse +import json +import os +import re +import sys +import time +from datetime import datetime, timezone +from typing import Optional + +try: + import requests +except ImportError: + print("Error: pip install requests", file=sys.stderr) + sys.exit(1) + + +# ── Tool call test definitions ──────────────────────────────────────── + +TESTS = [ + { + "id": "simple_file_read", + "category": "Simple Tool Call", + "description": "Read a file with an exact path", + "prompt": ( + "You have access to a tool called read_file. " + "Call it to read /etc/hostname. " + "Respond ONLY with a JSON tool call in this exact format:\n" + '{"name": "read_file", "arguments": {"path": "/etc/hostname"}}' + ), + "validate": lambda resp: _has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"}), + "scoring": "structural", # Can we parse the JSON at all? + }, + { + "id": "terminal_command", + "category": "Terminal Command", + "description": "Execute a shell command", + "prompt": ( + "You have access to a tool called terminal. " + "Call it to run the command: echo hello world. " + "Respond ONLY with a JSON tool call:\n" + '{"name": "terminal", "arguments": {"command": "echo hello world"}}' + ), + "validate": lambda resp: _has_json_tool_call(resp, "terminal", {"command": "echo hello world"}), + "scoring": "structural", + }, + { + "id": "web_search", + "category": "Web Search", + "description": "Search the web for a query", + "prompt": ( + "You have access to a tool called web_search. " + "Search for: what is quantization in machine learning. " + "Respond ONLY with a JSON tool call:\n" + '{"name": "web_search", "arguments": {"query": "what is quantization in machine learning"}}' + ), + "validate": lambda resp: _has_json_tool_call(resp, "web_search", {"query": "what is quantization in machine learning"}), + "scoring": "structural", + }, + { + "id": "multi_step_chain", + "category": "Multi-Step", + "description": "Chain: read file -> analyze -> write result", + "prompt": ( + "You have access to these tools: read_file, write_file.\n" + "Task: Read /tmp/input.txt, count the words, then write the count to /tmp/count.txt.\n" + "First, call read_file on /tmp/input.txt. 
" + "Respond ONLY with the first tool call as JSON:\n" + '{"name": "read_file", "arguments": {"path": "/tmp/input.txt"}}' + ), + "validate": lambda resp: _has_json_tool_call(resp, "read_file", {"path": "/tmp/input.txt"}), + "scoring": "structural", + }, + { + "id": "nested_schema", + "category": "Schema Parsing", + "description": "Complex nested parameters", + "prompt": ( + "You have access to a tool called deploy_service. " + "Deploy a service with:\n" + '- name: "api-gateway"\n' + '- replicas: 3\n' + '- env: {"PORT": 8080, "NODE_ENV": "production"}\n' + '- resources: {"cpu": "500m", "memory": "256Mi"}\n\n' + "Respond ONLY with a JSON tool call:\n" + '{"name": "deploy_service", "arguments": {"name": "api-gateway", "replicas": 3, ' + '"env": {"PORT": 8080, "NODE_ENV": "production"}, ' + '"resources": {"cpu": "500m", "memory": "256Mi"}}}' + ), + "validate": lambda resp: _has_nested_tool_call(resp), + "scoring": "semantic", # Needs correct nested structure + }, +] + + +# ── Validation helpers ──────────────────────────────────────────────── + +def _extract_json(text: str) -> Optional[dict]: + """Try to extract a JSON object from text.""" + # Try direct parse + text = text.strip() + try: + obj = json.loads(text) + if isinstance(obj, dict): + return obj + except json.JSONDecodeError: + pass + + # Try finding JSON in code blocks + code_block = re.search(r"```(?:json)?\s*({.*?})\s*```", text, re.DOTALL) + if code_block: + try: + return json.loads(code_block.group(1)) + except json.JSONDecodeError: + pass + + # Try finding any JSON object + json_match = re.search(r"({[^{}]*(?:{[^{}]*}[^{}]*)*})", text) + if json_match: + try: + return json.loads(json_match.group(1)) + except json.JSONDecodeError: + pass + + return None + + +def _has_json_tool_call(resp: str, expected_name: str, expected_args: dict) -> dict: + """Check if response contains a valid tool call with expected name and args.""" + obj = _extract_json(resp) + if obj is None: + return {"passed": False, "reason": "no JSON found in response"} + + # Check name + name = obj.get("name", obj.get("function", {}).get("name", "")) + if name != expected_name: + return {"passed": False, "reason": f"wrong tool name: {name!r}, expected {expected_name!r}"} + + # Check arguments exist + args = obj.get("arguments", obj.get("function", {}).get("arguments", obj.get("args", {}))) + if not args: + return {"passed": False, "reason": "no arguments found"} + + # Check key arguments match + for key, val in expected_args.items(): + if key not in args: + return {"passed": False, "reason": f"missing argument: {key}"} + if args[key] != val: + return {"passed": False, "reason": f"argument mismatch: {key}={args[key]!r}, expected {val!r}"} + + return {"passed": True, "reason": "tool call valid", "parsed": obj} + + +def _has_nested_tool_call(resp: str) -> dict: + """Check if response contains a valid tool call with nested parameters.""" + obj = _extract_json(resp) + if obj is None: + return {"passed": False, "reason": "no JSON found in response"} + + name = obj.get("name", obj.get("function", {}).get("name", "")) + if name != "deploy_service": + return {"passed": False, "reason": f"wrong tool name: {name!r}"} + + args = obj.get("arguments", obj.get("function", {}).get("arguments", obj.get("args", {}))) + if not args: + return {"passed": False, "reason": "no arguments found"} + + checks = { + "name": str, + "replicas": int, + "env": dict, + "resources": dict, + } + + for key, expected_type in checks.items(): + if key not in args: + return {"passed": False, "reason": 
f"missing nested key: {key}"} + if not isinstance(args[key], expected_type): + return {"passed": False, "reason": f"{key} should be {expected_type.__name__}, got {type(args[key]).__name__}"} + + # Check env has PORT + env = args.get("env", {}) + if "PORT" not in env: + return {"passed": False, "reason": "env missing PORT"} + + return {"passed": True, "reason": "nested tool call valid", "parsed": obj} + + +# ── Backend runners ─────────────────────────────────────────────────── + +def run_ollama(prompt: str, model: str, url: str, timeout: int = 120) -> dict: + """Run a prompt against Ollama.""" + api_url = f"{url.rstrip('/')}/api/generate" + start = time.time() + try: + resp = requests.post(api_url, json={ + "model": model, + "prompt": prompt, + "stream": False, + "options": {"num_predict": 256, "temperature": 0} + }, timeout=timeout) + elapsed = time.time() - start + resp.raise_for_status() + data = resp.json() + return { + "response": data.get("response", ""), + "latency_s": round(elapsed, 3), + "tokens": data.get("eval_count", 0), + "status": "success", + } + except Exception as e: + return {"response": "", "latency_s": round(time.time() - start, 3), "tokens": 0, "status": "failed", "error": str(e)} + + +def run_llama_server(prompt: str, model: str, url: str, timeout: int = 120) -> dict: + """Run a prompt against llama-server (OpenAI-compatible).""" + api_url = f"{url.rstrip('/')}/v1/chat/completions" + start = time.time() + try: + resp = requests.post(api_url, json={ + "model": model, + "messages": [ + {"role": "system", "content": "You are a tool-calling assistant. Respond ONLY with JSON tool calls."}, + {"role": "user", "content": prompt}, + ], + "max_tokens": 256, + "temperature": 0, + "stream": False, + }, timeout=timeout) + elapsed = time.time() - start + resp.raise_for_status() + data = resp.json() + content = data.get("choices", [{}])[0].get("message", {}).get("content", "") + usage = data.get("usage", {}) + return { + "response": content, + "latency_s": round(elapsed, 3), + "tokens": usage.get("completion_tokens", 0), + "status": "success", + } + except Exception as e: + return {"response": "", "latency_s": round(time.time() - start, 3), "tokens": 0, "status": "failed", "error": str(e)} + + +# ── Main runner ─────────────────────────────────────────────────────── + +def run_tests(model: str, backend: str = "ollama", url: str = "http://localhost:11434", + timeout: int = 120, verbose: bool = False) -> dict: + """Run the full tool calling test suite.""" + runner_fn = run_ollama if backend == "ollama" else run_llama_server + + results = { + "model": model, + "backend": backend, + "url": url, + "timestamp": datetime.now(timezone.utc).isoformat(), + "tests": [], + "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0}, + } + + print(f"Testing tool calling on: {model} ({backend})\n") + + for test in TESTS: + print(f" [{test['id']}] {test['description']}...", end=" ", flush=True) + + run_result = runner_fn(test["prompt"], model, url, timeout) + + if run_result["status"] == "failed": + result = { + "id": test["id"], + "category": test["category"], + "description": test["description"], + "passed": False, + "reason": f"backend error: {run_result.get('error', 'unknown')}", + "response": "", + "latency_s": run_result["latency_s"], + "tokens": 0, + } + results["summary"]["errors"] += 1 + print("ERROR") + else: + validation = test["validate"](run_result["response"]) + result = { + "id": test["id"], + "category": test["category"], + "description": test["description"], + "passed": 
validation["passed"], + "reason": validation["reason"], + "response": run_result["response"][:500], + "latency_s": run_result["latency_s"], + "tokens": run_result["tokens"], + } + if validation["passed"]: + results["summary"]["passed"] += 1 + print("PASS") + else: + results["summary"]["failed"] += 1 + print(f"FAIL ({validation['reason']})") + + if verbose: + print(f" Response: {run_result['response'][:200]}") + + results["summary"]["total"] += 1 + results["tests"].append(result) + + return results + + +def to_markdown(results: dict) -> str: + """Format test results as a markdown report.""" + lines = [] + lines.append(f"# Tool Calling Viability: {results['model']}") + lines.append("") + lines.append(f"**Date**: {results['timestamp']}") + lines.append(f"**Backend**: {results['backend']} ({results['url']})") + lines.append(f"**Model**: {results['model']}") + lines.append("") + + s = results["summary"] + pass_rate = s["passed"] / s["total"] * 100 if s["total"] > 0 else 0 + lines.append(f"## Summary: {s['passed']}/{s['total']} passed ({pass_rate:.0f}%)") + lines.append("") + lines.append(f"| Metric | Value |") + lines.append(f"|--------|-------|") + lines.append(f"| Total tests | {s['total']} |") + lines.append(f"| Passed | {s['passed']} |") + lines.append(f"| Failed | {s['failed']} |") + lines.append(f"| Errors | {s['errors']} |") + lines.append("") + + lines.append("## Results by Category") + lines.append("") + lines.append("| Test | Category | Result | Reason | Latency | Tokens |") + lines.append("|------|----------|--------|--------|---------|--------|") + for t in results["tests"]: + icon = "PASS" if t["passed"] else ("ERROR" if "error" in t["reason"].lower() else "FAIL") + lines.append(f"| {t['id']} | {t['category']} | {icon} | {t['reason']} | {t['latency_s']}s | {t['tokens']} |") + lines.append("") + + lines.append("## Verdict") + lines.append("") + if pass_rate == 100: + lines.append("**FULLY VIABLE** — All tool calling patterns work. Ready for production edge deployment.") + elif pass_rate >= 60: + lines.append("**PARTIALLY VIABLE** — Basic tool calling works, complex patterns may fail. Consider for simple agents.") + elif pass_rate >= 20: + lines.append("**MARGINAL** — Only simplest tool calls work. 
Not recommended for production.") + else: + lines.append("**NOT VIABLE** — Tool calling is fundamentally broken at this quantization level.") + lines.append("") + + lines.append("## Failure Analysis") + lines.append("") + failed = [t for t in results["tests"] if not t["passed"]] + if not failed: + lines.append("No failures.") + else: + for t in failed: + lines.append(f"### {t['id']}") + lines.append(f"- **Category**: {t['category']}") + lines.append(f"- **Failure**: {t['reason']}") + lines.append(f"- **Response** (first 300 chars): `{t['response'][:300]}`") + lines.append("") + lines.append("") + + lines.append("## Recommendations") + lines.append("") + if pass_rate >= 80: + lines.append("- Deploy for simple single-tool-call workflows") + lines.append("- Add retry logic for multi-step chains") + lines.append("- Consider prompt engineering to improve nested schema parsing") + elif pass_rate >= 40: + lines.append("- Use for keyword/rule-based tool routing only") + lines.append("- Do NOT use for complex multi-step workflows") + lines.append("- Consider a larger model (Q4 quantized) as fallback") + else: + lines.append("- 1-bit quantization is too lossy for tool calling") + lines.append("- Use Q4_0 as minimum viable quantization for tool use") + lines.append("- Reserve 1-bit models for text generation only") + + return "\n".join(lines) + + +def main(): + parser = argparse.ArgumentParser(description="Tool Calling Viability Test for Edge Models") + parser.add_argument("--model", "-m", required=True, help="Model name") + parser.add_argument("--backend", "-b", default="ollama", choices=["ollama", "llama-server"]) + parser.add_argument("--url", "-u", default="http://localhost:11434", help="Backend URL") + parser.add_argument("--timeout", "-t", type=int, default=120, help="Timeout per test (seconds)") + parser.add_argument("--output", "-o", help="Output markdown file path") + parser.add_argument("--json", action="store_true", help="JSON output") + parser.add_argument("--verbose", "-v", action="store_true", help="Show full responses") + args = parser.parse_args() + + results = run_tests(args.model, args.backend, args.url, args.timeout, args.verbose) + + if args.json: + print(json.dumps(results, indent=2)) + else: + md = to_markdown(results) + if args.output: + with open(args.output, "w") as f: + f.write(md) + print(f"\nReport written to: {args.output}") + else: + print("\n" + md) + + +if __name__ == "__main__": + main()
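
Note on the accepted tool-call shapes: _has_json_tool_call reads "name" and "arguments" from the top level first and only falls back to an OpenAI-style "function" wrapper, so both response shapes below validate. A minimal sketch of exercising the validator directly (assumes the benchmarks directory is on sys.path; importing the script does not start a run because main() sits under the __main__ guard):

    from test_tool_calling import _has_json_tool_call

    # Flat shape, exactly as the test prompts request it.
    flat = '{"name": "terminal", "arguments": {"command": "echo hello world"}}'
    print(_has_json_tool_call(flat, "terminal", {"command": "echo hello world"}))     # passed: True

    # OpenAI-style wrapper; name/arguments are read from obj["function"] as a fallback.
    wrapped = '{"function": {"name": "terminal", "arguments": {"command": "echo hello world"}}}'
    print(_has_json_tool_call(wrapped, "terminal", {"command": "echo hello world"}))  # passed: True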
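
The suite can also be driven from another script instead of the CLI; a rough sketch using run_tests and to_markdown as defined above, assuming an Ollama instance is already serving the model on the default port (the model name follows the docstring's usage examples):

    import json
    from test_tool_calling import run_tests, to_markdown

    results = run_tests("bonsai-1.7b", backend="ollama", url="http://localhost:11434")

    # Human-readable report, same content as the markdown written via --output.
    print(to_markdown(results))

    # Machine-readable counts, mirroring the "summary" block of the --json output.
    print(json.dumps(results["summary"], indent=2))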