#!/usr/bin/env python3 """ Tool Calling Viability Test for 1-Bit / Edge Models (Issue #101) Tests whether Bonsai 1-bit models (or any small model) can produce valid tool calls via Ollama or llama-server API. Test suite (5 categories): 1. Simple tool call: file read with exact path 2. Terminal command execution 3. Web search 4. Multi-step: read file -> analyze -> write result 5. Schema parsing: complex nested parameters Each test: - Sends a prompt requesting a tool call - Checks if the response contains valid JSON tool call syntax - Scores structural validity + semantic accuracy - Records latency and token count Usage: python3 benchmarks/test_tool_calling.py --model bonsai-1.7b python3 benchmarks/test_tool_calling.py --model qwen3.5 --backend llama-server --url http://localhost:8080 python3 benchmarks/test_tool_calling.py --model bonsai-1.7b --output benchmarks/bonsai-tool-calling.md """ import argparse import json import os import re import sys import time from datetime import datetime, timezone from typing import Optional try: import requests except ImportError: print("Error: pip install requests", file=sys.stderr) sys.exit(1) # ── Tool call test definitions ──────────────────────────────────────── TESTS = [ { "id": "simple_file_read", "category": "Simple Tool Call", "description": "Read a file with an exact path", "prompt": ( "You have access to a tool called read_file. " "Call it to read /etc/hostname. " "Respond ONLY with a JSON tool call in this exact format:\n" '{"name": "read_file", "arguments": {"path": "/etc/hostname"}}' ), "validate": lambda resp: _has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"}), "scoring": "structural", # Can we parse the JSON at all? }, { "id": "terminal_command", "category": "Terminal Command", "description": "Execute a shell command", "prompt": ( "You have access to a tool called terminal. " "Call it to run the command: echo hello world. " "Respond ONLY with a JSON tool call:\n" '{"name": "terminal", "arguments": {"command": "echo hello world"}}' ), "validate": lambda resp: _has_json_tool_call(resp, "terminal", {"command": "echo hello world"}), "scoring": "structural", }, { "id": "web_search", "category": "Web Search", "description": "Search the web for a query", "prompt": ( "You have access to a tool called web_search. " "Search for: what is quantization in machine learning. " "Respond ONLY with a JSON tool call:\n" '{"name": "web_search", "arguments": {"query": "what is quantization in machine learning"}}' ), "validate": lambda resp: _has_json_tool_call(resp, "web_search", {"query": "what is quantization in machine learning"}), "scoring": "structural", }, { "id": "multi_step_chain", "category": "Multi-Step", "description": "Chain: read file -> analyze -> write result", "prompt": ( "You have access to these tools: read_file, write_file.\n" "Task: Read /tmp/input.txt, count the words, then write the count to /tmp/count.txt.\n" "First, call read_file on /tmp/input.txt. " "Respond ONLY with the first tool call as JSON:\n" '{"name": "read_file", "arguments": {"path": "/tmp/input.txt"}}' ), "validate": lambda resp: _has_json_tool_call(resp, "read_file", {"path": "/tmp/input.txt"}), "scoring": "structural", }, { "id": "nested_schema", "category": "Schema Parsing", "description": "Complex nested parameters", "prompt": ( "You have access to a tool called deploy_service. " "Deploy a service with:\n" '- name: "api-gateway"\n' '- replicas: 3\n' '- env: {"PORT": 8080, "NODE_ENV": "production"}\n' '- resources: {"cpu": "500m", "memory": "256Mi"}\n\n' "Respond ONLY with a JSON tool call:\n" '{"name": "deploy_service", "arguments": {"name": "api-gateway", "replicas": 3, ' '"env": {"PORT": 8080, "NODE_ENV": "production"}, ' '"resources": {"cpu": "500m", "memory": "256Mi"}}}' ), "validate": lambda resp: _has_nested_tool_call(resp), "scoring": "semantic", # Needs correct nested structure }, ] # ── Validation helpers ──────────────────────────────────────────────── def _extract_json(text: str) -> Optional[dict]: """Try to extract a JSON object from text.""" # Try direct parse text = text.strip() try: obj = json.loads(text) if isinstance(obj, dict): return obj except json.JSONDecodeError: pass # Try finding JSON in code blocks code_block = re.search(r"```(?:json)?\s*({.*?})\s*```", text, re.DOTALL) if code_block: try: return json.loads(code_block.group(1)) except json.JSONDecodeError: pass # Try finding any JSON object json_match = re.search(r"({[^{}]*(?:{[^{}]*}[^{}]*)*})", text) if json_match: try: return json.loads(json_match.group(1)) except json.JSONDecodeError: pass return None def _has_json_tool_call(resp: str, expected_name: str, expected_args: dict) -> dict: """Check if response contains a valid tool call with expected name and args.""" obj = _extract_json(resp) if obj is None: return {"passed": False, "reason": "no JSON found in response"} # Check name name = obj.get("name", obj.get("function", {}).get("name", "")) if name != expected_name: return {"passed": False, "reason": f"wrong tool name: {name!r}, expected {expected_name!r}"} # Check arguments exist args = obj.get("arguments", obj.get("function", {}).get("arguments", obj.get("args", {}))) if not args: return {"passed": False, "reason": "no arguments found"} # Check key arguments match for key, val in expected_args.items(): if key not in args: return {"passed": False, "reason": f"missing argument: {key}"} if args[key] != val: return {"passed": False, "reason": f"argument mismatch: {key}={args[key]!r}, expected {val!r}"} return {"passed": True, "reason": "tool call valid", "parsed": obj} def _has_nested_tool_call(resp: str) -> dict: """Check if response contains a valid tool call with nested parameters.""" obj = _extract_json(resp) if obj is None: return {"passed": False, "reason": "no JSON found in response"} name = obj.get("name", obj.get("function", {}).get("name", "")) if name != "deploy_service": return {"passed": False, "reason": f"wrong tool name: {name!r}"} args = obj.get("arguments", obj.get("function", {}).get("arguments", obj.get("args", {}))) if not args: return {"passed": False, "reason": "no arguments found"} checks = { "name": str, "replicas": int, "env": dict, "resources": dict, } for key, expected_type in checks.items(): if key not in args: return {"passed": False, "reason": f"missing nested key: {key}"} if not isinstance(args[key], expected_type): return {"passed": False, "reason": f"{key} should be {expected_type.__name__}, got {type(args[key]).__name__}"} # Check env has PORT env = args.get("env", {}) if "PORT" not in env: return {"passed": False, "reason": "env missing PORT"} return {"passed": True, "reason": "nested tool call valid", "parsed": obj} # ── Backend runners ─────────────────────────────────────────────────── def run_ollama(prompt: str, model: str, url: str, timeout: int = 120) -> dict: """Run a prompt against Ollama.""" api_url = f"{url.rstrip('/')}/api/generate" start = time.time() try: resp = requests.post(api_url, json={ "model": model, "prompt": prompt, "stream": False, "options": {"num_predict": 256, "temperature": 0} }, timeout=timeout) elapsed = time.time() - start resp.raise_for_status() data = resp.json() return { "response": data.get("response", ""), "latency_s": round(elapsed, 3), "tokens": data.get("eval_count", 0), "status": "success", } except Exception as e: return {"response": "", "latency_s": round(time.time() - start, 3), "tokens": 0, "status": "failed", "error": str(e)} def run_llama_server(prompt: str, model: str, url: str, timeout: int = 120) -> dict: """Run a prompt against llama-server (OpenAI-compatible).""" api_url = f"{url.rstrip('/')}/v1/chat/completions" start = time.time() try: resp = requests.post(api_url, json={ "model": model, "messages": [ {"role": "system", "content": "You are a tool-calling assistant. Respond ONLY with JSON tool calls."}, {"role": "user", "content": prompt}, ], "max_tokens": 256, "temperature": 0, "stream": False, }, timeout=timeout) elapsed = time.time() - start resp.raise_for_status() data = resp.json() content = data.get("choices", [{}])[0].get("message", {}).get("content", "") usage = data.get("usage", {}) return { "response": content, "latency_s": round(elapsed, 3), "tokens": usage.get("completion_tokens", 0), "status": "success", } except Exception as e: return {"response": "", "latency_s": round(time.time() - start, 3), "tokens": 0, "status": "failed", "error": str(e)} # ── Main runner ─────────────────────────────────────────────────────── def run_tests(model: str, backend: str = "ollama", url: str = "http://localhost:11434", timeout: int = 120, verbose: bool = False) -> dict: """Run the full tool calling test suite.""" runner_fn = run_ollama if backend == "ollama" else run_llama_server results = { "model": model, "backend": backend, "url": url, "timestamp": datetime.now(timezone.utc).isoformat(), "tests": [], "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0}, } print(f"Testing tool calling on: {model} ({backend})\n") for test in TESTS: print(f" [{test['id']}] {test['description']}...", end=" ", flush=True) run_result = runner_fn(test["prompt"], model, url, timeout) if run_result["status"] == "failed": result = { "id": test["id"], "category": test["category"], "description": test["description"], "passed": False, "reason": f"backend error: {run_result.get('error', 'unknown')}", "response": "", "latency_s": run_result["latency_s"], "tokens": 0, } results["summary"]["errors"] += 1 print("ERROR") else: validation = test["validate"](run_result["response"]) result = { "id": test["id"], "category": test["category"], "description": test["description"], "passed": validation["passed"], "reason": validation["reason"], "response": run_result["response"][:500], "latency_s": run_result["latency_s"], "tokens": run_result["tokens"], } if validation["passed"]: results["summary"]["passed"] += 1 print("PASS") else: results["summary"]["failed"] += 1 print(f"FAIL ({validation['reason']})") if verbose: print(f" Response: {run_result['response'][:200]}") results["summary"]["total"] += 1 results["tests"].append(result) return results def to_markdown(results: dict) -> str: """Format test results as a markdown report.""" lines = [] lines.append(f"# Tool Calling Viability: {results['model']}") lines.append("") lines.append(f"**Date**: {results['timestamp']}") lines.append(f"**Backend**: {results['backend']} ({results['url']})") lines.append(f"**Model**: {results['model']}") lines.append("") s = results["summary"] pass_rate = s["passed"] / s["total"] * 100 if s["total"] > 0 else 0 lines.append(f"## Summary: {s['passed']}/{s['total']} passed ({pass_rate:.0f}%)") lines.append("") lines.append(f"| Metric | Value |") lines.append(f"|--------|-------|") lines.append(f"| Total tests | {s['total']} |") lines.append(f"| Passed | {s['passed']} |") lines.append(f"| Failed | {s['failed']} |") lines.append(f"| Errors | {s['errors']} |") lines.append("") lines.append("## Results by Category") lines.append("") lines.append("| Test | Category | Result | Reason | Latency | Tokens |") lines.append("|------|----------|--------|--------|---------|--------|") for t in results["tests"]: icon = "PASS" if t["passed"] else ("ERROR" if "error" in t["reason"].lower() else "FAIL") lines.append(f"| {t['id']} | {t['category']} | {icon} | {t['reason']} | {t['latency_s']}s | {t['tokens']} |") lines.append("") lines.append("## Verdict") lines.append("") if pass_rate == 100: lines.append("**FULLY VIABLE** — All tool calling patterns work. Ready for production edge deployment.") elif pass_rate >= 60: lines.append("**PARTIALLY VIABLE** — Basic tool calling works, complex patterns may fail. Consider for simple agents.") elif pass_rate >= 20: lines.append("**MARGINAL** — Only simplest tool calls work. Not recommended for production.") else: lines.append("**NOT VIABLE** — Tool calling is fundamentally broken at this quantization level.") lines.append("") lines.append("## Failure Analysis") lines.append("") failed = [t for t in results["tests"] if not t["passed"]] if not failed: lines.append("No failures.") else: for t in failed: lines.append(f"### {t['id']}") lines.append(f"- **Category**: {t['category']}") lines.append(f"- **Failure**: {t['reason']}") lines.append(f"- **Response** (first 300 chars): `{t['response'][:300]}`") lines.append("") lines.append("") lines.append("## Recommendations") lines.append("") if pass_rate >= 80: lines.append("- Deploy for simple single-tool-call workflows") lines.append("- Add retry logic for multi-step chains") lines.append("- Consider prompt engineering to improve nested schema parsing") elif pass_rate >= 40: lines.append("- Use for keyword/rule-based tool routing only") lines.append("- Do NOT use for complex multi-step workflows") lines.append("- Consider a larger model (Q4 quantized) as fallback") else: lines.append("- 1-bit quantization is too lossy for tool calling") lines.append("- Use Q4_0 as minimum viable quantization for tool use") lines.append("- Reserve 1-bit models for text generation only") return "\n".join(lines) def main(): parser = argparse.ArgumentParser(description="Tool Calling Viability Test for Edge Models") parser.add_argument("--model", "-m", required=True, help="Model name") parser.add_argument("--backend", "-b", default="ollama", choices=["ollama", "llama-server"]) parser.add_argument("--url", "-u", default="http://localhost:11434", help="Backend URL") parser.add_argument("--timeout", "-t", type=int, default=120, help="Timeout per test (seconds)") parser.add_argument("--output", "-o", help="Output markdown file path") parser.add_argument("--json", action="store_true", help="JSON output") parser.add_argument("--verbose", "-v", action="store_true", help="Show full responses") args = parser.parse_args() results = run_tests(args.model, args.backend, args.url, args.timeout, args.verbose) if args.json: print(json.dumps(results, indent=2)) else: md = to_markdown(results) if args.output: with open(args.output, "w") as f: f.write(md) print(f"\nReport written to: {args.output}") else: print("\n" + md) if __name__ == "__main__": main()