Compare commits
8 Commits
burn/54-17
...
feat/101-b
| Author | SHA1 | Date | |
|---|---|---|---|
| 590c4c7820 | |||
| 629be9714f | |||
| 3123d1fa8e | |||
| 3cd8750cbb | |||
| ef765bbd30 | |||
|
|
5f0d00f127 | ||
|
|
8affe79489 | ||
|
|
319f57780d |
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
build/
|
||||
*.pyc
|
||||
__pycache__/
|
||||
36
CMakeLists.txt
Normal file
36
CMakeLists.txt
Normal file
@@ -0,0 +1,36 @@
|
||||
# TurboQuant static library plus optional standalone validation tests.
cmake_minimum_required(VERSION 3.16...3.29)

project(turboquant LANGUAGES CXX)

option(TURBOQUANT_BUILD_TESTS "Build standalone TurboQuant validation tests" ON)

# Warning flags chosen per compiler at generate time. A generator
# expression (instead of an if(MSVC) configure-time branch) keeps the
# logic in one place and lets the same list be reused on every target.
set(TURBOQUANT_WARNING_OPTIONS
    "$<$<CXX_COMPILER_ID:MSVC>:/W4>"
    "$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wall;-Wextra;-Wpedantic>"
)

add_library(turboquant STATIC
    llama-turbo.cpp
)
# Namespaced alias so consumers can link turboquant::turboquant and get
# a hard configure error on typos instead of a silent link failure.
add_library(turboquant::turboquant ALIAS turboquant)

# Headers live next to the sources; consumers inherit this include path.
target_include_directories(turboquant PUBLIC
    ${CMAKE_CURRENT_SOURCE_DIR}
)

# C++17 is part of the public interface and propagates to consumers.
target_compile_features(turboquant PUBLIC cxx_std_17)
target_compile_options(turboquant PRIVATE ${TURBOQUANT_WARNING_OPTIONS})

if(TURBOQUANT_BUILD_TESTS)
    include(CTest)

    add_executable(turboquant_roundtrip_test
        tests/roundtrip_test.cpp
    )
    target_link_libraries(turboquant_roundtrip_test PRIVATE turboquant)
    target_compile_features(turboquant_roundtrip_test PRIVATE cxx_std_17)
    # Build the test with the same warning set as the library (the test
    # target previously got no warning flags at all).
    target_compile_options(turboquant_roundtrip_test PRIVATE ${TURBOQUANT_WARNING_OPTIONS})

    add_test(
        NAME turboquant_roundtrip
        COMMAND turboquant_roundtrip_test
    )
