test: tool calling on 1-bit models — test suite + harness (closes #101)

This commit is contained in:
2026-04-16 01:58:46 +00:00
parent 3cd8750cbb
commit 442c4dbcc7

View File

@@ -0,0 +1,709 @@
#!/usr/bin/env python3
"""
1-Bit Model Tool Calling Test Suite (Issue #101).
Tests whether quantized/1-bit models can handle structured tool calling.
Designed to be run against any OpenAI-compatible endpoint (llama-server, Ollama).
The core question: does 1-bit quantization destroy the precise JSON output
required for tool calling? This suite measures it empirically.
Usage:
# Against local llama-server
python3 benchmarks/test_bonsai_tool_calling.py \
--url http://localhost:8081/v1/chat/completions \
--model bonsai-1b
# Against Ollama
python3 benchmarks/test_bonsai_tool_calling.py \
--url http://localhost:11434/api/chat \
--model bonsai:latest \
--backend ollama
# Dry run (validate test cases without model)
python3 benchmarks/test_bonsai_tool_calling.py --dry-run
"""
import argparse
import json
import os
import re
import sys
import time
from dataclasses import dataclass, field, asdict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

import requests
class ToolCallCategory(Enum):
    """Categories of tool call complexity.

    Each TEST_CASES entry is tagged with one of these; reports group by the
    string value.
    """
    SIMPLE_READ = "simple_read"          # single tool, exact required param
    TERMINAL_CMD = "terminal_cmd"        # model must author a shell command
    WEB_SEARCH = "web_search"            # extract a search query from prose
    MULTI_STEP = "multi_step"            # task implies a sequence of calls
    NESTED_PARAMS = "nested_params"      # several required structured params
    ARRAY_PARAMS = "array_params"        # list-valued params (no cases yet in TEST_CASES)
    OPTIONAL_PARAMS = "optional_params"  # optional/default params (no cases yet in TEST_CASES)
    MULTI_TOOL_SELECT = "multi_tool_select"  # choose the right tool among several
class TestResult(Enum):
    """Outcome of one test run; the string values appear verbatim in reports."""
    PASS = "PASS"        # correct tool and every param check passed
    FAIL = "FAIL"        # no/wrong tool call, or unparseable arguments
    PARTIAL = "PARTIAL"  # correct tool but some param check failed, or text-only mention
    TIMEOUT = "TIMEOUT"  # request exceeded the --timeout budget
    ERROR = "ERROR"      # transport failure or unexpected exception
    SKIP = "SKIP"        # dry-run placeholder (no model contacted)
# ── Tool schemas (hermes-compatible) ─────────────────────────
# OpenAI-style function schemas offered to the model. Index order matters:
# TEST_CASES references these positionally — [0] read_file, [1] terminal,
# [2] web_search, [3] write_file, [4] patch.
TOOL_SCHEMAS = [
    # [0] read_file — one required param plus optional offset/limit defaults
    {
        "type": "function",
        "function": {
            "name": "read_file",
            "description": "Read a text file with line numbers.",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "File path to read"},
                    "offset": {"type": "integer", "description": "Start line (1-indexed)", "default": 1},
                    "limit": {"type": "integer", "description": "Max lines to read", "default": 500},
                },
                "required": ["path"],
            },
        },
    },
    # [1] terminal — free-form shell command generation
    {
        "type": "function",
        "function": {
            "name": "terminal",
            "description": "Execute a shell command.",
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {"type": "string", "description": "Shell command to execute"},
                    "timeout": {"type": "integer", "description": "Timeout in seconds", "default": 30},
                    "workdir": {"type": "string", "description": "Working directory"},
                },
                "required": ["command"],
            },
        },
    },
    # [2] web_search — query extraction from natural language
    {
        "type": "function",
        "function": {
            "name": "web_search",
            "description": "Search the web for information.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "Search query"},
                    "max_results": {"type": "integer", "description": "Max results to return", "default": 5},
                },
                "required": ["query"],
            },
        },
    },
    # [3] write_file — two required string params
    {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": "Write content to a file, creating directories as needed.",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "File path to write"},
                    "content": {"type": "string", "description": "Content to write"},
                },
                "required": ["path", "content"],
            },
        },
    },
    # [4] patch — three required params, find-and-replace edit
    {
        "type": "function",
        "function": {
            "name": "patch",
            "description": "Apply a targeted find-and-replace edit to a file.",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "File path to edit"},
                    "old_string": {"type": "string", "description": "Text to find"},
                    "new_string": {"type": "string", "description": "Replacement text"},
                    "replace_all": {"type": "boolean", "description": "Replace all occurrences", "default": False},
                },
                "required": ["path", "old_string", "new_string"],
            },
        },
    },
]
# ── Test case definitions ────────────────────────────────────
@dataclass
class ToolCallTestCase:
    """A single tool calling test case.

    Attributes:
        id: Unique identifier (used by the --test-id CLI filter).
        category: Complexity bucket used for report grouping.
        prompt: User message sent to the model.
        tools: OpenAI-style tool schemas offered for this case.
        expected_tool: Name of the tool the model should call.
        expected_params: Exact key/value pairs that must appear in the
            call's arguments.
        param_validators: Per-key predicates for fuzzy argument checks
            (e.g. "limit must be an int <= 20"); a missing key counts
            as a failed check.
        description: Human-readable summary for reports.
        difficulty: 1-5 scale; higher = harder.
    """
    id: str
    category: ToolCallCategory
    prompt: str
    tools: List[dict]
    expected_tool: str
    # Fixed: was `Dict[str, any]` — lowercase `any` is the builtin function,
    # not a type; `typing.Any` is the correct annotation.
    expected_params: Dict[str, Any]
    # Fixed: was `Dict[str, callable]` — `callable` is the builtin predicate,
    # not a type; validators take one argument and return truthy/falsy.
    param_validators: Dict[str, Callable[[Any], bool]] = field(default_factory=dict)
    description: str = ""
    difficulty: int = 1  # 1-5, higher = harder
