#!/usr/bin/env python3
"""Tests for benchmark comparison module (Issue #29).

Covers: ConfigEntry, ConfigResult, aggregation, comparison table,
demo mode, and prompt loading.
"""

import json
import sys
import tempfile
import unittest
from pathlib import Path

# Make the benchmarks/ directory importable from the tests/ directory.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "benchmarks"))

from compare_configs import (
    ConfigEntry,
    ConfigResult,
    DEFAULT_CONFIGS,
    aggregate,
    build_comparison_table,
    load_prompts,
    pick_winner,
    run_demo,
)


class TestConfigEntry(unittest.TestCase):
    def test_default_values(self):
        c = ConfigEntry(name="test", backend="ollama", model="gemma4", url="http://x")
        self.assertEqual(c.kv_type, "f16")
        self.assertFalse(c.layer_adaptive)

    def test_to_dict(self):
        c = ConfigEntry(
            name="test",
            backend="llama-server",
            model="g",
            url="http://x",
            kv_type="turbo4",
            layer_adaptive=True,
        )
        d = c.to_dict()
        self.assertEqual(d["kv_type"], "turbo4")
        self.assertTrue(d["layer_adaptive"])


class TestDefaultConfigs(unittest.TestCase):
    def test_four_configs(self):
        self.assertEqual(len(DEFAULT_CONFIGS), 4)

    def test_names(self):
        names = [c.name for c in DEFAULT_CONFIGS]
        self.assertIn("ollama-gemma4", names)
        self.assertIn("llama-f16", names)
        self.assertIn("llama-turbo4", names)
        self.assertIn("llama-turbo4-adaptive", names)

    def test_turbo4_adaptive_has_flag(self):
        cfg = next(c for c in DEFAULT_CONFIGS if c.name == "llama-turbo4-adaptive")
        self.assertTrue(cfg.layer_adaptive)
        self.assertEqual(cfg.kv_type, "turbo4")


class TestAggregate(unittest.TestCase):
    def _make_results(self, n_success: int, n_fail: int) -> list[dict]:
        # Synthetic per-prompt results: successes with slightly varied
        # metrics, followed by failures that carry only a latency.
        results = []
        for i in range(n_success):
            results.append({
                "status": "success",
                "ttft_s": 0.5 + i * 0.1,
                "tokens_per_sec": 20.0 + i * 0.5,
                "latency_s": 1.0 + i * 0.05,
            })
        for _ in range(n_fail):
            results.append({"status": "failed", "latency_s": 0.5})
        return results

    def test_basic_aggregate(self):
        results = self._make_results(5, 1)
        cfg = ConfigEntry(name="test", backend="ollama", model="m", url="http://x")
        agg = aggregate(results, cfg, peak_mb=100.0)
        self.assertEqual(agg.success, 5)
        self.assertEqual(agg.failed, 1)
        self.assertEqual(agg.total_prompts, 6)
        self.assertAlmostEqual(agg.peak_memory_mb, 100.0)
        self.assertGreater(agg.avg_tok_per_sec, 0)

    def test_no_success(self):
        results = [{"status": "failed", "latency_s": 0.1}]
        cfg = ConfigEntry(name="test", backend="ollama", model="m", url="http://x")
        agg = aggregate(results, cfg, peak_mb=0.0)
        self.assertEqual(agg.avg_tok_per_sec, 0.0)
        self.assertIsNone(agg.avg_ttft_s)


class TestPickWinner(unittest.TestCase):
    def test_highest_tps_wins(self):
        configs = [
            ConfigResult(config_name="slow", backend="o", model="m", kv_type="f",
                         total_prompts=5, success=5, failed=0, avg_ttft_s=1.0,
                         avg_tok_per_sec=10.0, avg_latency_s=2.0, peak_memory_mb=100),
            ConfigResult(config_name="fast", backend="o", model="m", kv_type="f",
                         total_prompts=5, success=5, failed=0, avg_ttft_s=0.5,
                         avg_tok_per_sec=25.0, avg_latency_s=1.5, peak_memory_mb=100),
        ]
        w = pick_winner(configs)
        self.assertEqual(w.config_name, "fast")
        self.assertTrue(w.winner)

    def test_no_success_returns_first(self):
        configs = [
            ConfigResult(config_name="dead", backend="o", model="m", kv_type="f",
                         total_prompts=5, success=0, failed=5, avg_ttft_s=None,
                         avg_tok_per_sec=0.0, avg_latency_s=0.0, peak_memory_mb=0),
        ]
        w = pick_winner(configs)
        self.assertEqual(w.config_name, "dead")


class TestComparisonTable(unittest.TestCase):
    def test_table_has_headers(self):
        configs = [
            ConfigResult(config_name="test-cfg", backend="o", model="m", kv_type="f",
                         total_prompts=5, success=5, failed=0, avg_ttft_s=0.5,
                         avg_tok_per_sec=20.0, avg_latency_s=1.5, peak_memory_mb=100),
        ]
        pick_winner(configs)  # sets the winner flag that the table renders
        table = build_comparison_table(configs)
        self.assertIn("Config", table)
        self.assertIn("tok/s", table)
        self.assertIn("WINNER", table)


class TestDemoMode(unittest.TestCase):
    def test_demo_produces_report(self):
        # delete=False so run_demo can write to the path after the handle closes.
        with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
            out_path = Path(f.name)
        try:
            report = run_demo(str(out_path))
            self.assertEqual(report["mode"], "demo")
            self.assertEqual(report["prompts_count"], 10)
            self.assertEqual(len(report["configs"]), 4)
            self.assertTrue(out_path.exists())
            saved = json.loads(out_path.read_text())
            self.assertIn("winner", saved)
        finally:
            out_path.unlink(missing_ok=True)

    def test_demo_without_output(self):
        report = run_demo()
        self.assertIn("winner", report)
        self.assertGreater(report["winner_tok_per_sec"], 0)


class TestLoadPrompts(unittest.TestCase):
    def test_load_test_prompts(self):
        prompts_file = (
            Path(__file__).resolve().parent.parent / "benchmarks" / "test_prompts.json"
        )
        # Only exercised when the prompts fixture is present in the checkout.
        if prompts_file.exists():
            prompts = load_prompts(str(prompts_file))
            self.assertGreater(len(prompts), 0)
            for p in prompts:
                self.assertIn("prompt", p)


if __name__ == "__main__":
    unittest.main()