"""Unit tests for infrastructure.self_correction.""" import pytest # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture(autouse=True) def _isolated_db(tmp_path, monkeypatch): """Point the self-correction module at a fresh temp database per test.""" import infrastructure.self_correction as sc_mod # Reset the cached path so each test gets a clean DB sc_mod._DB_PATH = tmp_path / "self_correction.db" yield sc_mod._DB_PATH = None # --------------------------------------------------------------------------- # log_self_correction # --------------------------------------------------------------------------- class TestLogSelfCorrection: def test_returns_event_id(self): from infrastructure.self_correction import log_self_correction eid = log_self_correction( source="test", original_intent="Do X", detected_error="ValueError: bad input", correction_strategy="Try Y instead", final_outcome="Y succeeded", ) assert isinstance(eid, str) assert len(eid) == 36 # UUID format def test_derives_error_type_from_error_string(self): from infrastructure.self_correction import get_corrections, log_self_correction log_self_correction( source="test", original_intent="Connect", detected_error="ConnectionRefusedError: port 80", correction_strategy="Use port 8080", final_outcome="ok", ) rows = get_corrections(limit=1) assert rows[0]["error_type"] == "ConnectionRefusedError" def test_explicit_error_type_preserved(self): from infrastructure.self_correction import get_corrections, log_self_correction log_self_correction( source="test", original_intent="Run task", detected_error="Some weird error", correction_strategy="Fix it", final_outcome="done", error_type="CustomError", ) rows = get_corrections(limit=1) assert rows[0]["error_type"] == "CustomError" def test_task_id_stored(self): from infrastructure.self_correction import get_corrections, log_self_correction log_self_correction( source="test", original_intent="intent", detected_error="err", correction_strategy="strat", final_outcome="outcome", task_id="task-abc-123", ) rows = get_corrections(limit=1) assert rows[0]["task_id"] == "task-abc-123" def test_outcome_status_stored(self): from infrastructure.self_correction import get_corrections, log_self_correction log_self_correction( source="test", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", outcome_status="failed", ) rows = get_corrections(limit=1) assert rows[0]["outcome_status"] == "failed" def test_long_strings_truncated(self): from infrastructure.self_correction import get_corrections, log_self_correction long = "x" * 3000 log_self_correction( source="test", original_intent=long, detected_error=long, correction_strategy=long, final_outcome=long, ) rows = get_corrections(limit=1) assert len(rows[0]["original_intent"]) <= 2000 # --------------------------------------------------------------------------- # get_corrections # --------------------------------------------------------------------------- class TestGetCorrections: def test_empty_db_returns_empty_list(self): from infrastructure.self_correction import get_corrections assert get_corrections() == [] def test_returns_newest_first(self): from infrastructure.self_correction import get_corrections, log_self_correction for i in range(3): log_self_correction( source="test", original_intent=f"intent {i}", detected_error="err", correction_strategy="fix", final_outcome="done", error_type=f"Type{i}", ) rows = get_corrections(limit=10) assert len(rows) == 3 # Newest first — Type2 should appear before Type0 types = [r["error_type"] for r in rows] assert types.index("Type2") < types.index("Type0") def test_limit_respected(self): from infrastructure.self_correction import get_corrections, log_self_correction for _ in range(5): log_self_correction( source="test", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", ) rows = get_corrections(limit=3) assert len(rows) == 3 # --------------------------------------------------------------------------- # get_patterns # --------------------------------------------------------------------------- class TestGetPatterns: def test_empty_db_returns_empty_list(self): from infrastructure.self_correction import get_patterns assert get_patterns() == [] def test_counts_by_error_type(self): from infrastructure.self_correction import get_patterns, log_self_correction for _ in range(3): log_self_correction( source="test", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", error_type="TimeoutError", ) log_self_correction( source="test", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", error_type="ValueError", ) patterns = get_patterns(top_n=10) by_type = {p["error_type"]: p for p in patterns} assert by_type["TimeoutError"]["count"] == 3 assert by_type["ValueError"]["count"] == 1 def test_success_vs_failed_counts(self): from infrastructure.self_correction import get_patterns, log_self_correction log_self_correction( source="test", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", error_type="Foo", outcome_status="success", ) log_self_correction( source="test", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", error_type="Foo", outcome_status="failed", ) patterns = get_patterns(top_n=5) foo = next(p for p in patterns if p["error_type"] == "Foo") assert foo["success_count"] == 1 assert foo["failed_count"] == 1 def test_ordered_by_count_desc(self): from infrastructure.self_correction import get_patterns, log_self_correction for _ in range(2): log_self_correction( source="t", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", error_type="Rare", ) for _ in range(5): log_self_correction( source="t", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", error_type="Common", ) patterns = get_patterns(top_n=5) assert patterns[0]["error_type"] == "Common" # --------------------------------------------------------------------------- # get_stats # --------------------------------------------------------------------------- class TestGetStats: def test_empty_db_returns_zeroes(self): from infrastructure.self_correction import get_stats stats = get_stats() assert stats["total"] == 0 assert stats["success_rate"] == 0 def test_counts_outcomes(self): from infrastructure.self_correction import get_stats, log_self_correction log_self_correction( source="t", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", outcome_status="success", ) log_self_correction( source="t", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", outcome_status="failed", ) stats = get_stats() assert stats["total"] == 2 assert stats["success_count"] == 1 assert stats["failed_count"] == 1 assert stats["success_rate"] == 50 def test_success_rate_100_when_all_succeed(self): from infrastructure.self_correction import get_stats, log_self_correction for _ in range(4): log_self_correction( source="t", original_intent="i", detected_error="e", correction_strategy="s", final_outcome="o", outcome_status="success", ) stats = get_stats() assert stats["success_rate"] == 100