# Eleven cases spanning difficulty 1-5. Note: the ARRAY_PARAMS and
# OPTIONAL_PARAMS categories currently have no cases here.
TEST_CASES = [
    # ── Level 1: Simple reads ──────────────────────────────
    ToolCallTestCase(
        id="simple-read-1",
        category=ToolCallCategory.SIMPLE_READ,
        prompt="Read the file at /tmp/test.txt",
        tools=[TOOL_SCHEMAS[0]],
        expected_tool="read_file",
        expected_params={"path": "/tmp/test.txt"},
        description="Exact path, single required param",
        difficulty=1,
    ),
    ToolCallTestCase(
        id="simple-read-with-limit",
        category=ToolCallCategory.SIMPLE_READ,
        prompt="Read the first 10 lines of /var/log/system.log",
        tools=[TOOL_SCHEMAS[0]],
        expected_tool="read_file",
        expected_params={"path": "/var/log/system.log"},
        # Accepts any small limit, not just exactly 10.
        param_validators={"limit": lambda v: isinstance(v, int) and v <= 20},
        description="Required + optional param",
        difficulty=2,
    ),
    # ── Level 2: Terminal commands ─────────────────────────
    ToolCallTestCase(
        id="terminal-simple",
        category=ToolCallCategory.TERMINAL_CMD,
        prompt="List all files in the current directory",
        tools=[TOOL_SCHEMAS[1]],
        expected_tool="terminal",
        expected_params={},
        param_validators={
            "command": lambda v: isinstance(v, str) and any(
                cmd in v for cmd in ["ls", "dir", "find"]
            )
        },
        description="Generate appropriate shell command",
        difficulty=2,
    ),
    ToolCallTestCase(
        id="terminal-pipe",
        category=ToolCallCategory.TERMINAL_CMD,
        prompt="Count how many Python files are in /tmp recursively",
        tools=[TOOL_SCHEMAS[1]],
        expected_tool="terminal",
        expected_params={},
        param_validators={
            "command": lambda v: isinstance(v, str) and (
                "find" in v or "ls" in v or "python" in v or ".py" in v
            )
        },
        description="Needs piped or recursive command",
        difficulty=3,
    ),
    # ── Level 3: Web search ────────────────────────────────
    ToolCallTestCase(
        id="web-search-simple",
        category=ToolCallCategory.WEB_SEARCH,
        prompt="Search for the current price of Bitcoin",
        tools=[TOOL_SCHEMAS[2]],
        expected_tool="web_search",
        expected_params={"query": "Bitcoin price"},
        # Validator is looser than expected_params: any bitcoin-ish query passes.
        param_validators={
            "query": lambda v: isinstance(v, str) and len(v) > 3 and "bitcoin" in v.lower()
        },
        description="Extract search query from natural language",
        difficulty=2,
    ),
    # ── Level 4: Multi-tool selection ──────────────────────
    ToolCallTestCase(
        id="multi-tool-select-read",
        category=ToolCallCategory.MULTI_TOOL_SELECT,
        prompt="Read the file at /etc/hostname",
        tools=TOOL_SCHEMAS[:3],  # read_file, terminal, web_search
        expected_tool="read_file",
        expected_params={"path": "/etc/hostname"},
        description="Choose correct tool from 3 options",
        difficulty=3,
    ),
    ToolCallTestCase(
        id="multi-tool-select-terminal",
        category=ToolCallCategory.MULTI_TOOL_SELECT,
        prompt="Check how much disk space is available",
        tools=TOOL_SCHEMAS[:3],
        expected_tool="terminal",
        expected_params={},
        param_validators={
            "command": lambda v: isinstance(v, str) and any(
                cmd in v for cmd in ["df", "du", "disk"]
            )
        },
        description="Choose terminal over read_file for system info",
        difficulty=3,
    ),
    ToolCallTestCase(
        id="multi-tool-select-search",
        category=ToolCallCategory.MULTI_TOOL_SELECT,
        prompt="What is the weather in Tokyo right now?",
        tools=TOOL_SCHEMAS[:3],
        expected_tool="web_search",
        expected_params={},
        param_validators={
            "query": lambda v: isinstance(v, str) and "weather" in v.lower() and "tokyo" in v.lower()
        },
        description="Choose web_search for real-time info",
        difficulty=3,
    ),
    # ── Level 5: Nested/complex params ─────────────────────
    ToolCallTestCase(
        id="write-file-with-content",
        category=ToolCallCategory.NESTED_PARAMS,
        prompt="Create a file at /tmp/hello.txt with the content 'Hello, World!'",
        tools=[TOOL_SCHEMAS[3]],
        expected_tool="write_file",
        expected_params={"path": "/tmp/hello.txt"},
        param_validators={
            "content": lambda v: isinstance(v, str) and "hello" in v.lower()
        },
        description="Two required string params",
        difficulty=3,
    ),
    ToolCallTestCase(
        id="patch-edit",
        category=ToolCallCategory.NESTED_PARAMS,
        prompt="In the file /tmp/config.yaml, replace 'debug: false' with 'debug: true'",
        tools=[TOOL_SCHEMAS[4]],
        expected_tool="patch",
        expected_params={"path": "/tmp/config.yaml"},
        param_validators={
            "old_string": lambda v: isinstance(v, str) and "debug: false" in v,
            "new_string": lambda v: isinstance(v, str) and "debug: true" in v,
        },
        description="Three required params, find-and-replace",
        difficulty=4,
    ),
    # ── Level 6: Multi-step reasoning ──────────────────────
    ToolCallTestCase(
        id="multi-step-read-then-write",
        category=ToolCallCategory.MULTI_STEP,
        prompt="Read /tmp/source.txt and write its contents to /tmp/backup.txt",
        tools=[TOOL_SCHEMAS[0], TOOL_SCHEMAS[3]],  # read_file + write_file
        expected_tool="read_file",  # First step should be reading
        expected_params={"path": "/tmp/source.txt"},
        description="Requires planning: read first, then write",
        difficulty=5,
    ),
]
# ── Test runner ──────────────────────────────────────────────
@dataclass
class TestRunResult:
    """Result of running a single test case.

    Serialized verbatim (via dataclasses.asdict) into the JSON results file
    and rendered into the markdown report.
    """
    test_id: str        # ToolCallTestCase.id
    category: str       # ToolCallCategory value string
    difficulty: int     # 1-5, copied from the test case
    result: str  # TestResult value
    expected_tool: str
    actual_tool: str    # tool name, or sentinel "none"/"text_only"/"(dry run)"/""
    expected_params: dict
    actual_params: dict
    param_scores: Dict[str, bool] = field(default_factory=dict)  # per-check pass/fail map
    response_text: str = ""   # model text content, truncated to 200 chars
    latency_s: float = 0.0    # wall-clock request time in seconds
    tokens_per_sec: float = 0.0  # generation throughput; 0.0 when not measured
    error: str = ""           # error/timeout description, truncated
    raw_response: dict = field(default_factory=dict)  # full normalized API response
