Compare commits
8 Commits
fix/74-git
...
burn/101-1
| Author | SHA1 | Date | |
|---|---|---|---|
| 3c815664e4 | |||
| 0d92de9b3f | |||
| 3caeaf13eb | |||
| 3cd8750cbb | |||
| ef765bbd30 | |||
|
|
5f0d00f127 | ||
|
|
8affe79489 | ||
|
|
319f57780d |
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
build/
|
||||||
|
*.pyc
|
||||||
|
__pycache__/
|
||||||
36
CMakeLists.txt
Normal file
36
CMakeLists.txt
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.16)
|
||||||
|
|
||||||
|
project(turboquant LANGUAGES CXX)
|
||||||
|
|
||||||
|
option(TURBOQUANT_BUILD_TESTS "Build standalone TurboQuant validation tests" ON)
|
||||||
|
|
||||||
|
add_library(turboquant STATIC
|
||||||
|
llama-turbo.cpp
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(turboquant PUBLIC
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
)
|
||||||
|
|
||||||
|
target_compile_features(turboquant PUBLIC cxx_std_17)
|
||||||
|
|
||||||
|
if(MSVC)
|
||||||
|
target_compile_options(turboquant PRIVATE /W4)
|
||||||
|
else()
|
||||||
|
target_compile_options(turboquant PRIVATE -Wall -Wextra -Wpedantic)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(TURBOQUANT_BUILD_TESTS)
|
||||||
|
include(CTest)
|
||||||
|
|
||||||
|
add_executable(turboquant_roundtrip_test
|
||||||
|
tests/roundtrip_test.cpp
|
||||||
|
)
|
||||||
|
target_link_libraries(turboquant_roundtrip_test PRIVATE turboquant)
|
||||||
|
target_compile_features(turboquant_roundtrip_test PRIVATE cxx_std_17)
|
||||||
|
|
||||||
|
add_test(
|
||||||
|
NAME turboquant_roundtrip
|
||||||
|
COMMAND turboquant_roundtrip_test
|
||||||
|
)
|
||||||
|
endif()
|
||||||
@@ -13,7 +13,7 @@ Unlock 64K-128K context on qwen3.5:27b within 32GB unified memory.
|
|||||||
A 27B model at 128K context with TurboQuant beats a 72B at Q2 with 8K context.
|
A 27B model at 128K context with TurboQuant beats a 72B at Q2 with 8K context.
|
||||||
|
|
||||||
## Status
|
## Status
|
||||||
See [issues](http://143.198.27.163:3000/Timmy_Foundation/turboquant/issues) for current progress.
|
See [issues](https://forge.alexanderwhitestone.com/Timmy_Foundation/turboquant/issues) for current progress.
|
||||||
|
|
||||||
## Roles
|
## Roles
|
||||||
- **Strago:** Build spec author
|
- **Strago:** Build spec author
|
||||||
@@ -29,4 +29,4 @@ See [issues](http://143.198.27.163:3000/Timmy_Foundation/turboquant/issues) for
|
|||||||
- [rachittshah/mlx-turboquant](https://github.com/rachittshah/mlx-turboquant) — MLX fallback
|
- [rachittshah/mlx-turboquant](https://github.com/rachittshah/mlx-turboquant) — MLX fallback
|
||||||
|
|
||||||
## Docs
|
## Docs
|
||||||
- [BUILD-SPEC.md](BUILD-SPEC.md) — Full build specification (Strago, v2.2)
|
- [Project Status](docs/PROJECT_STATUS.md) — Full project status and build specification
|
||||||
|
|||||||
49
benchmarks/bonsai-tool-calling.md
Normal file
49
benchmarks/bonsai-tool-calling.md
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Tool Calling Test Results — 1-Bit Models
|
||||||
|
|
||||||
|
**Status:** Pending execution
|
||||||
|
**Issue:** #101
|
||||||
|
**Model:** bonsai-1bit (to be tested)
|
||||||
|
**Backend:** Ollama
|
||||||
|
|
||||||
|
## Test Suite
|
||||||
|
|
||||||
|
10 test cases covering:
|
||||||
|
|
||||||
|
| # | Test | Type | Difficulty | Description |
|
||||||
|
|---|------|------|------------|-------------|
|
||||||
|
| 1 | simple_file_read | file_read | easy | Read README.md with exact path |
|
||||||
|
| 2 | absolute_path_read | file_read | easy | Read /etc/hostname with absolute path |
|
||||||
|
| 3 | simple_terminal | terminal | easy | Run `echo hello world` |
|
||||||
|
| 4 | terminal_ls | terminal | medium | List files in directory |
|
||||||
|
| 5 | web_search | web_search | easy | Search for a query |
|
||||||
|
| 6 | read_then_analyze | multi_step | medium | Read file then analyze content |
|
||||||
|
| 7 | nested_params | schema_parsing | hard | Complex nested parameters |
|
||||||
|
| 8 | optional_params | schema_parsing | medium | Tool with optional parameters |
|
||||||
|
| 9 | sequential_calls | multi_step | hard | Multiple tool calls in sequence |
|
||||||
|
| 10 | no_tool_needed | file_read | easy | No tool needed for simple question |
|
||||||
|
|
||||||
|
## Hypothesis
|
||||||
|
|
||||||
|
1-bit quantization destroys fine-grained reasoning. Tool calling (precise JSON output) may be impossible. But worth testing — the field is moving fast.
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
*To be filled after running:*
|
||||||
|
```bash
|
||||||
|
python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --report benchmarks/bonsai-tool-calling.md --results benchmarks/tool_calling_results.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Failure Modes (Expected)
|
||||||
|
|
||||||
|
If tests fail, likely causes:
|
||||||
|
1. **JSON formatting:** Model cannot produce valid JSON tool calls
|
||||||
|
2. **Parameter extraction:** Model confuses or drops parameters
|
||||||
|
3. **Schema adherence:** Model ignores tool schema constraints
|
||||||
|
4. **Consistency:** Model produces different formats across runs
|
||||||
|
|
||||||
|
## Alternative Edge Models
|
||||||
|
|
||||||
|
If 1-bit is not viable:
|
||||||
|
- **Qwen3.5 3B Q4** — Good tool calling, reasonable size
|
||||||
|
- **Phi-3 Mini** — Strong reasoning, supports function calling
|
||||||
|
- **Llama 3.2 3B** — Good balance of size and capability
|
||||||
255
benchmarks/test_tool_calling_1bit.py
Normal file
255
benchmarks/test_tool_calling_1bit.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tool Calling Test Suite for 1-Bit Models (Issue #101)
|
||||||
|
|
||||||
|
Tests whether Bonsai 1-bit models can handle tool calling at all.
|
||||||
|
Evaluates: file read, terminal execution, web search, multi-step workflows, schema parsing.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 benchmarks/test_tool_calling_1bit.py --model bonsai-1bit --backend ollama
|
||||||
|
python3 benchmarks/test_tool_calling_1bit.py --results benchmarks/tool_calling_results.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallType(Enum):
    """Capability categories exercised by the tool-calling suite."""

    FILE_READ = "file_read"
    TERMINAL_EXEC = "terminal_exec"
    WEB_SEARCH = "web_search"
    MULTI_STEP = "multi_step"
    SCHEMA_PARSING = "schema_parsing"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolCallTest:
    """One tool-calling scenario: a prompt plus the expected parsed call."""

    name: str                            # unique test identifier
    tool_type: ToolCallType              # capability category under test
    prompt: str                          # prompt sent verbatim to the model
    expected_tool: Optional[str]         # None => no tool call should be made
    expected_params: Dict[str, Any]      # params that must appear in the call
    validation_fn: Optional[str] = None  # key into VALIDATORS, if custom check
    difficulty: str = "easy"             # easy | medium | hard
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TestResult:
    """Outcome of a single tool-calling scenario."""

    test_name: str                                  # matches ToolCallTest.name
    tool_type: str                                  # ToolCallType value string
    passed: bool                                    # overall pass/fail
    latency_ms: float                               # wall-clock model latency
    response_text: str                              # raw response (caller truncates)
    parsed_tool: Optional[str] = None               # tool name extracted, if any
    parsed_params: Optional[Dict[str, Any]] = None  # params extracted, if any
    error: Optional[str] = None                     # transport-level error text
    quality_score: float = 0.0                      # 0.0-1.0 partial credit
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BenchmarkResult:
    """Aggregate of one full suite run against a single model."""

    model: str        # model identifier
    backend: str      # inference backend name
    timestamp: str    # ISO-8601 UTC run time
    results: List[TestResult] = field(default_factory=list)  # per-test outcomes
    summary: Dict[str, Any] = field(default_factory=dict)    # filled by run_all_tests
|
||||||
|
|
||||||
|
|
||||||
|
# Static test matrix. Each entry pairs a prompt with the tool call a
# well-behaved model should emit (expected_tool None => no call expected).
TOOL_CALL_TESTS = [
    ToolCallTest(
        "simple_file_read", ToolCallType.FILE_READ,
        "Use read_file to read README.md. Tools: read_file(path: str)",
        "read_file", {"path": "README.md"}, difficulty="easy"),
    ToolCallTest(
        "absolute_path_read", ToolCallType.FILE_READ,
        "Use read_file to read /etc/hostname. Tools: read_file(path: str)",
        "read_file", {"path": "/etc/hostname"}, difficulty="easy"),
    ToolCallTest(
        "simple_terminal", ToolCallType.TERMINAL_EXEC,
        "Use terminal to run: echo hello world. Tools: terminal(command: str)",
        "terminal", {"command": "echo hello world"}, difficulty="easy"),
    ToolCallTest(
        "terminal_ls", ToolCallType.TERMINAL_EXEC,
        "Use terminal to list files. Tools: terminal(command: str)",
        "terminal", {}, validation_fn="validate_ls", difficulty="medium"),
    ToolCallTest(
        "web_search", ToolCallType.WEB_SEARCH,
        "Use web_search for Python. Tools: web_search(query: str)",
        "web_search", {"query": "Python"}, difficulty="easy"),
    ToolCallTest(
        "read_then_analyze", ToolCallType.MULTI_STEP,
        "First read README.md then analyze. Tools: read_file(path: str)",
        "read_file", {"path": "README.md"}, difficulty="medium"),
    ToolCallTest(
        "nested_params", ToolCallType.SCHEMA_PARSING,
        "Use complex_tool(name=test, config={verbose:true}, tags=[a,b]). Tools: complex_tool(name: str, config: dict, tags: list)",
        "complex_tool", {"name": "test"}, difficulty="hard"),
    ToolCallTest(
        "optional_params", ToolCallType.SCHEMA_PARSING,
        "Use search for ML with limit 5. Tools: search(query: str, limit: int=10)",
        "search", {"query": "ML", "limit": 5}, difficulty="medium"),
    ToolCallTest(
        "sequential_calls", ToolCallType.MULTI_STEP,
        "First run pwd, then read README.md. Tools: terminal(command: str), read_file(path: str)",
        "terminal", {"command": "pwd"}, difficulty="hard"),
    ToolCallTest(
        "no_tool_needed", ToolCallType.FILE_READ,
        "What is 2+2? Tools: read_file(path: str)",
        None, {}, difficulty="easy"),
]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_ls(params):
    """Return True when the command is a reasonable directory listing."""
    cmd = params.get("command", "").strip()
    exact_forms = ("ls", "ls -l", "ls -la", "ls -1", "dir")
    # Accept exact known listings, or any `ls <args>` variant.
    return cmd in exact_forms or cmd.startswith("ls ")
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping ToolCallTest.validation_fn names to validator callables.
VALIDATORS = dict(validate_ls=validate_ls)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_call(response: str) -> Tuple[Optional[str], Optional[Dict]]:
    """Extract a tool call from raw model output.

    Tries, in order:
      1. JSON-style payloads: ``{"tool": ..., "params": {...}}`` or
         ``{"name": ..., "arguments": {...}}`` (possibly embedded in prose).
      2. Function-call syntax: ``tool_name(key=value, ...)``.

    Returns:
        ``(tool_name, params)`` on success; ``(None, None)`` when no call
        is found. Params from the function-call form are always strings.

    NOTE: the JSON regexes use ``{[^}]+}`` and therefore only match flat
    (non-nested) param objects; nested configs fall through to the
    function-call parser or to (None, None).
    """
    json_patterns = (
        r'"tool"\s*:\s*"([^"]+)"\s*,\s*"params"\s*:\s*({[^}]+})',
        r'"name"\s*:\s*"([^"]+)"\s*,\s*"arguments"\s*:\s*({[^}]+})',
    )
    for pattern in json_patterns:
        match = re.search(pattern, response, re.DOTALL)
        if match is None:
            continue
        try:
            return match.group(1), json.loads(match.group(2))
        except json.JSONDecodeError:
            # Was a bare `except:` — narrowed so real bugs (KeyboardInterrupt,
            # typos) are not silently swallowed. Malformed params fall through
            # to the next format.
            continue

    # Fallback: name(arg=value, ...) pseudo-function syntax.
    match = re.search(r'(\w+)\(([^)]+)\)', response)
    if match:
        tool_name = match.group(1)
        params = {
            m.group(1): m.group(2).strip().strip('"\'')
            for m in re.finditer(r'(\w+)\s*=\s*"?([^",)]+)"?', match.group(2))
        }
        return tool_name, params

    return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def call_model(prompt: str, model: str, backend: str, url: str, timeout: int = 60) -> Tuple[str, float]:
    """Send *prompt* to the model and return ``(response_text, latency_ms)``.

    Never raises: any transport/HTTP failure or unknown backend is reported
    as a response string starting with ``"ERROR:"``.
    """
    started = time.time()
    try:
        if backend == "ollama":
            payload = {
                "model": model,
                "prompt": prompt,
                "stream": False,
                "options": {"num_predict": 256, "temperature": 0.1},
            }
            resp = requests.post(f"{url}/api/generate", json=payload, timeout=timeout)
            resp.raise_for_status()
            text = resp.json().get("response", "")
        else:
            text = f"ERROR: Unknown backend {backend}"
    except Exception as exc:  # best-effort: fold any failure into the text
        text = f"ERROR: {exc}"
    elapsed_ms = (time.time() - started) * 1000
    return text, elapsed_ms
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(test: ToolCallTest, model: str, backend: str, url: str) -> TestResult:
    """Execute one scenario against the model and score the outcome.

    Scoring: 0.5 for naming the right tool + 0.5 for correct params.
    A test expecting *no* tool call passes only when none was parsed.
    """
    response, latency = call_model(test.prompt, model, backend, url)

    # Transport-level failure: record the error verbatim, no scoring.
    if response.startswith("ERROR:"):
        return TestResult(test.name, test.tool_type.value, False, latency,
                          response, error=response)

    tool, params = parse_tool_call(response)

    if test.expected_tool is None:
        # A correct model answers directly without calling a tool.
        ok = tool is None
        return TestResult(test.name, test.tool_type.value, ok, latency,
                          response[:500], tool, params,
                          quality_score=1.0 if ok else 0.0)

    if not tool:
        # Expected a call but none was produced.
        return TestResult(test.name, test.tool_type.value, False, latency,
                          response[:500], tool, params, quality_score=0.0)

    name_ok = tool.lower() == test.expected_tool.lower()

    validator = VALIDATORS.get(test.validation_fn) if test.validation_fn else None
    if validator is not None:
        params_ok = validator(params or {})
    elif not test.expected_params:
        params_ok = True
    else:
        found = params or {}

        def _param_matches(key, want):
            # String expectations match by case-insensitive containment;
            # everything else must match exactly.
            if key not in found:
                return False
            if isinstance(want, str):
                return str(want).lower() in str(found.get(key, "")).lower()
            return found.get(key) == want

        params_ok = all(_param_matches(k, v) for k, v in test.expected_params.items())

    return TestResult(test.name, test.tool_type.value, name_ok and params_ok,
                      latency, response[:500], tool, params,
                      quality_score=(0.5 if name_ok else 0) + (0.5 if params_ok else 0))
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_tests(model: str, backend: str, url: str) -> BenchmarkResult:
    """Run every scenario in TOOL_CALL_TESTS and attach summary statistics."""
    bench = BenchmarkResult(model, backend, datetime.now(timezone.utc).isoformat())
    print(f"Testing {model} ({backend})")
    print("=" * 50)

    for case in TOOL_CALL_TESTS:
        outcome = run_test(case, model, backend, url)
        bench.results.append(outcome)
        status = "PASS" if outcome.passed else "FAIL"
        print(f"  {status} {case.name} ({outcome.latency_ms:.0f}ms, q={outcome.quality_score:.0%})")

    n = len(bench.results)
    n_pass = sum(1 for r in bench.results if r.passed)
    bench.summary = {
        "total": n, "passed": n_pass, "failed": n - n_pass,
        "pass_rate": n_pass / n if n else 0,
        "avg_latency_ms": sum(r.latency_ms for r in bench.results) / n if n else 0,
        "avg_quality": sum(r.quality_score for r in bench.results) / n if n else 0,
    }
    return bench
|
||||||
|
|
||||||
|
|
||||||
|
def generate_report(results: BenchmarkResult) -> str:
    """Render one benchmark run as a markdown report string."""
    s = results.summary
    lines = [
        "# Tool Calling Test Results - 1-Bit Models", "",
        f"**Model:** {results.model} ",
        f"**Backend:** {results.backend} ",
        f"**Timestamp:** {results.timestamp}", "",
        "## Summary", "",
        f"- Pass Rate: {s['passed']}/{s['total']} ({s['pass_rate']:.0%})",
        f"- Avg Latency: {s['avg_latency_ms']:.0f}ms",
        f"- Avg Quality: {s['avg_quality']:.0%}", "",
        "## Detailed Results", "",
    ]

    # NOTE(review): assumes results.results is ordered exactly like
    # TOOL_CALL_TESTS (true when produced by run_all_tests).
    for outcome, spec in zip(results.results, TOOL_CALL_TESTS):
        status = "PASS" if outcome.passed else "FAIL"
        lines.append(f"- {status} {outcome.test_name} ({spec.difficulty}, {outcome.latency_ms:.0f}ms)")

    lines.extend(["", "## Conclusion", ""])
    rate = s['pass_rate']
    if rate >= 0.8:
        verdict = "VIABLE"
    elif rate >= 0.5:
        verdict = "MARGINAL"
    else:
        verdict = "NOT VIABLE"
    lines.append(f"**{verdict}** - {rate:.0%} pass rate.")
    lines.extend(["", "### Alternatives", "- Qwen3.5 3B Q4", "- Phi-3 Mini", "- Llama 3.2 3B"])
    return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: run the suite, print a summary, persist artifacts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="bonsai-1bit")
    parser.add_argument("--backend", default="ollama")
    parser.add_argument("--url", default="http://localhost:11434")
    parser.add_argument("--results", help="Save results JSON")
    parser.add_argument("--report", help="Save report markdown")
    args = parser.parse_args()

    results = run_all_tests(args.model, args.backend, args.url)
    print(f"\nSUMMARY: {results.summary['passed']}/{results.summary['total']} passed")

    def _write(path, render):
        # Create the target directory before writing the artifact.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w") as fh:
            render(fh)

    if args.results:
        _write(args.results, lambda fh: json.dump(asdict(results), fh, indent=2))

    if args.report:
        _write(args.report, lambda fh: fh.write(generate_report(results)))
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry guard: only run the suite when executed directly.
if __name__ == "__main__":
    main()
|
||||||
@@ -135,7 +135,5 @@ llama-server -m model.gguf --port 8081 -ctk q8_0 -ctv turbo4 -c 131072
|
|||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
- [TurboQuant Build Spec](../BUILD-SPEC.md)
|
- [Project Status](../docs/PROJECT_STATUS.md)
|
||||||
- [Phase 1 Report](../PHASE1-REPORT.md)
|
|
||||||
- [Full Knowledge Transfer](../FULL-REPORT.md)
|
|
||||||
- [llama.cpp TurboQuant Fork](https://github.com/TheTom/llama-cpp-turboquant)
|
- [llama.cpp TurboQuant Fork](https://github.com/TheTom/llama-cpp-turboquant)
|
||||||
|
|||||||
104
tests/roundtrip_test.cpp
Normal file
104
tests/roundtrip_test.cpp
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
#include "llama-turbo.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr int kDim = 128;
|
||||||
|
constexpr float kCosineThreshold = 0.99f;
|
||||||
|
constexpr float kZeroTolerance = 1.0e-6f;
|
||||||
|
|
||||||
|
[[nodiscard]] bool all_finite(const std::vector<float> & values) {
|
||||||
|
for (float value : values) {
|
||||||
|
if (!std::isfinite(value)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] float max_abs(const std::vector<float> & values) {
|
||||||
|
float best = 0.0f;
|
||||||
|
for (float value : values) {
|
||||||
|
best = std::max(best, std::fabs(value));
|
||||||
|
}
|
||||||
|
return best;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] float cosine_similarity(const std::vector<float> & lhs, const std::vector<float> & rhs) {
|
||||||
|
float dot = 0.0f;
|
||||||
|
float lhs_norm = 0.0f;
|
||||||
|
float rhs_norm = 0.0f;
|
||||||
|
for (int i = 0; i < kDim; ++i) {
|
||||||
|
dot += lhs[i] * rhs[i];
|
||||||
|
lhs_norm += lhs[i] * lhs[i];
|
||||||
|
rhs_norm += rhs[i] * rhs[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
const float denom = std::sqrt(lhs_norm) * std::sqrt(rhs_norm);
|
||||||
|
return denom == 0.0f ? 1.0f : dot / denom;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<float> roundtrip(const std::vector<float> & input, float & norm_out) {
|
||||||
|
std::vector<uint8_t> packed(kDim / 2, 0);
|
||||||
|
norm_out = -1.0f;
|
||||||
|
polar_quant_encode_turbo4(input.data(), packed.data(), &norm_out, kDim);
|
||||||
|
|
||||||
|
std::vector<float> decoded(kDim, 0.0f);
|
||||||
|
polar_quant_decode_turbo4(packed.data(), decoded.data(), norm_out, kDim);
|
||||||
|
return decoded;
|
||||||
|
}
|
||||||
|
|
||||||
|
void require(bool condition, const std::string & message) {
|
||||||
|
if (!condition) {
|
||||||
|
throw std::runtime_error(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_zero_vector_roundtrip() {
|
||||||
|
std::vector<float> zeros(kDim, 0.0f);
|
||||||
|
float norm = -1.0f;
|
||||||
|
const auto decoded = roundtrip(zeros, norm);
|
||||||
|
|
||||||
|
require(norm == 0.0f, "zero vector should encode with zero norm");
|
||||||
|
require(all_finite(decoded), "zero vector decode produced non-finite values");
|
||||||
|
require(max_abs(decoded) <= kZeroTolerance, "zero vector decode should remain near zero");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_gaussian_roundtrip_quality() {
|
||||||
|
std::mt19937 rng(12345);
|
||||||
|
std::normal_distribution<float> dist(0.0f, 1.0f);
|
||||||
|
|
||||||
|
std::vector<float> input(kDim, 0.0f);
|
||||||
|
for (float & value : input) {
|
||||||
|
value = dist(rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
float norm = -1.0f;
|
||||||
|
const auto decoded = roundtrip(input, norm);
|
||||||
|
|
||||||
|
require(norm > 0.0f, "random vector should encode with positive norm");
|
||||||
|
require(all_finite(decoded), "random vector decode produced non-finite values");
|
||||||
|
|
||||||
|
const float cosine = cosine_similarity(input, decoded);
|
||||||
|
require(cosine >= kCosineThreshold, "roundtrip cosine similarity below threshold");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
try {
|
||||||
|
test_zero_vector_roundtrip();
|
||||||
|
test_gaussian_roundtrip_quality();
|
||||||
|
std::cout << "PASS: turboquant standalone roundtrip tests\n";
|
||||||
|
return 0;
|
||||||
|
} catch (const std::exception & exc) {
|
||||||
|
std::cerr << "FAIL: " << exc.what() << '\n';
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
125
tests/test_tool_calling_suite.py
Normal file
125
tests/test_tool_calling_suite.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tests for Tool Calling Test Suite.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Add benchmarks to path
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks"))
|
||||||
|
|
||||||
|
from test_tool_calling_1bit import (
|
||||||
|
ToolCallType,
|
||||||
|
ToolCallTest,
|
||||||
|
TestResult,
|
||||||
|
BenchmarkResult,
|
||||||
|
parse_tool_call,
|
||||||
|
validate_ls,
|
||||||
|
TOOL_CALL_TESTS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCallType:
    """Sanity-check the enum's string values."""

    def test_values(self):
        expected = {
            ToolCallType.FILE_READ: "file_read",
            ToolCallType.TERMINAL_EXEC: "terminal_exec",
            ToolCallType.WEB_SEARCH: "web_search",
        }
        for member, value in expected.items():
            assert member.value == value
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseToolCall:
    """Exercise every response format parse_tool_call understands."""

    def test_json_format(self):
        tool, params = parse_tool_call('{"tool": "read_file", "params": {"path": "test.txt"}}')
        assert (tool, params) == ("read_file", {"path": "test.txt"})

    def test_json_alt_format(self):
        tool, params = parse_tool_call('{"name": "terminal", "arguments": {"command": "ls"}}')
        assert (tool, params) == ("terminal", {"command": "ls"})

    def test_function_format(self):
        tool, params = parse_tool_call('read_file(path="test.txt")')
        assert tool == "read_file"
        assert params.get("path") == "test.txt"

    def test_no_tool(self):
        tool, params = parse_tool_call("The answer is 4.")
        assert tool is None
        assert params is None

    def test_embedded_json(self):
        reply = 'I will call the tool: {"tool": "search", "params": {"query": "test"}}'
        tool, params = parse_tool_call(reply)
        assert (tool, params) == ("search", {"query": "test"})
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateLs:
    """The ls validator accepts listing commands and rejects others."""

    def test_simple_ls(self):
        assert validate_ls({"command": "ls"}) is True

    def test_ls_flags(self):
        for cmd in ("ls -la", "ls -l"):
            assert validate_ls({"command": cmd}) is True

    def test_not_ls(self):
        for cmd in ("pwd", "cat file.txt"):
            assert validate_ls({"command": cmd}) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCallTests:
    """Structural checks on the static TOOL_CALL_TESTS fixture."""

    def test_all_tests_valid(self):
        assert len(TOOL_CALL_TESTS) == 10
        for case in TOOL_CALL_TESTS:
            assert case.name
            assert isinstance(case.tool_type, ToolCallType)
            assert case.prompt

    def test_difficulty_distribution(self):
        seen = {case.difficulty for case in TOOL_CALL_TESTS}
        assert {"easy", "medium", "hard"} <= seen

    def test_type_distribution(self):
        seen = {case.tool_type for case in TOOL_CALL_TESTS}
        for required in (ToolCallType.FILE_READ,
                         ToolCallType.TERMINAL_EXEC,
                         ToolCallType.WEB_SEARCH):
            assert required in seen
|
||||||
|
|
||||||
|
|
||||||
|
class TestTestResult:
    """Construction and dict conversion of TestResult."""

    @staticmethod
    def _make():
        # Shared minimal fixture: only the required positional fields.
        return TestResult("test1", "file_read", True, 100.0, "response")

    def test_creation(self):
        result = self._make()
        assert result.test_name == "test1"
        assert result.passed is True
        assert result.latency_ms == 100.0

    def test_to_dict(self):
        d = self._make().__dict__
        assert "test_name" in d
        assert "passed" in d
|
||||||
|
|
||||||
|
|
||||||
|
class TestBenchmarkResult:
    """Construction and result accumulation of BenchmarkResult."""

    def test_creation(self):
        bench = BenchmarkResult("model1", "ollama", "2026-01-01T00:00:00Z")
        assert bench.model == "model1"
        assert bench.results == []

    def test_summary(self):
        bench = BenchmarkResult("model1", "ollama", "2026-01-01T00:00:00Z")
        bench.results = [
            TestResult("t1", "file_read", True, 100, "r1"),
            TestResult("t2", "terminal", False, 200, "r2"),
        ]
        # bench.summary itself is computed by run_all_tests, not here.
        assert len(bench.results) == 2
|
||||||
Reference in New Issue
Block a user