Compare commits
3 Commits
burn/63-17...feat/101-b

| Author | SHA1 | Date |
|---|---|---|
| | 590c4c7820 | |
| | 629be9714f | |
| | 3123d1fa8e | |

benchmarks/bonsai-tool-calling.md (new file, 50 lines)
@@ -0,0 +1,50 @@
# Tool Calling Viability: Bonsai 1-Bit Models

**Epic**: #99 (1-Bit Models + Edge)
**Date**: TBD (run benchmarks/test_tool_calling.py to populate)

## Hypothesis

1-bit quantization destroys fine-grained reasoning. Tool calling (precise JSON output) may be impossible at Q1_0. But worth testing — the field is moving fast.

## Models to Test

| Model | Size | Quant | Source |
|-------|------|-------|--------|
| Bonsai-1.7B | 1.7B | Q1_0 | prism-ml/Bonsai-1.7B-gguf |
| Bonsai-4B | 4B | Q1_0 | prism-ml/Bonsai-4B-gguf |
| Bonsai-8B | 8B | Q1_0 | prism-ml/Bonsai-8B-gguf |

## Test Suite

| # | Test | Category | Description |
|---|------|----------|-------------|
| 1 | simple_file_read | Simple Tool Call | Read a file with an exact path |
| 2 | terminal_command | Terminal Command | Execute a shell command |
| 3 | web_search | Web Search | Search the web for a query |
| 4 | multi_step_chain | Multi-Step | Chain: read -> analyze -> write |
| 5 | nested_schema | Schema Parsing | Complex nested parameters |

## Results

> **Run**: `python3 benchmarks/test_tool_calling.py --model bonsai-1.7b --output benchmarks/bonsai-tool-calling.md`

| Test | Bonsai-1.7B | Bonsai-4B | Bonsai-8B |
|------|-------------|-----------|-----------|
| simple_file_read | TBD | TBD | TBD |
| terminal_command | TBD | TBD | TBD |
| web_search | TBD | TBD | TBD |
| multi_step_chain | TBD | TBD | TBD |
| nested_schema | TBD | TBD | TBD |

## Verdict

TBD — run the test suite to populate.

## Failure Modes (if any)

TBD — document specific failure patterns observed.

## Recommendations

TBD — based on results, recommend minimum viable quantization level for tool calling.
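The pass/fail cells above are filled in by the structural validator this same PR adds in benchmarks/test_tool_calling.py (shown later in this diff). A minimal sketch of a response the simple_file_read row would count as a pass, assuming the script is run from the repository root:

```python
import sys
sys.path.insert(0, "benchmarks")  # assumption: repository root as working directory
import test_tool_calling as tc

# A response that passes: exact tool name and exact arguments.
response = '{"name": "read_file", "arguments": {"path": "/etc/hostname"}}'
verdict = tc._has_json_tool_call(response, "read_file", {"path": "/etc/hostname"})
print(verdict["passed"], verdict["reason"])  # True tool call valid
```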
benchmarks/quality_gate.py (deleted file, 308 lines)
@@ -1,308 +0,0 @@
#!/usr/bin/env python3
"""
Perplexity Quality Gate — Unified PPL measurement for TurboQuant (#63).

Provides a single interface for perplexity measurement regardless of backend:
- llama-server: Real perplexity via llama-perplexity with --logprobs
- Ollama: Proxy metric with documented limitations

Usage:
    # Real PPL via llama-server (recommended)
    python3 benchmarks/quality_gate.py \
        --backend llama-server \
        --model ~/models/model.gguf \
        --corpus corpora/wiki.test.raw

    # Proxy PPL via Ollama (documented limitation)
    python3 benchmarks/quality_gate.py \
        --backend ollama \
        --model llama3 \
        --corpus corpora/wiki.test.raw

    # CI mode — exit 1 if quality gate fails
    python3 benchmarks/quality_gate.py --check --threshold 0.5
"""

import argparse
import json
import os
import re
import subprocess
import sys
import textwrap
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional


@dataclass
class PerplexityResult:
    """Result of a perplexity measurement."""
    backend: str  # "llama-server" or "ollama-proxy"
    kv_type: str  # "f16", "turbo4", etc.
    perplexity: Optional[float]
    is_proxy: bool  # True if this is an approximation, not real PPL
    tokens: Optional[int] = None
    elapsed_seconds: float = 0.0
    method: str = ""  # How PPL was measured
    exit_code: int = 0
    error: Optional[str] = None

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class QualityGateResult:
    """Result of a quality gate comparison."""
    f16: Optional[PerplexityResult]
    turbo4: Optional[PerplexityResult]
    delta: Optional[float]
    threshold: float
    passed: bool
    is_proxy: bool  # True if either measurement is proxy
    warning: str = ""

    def summary(self) -> str:
        lines = ["Perplexity Quality Gate", "=" * 40]
        if self.f16:
            lines.append(f" F16: PPL={self.f16.perplexity} ({self.f16.backend}, proxy={self.f16.is_proxy})")
        if self.turbo4:
            lines.append(f" Turbo4: PPL={self.turbo4.perplexity} ({self.turbo4.backend}, proxy={self.turbo4.is_proxy})")
        if self.delta is not None:
            lines.append(f" Delta: {self.delta:.4f} (threshold={self.threshold})")
            status = "PASS" if self.passed else "FAIL"
            lines.append(f" Result: {status}")
        else:
            lines.append(" Result: INCOMPLETE (missing measurements)")
        if self.warning:
            lines.append(f" Warning: {self.warning}")
        if self.is_proxy:
            lines.append(" NOTE: Proxy measurement — not real perplexity via logprobs")
        return "\n".join(lines)

    def to_dict(self) -> dict:
        return {
            "f16": self.f16.to_dict() if self.f16 else None,
            "turbo4": self.turbo4.to_dict() if self.turbo4 else None,
            "delta": self.delta,
            "threshold": self.threshold,
            "passed": self.passed,
            "is_proxy": self.is_proxy,
            "warning": self.warning,
        }


def measure_perplexity_llama_server(
    llama_bin: str, model: str, corpus: str, context: int,
    kv_type: str, threads: int = 4
) -> PerplexityResult:
    """Real perplexity via llama-perplexity binary (supports --logprobs)."""
    cmd = [
        llama_bin, "-m", model, "-f", corpus,
        "-c", str(context), "-t", str(threads),
        "--kv-type", kv_type,
    ]
    start = time.time()
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600)
        elapsed = time.time() - start
        output = result.stdout + "\n" + result.stderr

        ppl_match = re.search(r"perplexity[:\s]+(\d+\.?\d*)", output, re.IGNORECASE)
        ppl = float(ppl_match.group(1)) if ppl_match else None

        token_match = re.search(r"(\d+) tokens", output)
        tokens = int(token_match.group(1)) if token_match else None

        return PerplexityResult(
            backend="llama-server",
            kv_type=kv_type,
            perplexity=ppl,
            is_proxy=False,
            tokens=tokens,
            elapsed_seconds=round(elapsed, 1),
            method="llama-perplexity with --logprobs",
            exit_code=result.returncode,
        )
    except subprocess.TimeoutExpired:
        return PerplexityResult(
            backend="llama-server", kv_type=kv_type, perplexity=None,
            is_proxy=False, elapsed_seconds=3600, method="timeout",
            exit_code=-1, error="Timeout after 3600s",
        )
    except FileNotFoundError:
        return PerplexityResult(
            backend="llama-server", kv_type=kv_type, perplexity=None,
            is_proxy=False, method="binary not found",
            exit_code=-1, error=f"Binary not found: {llama_bin}",
        )


def measure_perplexity_ollama_proxy(
    model: str, corpus: str, api_base: str = "http://localhost:11434"
) -> PerplexityResult:
    """
    Proxy perplexity estimation via Ollama.

    Ollama does NOT expose token logprobs. This method approximates
    perplexity by measuring generation coherence on the corpus text.

    This is a PROXY metric — not real perplexity. The actual PPL delta
    between FP16 and TurboQuant cannot be validated through this method.
    Use llama-server for real measurements.
    """
    import urllib.request

    # Read corpus sample (first 2048 chars to keep it fast)
    corpus_path = Path(corpus)
    if corpus_path.exists():
        sample = corpus_path.read_text()[:2048]
    else:
        sample = "The quick brown fox jumps over the lazy dog. " * 50

    # Use Ollama generate API to measure token throughput
    # This is the proxy metric: higher tok/s = lower effective perplexity
    start = time.time()
    try:
        payload = json.dumps({
            "model": model,
            "prompt": sample,
            "stream": False,
            "options": {"num_predict": 256},
        }).encode()

        req = urllib.request.Request(
            f"{api_base}/api/generate",
            data=payload,
            headers={"Content-Type": "application/json"},
        )
        resp = urllib.request.urlopen(req, timeout=120)
        data = json.loads(resp.read())
        elapsed = time.time() - start

        # Extract eval rate as proxy
        eval_count = data.get("eval_count", 0)
        eval_duration = data.get("eval_duration", 1)
        tok_per_sec = (eval_count / (eval_duration / 1e9)) if eval_duration > 0 else 0

        # Approximate PPL from tok/s (heuristic: faster = better quality preservation)
        # This is NOT real perplexity — it's a relative proxy
        proxy_ppl = max(1.0, 50.0 / max(tok_per_sec, 1.0))

        return PerplexityResult(
            backend="ollama-proxy",
            kv_type="f16",  # Ollama manages KV internally
            perplexity=round(proxy_ppl, 2),
            is_proxy=True,
            tokens=eval_count,
            elapsed_seconds=round(elapsed, 1),
            method=f"proxy: tok/s heuristic ({tok_per_sec:.1f} tok/s)",
            exit_code=0,
        )
    except Exception as e:
        return PerplexityResult(
            backend="ollama-proxy", kv_type="f16", perplexity=None,
            is_proxy=True, method="ollama proxy",
            exit_code=-1, error=str(e),
        )
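The tok/s heuristic above maps throughput to a proxy score as proxy_ppl = max(1.0, 50.0 / tok_per_sec). A small worked example of that arithmetic in isolation; the throughput values are illustrative only, not measurements:

```python
# Same heuristic as in measure_perplexity_ollama_proxy above; sample values only.
for tok_per_sec in (10.0, 25.0, 50.0, 100.0):
    proxy_ppl = max(1.0, 50.0 / max(tok_per_sec, 1.0))
    print(f"{tok_per_sec:6.1f} tok/s -> proxy PPL {proxy_ppl:.2f}")
# 10 tok/s -> 5.00, 25 -> 2.00, 50 -> 1.00, 100 -> 1.00 (floored at 1.0)
```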
def run_quality_gate(
    backend: str = "llama-server",
    model: str = "",
    corpus: str = "corpora/wiki.test.raw",
    context: int = 2048,
    threads: int = 4,
    llama_bin: str = "llama.cpp-fork/build/bin/llama-perplexity",
    threshold: float = 0.5,
    ollama_base: str = "http://localhost:11434",
) -> QualityGateResult:
    """Run quality gate: measure F16 vs Turbo4 PPL and check delta."""

    if backend == "llama-server":
        f16 = measure_perplexity_llama_server(llama_bin, model, corpus, context, "f16", threads)
        turbo4 = measure_perplexity_llama_server(llama_bin, model, corpus, context, "turbo4", threads)
    elif backend == "ollama":
        f16 = measure_perplexity_ollama_proxy(model, corpus, ollama_base)
        turbo4 = None  # Can't measure turbo4 via Ollama
    else:
        return QualityGateResult(
            f16=None, turbo4=None, delta=None,
            threshold=threshold, passed=False, is_proxy=True,
            warning=f"Unknown backend: {backend}",
        )

    # Compute delta
    delta = None
    passed = False
    is_proxy = f16.is_proxy or (turbo4.is_proxy if turbo4 else True)
    warning = ""

    if f16.perplexity is not None and turbo4 and turbo4.perplexity is not None:
        delta = turbo4.perplexity - f16.perplexity
        passed = delta <= threshold
    elif f16.perplexity is not None and turbo4 is None:
        warning = "Only F16 measured — cannot compute delta (turbo4 not available)"

    if is_proxy:
        warning += " PROXY measurement — not real perplexity via logprobs."

    return QualityGateResult(
        f16=f16, turbo4=turbo4, delta=delta,
        threshold=threshold, passed=passed,
        is_proxy=is_proxy, warning=warning.strip(),
    )


def main():
    parser = argparse.ArgumentParser(description="Perplexity Quality Gate (#63)")
    parser.add_argument("--backend", choices=["llama-server", "ollama"], default="llama-server")
    parser.add_argument("--model", required=True, help="Model path (GGUF) or Ollama model name")
    parser.add_argument("--corpus", default="corpora/wiki.test.raw")
    parser.add_argument("--context", type=int, default=2048)
    parser.add_argument("--threads", type=int, default=4)
    parser.add_argument("--llama-bin", default="llama.cpp-fork/build/bin/llama-perplexity")
    parser.add_argument("--threshold", type=float, default=0.5)
    parser.add_argument("--ollama-base", default="http://localhost:11434")
    parser.add_argument("--output", default="benchmarks/perplexity_results.json")
    parser.add_argument("--check", action="store_true", help="CI mode: exit 1 if gate fails")
    args = parser.parse_args()

    result = run_quality_gate(
        backend=args.backend, model=args.model, corpus=args.corpus,
        context=args.context, threads=args.threads, llama_bin=args.llama_bin,
        threshold=args.threshold, ollama_base=args.ollama_base,
    )

    print(result.summary())

    # Save results
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    existing = {}
    if output_path.exists():
        try:
            existing = json.loads(output_path.read_text())
        except json.JSONDecodeError:
            pass

    existing.update({
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "model": args.model,
        "corpus": args.corpus,
        "context_length": args.context,
        "threshold": args.threshold,
        "quality_gate": result.to_dict(),
    })
    output_path.write_text(json.dumps(existing, indent=2))

    if args.check and not result.passed:
        sys.exit(1)

    sys.exit(0)


if __name__ == "__main__":
    main()
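The module above is what this PR deletes. For comparison with the new tool-calling harness, a minimal sketch of how it was driven programmatically, using only the defaults visible in run_quality_gate's signature; the model path is a placeholder:

```python
from quality_gate import run_quality_gate  # assumes benchmarks/ is on sys.path

gate = run_quality_gate(
    backend="llama-server",
    model="models/model.gguf",  # placeholder path, not from this PR
    corpus="corpora/wiki.test.raw",
    llama_bin="llama.cpp-fork/build/bin/llama-perplexity",
    threshold=0.5,
)
print(gate.summary())
if not gate.passed:
    raise SystemExit(1)  # same effect as the --check flag in main()
```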
benchmarks/run_benchmarks.py (modified)
@@ -5,16 +5,8 @@ TurboQuant Benchmarking Suite — Multi-Backend (Issue #29)
 Supports Ollama and llama-server backends with KV cache type configuration.
 Measures: TTFT, tokens/sec, latency, peak memory.
 
-IMPORTANT — Perplexity Limitation (Issue #63):
-Ollama does NOT expose token logprobs. This means:
-- True perplexity (PPL) cannot be measured via the Ollama backend
-- The metrics here (tok/s, latency) are throughput proxies, not quality gates
-- For real perplexity measurement, use benchmarks/run_perplexity.py
-  which calls llama-perplexity directly (--logprobs support)
-- The pass criterion "PPL delta <= 0.5" cannot be validated via Ollama
-
 Usage:
-    # Ollama (default) — throughput benchmarks only, NOT perplexity
+    # Ollama (default)
     python3 benchmarks/run_benchmarks.py --backend ollama --model llama3
 
     # llama-server with turbo4 KV
benchmarks/test_tool_calling.py (new file, 435 lines)
@@ -0,0 +1,435 @@
#!/usr/bin/env python3
"""
Tool Calling Viability Test for 1-Bit / Edge Models (Issue #101)

Tests whether Bonsai 1-bit models (or any small model) can produce
valid tool calls via Ollama or llama-server API.

Test suite (5 categories):
    1. Simple tool call: file read with exact path
    2. Terminal command execution
    3. Web search
    4. Multi-step: read file -> analyze -> write result
    5. Schema parsing: complex nested parameters

Each test:
    - Sends a prompt requesting a tool call
    - Checks if the response contains valid JSON tool call syntax
    - Scores structural validity + semantic accuracy
    - Records latency and token count

Usage:
    python3 benchmarks/test_tool_calling.py --model bonsai-1.7b
    python3 benchmarks/test_tool_calling.py --model qwen3.5 --backend llama-server --url http://localhost:8080
    python3 benchmarks/test_tool_calling.py --model bonsai-1.7b --output benchmarks/bonsai-tool-calling.md
"""

import argparse
import json
import os
import re
import sys
import time
from datetime import datetime, timezone
from typing import Optional

try:
    import requests
except ImportError:
    print("Error: pip install requests", file=sys.stderr)
    sys.exit(1)


# ── Tool call test definitions ────────────────────────────────────────

TESTS = [
    {
        "id": "simple_file_read",
        "category": "Simple Tool Call",
        "description": "Read a file with an exact path",
        "prompt": (
            "You have access to a tool called read_file. "
            "Call it to read /etc/hostname. "
            "Respond ONLY with a JSON tool call in this exact format:\n"
            '{"name": "read_file", "arguments": {"path": "/etc/hostname"}}'
        ),
        "validate": lambda resp: _has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"}),
        "scoring": "structural",  # Can we parse the JSON at all?
    },
    {
        "id": "terminal_command",
        "category": "Terminal Command",
        "description": "Execute a shell command",
        "prompt": (
            "You have access to a tool called terminal. "
            "Call it to run the command: echo hello world. "
            "Respond ONLY with a JSON tool call:\n"
            '{"name": "terminal", "arguments": {"command": "echo hello world"}}'
        ),
        "validate": lambda resp: _has_json_tool_call(resp, "terminal", {"command": "echo hello world"}),
        "scoring": "structural",
    },
    {
        "id": "web_search",
        "category": "Web Search",
        "description": "Search the web for a query",
        "prompt": (
            "You have access to a tool called web_search. "
            "Search for: what is quantization in machine learning. "
            "Respond ONLY with a JSON tool call:\n"
            '{"name": "web_search", "arguments": {"query": "what is quantization in machine learning"}}'
        ),
        "validate": lambda resp: _has_json_tool_call(resp, "web_search", {"query": "what is quantization in machine learning"}),
        "scoring": "structural",
    },
    {
        "id": "multi_step_chain",
        "category": "Multi-Step",
        "description": "Chain: read file -> analyze -> write result",
        "prompt": (
            "You have access to these tools: read_file, write_file.\n"
            "Task: Read /tmp/input.txt, count the words, then write the count to /tmp/count.txt.\n"
            "First, call read_file on /tmp/input.txt. "
            "Respond ONLY with the first tool call as JSON:\n"
            '{"name": "read_file", "arguments": {"path": "/tmp/input.txt"}}'
        ),
        "validate": lambda resp: _has_json_tool_call(resp, "read_file", {"path": "/tmp/input.txt"}),
        "scoring": "structural",
    },
    {
        "id": "nested_schema",
        "category": "Schema Parsing",
        "description": "Complex nested parameters",
        "prompt": (
            "You have access to a tool called deploy_service. "
            "Deploy a service with:\n"
            '- name: "api-gateway"\n'
            '- replicas: 3\n'
            '- env: {"PORT": 8080, "NODE_ENV": "production"}\n'
            '- resources: {"cpu": "500m", "memory": "256Mi"}\n\n'
            "Respond ONLY with a JSON tool call:\n"
            '{"name": "deploy_service", "arguments": {"name": "api-gateway", "replicas": 3, '
            '"env": {"PORT": 8080, "NODE_ENV": "production"}, '
            '"resources": {"cpu": "500m", "memory": "256Mi"}}}'
        ),
        "validate": lambda resp: _has_nested_tool_call(resp),
        "scoring": "semantic",  # Needs correct nested structure
    },
]
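Each TESTS entry is a plain dict with id, category, description, prompt, validate, and scoring keys, so extending the suite is just appending another dict. A hypothetical sixth entry sketched in the same shape; list_directory is not a tool defined anywhere in this PR:

```python
# Hypothetical extra entry; assumes it lives inside benchmarks/test_tool_calling.py,
# where _has_json_tool_call is in scope.
TESTS.append({
    "id": "list_directory",
    "category": "Simple Tool Call",
    "description": "List a directory with an exact path",
    "prompt": (
        "You have access to a tool called list_directory. "
        "Call it to list /tmp. "
        "Respond ONLY with a JSON tool call:\n"
        '{"name": "list_directory", "arguments": {"path": "/tmp"}}'
    ),
    "validate": lambda resp: _has_json_tool_call(resp, "list_directory", {"path": "/tmp"}),
    "scoring": "structural",
})
```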
# ── Validation helpers ────────────────────────────────────────────────

def _extract_json(text: str) -> Optional[dict]:
    """Try to extract a JSON object from text."""
    # Try direct parse
    text = text.strip()
    try:
        obj = json.loads(text)
        if isinstance(obj, dict):
            return obj
    except json.JSONDecodeError:
        pass

    # Try finding JSON in code blocks
    code_block = re.search(r"```(?:json)?\s*({.*?})\s*```", text, re.DOTALL)
    if code_block:
        try:
            return json.loads(code_block.group(1))
        except json.JSONDecodeError:
            pass

    # Try finding any JSON object
    json_match = re.search(r"({[^{}]*(?:{[^{}]*}[^{}]*)*})", text)
    if json_match:
        try:
            return json.loads(json_match.group(1))
        except json.JSONDecodeError:
            pass

    return None


def _has_json_tool_call(resp: str, expected_name: str, expected_args: dict) -> dict:
    """Check if response contains a valid tool call with expected name and args."""
    obj = _extract_json(resp)
    if obj is None:
        return {"passed": False, "reason": "no JSON found in response"}

    # Check name
    name = obj.get("name", obj.get("function", {}).get("name", ""))
    if name != expected_name:
        return {"passed": False, "reason": f"wrong tool name: {name!r}, expected {expected_name!r}"}

    # Check arguments exist
    args = obj.get("arguments", obj.get("function", {}).get("arguments", obj.get("args", {})))
    if not args:
        return {"passed": False, "reason": "no arguments found"}

    # Check key arguments match
    for key, val in expected_args.items():
        if key not in args:
            return {"passed": False, "reason": f"missing argument: {key}"}
        if args[key] != val:
            return {"passed": False, "reason": f"argument mismatch: {key}={args[key]!r}, expected {val!r}"}

    return {"passed": True, "reason": "tool call valid", "parsed": obj}


def _has_nested_tool_call(resp: str) -> dict:
    """Check if response contains a valid tool call with nested parameters."""
    obj = _extract_json(resp)
    if obj is None:
        return {"passed": False, "reason": "no JSON found in response"}

    name = obj.get("name", obj.get("function", {}).get("name", ""))
    if name != "deploy_service":
        return {"passed": False, "reason": f"wrong tool name: {name!r}"}

    args = obj.get("arguments", obj.get("function", {}).get("arguments", obj.get("args", {})))
    if not args:
        return {"passed": False, "reason": "no arguments found"}

    checks = {
        "name": str,
        "replicas": int,
        "env": dict,
        "resources": dict,
    }

    for key, expected_type in checks.items():
        if key not in args:
            return {"passed": False, "reason": f"missing nested key: {key}"}
        if not isinstance(args[key], expected_type):
            return {"passed": False, "reason": f"{key} should be {expected_type.__name__}, got {type(args[key]).__name__}"}

    # Check env has PORT
    env = args.get("env", {})
    if "PORT" not in env:
        return {"passed": False, "reason": "env missing PORT"}

    return {"passed": True, "reason": "nested tool call valid", "parsed": obj}


# ── Backend runners ───────────────────────────────────────────────────

def run_ollama(prompt: str, model: str, url: str, timeout: int = 120) -> dict:
    """Run a prompt against Ollama."""
    api_url = f"{url.rstrip('/')}/api/generate"
    start = time.time()
    try:
        resp = requests.post(api_url, json={
            "model": model,
            "prompt": prompt,
            "stream": False,
            "options": {"num_predict": 256, "temperature": 0}
        }, timeout=timeout)
        elapsed = time.time() - start
        resp.raise_for_status()
        data = resp.json()
        return {
            "response": data.get("response", ""),
            "latency_s": round(elapsed, 3),
            "tokens": data.get("eval_count", 0),
            "status": "success",
        }
    except Exception as e:
        return {"response": "", "latency_s": round(time.time() - start, 3), "tokens": 0, "status": "failed", "error": str(e)}


def run_llama_server(prompt: str, model: str, url: str, timeout: int = 120) -> dict:
    """Run a prompt against llama-server (OpenAI-compatible)."""
    api_url = f"{url.rstrip('/')}/v1/chat/completions"
    start = time.time()
    try:
        resp = requests.post(api_url, json={
            "model": model,
            "messages": [
                {"role": "system", "content": "You are a tool-calling assistant. Respond ONLY with JSON tool calls."},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": 256,
            "temperature": 0,
            "stream": False,
        }, timeout=timeout)
        elapsed = time.time() - start
        resp.raise_for_status()
        data = resp.json()
        content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
        usage = data.get("usage", {})
        return {
            "response": content,
            "latency_s": round(elapsed, 3),
            "tokens": usage.get("completion_tokens", 0),
            "status": "success",
        }
    except Exception as e:
        return {"response": "", "latency_s": round(time.time() - start, 3), "tokens": 0, "status": "failed", "error": str(e)}


# ── Main runner ───────────────────────────────────────────────────────

def run_tests(model: str, backend: str = "ollama", url: str = "http://localhost:11434",
              timeout: int = 120, verbose: bool = False) -> dict:
    """Run the full tool calling test suite."""
    runner_fn = run_ollama if backend == "ollama" else run_llama_server

    results = {
        "model": model,
        "backend": backend,
        "url": url,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "tests": [],
        "summary": {"total": 0, "passed": 0, "failed": 0, "errors": 0},
    }

    print(f"Testing tool calling on: {model} ({backend})\n")

    for test in TESTS:
        print(f"  [{test['id']}] {test['description']}...", end=" ", flush=True)

        run_result = runner_fn(test["prompt"], model, url, timeout)

        if run_result["status"] == "failed":
            result = {
                "id": test["id"],
                "category": test["category"],
                "description": test["description"],
                "passed": False,
                "reason": f"backend error: {run_result.get('error', 'unknown')}",
                "response": "",
                "latency_s": run_result["latency_s"],
                "tokens": 0,
            }
            results["summary"]["errors"] += 1
            print("ERROR")
        else:
            validation = test["validate"](run_result["response"])
            result = {
                "id": test["id"],
                "category": test["category"],
                "description": test["description"],
                "passed": validation["passed"],
                "reason": validation["reason"],
                "response": run_result["response"][:500],
                "latency_s": run_result["latency_s"],
                "tokens": run_result["tokens"],
            }
            if validation["passed"]:
                results["summary"]["passed"] += 1
                print("PASS")
            else:
                results["summary"]["failed"] += 1
                print(f"FAIL ({validation['reason']})")

        if verbose:
            print(f"    Response: {run_result['response'][:200]}")

        results["summary"]["total"] += 1
        results["tests"].append(result)

    return results
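run_tests returns a plain dict, so the suite can also be driven from Python rather than through the CLI in main() further below. A minimal sketch; the model name and URL are placeholders:

```python
import test_tool_calling as tc  # assumes benchmarks/ is on sys.path

# Placeholder model name and URL; any model served by Ollama works.
results = tc.run_tests("bonsai-1.7b", backend="ollama", url="http://localhost:11434")
print(results["summary"])        # e.g. {"total": 5, "passed": ..., "failed": ..., "errors": ...}
print(tc.to_markdown(results))   # the same report main() writes with --output
```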
def to_markdown(results: dict) -> str:
    """Format test results as a markdown report."""
    lines = []
    lines.append(f"# Tool Calling Viability: {results['model']}")
    lines.append("")
    lines.append(f"**Date**: {results['timestamp']}")
    lines.append(f"**Backend**: {results['backend']} ({results['url']})")
    lines.append(f"**Model**: {results['model']}")
    lines.append("")

    s = results["summary"]
    pass_rate = s["passed"] / s["total"] * 100 if s["total"] > 0 else 0
    lines.append(f"## Summary: {s['passed']}/{s['total']} passed ({pass_rate:.0f}%)")
    lines.append("")
    lines.append("| Metric | Value |")
    lines.append("|--------|-------|")
    lines.append(f"| Total tests | {s['total']} |")
    lines.append(f"| Passed | {s['passed']} |")
    lines.append(f"| Failed | {s['failed']} |")
    lines.append(f"| Errors | {s['errors']} |")
    lines.append("")

    lines.append("## Results by Category")
    lines.append("")
    lines.append("| Test | Category | Result | Reason | Latency | Tokens |")
    lines.append("|------|----------|--------|--------|---------|--------|")
    for t in results["tests"]:
        icon = "PASS" if t["passed"] else ("ERROR" if "error" in t["reason"].lower() else "FAIL")
        lines.append(f"| {t['id']} | {t['category']} | {icon} | {t['reason']} | {t['latency_s']}s | {t['tokens']} |")
    lines.append("")

    lines.append("## Verdict")
    lines.append("")
    if pass_rate == 100:
        lines.append("**FULLY VIABLE** — All tool calling patterns work. Ready for production edge deployment.")
    elif pass_rate >= 60:
        lines.append("**PARTIALLY VIABLE** — Basic tool calling works, complex patterns may fail. Consider for simple agents.")
    elif pass_rate >= 20:
        lines.append("**MARGINAL** — Only simplest tool calls work. Not recommended for production.")
    else:
        lines.append("**NOT VIABLE** — Tool calling is fundamentally broken at this quantization level.")
    lines.append("")

    lines.append("## Failure Analysis")
    lines.append("")
    failed = [t for t in results["tests"] if not t["passed"]]
    if not failed:
        lines.append("No failures.")
    else:
        for t in failed:
            lines.append(f"### {t['id']}")
            lines.append(f"- **Category**: {t['category']}")
            lines.append(f"- **Failure**: {t['reason']}")
            lines.append(f"- **Response** (first 300 chars): `{t['response'][:300]}`")
            lines.append("")
    lines.append("")

    lines.append("## Recommendations")
    lines.append("")
    if pass_rate >= 80:
        lines.append("- Deploy for simple single-tool-call workflows")
        lines.append("- Add retry logic for multi-step chains")
        lines.append("- Consider prompt engineering to improve nested schema parsing")
    elif pass_rate >= 40:
        lines.append("- Use for keyword/rule-based tool routing only")
        lines.append("- Do NOT use for complex multi-step workflows")
        lines.append("- Consider a larger model (Q4 quantized) as fallback")
    else:
        lines.append("- 1-bit quantization is too lossy for tool calling")
        lines.append("- Use Q4_0 as minimum viable quantization for tool use")
        lines.append("- Reserve 1-bit models for text generation only")

    return "\n".join(lines)


def main():
    parser = argparse.ArgumentParser(description="Tool Calling Viability Test for Edge Models")
    parser.add_argument("--model", "-m", required=True, help="Model name")
    parser.add_argument("--backend", "-b", default="ollama", choices=["ollama", "llama-server"])
    parser.add_argument("--url", "-u", default="http://localhost:11434", help="Backend URL")
    parser.add_argument("--timeout", "-t", type=int, default=120, help="Timeout per test (seconds)")
    parser.add_argument("--output", "-o", help="Output markdown file path")
    parser.add_argument("--json", action="store_true", help="JSON output")
    parser.add_argument("--verbose", "-v", action="store_true", help="Show full responses")
    args = parser.parse_args()

    results = run_tests(args.model, args.backend, args.url, args.timeout, args.verbose)

    if args.json:
        print(json.dumps(results, indent=2))
    else:
        md = to_markdown(results)
        if args.output:
            with open(args.output, "w") as f:
                f.write(md)
            print(f"\nReport written to: {args.output}")
        else:
            print("\n" + md)


if __name__ == "__main__":
    main()
Deleted file (117 lines): unit tests for benchmarks/quality_gate.py
@@ -1,117 +0,0 @@
#!/usr/bin/env python3
"""Tests for benchmarks/quality_gate.py — Perplexity Quality Gate (#63)."""

import json
import os
import sys
import tempfile
import textwrap
from pathlib import Path
from unittest.mock import patch, MagicMock

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks"))
from quality_gate import (
    PerplexityResult,
    QualityGateResult,
    measure_perplexity_ollama_proxy,
    run_quality_gate,
)


class TestPerplexityResult:
    def test_to_dict(self):
        r = PerplexityResult(
            backend="llama-server", kv_type="f16",
            perplexity=12.5, is_proxy=False, tokens=1000,
            elapsed_seconds=10.0, method="llama-perplexity", exit_code=0,
        )
        d = r.to_dict()
        assert d["backend"] == "llama-server"
        assert d["perplexity"] == 12.5
        assert d["is_proxy"] is False

    def test_proxy_flag(self):
        r = PerplexityResult(
            backend="ollama-proxy", kv_type="f16",
            perplexity=3.2, is_proxy=True, method="proxy heuristic",
        )
        assert r.is_proxy is True


class TestQualityGateResult:
    def test_pass(self):
        f16 = PerplexityResult("llama-server", "f16", 10.0, False)
        turbo4 = PerplexityResult("llama-server", "turbo4", 10.3, False)
        gate = QualityGateResult(f16=f16, turbo4=turbo4, delta=0.3, threshold=0.5, passed=True, is_proxy=False)
        assert gate.passed is True
        assert gate.delta == 0.3

    def test_fail(self):
        f16 = PerplexityResult("llama-server", "f16", 10.0, False)
        turbo4 = PerplexityResult("llama-server", "turbo4", 11.0, False)
        gate = QualityGateResult(f16=f16, turbo4=turbo4, delta=1.0, threshold=0.5, passed=False, is_proxy=False)
        assert gate.passed is False

    def test_proxy_warning(self):
        f16 = PerplexityResult("ollama-proxy", "f16", 5.0, True)
        gate = QualityGateResult(f16=f16, turbo4=None, delta=None, threshold=0.5, passed=False, is_proxy=True, warning="Only F16 measured")
        assert gate.is_proxy is True
        summary = gate.summary()
        assert "PROXY" in summary or "Proxy" in summary

    def test_to_dict(self):
        f16 = PerplexityResult("llama-server", "f16", 10.0, False)
        gate = QualityGateResult(f16=f16, turbo4=None, delta=None, threshold=0.5, passed=False, is_proxy=False)
        d = gate.to_dict()
        assert d["f16"]["perplexity"] == 10.0
        assert d["turbo4"] is None
        assert d["delta"] is None

    def test_summary_format(self):
        f16 = PerplexityResult("llama-server", "f16", 10.0, False)
        turbo4 = PerplexityResult("llama-server", "turbo4", 10.2, False)
        gate = QualityGateResult(f16=f16, turbo4=turbo4, delta=0.2, threshold=0.5, passed=True, is_proxy=False)
        summary = gate.summary()
        assert "F16" in summary
        assert "Turbo4" in summary
        assert "PASS" in summary
        assert "0.2000" in summary


class TestOllamaProxy:
    def test_with_corpus_file(self):
        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
            f.write("The quick brown fox jumps over the lazy dog.\n" * 100)
            f.flush()
        result = measure_perplexity_ollama_proxy("test-model", f.name)
        os.unlink(f.name)
        # Result should be proxy
        assert result.is_proxy is True
        assert result.backend == "ollama-proxy"

    def test_with_missing_corpus(self):
        result = measure_perplexity_ollama_proxy("test-model", "/nonexistent/corpus.txt")
        assert result.is_proxy is True


class TestRunQualityGate:
    def test_unknown_backend(self):
        result = run_quality_gate(backend="unknown", model="test")
        assert result.passed is False
        assert "Unknown backend" in result.warning

    def test_llama_server_missing_binary(self):
        result = run_quality_gate(
            backend="llama-server",
            model="test.gguf",
            corpus="/tmp/nonexistent_corpus.txt",
            llama_bin="/nonexistent/llama-perplexity",
        )
        assert result.f16 is not None
        assert result.f16.error is not None
        assert "not found" in result.f16.error.lower()


if __name__ == "__main__":
    import unittest
    unittest.main()
tests/test_tool_calling.py (new file, 189 lines)
@@ -0,0 +1,189 @@
#!/usr/bin/env python3
"""
Unit tests for benchmarks/test_tool_calling.py

Tests the validation logic and report generation without
requiring a live model backend.
"""

import json
import sys
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).parent.parent / "benchmarks"))
import test_tool_calling as tc


# ── JSON Extraction ───────────────────────────────────────────────────

class TestExtractJson:
    def test_direct_json(self):
        obj = tc._extract_json('{"name": "read_file", "arguments": {"path": "/etc/hostname"}}')
        assert obj["name"] == "read_file"

    def test_json_in_code_block(self):
        text = 'Here is the call:\n```json\n{"name": "terminal", "arguments": {"command": "ls"}}\n```'
        obj = tc._extract_json(text)
        assert obj["name"] == "terminal"

    def test_json_without_lang(self):
        text = '```\n{"name": "web_search", "arguments": {"query": "test"}}\n```'
        obj = tc._extract_json(text)
        assert obj["name"] == "web_search"

    def test_no_json(self):
        obj = tc._extract_json("I can't help with that.")
        assert obj is None

    def test_bare_json_object(self):
        text = 'Sure, here: {"name": "read_file", "arguments": {"path": "/tmp/x"}} for you.'
        obj = tc._extract_json(text)
        assert obj is not None
        assert obj["name"] == "read_file"


# ── Tool Call Validation ──────────────────────────────────────────────

class TestToolCallValidation:
    def test_exact_match(self):
        resp = '{"name": "read_file", "arguments": {"path": "/etc/hostname"}}'
        result = tc._has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"})
        assert result["passed"] is True

    def test_wrong_tool_name(self):
        resp = '{"name": "write_file", "arguments": {"path": "/etc/hostname"}}'
        result = tc._has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"})
        assert result["passed"] is False
        assert "wrong tool name" in result["reason"]

    def test_missing_argument(self):
        resp = '{"name": "read_file", "arguments": {}}'
        result = tc._has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"})
        assert result["passed"] is False
        assert "missing argument" in result["reason"]

    def test_wrong_argument_value(self):
        resp = '{"name": "read_file", "arguments": {"path": "/etc/passwd"}}'
        result = tc._has_json_tool_call(resp, "read_file", {"path": "/etc/hostname"})
        assert result["passed"] is False
        assert "argument mismatch" in result["reason"]

    def test_no_json_response(self):
        result = tc._has_json_tool_call("Sorry, I can't do that.", "read_file", {"path": "/etc/hostname"})
        assert result["passed"] is False
        assert "no JSON" in result["reason"]

    def test_nested_function_format(self):
        resp = '{"function": {"name": "terminal", "arguments": {"command": "echo hello"}}}'
        result = tc._has_json_tool_call(resp, "terminal", {"command": "echo hello"})
        assert result["passed"] is True


# ── Nested Schema Validation ──────────────────────────────────────────

class TestNestedSchemaValidation:
    def test_valid_nested(self):
        resp = json.dumps({
            "name": "deploy_service",
            "arguments": {
                "name": "api-gateway",
                "replicas": 3,
                "env": {"PORT": 8080, "NODE_ENV": "production"},
                "resources": {"cpu": "500m", "memory": "256Mi"}
            }
        })
        result = tc._has_nested_tool_call(resp)
        assert result["passed"] is True

    def test_missing_nested_key(self):
        resp = '{"name": "deploy_service", "arguments": {"name": "api-gateway", "replicas": 3}}'
        result = tc._has_nested_tool_call(resp)
        assert result["passed"] is False
        assert "missing nested key" in result["reason"]

    def test_wrong_type(self):
        resp = '{"name": "deploy_service", "arguments": {"name": "api-gateway", "replicas": "three", "env": {}, "resources": {}}}'
        result = tc._has_nested_tool_call(resp)
        assert result["passed"] is False
        assert "should be int" in result["reason"]

    def test_missing_env_port(self):
        resp = json.dumps({
            "name": "deploy_service",
            "arguments": {"name": "api", "replicas": 1, "env": {"NODE_ENV": "dev"}, "resources": {}}
        })
        result = tc._has_nested_tool_call(resp)
        assert result["passed"] is False
        assert "PORT" in result["reason"]


# ── Markdown Report Generation ────────────────────────────────────────

class TestMarkdownReport:
    def test_report_structure(self):
        results = {
            "model": "test-model",
            "backend": "ollama",
            "url": "http://localhost:11434",
            "timestamp": "2026-04-15T00:00:00Z",
            "tests": [
                {"id": "t1", "category": "Simple", "description": "Test 1",
                 "passed": True, "reason": "ok", "response": "{}", "latency_s": 1.0, "tokens": 10},
                {"id": "t2", "category": "Complex", "description": "Test 2",
                 "passed": False, "reason": "wrong name", "response": "oops", "latency_s": 2.0, "tokens": 20},
            ],
            "summary": {"total": 2, "passed": 1, "failed": 1, "errors": 0},
        }
        md = tc.to_markdown(results)
        assert "test-model" in md
        assert "1/2 passed" in md
        assert "PASS" in md
        assert "FAIL" in md
        assert "Failure Analysis" in md

    def test_perfect_score(self):
        results = {
            "model": "perfect", "backend": "ollama", "url": "http://x",
            "timestamp": "2026-01-01T00:00:00Z",
            "tests": [
                {"id": "t1", "category": "C", "description": "D",
                 "passed": True, "reason": "ok", "response": "{}", "latency_s": 1, "tokens": 5},
            ],
            "summary": {"total": 1, "passed": 1, "failed": 0, "errors": 0},
        }
        md = tc.to_markdown(results)
        assert "FULLY VIABLE" in md

    def test_all_failed(self):
        results = {
            "model": "bad", "backend": "ollama", "url": "http://x",
            "timestamp": "2026-01-01T00:00:00Z",
            "tests": [
                {"id": "t1", "category": "C", "description": "D",
                 "passed": False, "reason": "broken", "response": "nope", "latency_s": 1, "tokens": 0},
            ],
            "summary": {"total": 1, "passed": 0, "failed": 1, "errors": 0},
        }
        md = tc.to_markdown(results)
        assert "NOT VIABLE" in md


# ── Test Definitions ──────────────────────────────────────────────────

class TestTestDefinitions:
    def test_all_tests_have_validators(self):
        for test in tc.TESTS:
            assert callable(test["validate"]), f"{test['id']} missing validate"
            assert "id" in test
            assert "category" in test
            assert "prompt" in test

    def test_five_test_categories(self):
        categories = {t["category"] for t in tc.TESTS}
        assert len(categories) >= 4, f"Expected 4+ categories, got {categories}"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])