def call_openai_compatible(
    messages: list,
    tools: list,
    url: str,
    model: str,
    timeout: int = 120,
) -> dict:
    """POST a chat request to an OpenAI-compatible /chat/completions endpoint.

    Uses temperature 0.0 and tool_choice "auto" so tool-call output is as
    deterministic as the backend allows. Raises requests exceptions on HTTP
    errors or timeouts; returns the parsed JSON body otherwise.
    """
    body = dict(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice="auto",
        max_tokens=512,
        temperature=0.0,
    )
    reply = requests.post(url, json=body, timeout=timeout)
    reply.raise_for_status()
    return reply.json()
def call_ollama(
    messages: list,
    tools: list,
    url: str,
    model: str,
    timeout: int = 120,
) -> dict:
    """Call Ollama /api/chat endpoint.

    Rebuilds each tool schema explicitly (forwarding only name, description
    and parameters) and normalizes Ollama's reply into the OpenAI
    {"choices": [{"message": ...}]} shape so downstream validation is
    backend-agnostic.
    """
    # Convert OpenAI tool format to Ollama format.
    converted = [
        {
            "type": "function",
            "function": {
                "name": spec["function"]["name"],
                "description": spec["function"]["description"],
                "parameters": spec["function"]["parameters"],
            },
        }
        for spec in tools
    ]
    payload = {
        "model": model,
        "messages": messages,
        "tools": converted,
        "stream": False,
    }
    reply = requests.post(url, json=payload, timeout=timeout)
    reply.raise_for_status()
    body = reply.json()
    # Normalize to OpenAI format.
    message = body.get("message", {})
    normalized = {"content": message.get("content", "")}
    if message.get("tool_calls"):
        normalized["tool_calls"] = message["tool_calls"]
    return {"choices": [{"message": normalized}]}
