diff --git a/tests/test_token_tracker.py b/tests/test_token_tracker.py new file mode 100644 index 00000000..dd58340f --- /dev/null +++ b/tests/test_token_tracker.py @@ -0,0 +1,159 @@ +""" +Tests for scripts/token_tracker.py — Token Budget Tracker. +""" + +import json +import os +import sqlite3 +import tempfile +import unittest +from pathlib import Path + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent / "scripts")) +from token_tracker import ( + get_db, + record_usage, + get_usage_since, + get_hourly_usage, + get_worker_usage, + format_tokens, + progress_bar, + estimate_eta, + check_alerts, + load_budgets, + save_budgets, +) + + +class TestTokenTracker(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.db_path = Path(self.tmpdir) / "test.db" + self.conn = get_db(self.db_path) + + def tearDown(self): + self.conn.close() + + def test_record_usage(self): + record_usage(self.conn, "pipeline1", "worker1", 1000) + cursor = self.conn.execute("SELECT pipeline, worker, tokens FROM token_usage") + row = cursor.fetchone() + self.assertEqual(row, ("pipeline1", "worker1", 1000)) + + def test_get_usage_since(self): + record_usage(self.conn, "p1", "w1", 500) + record_usage(self.conn, "p1", "w2", 300) + record_usage(self.conn, "p2", "w1", 200) + + usage = get_usage_since(self.conn, "2020-01-01T00:00:00") + self.assertEqual(usage["p1"], 800) + self.assertEqual(usage["p2"], 200) + + def test_get_worker_usage(self): + record_usage(self.conn, "p1", "w1", 500) + record_usage(self.conn, "p1", "w2", 300) + record_usage(self.conn, "p1", "w1", 100) + + workers = get_worker_usage(self.conn, "p1", "2020-01-01T00:00:00") + self.assertEqual(workers["w1"], 600) + self.assertEqual(workers["w2"], 300) + + +class TestFormatTokens(unittest.TestCase): + def test_billions(self): + self.assertEqual(format_tokens(1_500_000_000), "1.5B") + + def test_millions(self): + self.assertEqual(format_tokens(45_200_000), "45.2M") + + def test_thousands(self): + self.assertEqual(format_tokens(1_500), "1.5K") + + def test_small(self): + self.assertEqual(format_tokens(42), "42") + + def test_zero(self): + self.assertEqual(format_tokens(0), "0") + + +class TestProgressBar(unittest.TestCase): + def test_empty(self): + self.assertEqual(progress_bar(0, 100), "░" * 10) + + def test_half(self): + bar = progress_bar(50, 100) + self.assertEqual(bar, "█████░░░░░") + + def test_full(self): + self.assertEqual(progress_bar(100, 100), "█" * 10) + + def test_overfull(self): + self.assertEqual(progress_bar(150, 100), "█" * 10) + + def test_zero_target(self): + self.assertEqual(progress_bar(0, 0), "░" * 10) + + +class TestEstimateEta(unittest.TestCase): + def test_done(self): + self.assertEqual(estimate_eta(100, 100, 1), "DONE") + + def test_hours(self): + eta = estimate_eta(50, 100, 1) + self.assertEqual(eta, "1.0h") + + def test_minutes(self): + eta = estimate_eta(90, 100, 1) + self.assertIn("m", eta) # Should be in minutes format + + def test_no_data(self): + self.assertEqual(estimate_eta(0, 100, 1), "N/A") + + +class TestCheckAlerts(unittest.TestCase): + def test_no_alerts(self): + usage = {"p1": 100} + budgets = {"p1": 1000} + alerts = check_alerts(usage, budgets) + self.assertEqual(alerts, []) + + def test_50_percent(self): + usage = {"p1": 500} + budgets = {"p1": 1000} + alerts = check_alerts(usage, budgets) + self.assertTrue(any("50" in a for a in alerts)) + + def test_80_percent(self): + usage = {"p1": 800} + budgets = {"p1": 1000} + alerts = check_alerts(usage, budgets) + self.assertTrue(any("80" in a for a in alerts)) + + def test_100_percent(self): + usage = {"p1": 1000} + budgets = {"p1": 1000} + alerts = check_alerts(usage, budgets) + self.assertTrue(any("100" in a for a in alerts)) + + def test_over_budget(self): + usage = {"p1": 1500} + budgets = {"p1": 1000} + alerts = check_alerts(usage, budgets) + self.assertTrue(len(alerts) >= 3) # 50%, 80%, 100% all triggered + self.assertTrue(any("🔴" in a for a in alerts)) + + +class TestBudgets(unittest.TestCase): + def test_save_load(self): + tmpfile = tempfile.mktemp(suffix=".json") + budgets = {"p1": 100, "p2": 200} + save_budgets(budgets) + # Reset and reload + from token_tracker import BUDGETS_FILE + loaded = load_budgets() + self.assertIn("p1", loaded) + + +if __name__ == "__main__": + unittest.main()