diff --git a/tests/test_config_drift.py b/tests/test_config_drift.py new file mode 100644 index 00000000..265dd78b --- /dev/null +++ b/tests/test_config_drift.py @@ -0,0 +1,149 @@ +""" +Tests for scripts/config_drift.py — Config drift detection. +""" + +import json +import tempfile +import unittest +from pathlib import Path + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent / "scripts")) +from config_drift import ( + get_nested_value, + compare_configs, + parse_yaml_basic, + generate_report, +) + + +class TestGetNestedValue(unittest.TestCase): + def test_top_level(self): + config = {"provider": "openrouter"} + self.assertEqual(get_nested_value(config, "provider"), "openrouter") + + def test_nested(self): + config = {"cron": {"enabled": True, "workers": 4}} + self.assertEqual(get_nested_value(config, "cron.enabled"), True) + self.assertEqual(get_nested_value(config, "cron.workers"), 4) + + def test_missing_key(self): + config = {"provider": "openrouter"} + self.assertIsNone(get_nested_value(config, "missing")) + + def test_missing_nested(self): + config = {"cron": {}} + self.assertIsNone(get_nested_value(config, "cron.enabled")) + + def test_deep_nesting(self): + config = {"a": {"b": {"c": "value"}}} + self.assertEqual(get_nested_value(config, "a.b.c"), "value") + + +class TestCompareConfigs(unittest.TestCase): + def test_no_diff(self): + canonical = {"provider": "openrouter", "model": "mimo"} + remote = {"provider": "openrouter", "model": "mimo"} + diffs = compare_configs(canonical, remote, ["provider", "model"]) + self.assertEqual(diffs, []) + + def test_single_diff(self): + canonical = {"provider": "openrouter"} + remote = {"provider": "anthropic"} + diffs = compare_configs(canonical, remote, ["provider"]) + self.assertEqual(len(diffs), 1) + self.assertEqual(diffs[0][0], "provider") + self.assertEqual(diffs[0][1], "openrouter") + self.assertEqual(diffs[0][2], "anthropic") + + def test_multiple_diffs(self): + canonical = {"provider": "openrouter", "model": "mimo"} + remote = {"provider": "anthropic", "model": "claude"} + diffs = compare_configs(canonical, remote, ["provider", "model"]) + self.assertEqual(len(diffs), 2) + + def test_nested_diff(self): + canonical = {"cron": {"enabled": True}} + remote = {"cron": {"enabled": False}} + diffs = compare_configs(canonical, remote, ["cron.enabled"]) + self.assertEqual(len(diffs), 1) + self.assertEqual(diffs[0][0], "cron.enabled") + + def test_missing_in_remote(self): + canonical = {"provider": "openrouter"} + remote = {} + diffs = compare_configs(canonical, remote, ["provider"]) + self.assertEqual(len(diffs), 1) + + def test_extra_in_remote(self): + canonical = {} + remote = {"provider": "openrouter"} + diffs = compare_configs(canonical, remote, ["provider"]) + self.assertEqual(len(diffs), 1) + + +class TestParseYamlBasic(unittest.TestCase): + def test_simple(self): + content = "provider: openrouter\nmodel: mimo-v2-pro\n" + result = parse_yaml_basic(content) + self.assertEqual(result["provider"], "openrouter") + self.assertEqual(result["model"], "mimo-v2-pro") + + def test_boolean(self): + content = "enabled: true\ndisabled: false\n" + result = parse_yaml_basic(content) + self.assertEqual(result["enabled"], True) + self.assertEqual(result["disabled"], False) + + def test_integer(self): + content = "workers: 4\nport: 8080\n" + result = parse_yaml_basic(content) + self.assertEqual(result["workers"], 4) + self.assertEqual(result["port"], 8080) + + def test_comments_skipped(self): + content = "# This is a comment\nprovider: openrouter\n" + result = parse_yaml_basic(content) + self.assertNotIn("#", result) + self.assertEqual(result["provider"], "openrouter") + + def test_quoted_values(self): + content = 'name: "hello world"\nother: \'single quotes\'\n' + result = parse_yaml_basic(content) + self.assertEqual(result["name"], "hello world") + self.assertEqual(result["other"], "single quotes") + + +class TestGenerateReport(unittest.TestCase): + def test_all_ok(self): + results = { + "node1": {"status": "ok", "diffs": []}, + "node2": {"status": "ok", "diffs": []}, + } + report = generate_report(results, ["provider"]) + self.assertIn("OK", report) + self.assertIn("2 ok", report) + + def test_drift_reported(self): + results = { + "node1": { + "status": "drift", + "diffs": [("provider", "openrouter", "anthropic")] + }, + "node2": {"status": "ok", "diffs": []}, + } + report = generate_report(results, ["provider"]) + self.assertIn("DRIFT DETECTED", report) + self.assertIn("openrouter", report) + self.assertIn("anthropic", report) + + def test_unreachable_reported(self): + results = { + "node1": {"status": "unreachable", "diffs": []}, + } + report = generate_report(results, ["provider"]) + self.assertIn("UNREACHABLE", report) + + +if __name__ == "__main__": + unittest.main()