def validate_tool_call(
    response: dict,
    test: ToolCallTestCase,
) -> Tuple[TestResult, str, dict, Dict[str, bool]]:
    """
    Validate a model response against a test case.

    Returns: (result, actual_tool, actual_params, param_scores)
      - PASS: correct tool, all expected params and validators satisfied
      - PARTIAL: correct tool but some check failed, or a text-only reply
        that at least mentions the expected tool name
      - FAIL: malformed response, no/wrong tool call, or unparseable args
    """
    try:
        msg = response["choices"][0]["message"]
    except (KeyError, IndexError):
        return TestResult.FAIL, "", {}, {}
    # Check if model called a tool
    tool_calls = msg.get("tool_calls", [])
    if not tool_calls:
        # Model responded with text instead — check if it at least mentioned the tool
        content = msg.get("content", "")
        if test.expected_tool in content:
            return TestResult.PARTIAL, "text_only", {"content": content}, {}
        return TestResult.FAIL, "none", {}, {}
    tc = tool_calls[0]
    actual_tool = tc.get("function", {}).get("name", "")
    # Parse arguments. OpenAI sends a JSON string; Ollama sends a dict.
    args_raw = tc.get("function", {}).get("arguments", "{}")
    try:
        if isinstance(args_raw, str):
            actual_params = json.loads(args_raw)
        else:
            # Fixed: an explicit None under "arguments" previously leaked
            # into json.loads-free path as None; treat it as empty args.
            actual_params = args_raw if args_raw is not None else {}
    except (json.JSONDecodeError, TypeError):
        return TestResult.FAIL, actual_tool, {}, {"json_parse": False}
    # Check tool name
    if actual_tool != test.expected_tool:
        return TestResult.FAIL, actual_tool, actual_params, {
            "tool_match": False
        }
    # Validate expected params (exact match required)
    param_scores = {"tool_match": True}
    all_pass = True
    for key, expected_val in test.expected_params.items():
        ok = key in actual_params and actual_params[key] == expected_val
        param_scores[f"param_{key}"] = ok
        if not ok:
            all_pass = False
    # Run custom validators; a missing key or a raising validator counts as failure.
    for key, validator in test.param_validators.items():
        if key in actual_params:
            try:
                ok = bool(validator(actual_params[key]))
            except Exception:
                ok = False
        else:
            ok = False
        param_scores[f"validator_{key}"] = ok
        if not ok:
            all_pass = False
    # Fixed: original had two identical PASS branches (with and without
    # expected params) — collapsed into a single check.
    if all_pass:
        return TestResult.PASS, actual_tool, actual_params, param_scores
    return TestResult.PARTIAL, actual_tool, actual_params, param_scores
