139 lines
4.0 KiB
Python
139 lines
4.0 KiB
Python
"""Tests for crisis.ab_testing — A/B test framework for crisis detection (#101)."""
|
|
|
|
import os
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from crisis.ab_testing import ABTestCrisisDetector
|
|
from crisis.detect import CrisisDetectionResult, detect_crisis
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clear_variant_override():
    """Run every test with CRISIS_AB_VARIANT unset, then restore the prior state.

    The value present before the test (if any) is popped up front and written
    back during teardown. If the variable was absent beforehand, whatever a
    test assigned to it is removed again, so overrides never leak between tests.
    """
    env_key = "CRISIS_AB_VARIANT"
    previous = os.environ.pop(env_key, None)
    try:
        yield
    finally:
        # Clear anything the test set, then reinstate the saved value if one existed.
        os.environ.pop(env_key, None)
        if previous is not None:
            os.environ[env_key] = previous
|
|
|
|
|
|
def _make_variant(level: str, indicators=None):
    """Build a stub detector that always reports the given crisis level.

    Args:
        level: Level string placed on every returned result (e.g. ``"LOW"``).
        indicators: Indicators attached to every result. ``None`` (the default)
            yields a single ``mock_<level lowercased>`` indicator. An explicitly
            empty sequence is honoured and yields results with no indicators.

    Returns:
        A callable ``fn(text) -> CrisisDetectionResult`` that ignores ``text``.
    """
    # Compare against None rather than relying on truthiness, so that passing
    # ``[]`` really means "no indicators" instead of silently using the default.
    if indicators is None:
        indicators = [f"mock_{level.lower()}"]
    # Snapshot once: later mutation of the caller's list cannot leak in, and a
    # one-shot iterable (e.g. a generator) is not exhausted after the first call.
    snapshot = tuple(indicators)

    def fn(text: str) -> CrisisDetectionResult:
        # Hand out a fresh list per call so results stay independent of each other.
        return CrisisDetectionResult(level=level, indicators=list(snapshot))

    return fn
|
|
|
|
|
|
def test_detect_returns_result_variant_and_logged_record():
    """detect() returns the chosen variant's result, its label, and the logged record id."""
    detector = ABTestCrisisDetector(
        variant_a=_make_variant("LOW"),
        variant_b=_make_variant("HIGH"),
    )

    # Pin the assignment to variant A so the test is deterministic.
    with patch.object(detector, "_select_variant", return_value="A"):
        result, variant, record_id = detector.detect("test message")

    assert isinstance(result, CrisisDetectionResult)
    # The two stubs report different levels, so this proves the returned
    # result came from variant A and not from variant B.
    assert result.level == "LOW"
    assert variant == "A"
    assert record_id == 0
    assert len(detector.records) == 1
    assert detector.records[0].variant == "A"
    assert detector.records[0].level == "LOW"
|
|
|
|
|
|
def test_env_override_forces_variant_b():
    """Setting CRISIS_AB_VARIANT to lowercase "b" forces variant B to be used."""
    low_variant = _make_variant("LOW")
    high_variant = _make_variant("HIGH")

    # Set before constructing the detector; the autouse fixture cleans it up.
    os.environ["CRISIS_AB_VARIANT"] = "b"
    detector = ABTestCrisisDetector(variant_a=low_variant, variant_b=high_variant)

    outcome, chosen, _record_id = detector.detect("test")

    assert chosen == "B"
    assert outcome.level == "HIGH"
|
|
|
|
|
|
def test_get_stats_reports_latency_counts_and_level_breakdown():
    """get_stats() reports per-variant counts, level tallies, and average fields."""
    detector = ABTestCrisisDetector(
        variant_a=_make_variant("LOW"),
        variant_b=_make_variant("CRITICAL"),
    )

    # Route the first two messages to A and the third to B.
    assignments = ["A", "A", "B"]
    with patch.object(detector, "_select_variant", side_effect=assignments):
        for message in ("first", "second", "third"):
            detector.detect(message)

    stats = detector.get_stats()
    a_stats, b_stats = stats["A"], stats["B"]

    assert a_stats["count"] == 2
    assert b_stats["count"] == 1
    assert a_stats["levels"]["LOW"] == 2
    assert b_stats["levels"]["CRITICAL"] == 1
    assert "avg_latency_ms" in a_stats
    assert "avg_indicator_count" in b_stats
|
|
|
|
|
|
def test_false_positive_rate_is_computed_from_reviewed_outcomes():
    """Per-variant false-positive rate is derived from outcomes fed to record_outcome()."""
    detector = ABTestCrisisDetector(
        variant_a=_make_variant("LOW"),
        variant_b=_make_variant("HIGH"),
    )

    messages = ("first", "second", "third")
    with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
        record_ids = [rid for _res, _variant, rid in map(detector.detect, messages)]

    # Verdicts line up with the records above: the first A is a false positive,
    # the second A and the single B are not.
    verdicts = (True, False, False)
    for rid, is_false_positive in zip(record_ids, verdicts):
        detector.record_outcome(rid, false_positive=is_false_positive)

    stats = detector.get_stats()
    assert stats["A"]["reviewed_count"] == 2
    assert stats["A"]["false_positive_rate"] == 0.5
    assert stats["B"]["false_positive_rate"] == 0.0
|
|
|
|
|
|
def test_record_outcome_rejects_unknown_record():
    """record_outcome() raises IndexError for an id that no detect() call produced."""
    fresh_detector = ABTestCrisisDetector(
        variant_a=_make_variant("LOW"),
        variant_b=_make_variant("HIGH"),
    )
    unknown_id = 99  # nothing has been detected yet, so no record can match

    with pytest.raises(IndexError):
        fresh_detector.record_outcome(unknown_id, false_positive=True)
|
|
|
|
|
|
def test_reset_clears_records_and_stats():
    """reset() drops logged records so stats show zero calls for both variants."""
    detector = ABTestCrisisDetector(
        variant_a=_make_variant("LOW"),
        variant_b=_make_variant("HIGH"),
    )
    detector.detect("test")  # log one record; variant choice is left unpatched
    detector.reset()

    assert detector.records == []
    stats = detector.get_stats()
    for label in ("A", "B"):
        assert stats[label]["count"] == 0
|
|
|
|
|
|
def test_with_real_detector_integration():
    """End-to-end check with the real detect_crisis wired into both variants."""
    real_detector = detect_crisis
    detector = ABTestCrisisDetector(variant_a=real_detector, variant_b=real_detector)

    result, variant, record_id = detector.detect("I want to kill myself")

    assert result.level == "CRITICAL"
    # Either arm may be chosen, since both run the same detector.
    assert variant in {"A", "B"}
    assert record_id == 0
|