# Forked from Rockachopa/Timmy-time-dashboard (original listing: Python, 266 lines, 9.1 KiB).
"""Unit tests for infrastructure.self_correction."""
|
|
|
|
|
|
import pytest
|
|
|
|
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _isolated_db(tmp_path, monkeypatch):
    """Point the self-correction module at a fresh temp database per test.

    The module-level ``_DB_PATH`` is overridden with a file under pytest's
    per-test ``tmp_path``. The override goes through ``monkeypatch`` so the
    value the module held before the test is restored at teardown, instead
    of being unconditionally replaced with ``None``.
    """
    import infrastructure.self_correction as sc_mod

    # Override the cached path so each test gets a clean DB. ``raising=False``
    # mirrors the plain assignment this replaces: it never fails if the
    # attribute is absent, and monkeypatch undoes the change automatically.
    monkeypatch.setattr(
        sc_mod, "_DB_PATH", tmp_path / "self_correction.db", raising=False
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# log_self_correction
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestLogSelfCorrection:
    """Log events with ``log_self_correction`` and check what is read back."""

    def test_returns_event_id(self):
        """The returned event id is a 36-character string."""
        from infrastructure.self_correction import log_self_correction

        event = {
            "source": "test",
            "original_intent": "Do X",
            "detected_error": "ValueError: bad input",
            "correction_strategy": "Try Y instead",
            "final_outcome": "Y succeeded",
        }
        event_id = log_self_correction(**event)

        assert isinstance(event_id, str)
        assert len(event_id) == 36  # UUID format

    def test_derives_error_type_from_error_string(self):
        """``ConnectionRefusedError: port 80`` is stored as ``ConnectionRefusedError``."""
        from infrastructure.self_correction import get_corrections, log_self_correction

        event = {
            "source": "test",
            "original_intent": "Connect",
            "detected_error": "ConnectionRefusedError: port 80",
            "correction_strategy": "Use port 8080",
            "final_outcome": "ok",
        }
        log_self_correction(**event)

        newest = get_corrections(limit=1)[0]
        assert newest["error_type"] == "ConnectionRefusedError"

    def test_explicit_error_type_preserved(self):
        """An ``error_type`` passed by the caller is stored unchanged."""
        from infrastructure.self_correction import get_corrections, log_self_correction

        event = {
            "source": "test",
            "original_intent": "Run task",
            "detected_error": "Some weird error",
            "correction_strategy": "Fix it",
            "final_outcome": "done",
            "error_type": "CustomError",
        }
        log_self_correction(**event)

        newest = get_corrections(limit=1)[0]
        assert newest["error_type"] == "CustomError"

    def test_task_id_stored(self):
        """The ``task_id`` passed at log time is present on the stored row."""
        from infrastructure.self_correction import get_corrections, log_self_correction

        event = {
            "source": "test",
            "original_intent": "intent",
            "detected_error": "err",
            "correction_strategy": "strat",
            "final_outcome": "outcome",
            "task_id": "task-abc-123",
        }
        log_self_correction(**event)

        newest = get_corrections(limit=1)[0]
        assert newest["task_id"] == "task-abc-123"

    def test_outcome_status_stored(self):
        """The ``outcome_status`` passed at log time is present on the stored row."""
        from infrastructure.self_correction import get_corrections, log_self_correction

        event = {
            "source": "test",
            "original_intent": "i",
            "detected_error": "e",
            "correction_strategy": "s",
            "final_outcome": "o",
            "outcome_status": "failed",
        }
        log_self_correction(**event)

        newest = get_corrections(limit=1)[0]
        assert newest["outcome_status"] == "failed"

    def test_long_strings_truncated(self):
        """A 3000-character ``original_intent`` is stored with at most 2000 characters."""
        from infrastructure.self_correction import get_corrections, log_self_correction

        oversized = "x" * 3000
        text_fields = (
            "original_intent",
            "detected_error",
            "correction_strategy",
            "final_outcome",
        )
        log_self_correction(source="test", **dict.fromkeys(text_fields, oversized))

        newest = get_corrections(limit=1)[0]
        assert len(newest["original_intent"]) <= 2000
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# get_corrections
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetCorrections:
    """Check the rows returned by ``get_corrections``."""

    def test_empty_db_returns_empty_list(self):
        """With nothing logged, the result equals an empty list."""
        from infrastructure.self_correction import get_corrections

        found = get_corrections()
        assert found == []

    def test_returns_newest_first(self):
        """Of three events logged in sequence, ``Type2`` is listed before ``Type0``."""
        from infrastructure.self_correction import get_corrections, log_self_correction

        for seq in range(3):
            event = {
                "source": "test",
                "original_intent": f"intent {seq}",
                "detected_error": "err",
                "correction_strategy": "fix",
                "final_outcome": "done",
                "error_type": f"Type{seq}",
            }
            log_self_correction(**event)

        found = get_corrections(limit=10)
        assert len(found) == 3

        # Newest first — Type2 should appear before Type0
        seen_types = [entry["error_type"] for entry in found]
        latest_pos = seen_types.index("Type2")
        oldest_pos = seen_types.index("Type0")
        assert latest_pos < oldest_pos

    def test_limit_respected(self):
        """Five events logged with ``limit=3`` yield exactly three rows."""
        from infrastructure.self_correction import get_corrections, log_self_correction

        event = {
            "source": "test",
            "original_intent": "i",
            "detected_error": "e",
            "correction_strategy": "s",
            "final_outcome": "o",
        }
        for _ in range(5):
            log_self_correction(**event)

        found = get_corrections(limit=3)
        assert len(found) == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# get_patterns
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetPatterns:
    """Check the per-error-type aggregates returned by ``get_patterns``."""

    def test_empty_db_returns_empty_list(self):
        """With nothing logged, the result equals an empty list."""
        from infrastructure.self_correction import get_patterns

        found = get_patterns()
        assert found == []

    def test_counts_by_error_type(self):
        """Three ``TimeoutError`` events and one ``ValueError`` give counts 3 and 1."""
        from infrastructure.self_correction import get_patterns, log_self_correction

        shared = {
            "source": "test",
            "original_intent": "i",
            "detected_error": "e",
            "correction_strategy": "s",
            "final_outcome": "o",
        }
        for _ in range(3):
            log_self_correction(**shared, error_type="TimeoutError")
        log_self_correction(**shared, error_type="ValueError")

        summary = {entry["error_type"]: entry for entry in get_patterns(top_n=10)}
        assert summary["TimeoutError"]["count"] == 3
        assert summary["ValueError"]["count"] == 1

    def test_success_vs_failed_counts(self):
        """One success and one failure for ``Foo`` are tallied separately."""
        from infrastructure.self_correction import get_patterns, log_self_correction

        shared = {
            "source": "test",
            "original_intent": "i",
            "detected_error": "e",
            "correction_strategy": "s",
            "final_outcome": "o",
            "error_type": "Foo",
        }
        for status in ("success", "failed"):
            log_self_correction(**shared, outcome_status=status)

        found = get_patterns(top_n=5)
        foo_entry = next(entry for entry in found if entry["error_type"] == "Foo")
        assert foo_entry["success_count"] == 1
        assert foo_entry["failed_count"] == 1

    def test_ordered_by_count_desc(self):
        """The type logged five times is listed ahead of the one logged twice."""
        from infrastructure.self_correction import get_patterns, log_self_correction

        shared = {
            "source": "t",
            "original_intent": "i",
            "detected_error": "e",
            "correction_strategy": "s",
            "final_outcome": "o",
        }
        # Log the rarer type first so the expected order is not just log order.
        for label, repeats in (("Rare", 2), ("Common", 5)):
            for _ in range(repeats):
                log_self_correction(**shared, error_type=label)

        found = get_patterns(top_n=5)
        assert found[0]["error_type"] == "Common"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# get_stats
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetStats:
    """Check the summary figures returned by ``get_stats``."""

    def test_empty_db_returns_zeroes(self):
        """With nothing logged, ``total`` and ``success_rate`` are both zero."""
        from infrastructure.self_correction import get_stats

        summary = get_stats()

        assert summary["total"] == 0
        assert summary["success_rate"] == 0

    def test_counts_outcomes(self):
        """One success plus one failure gives total 2, one of each, and a rate of 50."""
        from infrastructure.self_correction import get_stats, log_self_correction

        shared = {
            "source": "t",
            "original_intent": "i",
            "detected_error": "e",
            "correction_strategy": "s",
            "final_outcome": "o",
        }
        for status in ("success", "failed"):
            log_self_correction(**shared, outcome_status=status)

        summary = get_stats()

        assert summary["total"] == 2
        assert summary["success_count"] == 1
        assert summary["failed_count"] == 1
        assert summary["success_rate"] == 50

    def test_success_rate_100_when_all_succeed(self):
        """Four successful events and no failures give a success rate of 100."""
        from infrastructure.self_correction import get_stats, log_self_correction

        event = {
            "source": "t",
            "original_intent": "i",
            "detected_error": "e",
            "correction_strategy": "s",
            "final_outcome": "o",
            "outcome_status": "success",
        }
        for _ in range(4):
            log_self_correction(**event)

        summary = get_stats()
        assert summary["success_rate"] == 100
|