def run_test(
    test: ToolCallTestCase,
    url: str,
    model: str,
    backend: str = "openai",
    timeout: int = 120,
) -> TestRunResult:
    """Run a single test case against the model.

    Sends the test prompt plus tool schemas to the endpoint, validates the
    reply, and wraps everything in a TestRunResult. Timeouts and any other
    exception are captured as TIMEOUT/ERROR rows instead of propagating, so
    one bad request never aborts the suite.
    """
    messages = [{"role": "user", "content": test.prompt}]
    start = time.time()
    try:
        if backend == "ollama":
            response = call_ollama(messages, test.tools, url, model, timeout)
        else:
            response = call_openai_compatible(messages, test.tools, url, model, timeout)
        elapsed = time.time() - start
        result, actual_tool, actual_params, param_scores = validate_tool_call(response, test)
        # Extract text response (may be absent on malformed replies)
        try:
            text = response["choices"][0]["message"].get("content", "")
        except (KeyError, IndexError):
            text = ""
        # Fixed: tokens_per_sec exists on TestRunResult but was never
        # populated. Best-effort: derive it from OpenAI-style usage stats
        # (Ollama's normalized response has no "usage", so it stays 0.0).
        tokens_per_sec = 0.0
        usage = response.get("usage") or {}
        completion_tokens = usage.get("completion_tokens", 0)
        if elapsed > 0 and completion_tokens:
            tokens_per_sec = round(completion_tokens / elapsed, 2)
        return TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=result.value,
            expected_tool=test.expected_tool,
            actual_tool=actual_tool,
            expected_params=test.expected_params,
            actual_params=actual_params,
            param_scores=param_scores,
            response_text=text[:200],
            latency_s=round(elapsed, 3),
            tokens_per_sec=tokens_per_sec,
            raw_response=response,
        )
    except requests.exceptions.Timeout:
        return TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=TestResult.TIMEOUT.value,
            expected_tool=test.expected_tool,
            actual_tool="",
            expected_params=test.expected_params,
            actual_params={},
            error=f"Timeout after {timeout}s",
        )
    except Exception as e:
        # Broad catch is deliberate: any failure becomes an ERROR row.
        return TestRunResult(
            test_id=test.id,
            category=test.category.value,
            difficulty=test.difficulty,
            result=TestResult.ERROR.value,
            expected_tool=test.expected_tool,
            actual_tool="",
            expected_params=test.expected_params,
            actual_params={},
            error=str(e)[:200],
        )