endif()
|
||||
@@ -13,7 +13,7 @@ Unlock 64K-128K context on qwen3.5:27b within 32GB unified memory.
|
||||
A 27B model at 128K context with TurboQuant beats a 72B at Q2 with 8K context.
|
||||
|
||||
## Status
|
||||
See [issues](http://143.198.27.163:3000/Timmy_Foundation/turboquant/issues) for current progress.
|
||||
See [issues](https://forge.alexanderwhitestone.com/Timmy_Foundation/turboquant/issues) for current progress.
|
||||
|
||||
## Roles
|
||||
- **Strago:** Build spec author
|
||||
@@ -29,4 +29,4 @@ See [issues](http://143.198.27.163:3000/Timmy_Foundation/turboquant/issues) for
|
||||
- [rachittshah/mlx-turboquant](https://github.com/rachittshah/mlx-turboquant) — MLX fallback
|
||||
|
||||
## Docs
|
||||
- [BUILD-SPEC.md](BUILD-SPEC.md) — Full build specification (Strago, v2.2)
|
||||
- [Project Status](docs/PROJECT_STATUS.md) — Full project status and build specification
|
||||
|
||||
50
benchmarks/bonsai-tool-calling.md
Normal file
50
benchmarks/bonsai-tool-calling.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# Tool Calling Viability: Bonsai 1-Bit Models
|
||||
|
||||
**Epic**: #99 (1-Bit Models + Edge)
|
||||
**Date**: TBD (run benchmarks/test_tool_calling.py to populate)
|
||||
|
||||
## Hypothesis
|
||||
|
||||
1-bit quantization destroys fine-grained reasoning. Tool calling (precise JSON output) may be impossible at Q1_0. But worth testing — the field is moving fast.
|
||||
|
||||
## Models to Test
|
||||
|
||||
| Model | Size | Quant | Source |
|
||||
|-------|------|-------|--------|
|
||||
| Bonsai-1.7B | 1.7B | Q1_0 | prism-ml/Bonsai-1.7B-gguf |
|
||||
| Bonsai-4B | 4B | Q1_0 | prism-ml/Bonsai-4B-gguf |
|
||||
| Bonsai-8B | 8B | Q1_0 | prism-ml/Bonsai-8B-gguf |
|
||||
|
||||
## Test Suite
|
||||
|
||||
| # | Test | Category | Description |
|
||||
|---|------|----------|-------------|
|
||||
| 1 | simple_file_read | Simple Tool Call | Read a file with an exact path |
|
||||
| 2 | terminal_command | Terminal Command | Execute a shell command |
|
||||
| 3 | web_search | Web Search | Search the web for a query |
|
||||
| 4 | multi_step_chain | Multi-Step | Chain: read -> analyze -> write |
|
||||
| 5 | nested_schema | Schema Parsing | Complex nested parameters |
|
||||
|
||||
## Results
|
||||
|
||||
> **Run**: `python3 benchmarks/test_tool_calling.py --model bonsai-1.7b --output benchmarks/bonsai-tool-calling.md`
|
||||
|
||||
| Test | Bonsai-1.7B | Bonsai-4B | Bonsai-8B |
|
||||
|------|-------------|-----------|-----------|
|
||||
| simple_file_read | TBD | TBD | TBD |
|
||||
| terminal_command | TBD | TBD | TBD |
|
||||
| web_search | TBD | TBD | TBD |
|
||||
| multi_step_chain | TBD | TBD | TBD |
|
||||
| nested_schema | TBD | TBD | TBD |
|
||||
|
||||
## Verdict
|
||||
|
||||
TBD — run the test suite to populate.
|
||||
|
||||
## Failure Modes (if any)
|
||||
|
||||
TBD — document specific failure patterns observed.
|
||||
|
||||
## Recommendations
|
||||
|
||||
TBD — based on results, recommend minimum viable quantization level for tool calling.
|
||||
435
benchmarks/test_tool_calling.py
Normal file
435
benchmarks/test_tool_calling.py
Normal file
@@ -0,0 +1,435 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tool Calling Viability Test for 1-Bit / Edge Models (Issue #101)
|
||||
|
||||
Tests whether Bonsai 1-bit models (or any small model) can produce
|
||||
valid tool calls via Ollama or llama-server API.
|
||||
|
||||
Test suite (5 categories):
|
||||
1. Simple tool call: file read with exact path
|
||||
2. Terminal command execution
|
||||
3. Web search
|
||||
4. Multi-step: read file -> analyze -> write result
|
||||
5. Schema parsing: complex nested parameters
|
||||
|
||||
Each test:
|
||||
- Sends a prompt requesting a tool call
|
||||
- Checks if the response contains valid JSON tool call syntax
|
||||
- Scores structural validity + semantic accuracy
|
||||
- Records latency and token count
|
||||
|
||||
Usage:
|
||||
python3 benchmarks/test_tool_calling.py --model bonsai-1.7b
|
||||
python3 benchmarks/test_tool_calling.py --model qwen3.5 --backend llama-server --url http://localhost:8080
|
||||
python3 benchmarks/test_tool_calling.py --model bonsai-1.7b --output benchmarks/bonsai-tool-calling.md
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
# `requests` is the script's only third-party dependency; fail fast with
# an actionable install hint instead of a bare ImportError traceback.
try:
    import requests
except ImportError:
    print("Error: pip install requests", file=sys.stderr)
    sys.exit(1)
|
||||
|
||||
|
||||
# ── Tool call test definitions ────────────────────────────────────────

# Each entry describes one test case:
#   id/category/description — reporting metadata,
#   prompt      — asks the model to emit exactly one JSON tool call,
#   validate    — callable(response_text) -> {"passed": bool, "reason": str, ...},
#   scoring     — "structural" (JSON parses with the right name/args) or
#                 "semantic" (nested structure must also be correct).
TESTS = [
    {
        "id": "simple_file_read",
        "category": "Simple Tool Call",
        "description": "Read a file with an exact path",
        "prompt": (
            "You have access to a tool called read_file. "
            "Call it to read /etc/hostname. "
            "Respond ONLY with a JSON tool call in this exact format:\n"
            '{"name": "read_file", "arguments": {"path": "/etc/hostname"}}'
        ),
        "validate": lambda resp: _has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"}),
        "scoring": "structural",  # Can we parse the JSON at all?
    },
    {
        "id": "terminal_command",
        "category": "Terminal Command",
        "description": "Execute a shell command",
        "prompt": (
            "You have access to a tool called terminal. "
            "Call it to run the command: echo hello world. "
            "Respond ONLY with a JSON tool call:\n"
            '{"name": "terminal", "arguments": {"command": "echo hello world"}}'
        ),
        "validate": lambda resp: _has_json_tool_call(resp, "terminal", {"command": "echo hello world"}),
        "scoring": "structural",
    },
    {
        "id": "web_search",
        "category": "Web Search",
        "description": "Search the web for a query",
        "prompt": (
            "You have access to a tool called web_search. "
            "Search for: what is quantization in machine learning. "
            "Respond ONLY with a JSON tool call:\n"
            '{"name": "web_search", "arguments": {"query": "what is quantization in machine learning"}}'
        ),
        "validate": lambda resp: _has_json_tool_call(resp, "web_search", {"query": "what is quantization in machine learning"}),
        "scoring": "structural",
    },
    {
        "id": "multi_step_chain",
        "category": "Multi-Step",
        "description": "Chain: read file -> analyze -> write result",
        # Only the FIRST call of the chain is requested/validated here.
        "prompt": (
            "You have access to these tools: read_file, write_file.\n"
            "Task: Read /tmp/input.txt, count the words, then write the count to /tmp/count.txt.\n"
            "First, call read_file on /tmp/input.txt. "
            "Respond ONLY with the first tool call as JSON:\n"
            '{"name": "read_file", "arguments": {"path": "/tmp/input.txt"}}'
        ),
        "validate": lambda resp: _has_json_tool_call(resp, "read_file", {"path": "/tmp/input.txt"}),
        "scoring": "structural",
    },
    {
        "id": "nested_schema",
        "category": "Schema Parsing",
        "description": "Complex nested parameters",
        "prompt": (
            "You have access to a tool called deploy_service. "
            "Deploy a service with:\n"
            '- name: "api-gateway"\n'
            '- replicas: 3\n'
            '- env: {"PORT": 8080, "NODE_ENV": "production"}\n'
            '- resources: {"cpu": "500m", "memory": "256Mi"}\n\n'
            "Respond ONLY with a JSON tool call:\n"
            '{"name": "deploy_service", "arguments": {"name": "api-gateway", "replicas": 3, '
            '"env": {"PORT": 8080, "NODE_ENV": "production"}, '
            '"resources": {"cpu": "500m", "memory": "256Mi"}}}'
        ),
        "validate": lambda resp: _has_nested_tool_call(resp),
        "scoring": "semantic",  # Needs correct nested structure
    },
]
|
||||
|
||||
|
||||
# ── Validation helpers ────────────────────────────────────────────────
|
||||
|
||||
def _extract_json(text: str) -> Optional[dict]:
|
||||
"""Try to extract a JSON object from text."""
|
||||
# Try direct parse
|
||||
text = text.strip()
|
||||
try:
|
||||
obj = json.loads(text)
|
||||
if isinstance(obj, dict):
|
||||
return obj
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding JSON in code blocks
|
||||
code_block = re.search(r"```(?:json)?\s*({.*?})\s*```", text, re.DOTALL)
|
||||
if code_block:
|
||||
try:
|
||||
return json.loads(code_block.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try finding any JSON object
|
||||
json_match = re.search(r"({[^{}]*(?:{[^{}]*}[^{}]*)*})", text)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _has_json_tool_call(resp: str, expected_name: str, expected_args: dict) -> dict:
    """Validate that `resp` carries a JSON tool call with the expected
    tool name and argument values.

    Returns:
        {"passed": bool, "reason": str} plus, on success,
        "parsed": the extracted JSON object.
    """
    parsed = _extract_json(resp)
    if parsed is None:
        return {"passed": False, "reason": "no JSON found in response"}

    # Accept both the flat {"name": ...} shape and the OpenAI-style
    # {"function": {"name": ...}} shape.
    tool_name = parsed.get("name", parsed.get("function", {}).get("name", ""))
    if tool_name != expected_name:
        return {"passed": False, "reason": f"wrong tool name: {tool_name!r}, expected {expected_name!r}"}

    # Arguments may live under "arguments", "function.arguments", or "args".
    tool_args = parsed.get("arguments", parsed.get("function", {}).get("arguments", parsed.get("args", {})))
    if not tool_args:
        return {"passed": False, "reason": "no arguments found"}

    # Every expected argument must be present with the exact value.
    for arg_key, expected_val in expected_args.items():
        if arg_key not in tool_args:
            return {"passed": False, "reason": f"missing argument: {arg_key}"}
        actual_val = tool_args[arg_key]
        if actual_val != expected_val:
            return {"passed": False, "reason": f"argument mismatch: {arg_key}={actual_val!r}, expected {expected_val!r}"}

    return {"passed": True, "reason": "tool call valid", "parsed": parsed}
|
||||
|
||||
|
||||
def _has_nested_tool_call(resp: str) -> dict:
    """Validate the deploy_service tool call, including its nested
    env/resources argument objects.

    Returns:
        {"passed": bool, "reason": str} plus, on success,
        "parsed": the extracted JSON object.
    """
    parsed = _extract_json(resp)
    if parsed is None:
        return {"passed": False, "reason": "no JSON found in response"}

    # Accept both flat and OpenAI-style function-call shapes.
    tool_name = parsed.get("name", parsed.get("function", {}).get("name", ""))
    if tool_name != "deploy_service":
        return {"passed": False, "reason": f"wrong tool name: {tool_name!r}"}

    tool_args = parsed.get("arguments", parsed.get("function", {}).get("arguments", parsed.get("args", {})))
    if not tool_args:
        return {"passed": False, "reason": "no arguments found"}

    # Every top-level argument must be present AND of the right type —
    # this is what distinguishes "semantic" scoring from "structural".
    expected_types = (("name", str), ("replicas", int), ("env", dict), ("resources", dict))
    for arg_key, required_type in expected_types:
        if arg_key not in tool_args:
            return {"passed": False, "reason": f"missing nested key: {arg_key}"}
        if not isinstance(tool_args[arg_key], required_type):
            return {"passed": False, "reason": f"{arg_key} should be {required_type.__name__}, got {type(tool_args[arg_key]).__name__}"}

    # Spot-check one nested value inside env.
    if "PORT" not in tool_args.get("env", {}):
        return {"passed": False, "reason": "env missing PORT"}

    return {"passed": True, "reason": "nested tool call valid", "parsed": parsed}
|
||||
|
||||
|
||||
# ── Backend runners ───────────────────────────────────────────────────
|
||||
|
||||
def run_ollama(prompt: str, model: str, url: str, timeout: int = 120) -> dict:
    """Send one prompt to an Ollama server and collect the reply.

    Returns:
        {"response": str, "latency_s": float, "tokens": int,
         "status": "success"} on success, or
        {"response": "", "latency_s": float, "tokens": 0,
         "status": "failed", "error": str} on any failure.
    """
    endpoint = f"{url.rstrip('/')}/api/generate"
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        # temperature 0 for deterministic output; cap generated tokens.
        "options": {"num_predict": 256, "temperature": 0},
    }
    started = time.time()
    try:
        reply = requests.post(endpoint, json=payload, timeout=timeout)
        latency = time.time() - started
        reply.raise_for_status()
        body = reply.json()
        return {
            "response": body.get("response", ""),
            "latency_s": round(latency, 3),
            "tokens": body.get("eval_count", 0),
            "status": "success",
        }
    except Exception as exc:
        # Transport errors, HTTP errors, and bad JSON all land here.
        return {"response": "", "latency_s": round(time.time() - started, 3), "tokens": 0, "status": "failed", "error": str(exc)}
|
||||
|
||||
|
||||
def run_llama_server(prompt: str, model: str, url: str, timeout: int = 120) -> dict:
    """Send one prompt to llama-server via its OpenAI-compatible API.

    Returns:
        {"response": str, "latency_s": float, "tokens": int,
         "status": "success"} on success, or
        {"response": "", "latency_s": float, "tokens": 0,
         "status": "failed", "error": str} on any failure.
    """
    endpoint = f"{url.rstrip('/')}/v1/chat/completions"
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": "You are a tool-calling assistant. Respond ONLY with JSON tool calls."},
            {"role": "user", "content": prompt},
        ],
        "max_tokens": 256,
        "temperature": 0,
        "stream": False,
    }
    started = time.time()
    try:
        reply = requests.post(endpoint, json=payload, timeout=timeout)
        latency = time.time() - started
        reply.raise_for_status()
        body = reply.json()
        # Defensive navigation of the OpenAI response shape.
        message_text = body.get("choices", [{}])[0].get("message", {}).get("content", "")
        usage = body.get("usage", {})
        return {
            "response": message_text,
            "latency_s": round(latency, 3),
            "tokens": usage.get("completion_tokens", 0),
            "status": "success",
        }
    except Exception as exc:
        # Transport errors, HTTP errors, and bad JSON all land here.
        return {"response": "", "latency_s": round(time.time() - started, 3), "tokens": 0, "status": "failed", "error": str(exc)}
|
||||
|
||||
|
||||
# ── Main runner ───────────────────────────────────────────────────────
|
||||
|
||||
def run_tests(model: str, backend: str = "ollama", url: str = "http://localhost:11434",
              timeout: int = 120, verbose: bool = False) -> dict:
    """Run the full tool calling test suite against one model/backend.

    Args:
        model: Model name as known to the backend.
        backend: "ollama" or anything else (treated as llama-server).
        url: Base URL of the backend HTTP API.
        timeout: Per-request timeout in seconds, passed to the runner.
        verbose: When True, echo the first 200 chars of each response.

    Returns:
        Dict with model/backend/url/timestamp, a "tests" list of
        per-test result dicts, and a "summary" with
        total/passed/failed/errors counts.
    """
    # Any backend value other than "ollama" falls through to llama-server.
    runner_fn = run_ollama if backend == "ollama" else run_llama_server

    results = {
        "model": model,
        "backend": backend,
        "url": url,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "tests": [],
        "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
    }

    print(f"Testing tool calling on: {model} ({backend})\n")

    for test in TESTS:
        print(f"  [{test['id']}] {test['description']}...", end=" ", flush=True)

        run_result = runner_fn(test["prompt"], model, url, timeout)

        if run_result["status"] == "failed":
            # Backend/transport error: counted as an "error", not a test
            # failure, so flaky infrastructure doesn't skew the verdict.
            result = {
                "id": test["id"],
                "category": test["category"],
                "description": test["description"],
                "passed": False,
                "reason": f"backend error: {run_result.get('error', 'unknown')}",
                "response": "",
                "latency_s": run_result["latency_s"],
                "tokens": 0,
            }
            results["summary"]["errors"] += 1
            print("ERROR")
        else:
            # Each test carries its own validate() callable (see TESTS).
            validation = test["validate"](run_result["response"])
            result = {
                "id": test["id"],
                "category": test["category"],
                "description": test["description"],
                "passed": validation["passed"],
                "reason": validation["reason"],
                # Truncate stored responses to keep reports readable.
                "response": run_result["response"][:500],
                "latency_s": run_result["latency_s"],
                "tokens": run_result["tokens"],
            }
            if validation["passed"]:
                results["summary"]["passed"] += 1
                print("PASS")
            else:
                results["summary"]["failed"] += 1
                print(f"FAIL ({validation['reason']})")

            if verbose:
                print(f"    Response: {run_result['response'][:200]}")

        results["summary"]["total"] += 1
        results["tests"].append(result)

    return results
|
||||
|
||||
|
||||
def to_markdown(results: dict) -> str:
    """Format test results (as produced by run_tests) into a markdown report.

    Sections: header, summary table, per-test results table, a verdict
    derived from the pass rate, failure analysis, and recommendations.

    Args:
        results: The dict returned by run_tests().

    Returns:
        The complete report as a single newline-joined string.
    """
    lines = []
    lines.append(f"# Tool Calling Viability: {results['model']}")
    lines.append("")
    lines.append(f"**Date**: {results['timestamp']}")
    lines.append(f"**Backend**: {results['backend']} ({results['url']})")
    lines.append(f"**Model**: {results['model']}")
    lines.append("")

    s = results["summary"]
    # Guard against division by zero when no tests ran.
    pass_rate = s["passed"] / s["total"] * 100 if s["total"] > 0 else 0
    lines.append(f"## Summary: {s['passed']}/{s['total']} passed ({pass_rate:.0f}%)")
    lines.append("")
    lines.append(f"| Metric | Value |")
    lines.append(f"|--------|-------|")
    lines.append(f"| Total tests | {s['total']} |")
    lines.append(f"| Passed | {s['passed']} |")
    lines.append(f"| Failed | {s['failed']} |")
    lines.append(f"| Errors | {s['errors']} |")
    lines.append("")

    lines.append("## Results by Category")
    lines.append("")
    lines.append("| Test | Category | Result | Reason | Latency | Tokens |")
    lines.append("|------|----------|--------|--------|---------|--------|")
    for t in results["tests"]:
        # Backend errors are distinguished from validation failures by
        # the word "error" appearing in the recorded reason string.
        icon = "PASS" if t["passed"] else ("ERROR" if "error" in t["reason"].lower() else "FAIL")
        lines.append(f"| {t['id']} | {t['category']} | {icon} | {t['reason']} | {t['latency_s']}s | {t['tokens']} |")
    lines.append("")

    lines.append("## Verdict")
    lines.append("")
    # Verdict tiers: 100% fully viable, >=60% partial, >=20% marginal.
    if pass_rate == 100:
        lines.append("**FULLY VIABLE** — All tool calling patterns work. Ready for production edge deployment.")
    elif pass_rate >= 60:
        lines.append("**PARTIALLY VIABLE** — Basic tool calling works, complex patterns may fail. Consider for simple agents.")
    elif pass_rate >= 20:
        lines.append("**MARGINAL** — Only simplest tool calls work. Not recommended for production.")
    else:
        lines.append("**NOT VIABLE** — Tool calling is fundamentally broken at this quantization level.")
    lines.append("")

    lines.append("## Failure Analysis")
    lines.append("")
    failed = [t for t in results["tests"] if not t["passed"]]
    if not failed:
        lines.append("No failures.")
    else:
        for t in failed:
            lines.append(f"### {t['id']}")
            lines.append(f"- **Category**: {t['category']}")
            lines.append(f"- **Failure**: {t['reason']}")
            lines.append(f"- **Response** (first 300 chars): `{t['response'][:300]}`")
            lines.append("")
    lines.append("")

    lines.append("## Recommendations")
    lines.append("")
    # NOTE: recommendation tiers use different cutoffs (80/40) than the
    # verdict tiers (100/60/20) above.
    if pass_rate >= 80:
        lines.append("- Deploy for simple single-tool-call workflows")
        lines.append("- Add retry logic for multi-step chains")
        lines.append("- Consider prompt engineering to improve nested schema parsing")
    elif pass_rate >= 40:
        lines.append("- Use for keyword/rule-based tool routing only")
        lines.append("- Do NOT use for complex multi-step workflows")
        lines.append("- Consider a larger model (Q4 quantized) as fallback")
    else:
        lines.append("- 1-bit quantization is too lossy for tool calling")
        lines.append("- Use Q4_0 as minimum viable quantization for tool use")
        lines.append("- Reserve 1-bit models for text generation only")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, run the suite, emit the report."""
    cli = argparse.ArgumentParser(description="Tool Calling Viability Test for Edge Models")
    cli.add_argument("--model", "-m", required=True, help="Model name")
    cli.add_argument("--backend", "-b", default="ollama", choices=["ollama", "llama-server"])
    cli.add_argument("--url", "-u", default="http://localhost:11434", help="Backend URL")
    cli.add_argument("--timeout", "-t", type=int, default=120, help="Timeout per test (seconds)")
    cli.add_argument("--output", "-o", help="Output markdown file path")
    cli.add_argument("--json", action="store_true", help="JSON output")
    cli.add_argument("--verbose", "-v", action="store_true", help="Show full responses")
    opts = cli.parse_args()

    suite_results = run_tests(opts.model, opts.backend, opts.url, opts.timeout, opts.verbose)

    # Machine-readable mode short-circuits the markdown report entirely.
    if opts.json:
        print(json.dumps(suite_results, indent=2))
        return

    report = to_markdown(suite_results)
    if not opts.output:
        print("\n" + report)
        return

    with open(opts.output, "w") as f:
        f.write(report)
    print(f"\nReport written to: {opts.output}")
|
||||
@@ -135,7 +135,5 @@ llama-server -m model.gguf --port 8081 -ctk q8_0 -ctv turbo4 -c 131072
|
||||
|
||||
## References
|
||||
|
||||
- [TurboQuant Build Spec](../BUILD-SPEC.md)
|
||||
- [Phase 1 Report](../PHASE1-REPORT.md)
|
||||
- [Full Knowledge Transfer](../FULL-REPORT.md)
|
||||
- [Project Status](../docs/PROJECT_STATUS.md)
|
||||
- [llama.cpp TurboQuant Fork](https://github.com/TheTom/llama-cpp-turboquant)
|
||||
|
||||
104
tests/roundtrip_test.cpp
Normal file
104
tests/roundtrip_test.cpp
Normal file
@@ -0,0 +1,104 @@
|
||||
// Standalone roundtrip validation for the TurboQuant 4-bit polar
// quantizer (polar_quant_encode_turbo4 / polar_quant_decode_turbo4).
// Exits 0 on success; prints FAIL and exits 1 on the first violation.
#include "llama-turbo.h"

#include <algorithm>  // std::max — was missing; used by max_abs()
#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>
#include <stdexcept>  // std::runtime_error — was missing; used by require()
#include <string>
#include <vector>

namespace {

// Vector dimension used by all roundtrip checks.
constexpr int kDim = 128;
// Minimum acceptable cosine similarity between input and decoded output.
constexpr float kCosineThreshold = 0.99f;
// Maximum magnitude tolerated when decoding the all-zero vector.
constexpr float kZeroTolerance = 1.0e-6f;

// True iff every element is finite (no NaN/Inf).
[[nodiscard]] bool all_finite(const std::vector<float> & values) {
    for (float value : values) {
        if (!std::isfinite(value)) {
            return false;
        }
    }
    return true;
}

// Largest absolute element; 0 for an empty vector.
[[nodiscard]] float max_abs(const std::vector<float> & values) {
    float best = 0.0f;
    for (float value : values) {
        best = std::max(best, std::fabs(value));
    }
    return best;
}

// Cosine similarity over the first kDim elements (both vectors must
// hold at least kDim values). Returns 1.0 when either norm is zero so
// zero-vector comparisons count as identical.
[[nodiscard]] float cosine_similarity(const std::vector<float> & lhs, const std::vector<float> & rhs) {
    float dot = 0.0f;
    float lhs_norm = 0.0f;
    float rhs_norm = 0.0f;
    for (int i = 0; i < kDim; ++i) {
        dot += lhs[i] * rhs[i];
        lhs_norm += lhs[i] * lhs[i];
        rhs_norm += rhs[i] * rhs[i];
    }

    const float denom = std::sqrt(lhs_norm) * std::sqrt(rhs_norm);
    return denom == 0.0f ? 1.0f : dot / denom;
}

// Encode then decode `input` through the turbo4 path. The encoded norm
// is written to `norm_out`. Packed size is kDim/2 bytes (two 4-bit
// codes per byte).
[[nodiscard]] std::vector<float> roundtrip(const std::vector<float> & input, float & norm_out) {
    std::vector<uint8_t> packed(kDim / 2, 0);
    norm_out = -1.0f;
    polar_quant_encode_turbo4(input.data(), packed.data(), &norm_out, kDim);

    std::vector<float> decoded(kDim, 0.0f);
    polar_quant_decode_turbo4(packed.data(), decoded.data(), norm_out, kDim);
    return decoded;
}

// Throw with `message` when `condition` is false; caught in main().
void require(bool condition, const std::string & message) {
    if (!condition) {
        throw std::runtime_error(message);
    }
}

// The zero vector must encode with zero norm and decode back to ~zero.
void test_zero_vector_roundtrip() {
    std::vector<float> zeros(kDim, 0.0f);
    float norm = -1.0f;
    const auto decoded = roundtrip(zeros, norm);

    require(norm == 0.0f, "zero vector should encode with zero norm");
    require(all_finite(decoded), "zero vector decode produced non-finite values");
    require(max_abs(decoded) <= kZeroTolerance, "zero vector decode should remain near zero");
}

// A fixed-seed Gaussian vector must survive the roundtrip with cosine
// similarity >= kCosineThreshold and no non-finite outputs.
void test_gaussian_roundtrip_quality() {
    std::mt19937 rng(12345);  // fixed seed keeps the test deterministic
    std::normal_distribution<float> dist(0.0f, 1.0f);

    std::vector<float> input(kDim, 0.0f);
    for (float & value : input) {
        value = dist(rng);
    }

    float norm = -1.0f;
    const auto decoded = roundtrip(input, norm);

    require(norm > 0.0f, "random vector should encode with positive norm");
    require(all_finite(decoded), "random vector decode produced non-finite values");

    const float cosine = cosine_similarity(input, decoded);
    require(cosine >= kCosineThreshold, "roundtrip cosine similarity below threshold");
}

} // namespace

int main() {
    try {
        test_zero_vector_roundtrip();
        test_gaussian_roundtrip_quality();
        std::cout << "PASS: turboquant standalone roundtrip tests\n";
        return 0;
    } catch (const std::exception & exc) {
        std::cerr << "FAIL: " << exc.what() << '\n';
        return 1;
    }
}
|
||||
@@ -1,263 +0,0 @@
|
||||
/*
|
||||
* Unit tests for PolarQuant Turbo4
|
||||
*
|
||||
* Compile: gcc -o test_polar_quant test_polar_quant.c llama-turbo.cpp -lm
|
||||
* Run: ./test_polar_quant
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
#include "llama-turbo.h"
|
||||
|
||||
#define TEST_ASSERT(cond, msg) do { if (!(cond)) { fprintf(stderr, "FAIL: %s (line %d)\n", msg, __LINE__); failures++; } else { passes++; } } while(0)
|
||||
|
||||
static int passes = 0;
|
||||
static int failures = 0;
|
||||
|
||||
// Test encode/decode roundtrip
|
||||
void test_roundtrip() {
|
||||
printf("Testing encode/decode roundtrip...\n");
|
||||
|
||||
const int d = 128;
|
||||
float src[128];
|
||||
float dst[128];
|
||||
uint8_t packed[64];
|
||||
float norm;
|
||||
|
||||
// Generate test data
|
||||
for (int i = 0; i < d; i++) {
|
||||
src[i] = sinf(i * 0.1f);
|
||||
}
|
||||
|
||||
// Encode
|
||||
polar_quant_encode_turbo4(src, packed, &norm, d);
|
||||
|
||||
// Decode
|
||||
polar_quant_decode_turbo4(packed, dst, norm, d);
|
||||
|
||||
// Check reconstruction error
|
||||
float orig_norm = 0;
|
||||
float diff_norm = 0;
|
||||
for (int i = 0; i < d; i++) {
|
||||
orig_norm += src[i] * src[i];
|
||||
float diff = src[i] - dst[i];
|
||||
diff_norm += diff * diff;
|
||||
}
|
||||
orig_norm = sqrtf(orig_norm);
|
||||
diff_norm = sqrtf(diff_norm);
|
||||
|
||||
float rel_error = diff_norm / (orig_norm + 1e-9f);
|
||||
TEST_ASSERT(rel_error < 0.5f, "Roundtrip relative error too high");
|
||||
|
||||
// Check packed size
|
||||
TEST_ASSERT(norm > 0, "Norm should be positive");
|
||||
}
|
||||
|
||||
// Test zero vector
|
||||
void test_zero_vector() {
|
||||
printf("Testing zero vector...\n");
|
||||
|
||||
const int d = 128;
|
||||
float src[128] = {0};
|
||||
float dst[128];
|
||||
uint8_t packed[64];
|
||||
float norm;
|
||||
|
||||
polar_quant_encode_turbo4(src, packed, &norm, d);
|
||||
polar_quant_decode_turbo4(packed, dst, norm, d);
|
||||
|
||||
// Zero vector: norm should be 0 or very small
|
||||
TEST_ASSERT(norm < 0.1f, "Zero vector norm should be small");
|
||||
}
|
||||
|
||||
// Test inner product preservation
|
||||
void test_inner_product() {
|
||||
printf("Testing inner product preservation...\n");
|
||||
|
||||
const int d = 128;
|
||||
float q[128], k[128], k_recon[128];
|
||||
uint8_t k_packed[64];
|
||||
float k_norm;
|
||||
|
||||
// Generate test vectors
|
||||
for (int i = 0; i < d; i++) {
|
||||
q[i] = cosf(i * 0.1f);
|
||||
k[i] = sinf(i * 0.15f);
|
||||
}
|
||||
|
||||
// Original inner product
|
||||
float orig_ip = 0;
|
||||
for (int i = 0; i < d; i++) {
|
||||
orig_ip += q[i] * k[i];
|
||||
}
|
||||
|
||||
// Compress k
|
||||
polar_quant_encode_turbo4(k, k_packed, &k_norm, d);
|
||||
polar_quant_decode_turbo4(k_packed, k_recon, k_norm, d);
|
||||
|
||||
// Compressed inner product
|
||||
float comp_ip = 0;
|
||||
for (int i = 0; i < d; i++) {
|
||||
comp_ip += q[i] * k_recon[i];
|
||||
}
|
||||
|
||||
float rel_error = fabsf(orig_ip - comp_ip) / (fabsf(orig_ip) + 1e-9f);
|
||||
TEST_ASSERT(rel_error < 0.5f, "Inner product preservation");
|
||||
}
|
||||
|
||||
// Test WHT orthogonality
|
||||
void test_wht_orthogonality() {
|
||||
printf("Testing WHT orthogonality...\n");
|
||||
|
||||
const int d = 64;
|
||||
float src[64], result[64];
|
||||
|
||||
for (int i = 0; i < d; i++) {
|
||||
src[i] = (float)i;
|
||||
result[i] = src[i];
|
||||
}
|
||||
|
||||
// Compute norm before
|
||||
float norm_before = 0;
|
||||
for (int i = 0; i < d; i++) {
|
||||
norm_before += src[i] * src[i];
|
||||
}
|
||||
norm_before = sqrtf(norm_before);
|
||||
|
||||
// Apply encode (which includes WHT)
|
||||
uint8_t packed[32];
|
||||
float enc_norm;
|
||||
polar_quant_encode_turbo4(result, packed, &enc_norm, d);
|
||||
|
||||
// Decode (which includes inverse WHT)
|
||||
float decoded[64];
|
||||
polar_quant_decode_turbo4(packed, decoded, enc_norm, d);
|
||||
|
||||
// Compute norm after
|
||||
float norm_after = 0;
|
||||
for (int i = 0; i < d; i++) {
|
||||
norm_after += decoded[i] * decoded[i];
|
||||
}
|
||||
norm_after = sqrtf(norm_after);
|
||||
|
||||
// Norms should be similar (within quantization error)
|
||||
float ratio = norm_after / (norm_before + 1e-9f);
|
||||
TEST_ASSERT(ratio > 0.5f && ratio < 2.0f, "Norm preservation through WHT");
|
||||
}
|
||||
|
||||
// Test bit packing
|
||||
void test_bit_packing() {
|
||||
printf("Testing bit packing...\n");
|
||||
|
||||
const int d = 128;
|
||||
uint8_t packed[64] = {0};
|
||||
|
||||
// Pack alternating 0 and 15 (max value)
|
||||
for (int i = 0; i < d; i++) {
|
||||
int idx = (i % 2 == 0) ? 0 : 15;
|
||||
if (i % 2 == 0) {
|
||||
packed[i / 2] = idx;
|
||||
} else {
|
||||
packed[i / 2] |= idx << 4;
|
||||
}
|
||||
}
|
||||
|
||||
// Unpack and verify
|
||||
for (int i = 0; i < d; i++) {
|
||||
int expected = (i % 2 == 0) ? 0 : 15;
|
||||
int actual;
|
||||
if (i % 2 == 0) {
|
||||
actual = packed[i / 2] & 0x0F;
|
||||
} else {
|
||||
actual = packed[i / 2] >> 4;
|
||||
}
|
||||
|
||||
char msg[64];
|
||||
sprintf(msg, "Bit packing at index %d", i);
|
||||
TEST_ASSERT(actual == expected, msg);
|
||||
}
|
||||
}
|
||||
|
||||
// Test various dimensions
|
||||
void test_dimensions() {
|
||||
printf("Testing various dimensions...\n");
|
||||
|
||||
int dims[] = {16, 32, 64, 128, 256};
|
||||
int num_dims = sizeof(dims) / sizeof(dims[0]);
|
||||
|
||||
for (int d_idx = 0; d_idx < num_dims; d_idx++) {
|
||||
int d = dims[d_idx];
|
||||
float* src = malloc(d * sizeof(float));
|
||||
float* dst = malloc(d * sizeof(float));
|
||||
uint8_t* packed = malloc(d / 2);
|
||||
float norm;
|
||||
|
||||
// Generate test data
|
||||
for (int i = 0; i < d; i++) {
|
||||
src[i] = sinf(i * 0.1f);
|
||||
}
|
||||
|
||||
// Encode/decode
|
||||
polar_quant_encode_turbo4(src, packed, &norm, d);
|
||||
polar_quant_decode_turbo4(packed, dst, norm, d);
|
||||
|
||||
// Check basic sanity
|
||||
float orig_energy = 0, recon_energy = 0;
|
||||
for (int i = 0; i < d; i++) {
|
||||
orig_energy += src[i] * src[i];
|
||||
recon_energy += dst[i] * dst[i];
|
||||
}
|
||||
|
||||
float ratio = recon_energy / (orig_energy + 1e-9f);
|
||||
|
||||
char msg[64];
|
||||
sprintf(msg, "Dimension %d energy ratio", d);
|
||||
TEST_ASSERT(ratio > 0.5f && ratio < 2.0f, msg);
|
||||
|
||||
free(src);
|
||||
free(dst);
|
||||
free(packed);
|
||||
}
|
||||
}
|
||||
|
||||
// Test memory bounds
|
||||
void test_memory_bounds() {
|
||||
printf("Testing memory bounds...\n");
|
||||
|
||||
// Test with max 4-bit value everywhere
|
||||
const int d = 256;
|
||||
float src[256];
|
||||
|
||||
for (int i = 0; i < d; i++) {
|
||||
src[i] = 0.35f; // Near max centroid
|
||||
}
|
||||
|
||||
uint8_t packed[128];
|
||||
float norm;
|
||||
|
||||
// Should not crash
|
||||
polar_quant_encode_turbo4(src, packed, &norm, d);
|
||||
|
||||
TEST_ASSERT(1, "Memory bounds check passed");
|
||||
}
|
||||
|
||||
// Entry point: run every suite in order, then report pass/fail counts.
// Exit status is non-zero when any assertion failed.
int main() {
    printf("=== PolarQuant Turbo4 Unit Tests ===\n\n");

    // Table of suites, executed in declaration order.
    void (*suites[])(void) = {
        test_roundtrip,
        test_zero_vector,
        test_inner_product,
        test_wht_orthogonality,
        test_bit_packing,
        test_dimensions,
        test_memory_bounds,
    };
    const int num_suites = (int)(sizeof(suites) / sizeof(suites[0]));
    for (int i = 0; i < num_suites; i++) {
        suites[i]();
    }

    printf("\n=== Results ===\n");
    printf("Passed: %d\n", passes);
    printf("Failed: %d\n", failures);

    return failures > 0 ? 1 : 0;
}
|
||||
@@ -1,410 +0,0 @@
|
||||
"""
|
||||
Unit tests for PolarQuant Turbo4 encode/decode.
|
||||
|
||||
Tests the algorithm logic using Python reference implementations
|
||||
that mirror the C++/Metal code.
|
||||
"""
|
||||
|
||||
import math
|
||||
import pytest
|
||||
import struct
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
# Lloyd-Max Centroids for N(0, 1/d) where d=128
# 4-bit (16 levels) - copied from llama-turbo.cpp
# NOTE(review): the tests below assume this table is strictly increasing,
# has exactly 16 entries, and reaches past +/-2*sigma for
# sigma = 1/sqrt(128) ~= 0.088 — keep in sync with the C++ table.
TURBO4_CENTROIDS = [
    -0.2154, -0.1523, -0.1121, -0.0812,
    -0.0554, -0.0321, -0.0105, 0.0105,
    0.0321, 0.0554, 0.0812, 0.1121,
    0.1523, 0.2154, 0.2800, 0.3500
]
|
||||
|
||||
|
||||
def fwht(a: List[float]) -> List[float]:
    """Orthonormal fast Walsh-Hadamard transform (Python reference).

    Operates on a copy, so the input list is never mutated. With the
    1/sqrt(n) scaling applied here, the transform preserves the L2 norm
    and is its own inverse.
    """
    n = len(a)
    out = list(a)

    # In-place butterfly passes over spans of doubling width.
    span = 1
    while span < n:
        for start in range(0, n, 2 * span):
            for idx in range(start, start + span):
                lo, hi = out[idx], out[idx + span]
                out[idx] = lo + hi
                out[idx + span] = lo - hi
        span *= 2

    # Normalize so the transform is orthonormal.
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in out]
|
||||
|
||||
|
||||
def polar_quant_encode(src: List[float]) -> Tuple[bytes, float]:
    """
    PolarQuant Turbo4 Encode (Python reference).

    Rotates the input with the WHT, records its L2 norm, snaps each
    normalized component onto the nearest entry of TURBO4_CENTROIDS, and
    packs the 4-bit indices two per byte (even index -> low nibble).

    Returns:
        Tuple of (packed_bytes, norm)
    """
    d = len(src)
    assert d % 2 == 0, "Dimension must be even"

    # Rotate, then normalize by the rotated vector's L2 norm.
    rotated = fwht(src)
    norm = math.sqrt(sum(x * x for x in rotated))
    inv_norm = 1.0 / (norm + 1e-9)

    # min() returns the first index on distance ties, which matches the
    # original linear scan's tie-breaking behavior.
    indices = [
        min(range(16), key=lambda j: abs(val * inv_norm - TURBO4_CENTROIDS[j]))
        for val in rotated
    ]

    # Two 4-bit indices per byte: even index -> low nibble, odd -> high.
    packed = bytearray(d // 2)
    for i, idx in enumerate(indices):
        if i % 2 == 0:
            packed[i // 2] = idx
        else:
            packed[i // 2] |= idx << 4

    return bytes(packed), norm
|
||||
|
||||
|
||||
def polar_quant_decode(src: bytes, norm: float, d: int) -> List[float]:
    """
    PolarQuant Turbo4 Decode (Python reference).

    Unpacks two 4-bit centroid indices per byte (low nibble first),
    scales the centroid values by `norm`, and applies the WHT once more —
    the orthonormal transform is its own inverse.

    Returns:
        Reconstructed float array
    """
    # Recover the scaled centroid values from the packed nibbles.
    values = []
    for i in range(d):
        byte = src[i // 2]
        idx = (byte >> 4) if i % 2 else (byte & 0x0F)
        values.append(TURBO4_CENTROIDS[idx] * norm)

    # Invert the rotation.
    return fwht(values)
|
||||
|
||||
|
||||
class TestEncodeDecodeRoundtrip:
    """Test that decode(encode(x)) ≈ x."""

    def test_zero_vector(self):
        """Zero vector should encode/decode to zero."""
        d = 128
        packed, norm = polar_quant_encode([0.0] * d)
        reconstructed = polar_quant_decode(packed, norm, d)

        # Zero has no information, reconstruction will be near-zero
        for i, value in enumerate(reconstructed):
            assert abs(value) < 0.1, f"Index {i}: {reconstructed[i]}"

    def test_unit_vector(self):
        """Unit vector should roundtrip reasonably."""
        d = 128
        src = [1.0 if i == 0 else 0.0 for i in range(d)]  # Unit vector

        packed, norm = polar_quant_encode(src)
        reconstructed = polar_quant_decode(packed, norm, d)

        # The dominant component must stay on the first coordinate.
        max_idx = reconstructed.index(max(reconstructed))
        assert max_idx == 0, f"Peak at index {max_idx}, expected 0"

    def test_random_vectors(self):
        """Random vectors should roundtrip with bounded error."""
        import random
        random.seed(42)

        d = 128
        errors = []

        for _ in range(10):
            src = [random.gauss(0, 0.1) for _ in range(d)]
            packed, norm = polar_quant_encode(src)
            reconstructed = polar_quant_decode(packed, norm, d)

            # Relative error: ||x - x'|| / ||x||
            orig_norm = math.sqrt(sum(x * x for x in src))
            diff_norm = math.sqrt(
                sum((a - b) ** 2 for a, b in zip(src, reconstructed))
            )
            errors.append(diff_norm / (orig_norm + 1e-9))

        avg_error = sum(errors) / len(errors)
        assert avg_error < 0.5, f"Average relative error {avg_error} too high"

    def test_various_dimensions(self):
        """Test with different power-of-2 dimensions."""
        for d in (16, 32, 64, 128, 256):
            src = [math.sin(i * 0.1) for i in range(d)]
            packed, norm = polar_quant_encode(src)
            reconstructed = polar_quant_decode(packed, norm, d)

            # 4-bit quantization loses significant energy, especially at
            # small dims; only require the energy stays in the ballpark.
            orig_energy = sum(x * x for x in src)
            recon_energy = sum(x * x for x in reconstructed)
            ratio = recon_energy / (orig_energy + 1e-9)
            assert 0.1 < ratio < 10.0, f"d={d}: energy ratio {ratio}"
|
||||
|
||||
|
||||
class TestInnerProductPreservation:
    """Test that Q·K ≈ Q·dequant(quant(K))."""

    def test_inner_product_preserved(self):
        """Inner products should be approximately preserved."""
        import random
        random.seed(123)

        d = 128

        # Two random vectors standing in for a query and a key.
        q = [random.gauss(0, 0.1) for _ in range(d)]
        k = [random.gauss(0, 0.1) for _ in range(d)]

        # Inner product before compression.
        orig_ip = sum(a * b for a, b in zip(q, k))

        # Compress k, then dot the query against the reconstruction.
        k_packed, k_norm = polar_quant_encode(k)
        k_reconstructed = polar_quant_decode(k_packed, k_norm, d)
        comp_ip = sum(a * b for a, b in zip(q, k_reconstructed))

        # 4-bit quantization has significant error, allow up to 100% error
        rel_error = abs(orig_ip - comp_ip) / (abs(orig_ip) + 1e-9)
        assert rel_error < 1.0, f"Inner product error {rel_error} too high"

    def test_self_inner_product(self):
        """Self inner product should be close to original."""
        d = 128
        x = [math.cos(i * 0.2) for i in range(d)]
        orig_self_ip = sum(a * a for a in x)

        packed, norm = polar_quant_encode(x)
        reconstructed = polar_quant_decode(packed, norm, d)
        comp_self_ip = sum(a * a for a in reconstructed)

        # Self inner product is energy; 4-bit quantization loses some, so
        # only require it lands in a loose band around the original.
        ratio = comp_self_ip / (orig_self_ip + 1e-9)
        assert 0.3 < ratio < 3.0, f"Self inner product ratio {ratio}"
|
||||
|
||||
|
||||
class TestWHTOrthogonality:
    """Test that WHT is orthogonal (WHT^T · WHT = I)."""

    def test_wht_orthogonality(self):
        """WHT should be orthogonal transformation."""
        d = 128
        src = [float(i) for i in range(d)]

        # With the orthonormal 1/sqrt(n) scaling, applying the transform
        # twice reproduces the original vector: WHT(WHT(x)) = x.
        result2 = fwht(fwht(src))

        for i, (orig, twice) in enumerate(zip(src, result2)):
            ratio = twice / (orig + 1e-9) if orig != 0 else twice
            # Should be close to 1.0 (or 0 if src[i] is 0)
            if abs(orig) > 0.01:
                assert abs(ratio - 1.0) < 0.1, f"Index {i}: ratio {ratio}"

    def test_wht_preserves_norm(self):
        """WHT should preserve L2 norm."""
        d = 128
        src = [math.sin(i) for i in range(d)]

        orig_norm = math.sqrt(sum(x * x for x in src))
        result_norm = math.sqrt(sum(x * x for x in fwht(src)))

        ratio = result_norm / orig_norm
        assert abs(ratio - 1.0) < 0.01, f"Norm ratio {ratio}, expected 1.0"

    def test_wht_linearity(self):
        """WHT should be linear: WHT(a+b) = WHT(a) + WHT(b)."""
        d = 64
        a = [float(i) * 0.1 for i in range(d)]
        b = [float(i) * 0.2 for i in range(d)]

        # Transform of the sum ...
        wht_sum = fwht([x + y for x, y in zip(a, b)])
        # ... versus sum of the transforms.
        sum_wht = [x + y for x, y in zip(fwht(a), fwht(b))]

        for i in range(d):
            assert abs(wht_sum[i] - sum_wht[i]) < 1e-6, f"Linearity failed at {i}"
|
||||
|
||||
|
||||
class TestCodebookCorrectness:
    """Test that centroids match Lloyd-Max for N(0, 1/128)."""

    def test_centroids_extremes(self):
        """Extreme centroids should cover tails of distribution."""
        min_c, max_c = min(TURBO4_CENTROIDS), max(TURBO4_CENTROIDS)
        assert min_c < -0.2, f"Min centroid {min_c} should be < -0.2"
        assert max_c > 0.2, f"Max centroid {max_c} should be > 0.2"

    def test_centroids_ordered(self):
        """Centroids should be strictly increasing."""
        neighbors = zip(TURBO4_CENTROIDS, TURBO4_CENTROIDS[1:])
        for i, (lo, hi) in enumerate(neighbors):
            assert lo < hi, f"Centroids not ordered at index {i}"

    def test_centroids_cover_range(self):
        """Centroids should cover reasonable range for N(0, 1/128)."""
        # For N(0, 1/128), std = 1/sqrt(128) ≈ 0.088; the codebook should
        # reach past +/-2 standard deviations.
        min_c, max_c = min(TURBO4_CENTROIDS), max(TURBO4_CENTROIDS)
        std = 1.0 / math.sqrt(128)  # ≈ 0.088

        assert min_c < -2 * std, f"Min centroid {min_c} should be < {-2*std}"
        assert max_c > 2 * std, f"Max centroid {max_c} should be > {2*std}"

    def test_centroids_count(self):
        """Should have exactly 16 centroids for 4-bit quantization."""
        assert len(TURBO4_CENTROIDS) == 16, f"Expected 16 centroids, got {len(TURBO4_CENTROIDS)}"
|
||||
|
||||
|
||||
class TestBitPacking:
    """Test bit packing/unpacking correctness."""

    @staticmethod
    def _pack(indices):
        # Two 4-bit values per byte: even index -> low nibble, odd -> high.
        packed = bytearray(len(indices) // 2)
        for i, idx in enumerate(indices):
            if i % 2 == 0:
                packed[i // 2] = idx
            else:
                packed[i // 2] |= idx << 4
        return packed

    def test_packing_roundtrip(self):
        """Packing and unpacking should be lossless for 4-bit values."""
        d = 128
        indices = [i % 16 for i in range(d)]
        packed = self._pack(indices)

        # Unpack and compare against the original index list.
        unpacked = []
        for i in range(d):
            byte = packed[i // 2]
            unpacked.append(byte >> 4 if i % 2 else byte & 0x0F)

        assert unpacked == indices, "Packing/unpacking mismatch"

    def test_packing_bounds(self):
        """Packed values should fit in 4 bits (0-15)."""
        packed = self._pack([15] * 128)

        # Each byte should have both nibbles = 15
        for byte in packed:
            assert byte == 0xFF, f"Expected 0xFF, got {hex(byte)}"

    def test_no_overflow(self):
        """Packing should not overflow with valid 4-bit values."""
        d = 256  # Larger dimension
        packed = self._pack([15] * d)

        # Should not crash or produce invalid values
        assert len(packed) == d // 2
|
||||
|
||||
|
||||
class TestMemoryBounds:
    """Test memory safety with various dimensions."""

    def _roundtrip(self, src):
        # Encode then decode, checking packed and output sizes.
        d = len(src)
        packed, norm = polar_quant_encode(src)
        assert len(packed) == d // 2
        reconstructed = polar_quant_decode(packed, norm, d)
        assert len(reconstructed) == d

    def test_minimum_dimension(self):
        """Should work with minimum dimension (2)."""
        self._roundtrip([1.0, 0.5])

    def test_large_dimension(self):
        """Should work with large dimensions."""
        self._roundtrip([math.sin(i * 0.01) for i in range(1024)])

    def test_odd_dimension_fails(self):
        """Odd dimensions should fail (need even for 4-bit packing)."""
        with pytest.raises(AssertionError):
            polar_quant_encode([0.0] * 127)  # 127 is odd
|
||||
|
||||
|
||||
# Allow running this file directly (outside a pytest session).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
189
tests/test_tool_calling.py
Normal file
189
tests/test_tool_calling.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unit tests for benchmarks/test_tool_calling.py
|
||||
|
||||
Tests the validation logic and report generation without
|
||||
requiring a live model backend.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "benchmarks"))
|
||||
import test_tool_calling as tc
|
||||
|
||||
|
||||
# ── JSON Extraction ───────────────────────────────────────────────────
|
||||
|
||||
class TestExtractJson:
    """Exercise tc._extract_json on the response shapes models emit."""

    def test_direct_json(self):
        """A bare JSON object is parsed directly."""
        parsed = tc._extract_json('{"name": "read_file", "arguments": {"path": "/etc/hostname"}}')
        assert parsed["name"] == "read_file"

    def test_json_in_code_block(self):
        """JSON inside a fenced ```json block is found."""
        text = 'Here is the call:\n```json\n{"name": "terminal", "arguments": {"command": "ls"}}\n```'
        parsed = tc._extract_json(text)
        assert parsed["name"] == "terminal"

    def test_json_without_lang(self):
        """A fenced block without a language tag also works."""
        text = '```\n{"name": "web_search", "arguments": {"query": "test"}}\n```'
        parsed = tc._extract_json(text)
        assert parsed["name"] == "web_search"

    def test_no_json(self):
        """Plain refusal text yields None."""
        assert tc._extract_json("I can't help with that.") is None

    def test_bare_json_object(self):
        """A JSON object embedded mid-sentence is still extracted."""
        text = 'Sure, here: {"name": "read_file", "arguments": {"path": "/tmp/x"}} for you.'
        parsed = tc._extract_json(text)
        assert parsed is not None
        assert parsed["name"] == "read_file"
|
||||
|
||||
|
||||
# ── Tool Call Validation ──────────────────────────────────────────────
|
||||
|
||||
class TestToolCallValidation:
    """Exercise tc._has_json_tool_call pass/fail verdicts."""

    def test_exact_match(self):
        """Correct tool name and arguments pass."""
        resp = '{"name": "read_file", "arguments": {"path": "/etc/hostname"}}'
        verdict = tc._has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"})
        assert verdict["passed"] is True

    def test_wrong_tool_name(self):
        """A different tool name fails with a 'wrong tool name' reason."""
        resp = '{"name": "write_file", "arguments": {"path": "/etc/hostname"}}'
        verdict = tc._has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"})
        assert verdict["passed"] is False
        assert "wrong tool name" in verdict["reason"]

    def test_missing_argument(self):
        """An empty arguments object fails with 'missing argument'."""
        verdict = tc._has_json_tool_call(
            '{"name": "read_file", "arguments": {}}',
            "read_file",
            {"path": "/etc/hostname"},
        )
        assert verdict["passed"] is False
        assert "missing argument" in verdict["reason"]

    def test_wrong_argument_value(self):
        """A wrong argument value fails with 'argument mismatch'."""
        verdict = tc._has_json_tool_call(
            '{"name": "read_file", "arguments": {"path": "/etc/passwd"}}',
            "read_file",
            {"path": "/etc/hostname"},
        )
        assert verdict["passed"] is False
        assert "argument mismatch" in verdict["reason"]

    def test_no_json_response(self):
        """A refusal with no JSON fails with a 'no JSON' reason."""
        verdict = tc._has_json_tool_call("Sorry, I can't do that.", "read_file", {"path": "/etc/hostname"})
        assert verdict["passed"] is False
        assert "no JSON" in verdict["reason"]

    def test_nested_function_format(self):
        """A nested {"function": {...}} wrapper is accepted."""
        resp = '{"function": {"name": "terminal", "arguments": {"command": "echo hello"}}}'
        verdict = tc._has_json_tool_call(resp, "terminal", {"command": "echo hello"})
        assert verdict["passed"] is True
|
||||
|
||||
|
||||
# ── Nested Schema Validation ──────────────────────────────────────────
|
||||
|
||||
class TestNestedSchemaValidation:
    """Exercise tc._has_nested_tool_call schema checks."""

    def test_valid_nested(self):
        """A fully-populated deploy_service call passes."""
        payload = json.dumps({
            "name": "deploy_service",
            "arguments": {
                "name": "api-gateway",
                "replicas": 3,
                "env": {"PORT": 8080, "NODE_ENV": "production"},
                "resources": {"cpu": "500m", "memory": "256Mi"},
            },
        })
        assert tc._has_nested_tool_call(payload)["passed"] is True

    def test_missing_nested_key(self):
        """Dropping env/resources fails with 'missing nested key'."""
        verdict = tc._has_nested_tool_call(
            '{"name": "deploy_service", "arguments": {"name": "api-gateway", "replicas": 3}}'
        )
        assert verdict["passed"] is False
        assert "missing nested key" in verdict["reason"]

    def test_wrong_type(self):
        """A string replicas value fails with a 'should be int' reason."""
        verdict = tc._has_nested_tool_call(
            '{"name": "deploy_service", "arguments": {"name": "api-gateway", "replicas": "three", "env": {}, "resources": {}}}'
        )
        assert verdict["passed"] is False
        assert "should be int" in verdict["reason"]

    def test_missing_env_port(self):
        """An env block without PORT fails, and the reason names PORT."""
        payload = json.dumps({
            "name": "deploy_service",
            "arguments": {"name": "api", "replicas": 1, "env": {"NODE_ENV": "dev"}, "resources": {}},
        })
        verdict = tc._has_nested_tool_call(payload)
        assert verdict["passed"] is False
        assert "PORT" in verdict["reason"]
|
||||
|
||||
|
||||
# ── Markdown Report Generation ────────────────────────────────────────
|
||||
|
||||
class TestMarkdownReport:
    """Exercise tc.to_markdown report generation."""

    def test_report_structure(self):
        """A mixed pass/fail run produces every expected report section."""
        payload = {
            "model": "test-model",
            "backend": "ollama",
            "url": "http://localhost:11434",
            "timestamp": "2026-04-15T00:00:00Z",
            "tests": [
                {"id": "t1", "category": "Simple", "description": "Test 1",
                 "passed": True, "reason": "ok", "response": "{}", "latency_s": 1.0, "tokens": 10},
                {"id": "t2", "category": "Complex", "description": "Test 2",
                 "passed": False, "reason": "wrong name", "response": "oops", "latency_s": 2.0, "tokens": 20},
            ],
            "summary": {"total": 2, "passed": 1, "failed": 1, "errors": 0},
        }
        report = tc.to_markdown(payload)
        # Model name, score line, per-test markers, and failure section.
        for fragment in ("test-model", "1/2 passed", "PASS", "FAIL", "Failure Analysis"):
            assert fragment in report

    def test_perfect_score(self):
        """An all-pass run is labeled FULLY VIABLE."""
        payload = {
            "model": "perfect", "backend": "ollama", "url": "http://x",
            "timestamp": "2026-01-01T00:00:00Z",
            "tests": [
                {"id": "t1", "category": "C", "description": "D",
                 "passed": True, "reason": "ok", "response": "{}", "latency_s": 1, "tokens": 5},
            ],
            "summary": {"total": 1, "passed": 1, "failed": 0, "errors": 0},
        }
        assert "FULLY VIABLE" in tc.to_markdown(payload)

    def test_all_failed(self):
        """An all-fail run is labeled NOT VIABLE."""
        payload = {
            "model": "bad", "backend": "ollama", "url": "http://x",
            "timestamp": "2026-01-01T00:00:00Z",
            "tests": [
                {"id": "t1", "category": "C", "description": "D",
                 "passed": False, "reason": "broken", "response": "nope", "latency_s": 1, "tokens": 0},
            ],
            "summary": {"total": 1, "passed": 0, "failed": 1, "errors": 0},
        }
        assert "NOT VIABLE" in tc.to_markdown(payload)
|
||||
|
||||
|
||||
# ── Test Definitions ──────────────────────────────────────────────────
|
||||
|
||||
class TestTestDefinitions:
    """Sanity-check the static tc.TESTS table."""

    def test_all_tests_have_validators(self):
        """Each entry carries a callable validator and the core fields."""
        for test in tc.TESTS:
            assert callable(test["validate"]), f"{test['id']} missing validate"
            for key in ("id", "category", "prompt"):
                assert key in test

    def test_five_test_categories(self):
        """The suite spans at least four distinct categories."""
        categories = {case["category"] for case in tc.TESTS}
        assert len(categories) >= 4, f"Expected 4+ categories, got {categories}"
|
||||
|
||||
|
||||
# Allow running this file directly (outside a pytest session).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user