Co-authored-by: Claude (Opus 4.6) <claude@hermes.local> Co-committed-by: Claude (Opus 4.6) <claude@hermes.local>
This commit was merged in pull request #1269.
This commit is contained in:
269
tests/unit/test_self_correction.py
Normal file
269
tests/unit/test_self_correction.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Unit tests for infrastructure.self_correction."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolated_db(tmp_path, monkeypatch):
|
||||
"""Point the self-correction module at a fresh temp database per test."""
|
||||
import infrastructure.self_correction as sc_mod
|
||||
|
||||
# Reset the cached path so each test gets a clean DB
|
||||
sc_mod._DB_PATH = tmp_path / "self_correction.db"
|
||||
yield
|
||||
sc_mod._DB_PATH = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# log_self_correction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLogSelfCorrection:
|
||||
def test_returns_event_id(self):
|
||||
from infrastructure.self_correction import log_self_correction
|
||||
|
||||
eid = log_self_correction(
|
||||
source="test",
|
||||
original_intent="Do X",
|
||||
detected_error="ValueError: bad input",
|
||||
correction_strategy="Try Y instead",
|
||||
final_outcome="Y succeeded",
|
||||
)
|
||||
assert isinstance(eid, str)
|
||||
assert len(eid) == 36 # UUID format
|
||||
|
||||
def test_derives_error_type_from_error_string(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="Connect",
|
||||
detected_error="ConnectionRefusedError: port 80",
|
||||
correction_strategy="Use port 8080",
|
||||
final_outcome="ok",
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert rows[0]["error_type"] == "ConnectionRefusedError"
|
||||
|
||||
def test_explicit_error_type_preserved(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="Run task",
|
||||
detected_error="Some weird error",
|
||||
correction_strategy="Fix it",
|
||||
final_outcome="done",
|
||||
error_type="CustomError",
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert rows[0]["error_type"] == "CustomError"
|
||||
|
||||
def test_task_id_stored(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="intent",
|
||||
detected_error="err",
|
||||
correction_strategy="strat",
|
||||
final_outcome="outcome",
|
||||
task_id="task-abc-123",
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert rows[0]["task_id"] == "task-abc-123"
|
||||
|
||||
def test_outcome_status_stored(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="i",
|
||||
detected_error="e",
|
||||
correction_strategy="s",
|
||||
final_outcome="o",
|
||||
outcome_status="failed",
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert rows[0]["outcome_status"] == "failed"
|
||||
|
||||
def test_long_strings_truncated(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
long = "x" * 3000
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent=long,
|
||||
detected_error=long,
|
||||
correction_strategy=long,
|
||||
final_outcome=long,
|
||||
)
|
||||
rows = get_corrections(limit=1)
|
||||
assert len(rows[0]["original_intent"]) <= 2000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_corrections
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetCorrections:
|
||||
def test_empty_db_returns_empty_list(self):
|
||||
from infrastructure.self_correction import get_corrections
|
||||
|
||||
assert get_corrections() == []
|
||||
|
||||
def test_returns_newest_first(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
for i in range(3):
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent=f"intent {i}",
|
||||
detected_error="err",
|
||||
correction_strategy="fix",
|
||||
final_outcome="done",
|
||||
error_type=f"Type{i}",
|
||||
)
|
||||
rows = get_corrections(limit=10)
|
||||
assert len(rows) == 3
|
||||
# Newest first — Type2 should appear before Type0
|
||||
types = [r["error_type"] for r in rows]
|
||||
assert types.index("Type2") < types.index("Type0")
|
||||
|
||||
def test_limit_respected(self):
|
||||
from infrastructure.self_correction import get_corrections, log_self_correction
|
||||
|
||||
for _ in range(5):
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="i",
|
||||
detected_error="e",
|
||||
correction_strategy="s",
|
||||
final_outcome="o",
|
||||
)
|
||||
rows = get_corrections(limit=3)
|
||||
assert len(rows) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_patterns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetPatterns:
|
||||
def test_empty_db_returns_empty_list(self):
|
||||
from infrastructure.self_correction import get_patterns
|
||||
|
||||
assert get_patterns() == []
|
||||
|
||||
def test_counts_by_error_type(self):
|
||||
from infrastructure.self_correction import get_patterns, log_self_correction
|
||||
|
||||
for _ in range(3):
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="i",
|
||||
detected_error="e",
|
||||
correction_strategy="s",
|
||||
final_outcome="o",
|
||||
error_type="TimeoutError",
|
||||
)
|
||||
log_self_correction(
|
||||
source="test",
|
||||
original_intent="i",
|
||||
detected_error="e",
|
||||
correction_strategy="s",
|
||||
final_outcome="o",
|
||||
error_type="ValueError",
|
||||
)
|
||||
patterns = get_patterns(top_n=10)
|
||||
by_type = {p["error_type"]: p for p in patterns}
|
||||
assert by_type["TimeoutError"]["count"] == 3
|
||||
assert by_type["ValueError"]["count"] == 1
|
||||
|
||||
def test_success_vs_failed_counts(self):
|
||||
from infrastructure.self_correction import get_patterns, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="test", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o",
|
||||
error_type="Foo", outcome_status="success",
|
||||
)
|
||||
log_self_correction(
|
||||
source="test", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o",
|
||||
error_type="Foo", outcome_status="failed",
|
||||
)
|
||||
patterns = get_patterns(top_n=5)
|
||||
foo = next(p for p in patterns if p["error_type"] == "Foo")
|
||||
assert foo["success_count"] == 1
|
||||
assert foo["failed_count"] == 1
|
||||
|
||||
def test_ordered_by_count_desc(self):
|
||||
from infrastructure.self_correction import get_patterns, log_self_correction
|
||||
|
||||
for _ in range(2):
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", error_type="Rare",
|
||||
)
|
||||
for _ in range(5):
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", error_type="Common",
|
||||
)
|
||||
patterns = get_patterns(top_n=5)
|
||||
assert patterns[0]["error_type"] == "Common"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetStats:
|
||||
def test_empty_db_returns_zeroes(self):
|
||||
from infrastructure.self_correction import get_stats
|
||||
|
||||
stats = get_stats()
|
||||
assert stats["total"] == 0
|
||||
assert stats["success_rate"] == 0
|
||||
|
||||
def test_counts_outcomes(self):
|
||||
from infrastructure.self_correction import get_stats, log_self_correction
|
||||
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", outcome_status="success",
|
||||
)
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", outcome_status="failed",
|
||||
)
|
||||
stats = get_stats()
|
||||
assert stats["total"] == 2
|
||||
assert stats["success_count"] == 1
|
||||
assert stats["failed_count"] == 1
|
||||
assert stats["success_rate"] == 50
|
||||
|
||||
def test_success_rate_100_when_all_succeed(self):
|
||||
from infrastructure.self_correction import get_stats, log_self_correction
|
||||
|
||||
for _ in range(4):
|
||||
log_self_correction(
|
||||
source="t", original_intent="i", detected_error="e",
|
||||
correction_strategy="s", final_outcome="o", outcome_status="success",
|
||||
)
|
||||
stats = get_stats()
|
||||
assert stats["success_rate"] == 100
|
||||
Reference in New Issue
Block a user