Compare commits
3 Commits
step35/67-...burn/101-1
| Author | SHA1 | Date |
|---|---|---|
| | 3c815664e4 | |
| | 0d92de9b3f | |
| | 3caeaf13eb | |
49 benchmarks/bonsai-tool-calling.md Normal file
@@ -0,0 +1,49 @@
# Tool Calling Test Results — 1-Bit Models

**Status:** Pending execution
**Issue:** #101
**Model:** bonsai-1bit (to be tested)
**Backend:** Ollama

## Test Suite

10 test cases covering:

| # | Test | Type | Difficulty | Description |
|---|------|------|------------|-------------|
| 1 | simple_file_read | file_read | easy | Read README.md with exact path |
| 2 | absolute_path_read | file_read | easy | Read /etc/hostname with absolute path |
| 3 | simple_terminal | terminal | easy | Run `echo hello world` |
| 4 | terminal_ls | terminal | medium | List files in directory |
| 5 | web_search | web_search | easy | Search for a query |
| 6 | read_then_analyze | multi_step | medium | Read file then analyze content |
| 7 | nested_params | schema_parsing | hard | Complex nested parameters |
| 8 | optional_params | schema_parsing | medium | Tool with optional parameters |
| 9 | sequential_calls | multi_step | hard | Multiple tool calls in sequence |
| 10 | no_tool_needed | file_read | easy | No tool needed for simple question |
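Each row corresponds to one `ToolCallTest` entry in `benchmarks/test_tool_calling_1bit.py`; test 1, copied from the suite for reference:

```python
# Test 1 as defined in TOOL_CALL_TESTS (benchmarks/test_tool_calling_1bit.py)
ToolCallTest("simple_file_read", ToolCallType.FILE_READ,
             "Use read_file to read README.md. Tools: read_file(path: str)",
             "read_file", {"path": "README.md"}, difficulty="easy")
```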
## Hypothesis

1-bit quantization destroys fine-grained reasoning. Tool calling (precise JSON output) may be impossible. But worth testing — the field is moving fast.

## Results

*To be filled after running:*

```bash
python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --report benchmarks/bonsai-tool-calling.md --results benchmarks/tool_calling_results.json
```
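Once the run completes, the `--results` file is simply the `BenchmarkResult` dataclass serialized via `asdict`; a minimal sketch of inspecting it, assuming the command above has been run:

```python
import json

# Load the JSON written by --results (asdict(BenchmarkResult))
with open("benchmarks/tool_calling_results.json") as f:
    data = json.load(f)

print(f"pass rate: {data['summary']['pass_rate']:.0%}")
for r in data["results"]:
    print(r["test_name"], "PASS" if r["passed"] else "FAIL", f"{r['latency_ms']:.0f} ms")
```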
## Failure Modes (Expected)

If tests fail, likely causes:

1. **JSON formatting:** Model cannot produce valid JSON tool calls (see the sketch after this list)
2. **Parameter extraction:** Model confuses or drops parameters
3. **Schema adherence:** Model ignores tool schema constraints
4. **Consistency:** Model produces different formats across runs
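For context, the harness's `parse_tool_call()` credits only two response shapes (a JSON tool call or a Python-style function call); a sketch of what passes and what falls into failure mode 1:

```python
# Shapes parse_tool_call() accepts (benchmarks/test_tool_calling_1bit.py):
ok_json = '{"tool": "read_file", "params": {"path": "README.md"}}'
ok_alt = '{"name": "read_file", "arguments": {"path": "README.md"}}'
ok_func = 'read_file(path="README.md")'

# Failure mode 1: truncated/invalid JSON matches neither pattern,
# so it is scored as "no tool call" and the test fails
bad = '{"tool": "read_file", "params": {"path": "README.md'
```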
## Alternative Edge Models

If 1-bit is not viable:

- **Qwen3.5 3B Q4** — Good tool calling, reasonable size
- **Phi-3 Mini** — Strong reasoning, supports function calling
- **Llama 3.2 3B** — Good balance of size and capability

255 benchmarks/test_tool_calling_1bit.py Normal file
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
"""
Tool Calling Test Suite for 1-Bit Models (Issue #101)

Tests whether Bonsai 1-bit models can handle tool calling at all.
Evaluates: file read, terminal execution, web search, multi-step workflows, schema parsing.

Usage:
    python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --backend ollama
    python3 benchmarks/test_tool_calling_1bit.py --results benchmarks/tool_calling_results.json
"""

import argparse
import json
import os
import re
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import requests

class ToolCallType(Enum):
    FILE_READ = "file_read"
    TERMINAL_EXEC = "terminal_exec"
    WEB_SEARCH = "web_search"
    MULTI_STEP = "multi_step"
    SCHEMA_PARSING = "schema_parsing"

@dataclass
class ToolCallTest:
    name: str
    tool_type: ToolCallType
    prompt: str
    expected_tool: Optional[str]
    expected_params: Dict[str, Any]
    validation_fn: Optional[str] = None
    difficulty: str = "easy"

@dataclass
class TestResult:
    test_name: str
    tool_type: str
    passed: bool
    latency_ms: float
    response_text: str
    parsed_tool: Optional[str] = None
    parsed_params: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    quality_score: float = 0.0

@dataclass
class BenchmarkResult:
    model: str
    backend: str
    timestamp: str
    results: List[TestResult] = field(default_factory=list)
    summary: Dict[str, Any] = field(default_factory=dict)

TOOL_CALL_TESTS = [
    ToolCallTest("simple_file_read", ToolCallType.FILE_READ,
                 "Use read_file to read README.md. Tools: read_file(path: str)",
                 "read_file", {"path": "README.md"}, difficulty="easy"),
    ToolCallTest("absolute_path_read", ToolCallType.FILE_READ,
                 "Use read_file to read /etc/hostname. Tools: read_file(path: str)",
                 "read_file", {"path": "/etc/hostname"}, difficulty="easy"),
    ToolCallTest("simple_terminal", ToolCallType.TERMINAL_EXEC,
                 "Use terminal to run: echo hello world. Tools: terminal(command: str)",
                 "terminal", {"command": "echo hello world"}, difficulty="easy"),
    ToolCallTest("terminal_ls", ToolCallType.TERMINAL_EXEC,
                 "Use terminal to list files. Tools: terminal(command: str)",
                 "terminal", {}, validation_fn="validate_ls", difficulty="medium"),
    ToolCallTest("web_search", ToolCallType.WEB_SEARCH,
                 "Use web_search for Python. Tools: web_search(query: str)",
                 "web_search", {"query": "Python"}, difficulty="easy"),
    ToolCallTest("read_then_analyze", ToolCallType.MULTI_STEP,
                 "First read README.md then analyze. Tools: read_file(path: str)",
                 "read_file", {"path": "README.md"}, difficulty="medium"),
    ToolCallTest("nested_params", ToolCallType.SCHEMA_PARSING,
                 "Use complex_tool(name=test, config={verbose:true}, tags=[a,b]). Tools: complex_tool(name: str, config: dict, tags: list)",
                 "complex_tool", {"name": "test"}, difficulty="hard"),
    ToolCallTest("optional_params", ToolCallType.SCHEMA_PARSING,
                 "Use search for ML with limit 5. Tools: search(query: str, limit: int=10)",
                 "search", {"query": "ML", "limit": 5}, difficulty="medium"),
    ToolCallTest("sequential_calls", ToolCallType.MULTI_STEP,
                 "First run pwd, then read README.md. Tools: terminal(command: str), read_file(path: str)",
                 "terminal", {"command": "pwd"}, difficulty="hard"),
    ToolCallTest("no_tool_needed", ToolCallType.FILE_READ,
                 "What is 2+2? Tools: read_file(path: str)",
                 None, {}, difficulty="easy"),
]

def validate_ls(params):
    cmd = params.get("command", "").strip()
    return cmd in ["ls", "ls -l", "ls -la", "ls -1", "dir"] or cmd.startswith("ls ")


VALIDATORS = {"validate_ls": validate_ls}

def parse_tool_call(response: str) -> Tuple[Optional[str], Optional[Dict]]:
    # JSON format: {"tool": ..., "params": ...} or {"name": ..., "arguments": ...}
    patterns = [
        r'"tool"\s*:\s*"([^"]+)"\s*,\s*"params"\s*:\s*({[^}]+})',
        r'"name"\s*:\s*"([^"]+)"\s*,\s*"arguments"\s*:\s*({[^}]+})',
    ]
    for pattern in patterns:
        match = re.search(pattern, response, re.DOTALL)
        if match:
            try:
                return match.group(1), json.loads(match.group(2))
            except json.JSONDecodeError:
                continue

    # Function-call format: tool_name(key=value, ...)
    match = re.search(r'(\w+)\(([^)]+)\)', response)
    if match:
        tool_name = match.group(1)
        params = {}
        for m in re.finditer(r'(\w+)\s*=\s*"?([^",)]+)"?', match.group(2)):
            params[m.group(1)] = m.group(2).strip().strip('"\'')
        return tool_name, params

    return None, None

def call_model(prompt: str, model: str, backend: str, url: str, timeout: int = 60) -> Tuple[str, float]:
    start = time.time()
    try:
        if backend == "ollama":
            resp = requests.post(f"{url}/api/generate", json={
                "model": model, "prompt": prompt, "stream": False,
                "options": {"num_predict": 256, "temperature": 0.1}
            }, timeout=timeout)
            resp.raise_for_status()
            text = resp.json().get("response", "")
        else:
            text = f"ERROR: Unknown backend {backend}"
    except Exception as e:
        text = f"ERROR: {e}"
    return text, (time.time() - start) * 1000

def run_test(test: ToolCallTest, model: str, backend: str, url: str) -> TestResult:
    response, latency = call_model(test.prompt, model, backend, url)

    if response.startswith("ERROR:"):
        return TestResult(test.name, test.tool_type.value, False, latency, response, error=response)

    parsed_tool, parsed_params = parse_tool_call(response)
    passed = False
    quality = 0.0

    if test.expected_tool is None:
        passed = parsed_tool is None
        quality = 1.0 if passed else 0.0
    elif parsed_tool:
        tool_match = parsed_tool.lower() == test.expected_tool.lower()
        if test.validation_fn and test.validation_fn in VALIDATORS:
            params_match = VALIDATORS[test.validation_fn](parsed_params or {})
        else:
            # Compare non-string expectations as strings so a "5" parsed
            # from function-call syntax matches the expected int 5.
            params_match = all(
                k in (parsed_params or {}) and
                (str(v).lower() in str(parsed_params.get(k, "")).lower() if isinstance(v, str)
                 else str(parsed_params.get(k)) == str(v))
                for k, v in test.expected_params.items()
            ) if test.expected_params else True
        passed = tool_match and params_match
        quality = (0.5 if tool_match else 0) + (0.5 if params_match else 0)

    return TestResult(test.name, test.tool_type.value, passed, latency, response[:500],
                      parsed_tool, parsed_params, quality_score=quality)

def run_all_tests(model: str, backend: str, url: str) -> BenchmarkResult:
    results = BenchmarkResult(model, backend, datetime.now(timezone.utc).isoformat())
    print(f"Testing {model} ({backend})")
    print("=" * 50)

    for test in TOOL_CALL_TESTS:
        result = run_test(test, model, backend, url)
        results.results.append(result)
        status = "PASS" if result.passed else "FAIL"
        print(f"  {status} {test.name} ({result.latency_ms:.0f}ms, q={result.quality_score:.0%})")

    total = len(results.results)
    passed = sum(1 for r in results.results if r.passed)
    results.summary = {
        "total": total, "passed": passed, "failed": total - passed,
        "pass_rate": passed / total if total else 0,
        "avg_latency_ms": sum(r.latency_ms for r in results.results) / total if total else 0,
        "avg_quality": sum(r.quality_score for r in results.results) / total if total else 0,
    }
    return results

def generate_report(results: BenchmarkResult) -> str:
    s = results.summary
    lines = [
        "# Tool Calling Test Results - 1-Bit Models", "",
        f"**Model:** {results.model}  ",
        f"**Backend:** {results.backend}  ",
        f"**Timestamp:** {results.timestamp}", "",
        "## Summary", "",
        f"- Pass Rate: {s['passed']}/{s['total']} ({s['pass_rate']:.0%})",
        f"- Avg Latency: {s['avg_latency_ms']:.0f}ms",
        f"- Avg Quality: {s['avg_quality']:.0%}", "",
        "## Detailed Results", "",
    ]
    for r, t in zip(results.results, TOOL_CALL_TESTS):
        lines.append(f"- {'PASS' if r.passed else 'FAIL'} {r.test_name} ({t.difficulty}, {r.latency_ms:.0f}ms)")
    lines.extend(["", "## Conclusion", ""])
    if s['pass_rate'] >= 0.8:
        lines.append(f"**VIABLE** - {s['pass_rate']:.0%} pass rate.")
    elif s['pass_rate'] >= 0.5:
        lines.append(f"**MARGINAL** - {s['pass_rate']:.0%} pass rate.")
    else:
        lines.append(f"**NOT VIABLE** - {s['pass_rate']:.0%} pass rate.")
    lines.extend(["", "### Alternatives", "- Qwen3.5 3B Q4", "- Phi-3 Mini", "- Llama 3.2 3B"])
    return "\n".join(lines)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="bonsai-1bit")
    parser.add_argument("--backend", default="ollama")
    parser.add_argument("--url", default="http://localhost:11434")
    parser.add_argument("--results", help="Save results JSON")
    parser.add_argument("--report", help="Save report markdown")
    args = parser.parse_args()

    results = run_all_tests(args.model, args.backend, args.url)
    print(f"\nSUMMARY: {results.summary['passed']}/{results.summary['total']} passed")

    if args.results:
        os.makedirs(os.path.dirname(args.results) or ".", exist_ok=True)
        with open(args.results, "w") as f:
            json.dump(asdict(results), f, indent=2)

    if args.report:
        os.makedirs(os.path.dirname(args.report) or ".", exist_ok=True)
        with open(args.report, "w") as f:
            f.write(generate_report(results))

if __name__ == "__main__":
    main()

125 tests/test_tool_calling_suite.py Normal file
@@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Tests for the Tool Calling Test Suite.
"""

import json
import pytest

# Add benchmarks to path so the suite under test is importable
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks"))

from test_tool_calling_1bit import (
    ToolCallType,
    ToolCallTest,
    TestResult,
    BenchmarkResult,
    parse_tool_call,
    validate_ls,
    TOOL_CALL_TESTS,
)

class TestToolCallType:
    def test_values(self):
        assert ToolCallType.FILE_READ.value == "file_read"
        assert ToolCallType.TERMINAL_EXEC.value == "terminal_exec"
        assert ToolCallType.WEB_SEARCH.value == "web_search"

class TestParseToolCall:
    def test_json_format(self):
        response = '{"tool": "read_file", "params": {"path": "test.txt"}}'
        tool, params = parse_tool_call(response)
        assert tool == "read_file"
        assert params == {"path": "test.txt"}

    def test_json_alt_format(self):
        response = '{"name": "terminal", "arguments": {"command": "ls"}}'
        tool, params = parse_tool_call(response)
        assert tool == "terminal"
        assert params == {"command": "ls"}

    def test_function_format(self):
        response = 'read_file(path="test.txt")'
        tool, params = parse_tool_call(response)
        assert tool == "read_file"
        assert params.get("path") == "test.txt"

    def test_no_tool(self):
        response = "The answer is 4."
        tool, params = parse_tool_call(response)
        assert tool is None
        assert params is None

    def test_embedded_json(self):
        response = 'I will call the tool: {"tool": "search", "params": {"query": "test"}}'
        tool, params = parse_tool_call(response)
        assert tool == "search"
        assert params == {"query": "test"}

class TestValidateLs:
    def test_simple_ls(self):
        assert validate_ls({"command": "ls"}) is True

    def test_ls_flags(self):
        assert validate_ls({"command": "ls -la"}) is True
        assert validate_ls({"command": "ls -l"}) is True

    def test_not_ls(self):
        assert validate_ls({"command": "pwd"}) is False
        assert validate_ls({"command": "cat file.txt"}) is False

class TestToolCallTests:
    def test_all_tests_valid(self):
        assert len(TOOL_CALL_TESTS) == 10
        for test in TOOL_CALL_TESTS:
            assert test.name
            assert isinstance(test.tool_type, ToolCallType)
            assert test.prompt

    def test_difficulty_distribution(self):
        difficulties = [t.difficulty for t in TOOL_CALL_TESTS]
        assert "easy" in difficulties
        assert "medium" in difficulties
        assert "hard" in difficulties

    def test_type_distribution(self):
        types = [t.tool_type for t in TOOL_CALL_TESTS]
        assert ToolCallType.FILE_READ in types
        assert ToolCallType.TERMINAL_EXEC in types
        assert ToolCallType.WEB_SEARCH in types

class TestTestResult:
    def test_creation(self):
        result = TestResult("test1", "file_read", True, 100.0, "response")
        assert result.test_name == "test1"
        assert result.passed is True
        assert result.latency_ms == 100.0

    def test_to_dict(self):
        result = TestResult("test1", "file_read", True, 100.0, "response")
        d = result.__dict__
        assert "test_name" in d
        assert "passed" in d

class TestBenchmarkResult:
    def test_creation(self):
        result = BenchmarkResult("model1", "ollama", "2026-01-01T00:00:00Z")
        assert result.model == "model1"
        assert result.results == []

    def test_summary(self):
        result = BenchmarkResult("model1", "ollama", "2026-01-01T00:00:00Z")
        result.results = [
            TestResult("t1", "file_read", True, 100, "r1"),
            TestResult("t2", "terminal", False, 200, "r2"),
        ]
        # Summary itself is computed by run_all_tests
        assert len(result.results) == 2