def run_dry_run() -> List[TestRunResult]:
    """Build a SKIP-marked result for every test case without contacting a model.

    Useful for checking that the test-case definitions themselves are valid.
    """
    return [
        TestRunResult(
            test_id=case.id,
            category=case.category.value,
            difficulty=case.difficulty,
            result=TestResult.SKIP.value,
            expected_tool=case.expected_tool,
            actual_tool="(dry run)",
            expected_params=case.expected_params,
            actual_params={},
        )
        for case in TEST_CASES
    ]
def generate_report(results: List[TestRunResult], model: str) -> str:
    """Generate markdown report.

    Sections: header, outcome summary table, per-difficulty breakdown,
    per-test details, and an overall viability verdict keyed off pass rate.
    """
    out = [
        f"# 1-Bit Model Tool Calling Test Results",
        f"",
        f"**Model:** {model}",
        f"**Date:** {time.strftime('%Y-%m-%d %H:%M:%S')}",
        f"**Test cases:** {len(results)}",
        f"",
    ]
    # Tally outcomes (result string -> count).
    tally: Dict[str, int] = {}
    for run in results:
        tally[run.result] = tally.get(run.result, 0) + 1
    out += ["## Summary", "", "| Result | Count |", "|--------|-------|"]
    for outcome in sorted(tally):
        out.append(f"| {outcome} | {tally[outcome]} |")
    out.append("")
    passed = tally.get("PASS", 0)
    total = len(results)
    pass_rate = (passed / total * 100) if total > 0 else 0
    out.append(f"**Pass rate: {pass_rate:.0f}%** ({passed}/{total})")
    out.append("")
    # Per-difficulty breakdown (difficulty scale is 1-5).
    out += [
        "## Results by Difficulty",
        "",
        "| Difficulty | PASS | PARTIAL | FAIL | Other |",
        "|-----------|------|---------|------|-------|",
    ]
    for level in range(1, 6):
        bucket = [run for run in results if run.difficulty == level]
        if not bucket:
            continue
        n_pass = sum(1 for run in bucket if run.result == "PASS")
        n_partial = sum(1 for run in bucket if run.result == "PARTIAL")
        n_fail = sum(1 for run in bucket if run.result in ("FAIL", "ERROR", "TIMEOUT"))
        n_other = len(bucket) - n_pass - n_partial - n_fail
        out.append(f"| {level}/5 | {n_pass} | {n_partial} | {n_fail} | {n_other} |")
    out.append("")
    # Per-test detail sections.
    out += ["## Detailed Results", ""]
    icons = {"PASS": "", "PARTIAL": "⚠️", "FAIL": "", "ERROR": "💥", "TIMEOUT": ""}
    for run in results:
        icon = icons.get(run.result, "")
        out.append(f"### {icon} {run.test_id} (difficulty {run.difficulty}/5)")
        out.append(f"- **Category:** {run.category}")
        out.append(f"- **Expected tool:** `{run.expected_tool}`")
        out.append(f"- **Actual tool:** `{run.actual_tool}`")
        if run.latency_s > 0:
            out.append(f"- **Latency:** {run.latency_s}s")
        if run.param_scores:
            out.append(f"- **Param scores:** {json.dumps(run.param_scores)}")
        if run.error:
            out.append(f"- **Error:** {run.error}")
        out.append("")
    # Viability verdict.
    out += ["## Viability Verdict", ""]
    if pass_rate >= 80:
        out.append("**VERDICT: VIABLE** — 1-bit model can handle tool calling for production use.")
    elif pass_rate >= 50:
        out.append("**VERDICT: CONDITIONALLY VIABLE** — Works for simple tools, struggles with complex params. Consider for edge deployment with guardrails.")
    elif pass_rate >= 20:
        out.append("**VERDICT: MARGINAL** — Can select correct tool sometimes, but parameter accuracy is too low for production. Investigate alternative quantization (2-bit, 3-bit).")
    else:
        out.append("**VERDICT: NOT VIABLE** — 1-bit quantization destroys tool calling capability. Recommend minimum 3-bit quantization for tool-using models.")
    out.append("")
    return "\n".join(out)
