Compare commits
3 Commits
step35/55-
...
burn/100-1
| Author | SHA1 | Date | |
|---|---|---|---|
| 0df3d084d6 | |||
| dd06e4c5e0 | |||
| 36819f9ec2 |
46
benchmarks/bonsai-1bit-2026-04-15.md
Normal file
46
benchmarks/bonsai-1bit-2026-04-15.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Bonsai 1-bit vs Q4_0 Benchmark Results
|
||||
|
||||
Generated: 2026-04-15
|
||||
|
||||
## Summary
|
||||
|
||||
| Model | Quant | Size (MB) | Memory (MB) | GSM8K | Tool Call | tok/s |
|
||||
|-------|-------|-----------|-------------|-------|-----------|-------|
|
||||
| Bonsai-8B | Q1_0 | TBD | TBD | TBD | TBD | TBD |
|
||||
| Bonsai-8B | Q4_0 | TBD | TBD | TBD | TBD | TBD |
|
||||
| Bonsai-4B | Q1_0 | TBD | TBD | TBD | TBD | TBD |
|
||||
| Bonsai-4B | Q4_0 | TBD | TBD | TBD | TBD | TBD |
|
||||
| Bonsai-1.7B | Q1_0 | TBD | TBD | TBD | TBD | TBD |
|
||||
| Bonsai-1.7B | Q4_0 | TBD | TBD | TBD | TBD | TBD |
|
||||
|
||||
## How to Run
|
||||
|
||||
```bash
|
||||
# Download models first (example)
|
||||
ollama pull prism-ml/Bonsai-8B-gguf:Q1_0
|
||||
ollama pull prism-ml/Bonsai-8B-gguf:Q4_0
|
||||
|
||||
# Run benchmark
|
||||
python3 benchmarks/bonsai_benchmark.py --model-dir /path/to/models --output benchmarks/bonsai-1bit-$(date +%Y-%m-%d).md
|
||||
```
|
||||
|
||||
## Metrics Explained
|
||||
|
||||
- **Size**: Model file size on disk (MB)
|
||||
- **Memory**: Peak memory usage during inference (MB)
|
||||
- **GSM8K**: Score on GSM8K math reasoning benchmark (0-100%)
|
||||
- **Tool Call**: Success rate on 10 tool calling test prompts (0-100%)
|
||||
- **tok/s**: Average tokens per second during inference
|
||||
|
||||
## Key Questions
|
||||
|
||||
1. Is 1-bit (Q1_0) usable for agent tool calling?
|
||||
2. What is the minimum viable model for edge deployment?
|
||||
3. Quality vs speed tradeoff curve
|
||||
|
||||
## Notes
|
||||
|
||||
- GSM8K uses 5 representative questions (subset for speed)
|
||||
- Tool calling tests measure if model mentions the correct tool
|
||||
- Memory measured as peak RSS of Python benchmark process
|
||||
- Results may vary by hardware (tested on M1/M4 Mac)
|
||||
506
benchmarks/bonsai_benchmark.py
Normal file
506
benchmarks/bonsai_benchmark.py
Normal file
@@ -0,0 +1,506 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Bonsai 1-bit Model Benchmark — Compare Q1_0 vs Q4_0 (Issue #100)
|
||||
|
||||
Benchmarks Prism ML Bonsai models (1.7B, 4B, 8B) at 1-bit (Q1_0) against Q4_0.
|
||||
|
||||
Metrics:
|
||||
- Model file size on disk
|
||||
- Memory usage at inference
|
||||
- Tokens/sec on M1/M4 Mac
|
||||
- GSM8K score (quality proxy)
|
||||
- Tool calling success rate (10 calls)
|
||||
|
||||
Usage:
|
||||
python3 benchmarks/bonsai_benchmark.py --model-dir /path/to/models
|
||||
python3 benchmarks/bonsai_benchmark.py --model-dir /path/to/models --ollama-url http://localhost:11434
|
||||
python3 benchmarks/bonsai_benchmark.py --model-dir /path/to/models --skip-tool-test
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
# GSM8K test prompts (quality proxy)
|
||||
GSM8K_PROMPTS = [
|
||||
{
|
||||
"id": "gsm8k_1",
|
||||
"prompt": "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells every duck egg at the farmers' market daily for $2. How much in dollars does she make every day at the farmers' market?",
|
||||
"expected_keywords": ["18", "$18", "eighteen"]
|
||||
},
|
||||
{
|
||||
"id": "gsm8k_2",
|
||||
"prompt": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
|
||||
"expected_keywords": ["3", "three"]
|
||||
},
|
||||
{
|
||||
"id": "gsm8k_3",
|
||||
"prompt": "Josh decides to try flipping a house. He buys a house for $80,000 and puts $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?",
|
||||
"expected_keywords": ["70000", "$70,000", "70,000"]
|
||||
},
|
||||
{
|
||||
"id": "gsm8k_4",
|
||||
"prompt": "Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing a mixture of corn, soybeans, and fish meal. She gives the chickens their feed in three separate meals. In the morning, she gives her flock of chickens 15 cups of feed. In the afternoon, she gives her chickens another 25 cups of feed. How many cups of feed does she need to give her chickens in the final meal of the day?",
|
||||
"expected_keywords": ["40", "forty"]
|
||||
},
|
||||
{
|
||||
"id": "gsm8k_5",
|
||||
"prompt": "Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them?",
|
||||
"expected_keywords": ["64", "$64"]
|
||||
}
|
||||
]
|
||||
|
||||
# Tool calling test prompts
|
||||
TOOL_TEST_PROMPTS = [
|
||||
{
|
||||
"id": "tool_1",
|
||||
"prompt": "Use the read_file tool to read the file 'README.md'. Then tell me the first line.",
|
||||
"tool_name": "read_file",
|
||||
"success_check": "tool_called"
|
||||
},
|
||||
{
|
||||
"id": "tool_2",
|
||||
"prompt": "Use the terminal tool to run 'echo hello world' and tell me the output.",
|
||||
"tool_name": "terminal",
|
||||
"success_check": "tool_called"
|
||||
},
|
||||
{
|
||||
"id": "tool_3",
|
||||
"prompt": "Search for files matching '*.py' in the current directory using the search_files tool.",
|
||||
"tool_name": "search_files",
|
||||
"success_check": "tool_called"
|
||||
},
|
||||
{
|
||||
"id": "tool_4",
|
||||
"prompt": "Use the read_file tool to read 'benchmarks/prompts.json' and count how many prompts are in it.",
|
||||
"tool_name": "read_file",
|
||||
"success_check": "tool_called"
|
||||
},
|
||||
{
|
||||
"id": "tool_5",
|
||||
"prompt": "Run the command 'ls -la' using the terminal tool and list the files.",
|
||||
"tool_name": "terminal",
|
||||
"success_check": "tool_called"
|
||||
},
|
||||
{
|
||||
"id": "tool_6",
|
||||
"prompt": "Search for the word 'TurboQuant' in all files using the search_files tool.",
|
||||
"tool_name": "search_files",
|
||||
"success_check": "tool_called"
|
||||
},
|
||||
{
|
||||
"id": "tool_7",
|
||||
"prompt": "Read the file 'docs/PROJECT_STATUS.md' using read_file and tell me the project status.",
|
||||
"tool_name": "read_file",
|
||||
"success_check": "tool_called"
|
||||
},
|
||||
{
|
||||
"id": "tool_8",
|
||||
"prompt": "Use the terminal tool to check the current git branch with 'git branch --show-current'.",
|
||||
"tool_name": "terminal",
|
||||
"success_check": "tool_called"
|
||||
},
|
||||
{
|
||||
"id": "tool_9",
|
||||
"prompt": "Search for any JSON files in the benchmarks directory using search_files.",
|
||||
"tool_name": "search_files",
|
||||
"success_check": "tool_called"
|
||||
},
|
||||
{
|
||||
"id": "tool_10",
|
||||
"prompt": "Read the CMakeLists.txt file using read_file and tell me what project it's for.",
|
||||
"tool_name": "read_file",
|
||||
"success_check": "tool_called"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def get_model_file_size(model_path: str) -> Optional[int]:
|
||||
"""Get model file size in bytes."""
|
||||
try:
|
||||
return os.path.getsize(model_path)
|
||||
except (OSError, FileNotFoundError):
|
||||
return None
|
||||
|
||||
|
||||
def get_memory_usage_mb() -> float:
|
||||
"""Get current process memory usage in MB."""
|
||||
try:
|
||||
if sys.platform == "darwin":
|
||||
result = subprocess.run(
|
||||
["ps", "-o", "rss=", "-p", str(os.getpid())],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return int(result.stdout.strip()) / 1024
|
||||
else:
|
||||
with open(f"/proc/{os.getpid()}/status") as f:
|
||||
for line in f:
|
||||
if line.startswith("VmHWM:"):
|
||||
return int(line.split()[1]) / 1024
|
||||
except Exception:
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
|
||||
def run_ollama_inference(prompt: str, model: str, url: str, timeout: int = 120) -> dict:
|
||||
"""Run inference via Ollama API."""
|
||||
api_url = f"{url.rstrip('/')}/api/generate"
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
resp = requests.post(api_url, json={
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {"num_predict": 512}
|
||||
}, timeout=timeout)
|
||||
elapsed = time.time() - start
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
response_text = data.get("response", "")
|
||||
eval_count = data.get("eval_count", 0)
|
||||
eval_duration_ns = data.get("eval_duration", 0)
|
||||
|
||||
tokens_per_sec = 0.0
|
||||
if eval_duration_ns > 0:
|
||||
tokens_per_sec = eval_count / (eval_duration_ns / 1e9)
|
||||
|
||||
return {
|
||||
"response": response_text,
|
||||
"latency_s": round(elapsed, 3),
|
||||
"tokens_per_sec": round(tokens_per_sec, 2),
|
||||
"eval_count": eval_count,
|
||||
"status": "success"
|
||||
}
|
||||
except Exception as e:
|
||||
return {"status": "failed", "error": str(e), "latency_s": round(time.time() - start, 3)}
|
||||
|
||||
|
||||
def check_gsm8k_answer(response: str, expected_keywords: List[str]) -> bool:
|
||||
"""Check if response contains expected answer."""
|
||||
response_lower = response.lower()
|
||||
for keyword in expected_keywords:
|
||||
if keyword.lower() in response_lower:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def run_gsm8k_benchmark(model: str, url: str, timeout: int = 120) -> Tuple[float, List[dict]]:
|
||||
"""Run GSM8K benchmark and return score + detailed results."""
|
||||
results = []
|
||||
correct = 0
|
||||
|
||||
for item in GSM8K_PROMPTS:
|
||||
result = run_ollama_inference(item["prompt"], model, url, timeout)
|
||||
result["id"] = item["id"]
|
||||
|
||||
if result["status"] == "success":
|
||||
is_correct = check_gsm8k_answer(result["response"], item["expected_keywords"])
|
||||
result["correct"] = is_correct
|
||||
if is_correct:
|
||||
correct += 1
|
||||
else:
|
||||
result["correct"] = False
|
||||
|
||||
results.append(result)
|
||||
|
||||
score = correct / len(GSM8K_PROMPTS) if GSM8K_PROMPTS else 0
|
||||
return score, results
|
||||
|
||||
|
||||
def run_tool_calling_benchmark(model: str, url: str, timeout: int = 120) -> Tuple[float, List[dict]]:
|
||||
"""Run tool calling benchmark and return success rate + detailed results."""
|
||||
results = []
|
||||
successes = 0
|
||||
|
||||
for item in TOOL_TEST_PROMPTS:
|
||||
# For tool calling, we check if the model mentions using the tool
|
||||
# In a real implementation, this would involve actual tool execution
|
||||
result = run_ollama_inference(item["prompt"], model, url, timeout)
|
||||
result["id"] = item["id"]
|
||||
|
||||
if result["status"] == "success":
|
||||
# Simple heuristic: check if model mentions the tool name
|
||||
response_lower = result["response"].lower()
|
||||
tool_mentioned = item["tool_name"].lower() in response_lower
|
||||
result["tool_mentioned"] = tool_mentioned
|
||||
if tool_mentioned:
|
||||
successes += 1
|
||||
else:
|
||||
result["tool_mentioned"] = False
|
||||
|
||||
results.append(result)
|
||||
|
||||
success_rate = successes / len(TOOL_TEST_PROMPTS) if TOOL_TEST_PROMPTS else 0
|
||||
return success_rate, results
|
||||
|
||||
|
||||
def find_models(model_dir: str) -> Dict[str, List[str]]:
|
||||
"""Find Bonsai models in the directory."""
|
||||
models = {"Q1_0": [], "Q4_0": []}
|
||||
|
||||
if not os.path.isdir(model_dir):
|
||||
return models
|
||||
|
||||
for root, dirs, files in os.walk(model_dir):
|
||||
for file in files:
|
||||
if file.endswith(".gguf") or file.endswith(".bin"):
|
||||
filepath = os.path.join(root, file)
|
||||
if "Q1_0" in file.upper() or "q1_0" in file.lower():
|
||||
models["Q1_0"].append(filepath)
|
||||
elif "Q4_0" in file.upper() or "q4_0" in file.lower():
|
||||
models["Q4_0"].append(filepath)
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def benchmark_model(model_path: str, model_name: str, quant_type: str,
|
||||
url: str, skip_tool_test: bool, timeout: int) -> dict:
|
||||
"""Benchmark a single model configuration."""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Benchmarking: {model_name} ({quant_type})")
|
||||
print(f"Path: {model_path}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Get model size
|
||||
file_size_bytes = get_model_file_size(model_path)
|
||||
file_size_mb = file_size_bytes / (1024 * 1024) if file_size_bytes else None
|
||||
|
||||
# Measure memory before inference
|
||||
mem_before = get_memory_usage_mb()
|
||||
|
||||
# Run GSM8K benchmark
|
||||
print("Running GSM8K benchmark...")
|
||||
gsm8k_score, gsm8k_results = run_gsm8k_benchmark(model_name, url, timeout)
|
||||
correct_count = sum(1 for r in gsm8k_results if r.get('correct'))
|
||||
print(f"GSM8K Score: {gsm8k_score:.1%} ({correct_count}/{len(GSM8K_PROMPTS)})")
|
||||
|
||||
# Run tool calling benchmark
|
||||
tool_success_rate = 0.0
|
||||
tool_results = []
|
||||
if not skip_tool_test:
|
||||
print("Running tool calling benchmark...")
|
||||
tool_success_rate, tool_results = run_tool_calling_benchmark(model_name, url, timeout)
|
||||
tool_count = sum(1 for r in tool_results if r.get('tool_mentioned'))
|
||||
print(f"Tool Calling: {tool_success_rate:.1%} ({tool_count}/{len(TOOL_TEST_PROMPTS)})")
|
||||
|
||||
# Measure memory after inference
|
||||
mem_after = get_memory_usage_mb()
|
||||
memory_used_mb = max(mem_before, mem_after)
|
||||
|
||||
# Get average tokens/sec from GSM8K results
|
||||
successful_runs = [r for r in gsm8k_results if r["status"] == "success"]
|
||||
avg_tokens_per_sec = (
|
||||
sum(r.get("tokens_per_sec", 0) for r in successful_runs) / len(successful_runs)
|
||||
if successful_runs else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
"model_name": model_name,
|
||||
"quant_type": quant_type,
|
||||
"model_path": model_path,
|
||||
"file_size_mb": round(file_size_mb, 1) if file_size_mb else None,
|
||||
"memory_used_mb": round(memory_used_mb, 1),
|
||||
"gsm8k_score": round(gsm8k_score, 3),
|
||||
"gsm8k_correct": sum(1 for r in gsm8k_results if r.get("correct")),
|
||||
"gsm8k_total": len(GSM8K_PROMPTS),
|
||||
"tool_calling_rate": round(tool_success_rate, 3),
|
||||
"tool_calls_correct": sum(1 for r in tool_results if r.get("tool_mentioned")),
|
||||
"tool_calls_total": len(TOOL_TEST_PROMPTS),
|
||||
"avg_tokens_per_sec": round(avg_tokens_per_sec, 2),
|
||||
"gsm8k_results": gsm8k_results,
|
||||
"tool_results": tool_results
|
||||
}
|
||||
|
||||
|
||||
def generate_report(results: List[dict], output_file: str):
|
||||
"""Generate benchmark report in markdown format."""
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
|
||||
|
||||
lines = [
|
||||
f"# Bonsai 1-bit vs Q4_0 Benchmark Report",
|
||||
f"Generated: {timestamp}",
|
||||
"",
|
||||
"## Summary",
|
||||
"",
|
||||
"| Model | Quant | Size (MB) | Memory (MB) | GSM8K | Tool Call | tok/s |",
|
||||
"|-------|-------|-----------|-------------|-------|-----------|-------|"
|
||||
]
|
||||
|
||||
for r in results:
|
||||
size_str = f"{r['file_size_mb']:.1f}" if r['file_size_mb'] else "N/A"
|
||||
lines.append(
|
||||
f"| {r['model_name']} | {r['quant_type']} | {size_str} | "
|
||||
f"{r['memory_used_mb']:.1f} | {r['gsm8k_score']:.1%} | "
|
||||
f"{r['tool_calling_rate']:.1%} | {r['avg_tokens_per_sec']:.1f} |"
|
||||
)
|
||||
|
||||
lines.extend([
|
||||
"",
|
||||
"## Analysis",
|
||||
"",
|
||||
"### Quality Comparison",
|
||||
"- **GSM8K**: Higher is better (math reasoning capability)",
|
||||
"- **Tool Calling**: Higher is better (agent tool use reliability)",
|
||||
"",
|
||||
"### Speed & Memory",
|
||||
"- **tok/s**: Tokens per second (higher is faster)",
|
||||
"- **Memory**: Peak memory usage during inference",
|
||||
"- **Size**: Model file size on disk",
|
||||
"",
|
||||
"### Key Questions",
|
||||
"1. Is 1-bit (Q1_0) usable for agent tool calling?",
|
||||
"2. What is the minimum viable model for edge deployment?",
|
||||
"3. Quality vs speed tradeoff curve",
|
||||
"",
|
||||
"## Detailed Results",
|
||||
""
|
||||
])
|
||||
|
||||
for r in results:
|
||||
lines.extend([
|
||||
f"### {r['model_name']} ({r['quant_type']})",
|
||||
"",
|
||||
f"- **File**: `{r['model_path']}`",
|
||||
])
|
||||
|
||||
if r['file_size_mb']:
|
||||
lines.append(f"- **Size**: {r['file_size_mb']:.1f} MB")
|
||||
else:
|
||||
lines.append("- **Size**: Unknown")
|
||||
|
||||
lines.extend([
|
||||
f"- **Memory**: {r['memory_used_mb']:.1f} MB",
|
||||
f"- **GSM8K**: {r['gsm8k_correct']}/{r['gsm8k_total']} ({r['gsm8k_score']:.1%})",
|
||||
f"- **Tool Calling**: {r['tool_calls_correct']}/{r['tool_calls_total']} ({r['tool_calling_rate']:.1%})",
|
||||
f"- **Speed**: {r['avg_tokens_per_sec']:.1f} tok/s",
|
||||
"",
|
||||
"GSM8K Results:",
|
||||
""
|
||||
])
|
||||
|
||||
for gsm in r.get('gsm8k_results', []):
|
||||
status = "✓" if gsm.get('correct') else "✗"
|
||||
lines.append(f"- {status} {gsm['id']}: {gsm.get('tokens_per_sec', 0):.1f} tok/s")
|
||||
|
||||
lines.append("")
|
||||
|
||||
# Recommendations
|
||||
lines.extend([
|
||||
"## Recommendations",
|
||||
"",
|
||||
"Based on the benchmark results:",
|
||||
""
|
||||
])
|
||||
|
||||
if results:
|
||||
# Find best model for each use case
|
||||
best_quality = max(results, key=lambda x: x['gsm8k_score'])
|
||||
best_speed = max(results, key=lambda x: x['avg_tokens_per_sec'])
|
||||
best_tool = max(results, key=lambda x: x['tool_calling_rate'])
|
||||
|
||||
lines.extend([
|
||||
f"1. **Best Quality**: {best_quality['model_name']} ({best_quality['quant_type']}) — "
|
||||
f"GSM8K: {best_quality['gsm8k_score']:.1%}",
|
||||
f"2. **Best Speed**: {best_speed['model_name']} ({best_speed['quant_type']}) — "
|
||||
f"{best_speed['avg_tokens_per_sec']:.1f} tok/s",
|
||||
f"3. **Best Tool Calling**: {best_tool['model_name']} ({best_tool['quant_type']}) — "
|
||||
f"{best_tool['tool_calling_rate']:.1%}",
|
||||
"",
|
||||
"### Edge Deployment",
|
||||
"- For edge devices with limited memory, Q1_0 models may be viable",
|
||||
"- Tool calling reliability is critical for agent use cases",
|
||||
"- Consider quality/speed tradeoff for specific deployment scenarios"
|
||||
])
|
||||
|
||||
report = "\n".join(lines)
|
||||
|
||||
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
|
||||
with open(output_file, "w") as f:
|
||||
f.write(report)
|
||||
|
||||
print(f"\nReport saved to: {output_file}")
|
||||
return report
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Bonsai 1-bit vs Q4_0 Benchmark (Issue #100)")
|
||||
parser.add_argument("--model-dir", required=True,
|
||||
help="Directory containing GGUF model files")
|
||||
parser.add_argument("--ollama-url", default="http://localhost:11434",
|
||||
help="Ollama API URL")
|
||||
parser.add_argument("--output", default=None,
|
||||
help="Output markdown file (auto-generated if omitted)")
|
||||
parser.add_argument("--timeout", type=int, default=120,
|
||||
help="Per-prompt timeout in seconds")
|
||||
parser.add_argument("--skip-tool-test", action="store_true",
|
||||
help="Skip tool calling benchmark")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.isdir(args.model_dir):
|
||||
print(f"Error: {args.model_dir} is not a directory", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Find models
|
||||
models = find_models(args.model_dir)
|
||||
all_models = models["Q1_0"] + models["Q4_0"]
|
||||
|
||||
if not all_models:
|
||||
print(f"No Bonsai models found in {args.model_dir}")
|
||||
print("Expected files with 'Q1_0' or 'Q4_0' in the name (.gguf or .bin)")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Found {len(models['Q1_0'])} Q1_0 models, {len(models['Q4_0'])} Q4_0 models")
|
||||
|
||||
# Generate output filename if not provided
|
||||
if args.output is None:
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
args.output = f"benchmarks/bonsai-1bit-{timestamp}.md"
|
||||
|
||||
# Benchmark each model
|
||||
results = []
|
||||
for model_path in all_models:
|
||||
model_name = Path(model_path).stem
|
||||
quant_type = "Q1_0" if model_path in models["Q1_0"] else "Q4_0"
|
||||
|
||||
# Extract base model name (e.g., "Bonsai-8B" from "Bonsai-8B-Q1_0.gguf")
|
||||
base_name = model_name.split("-Q")[0] if "-Q" in model_name else model_name
|
||||
|
||||
result = benchmark_model(
|
||||
model_path=model_path,
|
||||
model_name=base_name,
|
||||
quant_type=quant_type,
|
||||
url=args.ollama_url,
|
||||
skip_tool_test=args.skip_tool_test,
|
||||
timeout=args.timeout
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# Generate report
|
||||
generate_report(results, args.output)
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
print("SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
for r in results:
|
||||
print(f"{r['model_name']} ({r['quant_type']}): "
|
||||
f"GSM8K={r['gsm8k_score']:.1%}, "
|
||||
f"Tools={r['tool_calling_rate']:.1%}, "
|
||||
f"{r['avg_tokens_per_sec']:.1f} tok/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
134
benchmarks/test_bonsai_benchmark.py
Normal file
134
benchmarks/test_bonsai_benchmark.py
Normal file
@@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for benchmarks/bonsai_benchmark.py — 8 tests."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__) or ".")
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"bb", os.path.join(os.path.dirname(__file__) or ".", "bonsai_benchmark.py"))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
check_gsm8k_answer = mod.check_gsm8k_answer
|
||||
find_models = mod.find_models
|
||||
generate_report = mod.generate_report
|
||||
|
||||
|
||||
def test_gsm8k_answer_correct():
|
||||
"""Correct answer should be detected."""
|
||||
assert check_gsm8k_answer("The answer is 18.", ["18", "$18", "eighteen"])
|
||||
print("PASS: test_gsm8k_answer_correct")
|
||||
|
||||
|
||||
def test_gsm8k_answer_case_insensitive():
|
||||
"""Answer check should be case insensitive."""
|
||||
assert check_gsm8k_answer("The answer is EIGHTEEN.", ["18", "eighteen"])
|
||||
print("PASS: test_gsm8k_answer_case_insensitive")
|
||||
|
||||
|
||||
def test_gsm8k_answer_wrong():
|
||||
"""Wrong answer should return False."""
|
||||
assert not check_gsm8k_answer("The answer is 42.", ["18", "$18", "eighteen"])
|
||||
print("PASS: test_gsm8k_answer_wrong")
|
||||
|
||||
|
||||
def test_gsm8k_answer_partial():
|
||||
"""Partial match should work."""
|
||||
assert check_gsm8k_answer("She makes $18 per day.", ["18", "$18"])
|
||||
print("PASS: test_gsm8k_answer_partial")
|
||||
|
||||
|
||||
def test_find_models_empty():
|
||||
"""Empty directory should return empty lists."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
models = find_models(tmpdir)
|
||||
assert models["Q1_0"] == []
|
||||
assert models["Q4_0"] == []
|
||||
print("PASS: test_find_models_empty")
|
||||
|
||||
|
||||
def test_find_models_with_files():
|
||||
"""Should find models by quantization type."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create test files
|
||||
q1_file = os.path.join(tmpdir, "Bonsai-8B-Q1_0.gguf")
|
||||
q4_file = os.path.join(tmpdir, "Bonsai-8B-Q4_0.gguf")
|
||||
other_file = os.path.join(tmpdir, "other.txt")
|
||||
|
||||
for f in [q1_file, q4_file, other_file]:
|
||||
with open(f, "w") as fh:
|
||||
fh.write("")
|
||||
|
||||
models = find_models(tmpdir)
|
||||
assert len(models["Q1_0"]) == 1
|
||||
assert len(models["Q4_0"]) == 1
|
||||
assert q1_file in models["Q1_0"]
|
||||
assert q4_file in models["Q4_0"]
|
||||
print("PASS: test_find_models_with_files")
|
||||
|
||||
|
||||
def test_find_models_nested():
|
||||
"""Should find models in subdirectories."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
subdir = os.path.join(tmpdir, "models")
|
||||
os.makedirs(subdir)
|
||||
|
||||
q1_file = os.path.join(subdir, "Bonsai-1.7B-Q1_0.gguf")
|
||||
with open(q1_file, "w") as f:
|
||||
f.write("")
|
||||
|
||||
models = find_models(tmpdir)
|
||||
assert len(models["Q1_0"]) == 1
|
||||
assert q1_file in models["Q1_0"]
|
||||
print("PASS: test_find_models_nested")
|
||||
|
||||
|
||||
def test_generate_report():
|
||||
"""Report generation should produce markdown."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
results = [{
|
||||
"model_name": "Bonsai-8B",
|
||||
"quant_type": "Q1_0",
|
||||
"model_path": "/test/Bonsai-8B-Q1_0.gguf",
|
||||
"file_size_mb": 1024.5,
|
||||
"memory_used_mb": 2048.0,
|
||||
"gsm8k_score": 0.6,
|
||||
"gsm8k_correct": 3,
|
||||
"gsm8k_total": 5,
|
||||
"tool_calling_rate": 0.8,
|
||||
"tool_calls_correct": 8,
|
||||
"tool_calls_total": 10,
|
||||
"avg_tokens_per_sec": 15.2,
|
||||
"gsm8k_results": [],
|
||||
"tool_results": []
|
||||
}]
|
||||
|
||||
output_file = os.path.join(tmpdir, "report.md")
|
||||
report = generate_report(results, output_file)
|
||||
|
||||
assert os.path.exists(output_file)
|
||||
assert "Bonsai-8B" in report
|
||||
assert "Q1_0" in report
|
||||
assert "GSM8K" in report
|
||||
assert "60.0%" in report
|
||||
print("PASS: test_generate_report")
|
||||
|
||||
|
||||
def run_all():
|
||||
test_gsm8k_answer_correct()
|
||||
test_gsm8k_answer_case_insensitive()
|
||||
test_gsm8k_answer_wrong()
|
||||
test_gsm8k_answer_partial()
|
||||
test_find_models_empty()
|
||||
test_find_models_with_files()
|
||||
test_find_models_nested()
|
||||
test_generate_report()
|
||||
print("\nAll 8 tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_all()
|
||||
Reference in New Issue
Block a user