From dd06e4c5e01a97a07b3ad0d0443506ba8e93d95b Mon Sep 17 00:00:00 2001 From: Alexander Whitestone Date: Thu, 16 Apr 2026 02:17:55 +0000 Subject: [PATCH] bench: Add test_bonsai_benchmark.py (#100) --- benchmarks/test_bonsai_benchmark.py | 134 ++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 benchmarks/test_bonsai_benchmark.py diff --git a/benchmarks/test_bonsai_benchmark.py b/benchmarks/test_bonsai_benchmark.py new file mode 100644 index 00000000..6dcbb00d --- /dev/null +++ b/benchmarks/test_bonsai_benchmark.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +"""Tests for benchmarks/bonsai_benchmark.py — 8 tests.""" + +import json +import os +import sys +import tempfile + +sys.path.insert(0, os.path.dirname(__file__) or ".") +import importlib.util +spec = importlib.util.spec_from_file_location( + "bb", os.path.join(os.path.dirname(__file__) or ".", "bonsai_benchmark.py")) +mod = importlib.util.module_from_spec(spec) +spec.loader.exec_module(mod) + +check_gsm8k_answer = mod.check_gsm8k_answer +find_models = mod.find_models +generate_report = mod.generate_report + + +def test_gsm8k_answer_correct(): + """Correct answer should be detected.""" + assert check_gsm8k_answer("The answer is 18.", ["18", "$18", "eighteen"]) + print("PASS: test_gsm8k_answer_correct") + + +def test_gsm8k_answer_case_insensitive(): + """Answer check should be case insensitive.""" + assert check_gsm8k_answer("The answer is EIGHTEEN.", ["18", "eighteen"]) + print("PASS: test_gsm8k_answer_case_insensitive") + + +def test_gsm8k_answer_wrong(): + """Wrong answer should return False.""" + assert not check_gsm8k_answer("The answer is 42.", ["18", "$18", "eighteen"]) + print("PASS: test_gsm8k_answer_wrong") + + +def test_gsm8k_answer_partial(): + """Partial match should work.""" + assert check_gsm8k_answer("She makes $18 per day.", ["18", "$18"]) + print("PASS: test_gsm8k_answer_partial") + + +def test_find_models_empty(): + """Empty directory should return empty lists.""" + with tempfile.TemporaryDirectory() as tmpdir: + models = find_models(tmpdir) + assert models["Q1_0"] == [] + assert models["Q4_0"] == [] + print("PASS: test_find_models_empty") + + +def test_find_models_with_files(): + """Should find models by quantization type.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test files + q1_file = os.path.join(tmpdir, "Bonsai-8B-Q1_0.gguf") + q4_file = os.path.join(tmpdir, "Bonsai-8B-Q4_0.gguf") + other_file = os.path.join(tmpdir, "other.txt") + + for f in [q1_file, q4_file, other_file]: + with open(f, "w") as fh: + fh.write("") + + models = find_models(tmpdir) + assert len(models["Q1_0"]) == 1 + assert len(models["Q4_0"]) == 1 + assert q1_file in models["Q1_0"] + assert q4_file in models["Q4_0"] + print("PASS: test_find_models_with_files") + + +def test_find_models_nested(): + """Should find models in subdirectories.""" + with tempfile.TemporaryDirectory() as tmpdir: + subdir = os.path.join(tmpdir, "models") + os.makedirs(subdir) + + q1_file = os.path.join(subdir, "Bonsai-1.7B-Q1_0.gguf") + with open(q1_file, "w") as f: + f.write("") + + models = find_models(tmpdir) + assert len(models["Q1_0"]) == 1 + assert q1_file in models["Q1_0"] + print("PASS: test_find_models_nested") + + +def test_generate_report(): + """Report generation should produce markdown.""" + with tempfile.TemporaryDirectory() as tmpdir: + results = [{ + "model_name": "Bonsai-8B", + "quant_type": "Q1_0", + "model_path": "/test/Bonsai-8B-Q1_0.gguf", + "file_size_mb": 1024.5, + "memory_used_mb": 2048.0, + "gsm8k_score": 0.6, + "gsm8k_correct": 3, + "gsm8k_total": 5, + "tool_calling_rate": 0.8, + "tool_calls_correct": 8, + "tool_calls_total": 10, + "avg_tokens_per_sec": 15.2, + "gsm8k_results": [], + "tool_results": [] + }] + + output_file = os.path.join(tmpdir, "report.md") + report = generate_report(results, output_file) + + assert os.path.exists(output_file) + assert "Bonsai-8B" in report + assert "Q1_0" in report + assert "GSM8K" in report + assert "60.0%" in report + print("PASS: test_generate_report") + + +def run_all(): + test_gsm8k_answer_correct() + test_gsm8k_answer_case_insensitive() + test_gsm8k_answer_wrong() + test_gsm8k_answer_partial() + test_find_models_empty() + test_find_models_with_files() + test_find_models_nested() + test_generate_report() + print("\nAll 8 tests passed!") + + +if __name__ == "__main__": + run_all()