135 lines
4.1 KiB
Python
135 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Tests for benchmarks/bonsai_benchmark.py — 8 tests."""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
|
|
# Directory containing this test file; falls back to "." when __file__ has no
# directory component.
_HERE = os.path.dirname(__file__) or "."

# Put that directory at the front of sys.path.
# NOTE(review): presumably so bonsai_benchmark.py can import its siblings -- confirm.
sys.path.insert(0, _HERE)

import importlib.util

# Load bonsai_benchmark.py by file path under the module name "bb" instead of
# relying on a regular package import.
spec = importlib.util.spec_from_file_location(
    "bb", os.path.join(_HERE, "bonsai_benchmark.py"))
if spec is None or spec.loader is None:
    # Fail with a clear message rather than an AttributeError on None below.
    raise ImportError(
        "cannot build an import spec for bonsai_benchmark.py in " + _HERE)

mod = importlib.util.module_from_spec(spec)
# Register the module before executing it, as the importlib "importing a
# source file directly" recipe does, so code that looks the module up through
# sys.modules while it is being executed can find it.
sys.modules[spec.name] = mod
spec.loader.exec_module(mod)

# Short aliases for the functions under test.
check_gsm8k_answer = mod.check_gsm8k_answer
find_models = mod.find_models
generate_report = mod.generate_report
|
|
|
|
|
|
def test_gsm8k_answer_correct():
    """A response containing one of the accepted answers must be accepted."""
    response = "The answer is 18."
    accepted = ["18", "$18", "eighteen"]
    assert check_gsm8k_answer(response, accepted)
    print("PASS: test_gsm8k_answer_correct")
|
|
|
|
|
|
def test_gsm8k_answer_case_insensitive():
    """An upper-case answer word must match its lower-case accepted form."""
    response = "The answer is EIGHTEEN."
    accepted = ["18", "eighteen"]
    assert check_gsm8k_answer(response, accepted)
    print("PASS: test_gsm8k_answer_case_insensitive")
|
|
|
|
|
|
def test_gsm8k_answer_wrong():
    """A response with none of the accepted answers must yield a falsy result."""
    response = "The answer is 42."
    accepted = ["18", "$18", "eighteen"]
    matched = check_gsm8k_answer(response, accepted)
    assert not matched
    print("PASS: test_gsm8k_answer_wrong")
|
|
|
|
|
|
def test_gsm8k_answer_partial():
    """An accepted answer embedded in a longer sentence must still match."""
    response = "She makes $18 per day."
    accepted = ["18", "$18"]
    assert check_gsm8k_answer(response, accepted)
    print("PASS: test_gsm8k_answer_partial")
|
|
|
|
|
|
def test_find_models_empty():
    """Scanning an empty directory must give empty Q1_0 and Q4_0 lists."""
    with tempfile.TemporaryDirectory() as tmpdir:
        models = find_models(tmpdir)
        # Both quantization buckets are checked, in the same order as before.
        for quant in ("Q1_0", "Q4_0"):
            assert models[quant] == []
    print("PASS: test_find_models_empty")
|
|
|
|
|
|
def test_find_models_with_files():
    """Each .gguf file must land in the bucket named by its quantization tag.

    Two empty model files (Q1_0 and Q4_0) and one unrelated text file are
    created in a temporary directory. Each bucket must then hold exactly
    one path, and that path must be the matching model file.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        q1_file = os.path.join(tmpdir, "Bonsai-8B-Q1_0.gguf")
        q4_file = os.path.join(tmpdir, "Bonsai-8B-Q4_0.gguf")
        other_file = os.path.join(tmpdir, "other.txt")

        # Create all three as empty files; only their names matter here.
        for path in (q1_file, q4_file, other_file):
            with open(path, "w") as handle:
                handle.write("")

        models = find_models(tmpdir)
        expected = {"Q1_0": q1_file, "Q4_0": q4_file}

        # Bucket sizes first, then membership -- same order as the original.
        for quant in expected:
            assert len(models[quant]) == 1
        for quant, path in expected.items():
            assert path in models[quant]
    print("PASS: test_find_models_with_files")
|
|
|
|
|
|
def test_find_models_nested():
    """A model file inside a subdirectory must still be discovered."""
    with tempfile.TemporaryDirectory() as tmpdir:
        subdir = os.path.join(tmpdir, "models")
        os.makedirs(subdir)

        # Place one empty Q1_0 model file one level below the scanned root.
        q1_file = os.path.join(subdir, "Bonsai-1.7B-Q1_0.gguf")
        with open(q1_file, "w") as handle:
            handle.write("")

        found = find_models(tmpdir)["Q1_0"]
        assert len(found) == 1
        assert q1_file in found
    print("PASS: test_find_models_nested")
|
|
|
|
|
|
def test_generate_report():
    """generate_report must write the output file and return the report text.

    A single synthetic result record is passed in. The output path must
    exist afterwards, and the returned text must contain the model name,
    the quantization type, "GSM8K" and "60.0%".
    """
    record = {
        "model_name": "Bonsai-8B",
        "quant_type": "Q1_0",
        "model_path": "/test/Bonsai-8B-Q1_0.gguf",
        "file_size_mb": 1024.5,
        "memory_used_mb": 2048.0,
        "gsm8k_score": 0.6,
        "gsm8k_correct": 3,
        "gsm8k_total": 5,
        "tool_calling_rate": 0.8,
        "tool_calls_correct": 8,
        "tool_calls_total": 10,
        "avg_tokens_per_sec": 15.2,
        "gsm8k_results": [],
        "tool_results": [],
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        output_file = os.path.join(tmpdir, "report.md")
        report = generate_report([record], output_file)

        assert os.path.exists(output_file)
        # "60.0%" is presumably gsm8k_score 0.6 rendered as a percentage.
        for fragment in ("Bonsai-8B", "Q1_0", "GSM8K", "60.0%"):
            assert fragment in report
    print("PASS: test_generate_report")
|
|
|
|
|
|
def run_all():
    """Run every test in this file in order, then print a summary line.

    Each test raises AssertionError on failure, which aborts the run at the
    first failing test. The count in the summary is taken from the list of
    tests actually run, so it cannot drift from that list when a test is
    added or removed.
    """
    tests = (
        test_gsm8k_answer_correct,
        test_gsm8k_answer_case_insensitive,
        test_gsm8k_answer_wrong,
        test_gsm8k_answer_partial,
        test_find_models_empty,
        test_find_models_with_files,
        test_find_models_nested,
        test_generate_report,
    )
    for test in tests:
        test()
    print("\nAll {} tests passed!".format(len(tests)))
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: run the whole suite.
    run_all()
|