def main():
    """CLI entry point.

    Parses arguments, runs the suite (or a dry run), writes JSON results and
    a markdown report, then prints a pass-rate summary.
    """
    parser = argparse.ArgumentParser(description="Test tool calling on 1-bit models")
    parser.add_argument("--url", default="http://localhost:8081/v1/chat/completions",
                        help="Model API endpoint")
    parser.add_argument("--model", default="bonsai-1b", help="Model name")
    parser.add_argument("--backend", default="openai", choices=["openai", "ollama"],
                        help="API backend type")
    parser.add_argument("--timeout", type=int, default=120, help="Request timeout in seconds")
    parser.add_argument("--dry-run", action="store_true", help="Validate tests without model")
    parser.add_argument("--output", default="benchmarks/bonsai-tool-calling-results.json",
                        help="Output file for results")
    parser.add_argument("--report", default="benchmarks/bonsai-tool-calling.md",
                        help="Output file for markdown report")
    parser.add_argument("--test-id", help="Run a single test by ID")
    args = parser.parse_args()
    print("=" * 60)
    print(" 1-Bit Model Tool Calling Test Suite")
    print("=" * 60)
    if args.dry_run:
        print("\n[DRY RUN] Validating test cases...")
        results = run_dry_run()
        print(f" {len(results)} test cases validated")
        for r in results:
            print(f"{r.test_id} — expects {r.expected_tool} (difficulty {r.difficulty}/5)")
    else:
        print(f"\nModel: {args.model}")
        print(f"Endpoint: {args.url}")
        print(f"Backend: {args.backend}")
        print()
        tests = TEST_CASES
        if args.test_id:
            tests = [t for t in tests if t.id == args.test_id]
            if not tests:
                print(f"Test '{args.test_id}' not found")
                sys.exit(1)
        results = []
        for i, test in enumerate(tests):
            print(f" [{i+1}/{len(tests)}] {test.id} (difficulty {test.difficulty}/5)... ", end="", flush=True)
            result = run_test(test, args.url, args.model, args.backend, args.timeout)
            results.append(result)
            icon = {"PASS": "", "PARTIAL": "⚠️", "FAIL": "", "ERROR": "💥", "TIMEOUT": ""}.get(result.result, "")
            print(f"{icon} {result.result} ({result.latency_s}s)")
    # Save results
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w") as f:
        json.dump([asdict(r) for r in results], f, indent=2)
    print(f"\nResults saved to {args.output}")
    # Generate report
    report = generate_report(results, args.model)
    with open(args.report, "w") as f:
        f.write(report)
    print(f"Report saved to {args.report}")
    # Print summary
    pass_count = sum(1 for r in results if r.result == "PASS")
    total = len(results)
    # Fixed: guard against ZeroDivisionError when there are no results
    # (e.g. an empty TEST_CASES list).
    pct = (pass_count / total * 100) if total else 0.0
    print(f"\n{'='*60}")
    print(f" Results: {pass_count}/{total} passed ({pct:.0f}%)")
if __name__ == "__main__":
    main()