Compare commits
9 Commits
door/issue
...
fix/101
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7cef18fdcb | ||
|
|
706024e11e | ||
| d412939b4f | |||
| 07c582aa08 | |||
| 5f95dc1e39 | |||
| b1f3cac36d | |||
| 07b3f67845 | |||
| c22bbbaf65 | |||
| 543cb1d40f |
@@ -8,6 +8,7 @@ from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urg
|
||||
from .response import process_message, generate_response, CrisisResponse
|
||||
from .gateway import check_crisis, get_system_prompt, format_gateway_response
|
||||
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
|
||||
from .ab_testing import ABTestCrisisDetector, VariantRecord
|
||||
|
||||
__all__ = [
|
||||
"detect_crisis",
|
||||
@@ -23,4 +24,6 @@ __all__ = [
|
||||
"CrisisSessionTracker",
|
||||
"SessionState",
|
||||
"check_crisis_with_session",
|
||||
"ABTestCrisisDetector",
|
||||
"VariantRecord",
|
||||
]
|
||||
|
||||
112
crisis/ab_testing.py
Normal file
112
crisis/ab_testing.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""A/B test framework for crisis detection in the-door."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from .detect import CrisisDetectionResult
|
||||
|
||||
|
||||
def _get_variant_override() -> Optional[str]:
|
||||
"""Return env override for deterministic testing/debugging."""
|
||||
value = os.environ.get("CRISIS_AB_VARIANT", "").strip().upper()
|
||||
if value in {"A", "B"}:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VariantRecord:
|
||||
"""Single crisis detection event record with no user text or PII."""
|
||||
|
||||
variant: str
|
||||
level: str
|
||||
latency_ms: float
|
||||
indicator_count: int
|
||||
false_positive: Optional[bool] = None
|
||||
|
||||
|
||||
class ABTestCrisisDetector:
|
||||
"""Route crisis detection between two variants and collect comparison stats."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
variant_a: Callable[[str], CrisisDetectionResult],
|
||||
variant_b: Callable[[str], CrisisDetectionResult],
|
||||
split: float = 0.5,
|
||||
):
|
||||
self.variant_a = variant_a
|
||||
self.variant_b = variant_b
|
||||
self.split = max(0.0, min(1.0, float(split)))
|
||||
self.records: List[VariantRecord] = []
|
||||
|
||||
def _select_variant(self) -> str:
|
||||
override = _get_variant_override()
|
||||
if override:
|
||||
return override
|
||||
return "A" if random.random() < self.split else "B"
|
||||
|
||||
def detect(self, text: str) -> Tuple[CrisisDetectionResult, str, int]:
|
||||
variant = self._select_variant()
|
||||
detector = self.variant_a if variant == "A" else self.variant_b
|
||||
|
||||
start = time.perf_counter()
|
||||
result = detector(text)
|
||||
latency_ms = (time.perf_counter() - start) * 1000.0
|
||||
|
||||
record = VariantRecord(
|
||||
variant=variant,
|
||||
level=result.level,
|
||||
latency_ms=latency_ms,
|
||||
indicator_count=len(result.indicators),
|
||||
)
|
||||
self.records.append(record)
|
||||
return result, variant, len(self.records) - 1
|
||||
|
||||
def record_outcome(self, record_id: int, *, false_positive: bool) -> None:
|
||||
if record_id < 0 or record_id >= len(self.records):
|
||||
raise IndexError(f"Unknown record id: {record_id}")
|
||||
self.records[record_id].false_positive = bool(false_positive)
|
||||
|
||||
def get_stats(self) -> Dict[str, dict]:
|
||||
stats: Dict[str, dict] = {}
|
||||
for variant in ("A", "B"):
|
||||
records = [record for record in self.records if record.variant == variant]
|
||||
if not records:
|
||||
stats[variant] = {
|
||||
"count": 0,
|
||||
"reviewed_count": 0,
|
||||
"false_positive_rate": None,
|
||||
}
|
||||
continue
|
||||
|
||||
levels: Dict[str, int] = {}
|
||||
for record in records:
|
||||
levels[record.level] = levels.get(record.level, 0) + 1
|
||||
|
||||
reviewed = [record for record in records if record.false_positive is not None]
|
||||
false_positive_rate = None
|
||||
if reviewed:
|
||||
false_positive_rate = round(
|
||||
sum(1 for record in reviewed if record.false_positive) / len(reviewed),
|
||||
4,
|
||||
)
|
||||
|
||||
stats[variant] = {
|
||||
"count": len(records),
|
||||
"avg_latency_ms": round(sum(record.latency_ms for record in records) / len(records), 4),
|
||||
"max_latency_ms": round(max(record.latency_ms for record in records), 4),
|
||||
"min_latency_ms": round(min(record.latency_ms for record in records), 4),
|
||||
"avg_indicator_count": round(sum(record.indicator_count for record in records) / len(records), 4),
|
||||
"levels": levels,
|
||||
"reviewed_count": len(reviewed),
|
||||
"false_positive_rate": false_positive_rate,
|
||||
}
|
||||
return stats
|
||||
|
||||
def reset(self) -> None:
|
||||
self.records.clear()
|
||||
@@ -680,7 +680,7 @@ html, body {
|
||||
|
||||
<!-- Footer -->
|
||||
<footer id="footer">
|
||||
<a href="/about" aria-label="About The Door">about</a>
|
||||
<a href="/about.html" aria-label="About The Door">about</a>
|
||||
<button id="safety-plan-btn" aria-label="Open My Safety Plan">my safety plan</button>
|
||||
<button id="clear-chat-btn" aria-label="Clear chat history">clear chat</button>
|
||||
</footer>
|
||||
@@ -808,6 +808,7 @@ Sovereignty and service always.`;
|
||||
var crisisPanel = document.getElementById('crisis-panel');
|
||||
var crisisOverlay = document.getElementById('crisis-overlay');
|
||||
var overlayDismissBtn = document.getElementById('overlay-dismiss-btn');
|
||||
var overlayCallLink = document.querySelector('.overlay-call');
|
||||
var statusDot = document.querySelector('.status-dot');
|
||||
var statusText = document.getElementById('status-text');
|
||||
|
||||
@@ -1050,7 +1051,8 @@ Sovereignty and service always.`;
|
||||
}
|
||||
}, 1000);
|
||||
|
||||
overlayDismissBtn.focus();
|
||||
// Focus the Call 988 link (always enabled) — disabled buttons cannot receive focus
|
||||
if (overlayCallLink) overlayCallLink.focus();
|
||||
}
|
||||
|
||||
// Register focus trap on document (always listening, gated by class check)
|
||||
|
||||
138
tests/test_ab_testing.py
Normal file
138
tests/test_ab_testing.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Tests for crisis.ab_testing — A/B test framework for crisis detection (#101)."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crisis.ab_testing import ABTestCrisisDetector
|
||||
from crisis.detect import CrisisDetectionResult, detect_crisis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_variant_override():
|
||||
old = os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if old is not None:
|
||||
os.environ["CRISIS_AB_VARIANT"] = old
|
||||
else:
|
||||
os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
|
||||
|
||||
def _make_variant(level: str, indicators=None):
|
||||
indicators = indicators or [f"mock_{level.lower()}"]
|
||||
|
||||
def fn(text: str) -> CrisisDetectionResult:
|
||||
return CrisisDetectionResult(level=level, indicators=list(indicators))
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def test_detect_returns_result_variant_and_logged_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", return_value="A"):
|
||||
result, variant, record_id = detector.detect("test message")
|
||||
|
||||
assert isinstance(result, CrisisDetectionResult)
|
||||
assert variant == "A"
|
||||
assert record_id == 0
|
||||
assert len(detector.records) == 1
|
||||
assert detector.records[0].variant == "A"
|
||||
assert detector.records[0].level == "LOW"
|
||||
|
||||
|
||||
def test_env_override_forces_variant_b():
|
||||
os.environ["CRISIS_AB_VARIANT"] = "b"
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
result, variant, _ = detector.detect("test")
|
||||
|
||||
assert variant == "B"
|
||||
assert result.level == "HIGH"
|
||||
|
||||
|
||||
def test_get_stats_reports_latency_counts_and_level_breakdown():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("CRITICAL"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
detector.detect("first")
|
||||
detector.detect("second")
|
||||
detector.detect("third")
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 2
|
||||
assert stats["B"]["count"] == 1
|
||||
assert stats["A"]["levels"]["LOW"] == 2
|
||||
assert stats["B"]["levels"]["CRITICAL"] == 1
|
||||
assert "avg_latency_ms" in stats["A"]
|
||||
assert "avg_indicator_count" in stats["B"]
|
||||
|
||||
|
||||
def test_false_positive_rate_is_computed_from_reviewed_outcomes():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
_, _, a0 = detector.detect("first")
|
||||
_, _, a1 = detector.detect("second")
|
||||
_, _, b0 = detector.detect("third")
|
||||
|
||||
detector.record_outcome(a0, false_positive=True)
|
||||
detector.record_outcome(a1, false_positive=False)
|
||||
detector.record_outcome(b0, false_positive=False)
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["reviewed_count"] == 2
|
||||
assert stats["A"]["false_positive_rate"] == 0.5
|
||||
assert stats["B"]["false_positive_rate"] == 0.0
|
||||
|
||||
|
||||
def test_record_outcome_rejects_unknown_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
detector.record_outcome(99, false_positive=True)
|
||||
|
||||
|
||||
def test_reset_clears_records_and_stats():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
detector.detect("test")
|
||||
detector.reset()
|
||||
|
||||
assert detector.records == []
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 0
|
||||
assert stats["B"]["count"] == 0
|
||||
|
||||
|
||||
def test_with_real_detector_integration():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=detect_crisis,
|
||||
variant_b=detect_crisis,
|
||||
)
|
||||
|
||||
result, variant, record_id = detector.detect("I want to kill myself")
|
||||
|
||||
assert result.level == "CRITICAL"
|
||||
assert variant in ("A", "B")
|
||||
assert record_id == 0
|
||||
@@ -52,6 +52,34 @@ class TestCrisisOverlayFocusTrap(unittest.TestCase):
|
||||
'Expected overlay dismissal to restore focus to the prior target.',
|
||||
)
|
||||
|
||||
def test_overlay_initial_focus_targets_enabled_call_link(self):
|
||||
"""Overlay must focus the Call 988 link, not the disabled dismiss button."""
|
||||
# Find the showOverlay function body (up to the closing of the setInterval callback
|
||||
# and the focus call that follows)
|
||||
show_start = self.html.find('function showOverlay()')
|
||||
self.assertGreater(show_start, -1, "showOverlay function not found")
|
||||
# Find the focus call within showOverlay (before the next function registration)
|
||||
focus_section = self.html[show_start:show_start + 2000]
|
||||
self.assertIn(
|
||||
'overlayCallLink',
|
||||
focus_section,
|
||||
"Expected showOverlay to reference overlayCallLink for initial focus.",
|
||||
)
|
||||
# Ensure the old buggy pattern is gone
|
||||
focus_line_region = self.html[show_start + 800:show_start + 1200]
|
||||
self.assertNotIn(
|
||||
'overlayDismissBtn.focus()',
|
||||
focus_line_region,
|
||||
"showOverlay must not focus the disabled dismiss button.",
|
||||
)
|
||||
|
||||
def test_overlay_call_link_variable_is_declared(self):
|
||||
self.assertIn(
|
||||
"querySelector('.overlay-call')",
|
||||
self.html,
|
||||
"Expected a JS reference to the .overlay-call link element.",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -50,6 +50,22 @@ class TestCrisisOfflinePage(unittest.TestCase):
|
||||
for phrase in required_phrases:
|
||||
self.assertIn(phrase, self.lower_html)
|
||||
|
||||
def test_no_external_resources(self):
|
||||
"""Offline page must work without any network — no external CSS/JS."""
|
||||
import re
|
||||
html = self.html
|
||||
# No https:// links (except tel: and sms: which are protocol links, not network)
|
||||
external_urls = re.findall(r'href=["\']https://|src=["\']https://', html)
|
||||
self.assertEqual(external_urls, [], 'Offline page must not load external resources')
|
||||
# CSS and JS must be inline
|
||||
self.assertIn('<style>', html, 'CSS must be inline')
|
||||
self.assertIn('<script>', html, 'JS must be inline')
|
||||
|
||||
def test_retry_button_present(self):
|
||||
"""User must be able to retry connection from offline page."""
|
||||
self.assertIn('retry-connection', self.html)
|
||||
self.assertIn('Retry connection', self.html)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user