Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7cef18fdcb | ||
|
|
706024e11e | ||
| d412939b4f |
@@ -8,6 +8,7 @@ from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urg
|
||||
from .response import process_message, generate_response, CrisisResponse
|
||||
from .gateway import check_crisis, get_system_prompt, format_gateway_response
|
||||
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
|
||||
from .ab_testing import ABTestCrisisDetector, VariantRecord
|
||||
|
||||
__all__ = [
|
||||
"detect_crisis",
|
||||
@@ -23,4 +24,6 @@ __all__ = [
|
||||
"CrisisSessionTracker",
|
||||
"SessionState",
|
||||
"check_crisis_with_session",
|
||||
"ABTestCrisisDetector",
|
||||
"VariantRecord",
|
||||
]
|
||||
|
||||
112
crisis/ab_testing.py
Normal file
112
crisis/ab_testing.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""A/B test framework for crisis detection in the-door."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from .detect import CrisisDetectionResult
|
||||
|
||||
|
||||
def _get_variant_override() -> Optional[str]:
|
||||
"""Return env override for deterministic testing/debugging."""
|
||||
value = os.environ.get("CRISIS_AB_VARIANT", "").strip().upper()
|
||||
if value in {"A", "B"}:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VariantRecord:
|
||||
"""Single crisis detection event record with no user text or PII."""
|
||||
|
||||
variant: str
|
||||
level: str
|
||||
latency_ms: float
|
||||
indicator_count: int
|
||||
false_positive: Optional[bool] = None
|
||||
|
||||
|
||||
class ABTestCrisisDetector:
|
||||
"""Route crisis detection between two variants and collect comparison stats."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
variant_a: Callable[[str], CrisisDetectionResult],
|
||||
variant_b: Callable[[str], CrisisDetectionResult],
|
||||
split: float = 0.5,
|
||||
):
|
||||
self.variant_a = variant_a
|
||||
self.variant_b = variant_b
|
||||
self.split = max(0.0, min(1.0, float(split)))
|
||||
self.records: List[VariantRecord] = []
|
||||
|
||||
def _select_variant(self) -> str:
|
||||
override = _get_variant_override()
|
||||
if override:
|
||||
return override
|
||||
return "A" if random.random() < self.split else "B"
|
||||
|
||||
def detect(self, text: str) -> Tuple[CrisisDetectionResult, str, int]:
|
||||
variant = self._select_variant()
|
||||
detector = self.variant_a if variant == "A" else self.variant_b
|
||||
|
||||
start = time.perf_counter()
|
||||
result = detector(text)
|
||||
latency_ms = (time.perf_counter() - start) * 1000.0
|
||||
|
||||
record = VariantRecord(
|
||||
variant=variant,
|
||||
level=result.level,
|
||||
latency_ms=latency_ms,
|
||||
indicator_count=len(result.indicators),
|
||||
)
|
||||
self.records.append(record)
|
||||
return result, variant, len(self.records) - 1
|
||||
|
||||
def record_outcome(self, record_id: int, *, false_positive: bool) -> None:
|
||||
if record_id < 0 or record_id >= len(self.records):
|
||||
raise IndexError(f"Unknown record id: {record_id}")
|
||||
self.records[record_id].false_positive = bool(false_positive)
|
||||
|
||||
def get_stats(self) -> Dict[str, dict]:
|
||||
stats: Dict[str, dict] = {}
|
||||
for variant in ("A", "B"):
|
||||
records = [record for record in self.records if record.variant == variant]
|
||||
if not records:
|
||||
stats[variant] = {
|
||||
"count": 0,
|
||||
"reviewed_count": 0,
|
||||
"false_positive_rate": None,
|
||||
}
|
||||
continue
|
||||
|
||||
levels: Dict[str, int] = {}
|
||||
for record in records:
|
||||
levels[record.level] = levels.get(record.level, 0) + 1
|
||||
|
||||
reviewed = [record for record in records if record.false_positive is not None]
|
||||
false_positive_rate = None
|
||||
if reviewed:
|
||||
false_positive_rate = round(
|
||||
sum(1 for record in reviewed if record.false_positive) / len(reviewed),
|
||||
4,
|
||||
)
|
||||
|
||||
stats[variant] = {
|
||||
"count": len(records),
|
||||
"avg_latency_ms": round(sum(record.latency_ms for record in records) / len(records), 4),
|
||||
"max_latency_ms": round(max(record.latency_ms for record in records), 4),
|
||||
"min_latency_ms": round(min(record.latency_ms for record in records), 4),
|
||||
"avg_indicator_count": round(sum(record.indicator_count for record in records) / len(records), 4),
|
||||
"levels": levels,
|
||||
"reviewed_count": len(reviewed),
|
||||
"false_positive_rate": false_positive_rate,
|
||||
}
|
||||
return stats
|
||||
|
||||
def reset(self) -> None:
|
||||
self.records.clear()
|
||||
75
index.html
75
index.html
@@ -241,48 +241,6 @@ html, body {
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
/* ===== CHAT HEADER ===== */
|
||||
#chat-header {
|
||||
flex-shrink: 0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 12px;
|
||||
padding: 10px 12px;
|
||||
border-bottom: 1px solid #21262d;
|
||||
background: #11161d;
|
||||
}
|
||||
|
||||
.chat-header-title {
|
||||
font-size: 0.85rem;
|
||||
color: #8b949e;
|
||||
font-weight: 600;
|
||||
letter-spacing: 0.02em;
|
||||
}
|
||||
|
||||
#chat-safety-plan-btn {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 8px 12px;
|
||||
min-height: 36px;
|
||||
border: 1px solid #30363d;
|
||||
border-radius: 999px;
|
||||
background: transparent;
|
||||
color: #c9d1d9;
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
#chat-safety-plan-btn:hover,
|
||||
#chat-safety-plan-btn:focus {
|
||||
border-color: #58a6ff;
|
||||
background: rgba(88, 166, 255, 0.12);
|
||||
outline: 2px solid #58a6ff;
|
||||
outline-offset: 2px;
|
||||
}
|
||||
|
||||
/* ===== CHAT AREA ===== */
|
||||
#chat-area {
|
||||
flex: 1;
|
||||
@@ -691,14 +649,6 @@ html, body {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="chat-header">
|
||||
<div class="chat-header-title" aria-hidden="true">Conversation</div>
|
||||
<button id="chat-safety-plan-btn" type="button" aria-label="Open My Safety Plan from chat header">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="16" y1="13" x2="8" y2="13"/><line x1="16" y1="17" x2="8" y2="17"/><polyline points="10 9 9 9 8 9"/></svg>
|
||||
My Safety Plan
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Chat messages -->
|
||||
<div id="chat-area" role="log" aria-label="Chat messages" aria-live="polite" tabindex="0">
|
||||
<!-- Messages inserted here -->
|
||||
@@ -730,7 +680,7 @@ html, body {
|
||||
|
||||
<!-- Footer -->
|
||||
<footer id="footer">
|
||||
<a href="/about" aria-label="About The Door">about</a>
|
||||
<a href="/about.html" aria-label="About The Door">about</a>
|
||||
<button id="safety-plan-btn" aria-label="Open My Safety Plan">my safety plan</button>
|
||||
<button id="clear-chat-btn" aria-label="Clear chat history">clear chat</button>
|
||||
</footer>
|
||||
@@ -864,7 +814,6 @@ Sovereignty and service always.`;
|
||||
|
||||
// Safety Plan Elements
|
||||
var safetyPlanBtn = document.getElementById('safety-plan-btn');
|
||||
var chatSafetyPlanBtn = document.getElementById('chat-safety-plan-btn');
|
||||
var crisisSafetyPlanBtn = document.getElementById('crisis-safety-plan-btn');
|
||||
var safetyPlanModal = document.getElementById('safety-plan-modal');
|
||||
var closeSafetyPlan = document.getElementById('close-safety-plan');
|
||||
@@ -1336,25 +1285,19 @@ Sovereignty and service always.`;
|
||||
_spTriggerEl = null;
|
||||
}
|
||||
|
||||
function openSafetyPlan(triggerEl) {
|
||||
loadSafetyPlan();
|
||||
safetyPlanModal.classList.add('active');
|
||||
_activateSafetyPlanFocusTrap(triggerEl || document.activeElement);
|
||||
}
|
||||
|
||||
// Wire open buttons to activate focus trap
|
||||
safetyPlanBtn.addEventListener('click', function() {
|
||||
openSafetyPlan(safetyPlanBtn);
|
||||
});
|
||||
|
||||
chatSafetyPlanBtn.addEventListener('click', function() {
|
||||
openSafetyPlan(chatSafetyPlanBtn);
|
||||
loadSafetyPlan();
|
||||
safetyPlanModal.classList.add('active');
|
||||
_activateSafetyPlanFocusTrap(safetyPlanBtn);
|
||||
});
|
||||
|
||||
// Crisis panel safety plan button (if crisis panel is visible)
|
||||
if (crisisSafetyPlanBtn) {
|
||||
crisisSafetyPlanBtn.addEventListener('click', function() {
|
||||
openSafetyPlan(crisisSafetyPlanBtn);
|
||||
loadSafetyPlan();
|
||||
safetyPlanModal.classList.add('active');
|
||||
_activateSafetyPlanFocusTrap(crisisSafetyPlanBtn);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1501,7 +1444,9 @@ Sovereignty and service always.`;
|
||||
// Check for URL params (e.g., ?safetyplan=true for PWA shortcut)
|
||||
var urlParams = new URLSearchParams(window.location.search);
|
||||
if (urlParams.get('safetyplan') === 'true') {
|
||||
openSafetyPlan(chatSafetyPlanBtn || safetyPlanBtn);
|
||||
loadSafetyPlan();
|
||||
safetyPlanModal.classList.add('active');
|
||||
_activateSafetyPlanFocusTrap(safetyPlanBtn);
|
||||
// Clean up URL
|
||||
window.history.replaceState({}, document.title, window.location.pathname);
|
||||
}
|
||||
|
||||
138
tests/test_ab_testing.py
Normal file
138
tests/test_ab_testing.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Tests for crisis.ab_testing — A/B test framework for crisis detection (#101)."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crisis.ab_testing import ABTestCrisisDetector
|
||||
from crisis.detect import CrisisDetectionResult, detect_crisis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_variant_override():
|
||||
old = os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if old is not None:
|
||||
os.environ["CRISIS_AB_VARIANT"] = old
|
||||
else:
|
||||
os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
|
||||
|
||||
def _make_variant(level: str, indicators=None):
|
||||
indicators = indicators or [f"mock_{level.lower()}"]
|
||||
|
||||
def fn(text: str) -> CrisisDetectionResult:
|
||||
return CrisisDetectionResult(level=level, indicators=list(indicators))
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def test_detect_returns_result_variant_and_logged_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", return_value="A"):
|
||||
result, variant, record_id = detector.detect("test message")
|
||||
|
||||
assert isinstance(result, CrisisDetectionResult)
|
||||
assert variant == "A"
|
||||
assert record_id == 0
|
||||
assert len(detector.records) == 1
|
||||
assert detector.records[0].variant == "A"
|
||||
assert detector.records[0].level == "LOW"
|
||||
|
||||
|
||||
def test_env_override_forces_variant_b():
|
||||
os.environ["CRISIS_AB_VARIANT"] = "b"
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
result, variant, _ = detector.detect("test")
|
||||
|
||||
assert variant == "B"
|
||||
assert result.level == "HIGH"
|
||||
|
||||
|
||||
def test_get_stats_reports_latency_counts_and_level_breakdown():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("CRITICAL"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
detector.detect("first")
|
||||
detector.detect("second")
|
||||
detector.detect("third")
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 2
|
||||
assert stats["B"]["count"] == 1
|
||||
assert stats["A"]["levels"]["LOW"] == 2
|
||||
assert stats["B"]["levels"]["CRITICAL"] == 1
|
||||
assert "avg_latency_ms" in stats["A"]
|
||||
assert "avg_indicator_count" in stats["B"]
|
||||
|
||||
|
||||
def test_false_positive_rate_is_computed_from_reviewed_outcomes():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
_, _, a0 = detector.detect("first")
|
||||
_, _, a1 = detector.detect("second")
|
||||
_, _, b0 = detector.detect("third")
|
||||
|
||||
detector.record_outcome(a0, false_positive=True)
|
||||
detector.record_outcome(a1, false_positive=False)
|
||||
detector.record_outcome(b0, false_positive=False)
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["reviewed_count"] == 2
|
||||
assert stats["A"]["false_positive_rate"] == 0.5
|
||||
assert stats["B"]["false_positive_rate"] == 0.0
|
||||
|
||||
|
||||
def test_record_outcome_rejects_unknown_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
detector.record_outcome(99, false_positive=True)
|
||||
|
||||
|
||||
def test_reset_clears_records_and_stats():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
detector.detect("test")
|
||||
detector.reset()
|
||||
|
||||
assert detector.records == []
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 0
|
||||
assert stats["B"]["count"] == 0
|
||||
|
||||
|
||||
def test_with_real_detector_integration():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=detect_crisis,
|
||||
variant_b=detect_crisis,
|
||||
)
|
||||
|
||||
result, variant, record_id = detector.detect("I want to kill myself")
|
||||
|
||||
assert result.level == "CRITICAL"
|
||||
assert variant in ("A", "B")
|
||||
assert record_id == 0
|
||||
@@ -1,20 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
INDEX = Path("index.html")
|
||||
|
||||
|
||||
def test_chat_header_has_persistent_safety_plan_button():
|
||||
html = INDEX.read_text()
|
||||
assert 'id="chat-header"' in html
|
||||
assert 'id="chat-safety-plan-btn"' in html
|
||||
assert 'aria-label="Open My Safety Plan from chat header"' in html
|
||||
assert 'My Safety Plan' in html
|
||||
|
||||
|
||||
def test_chat_header_button_opens_existing_safety_plan_modal():
|
||||
html = INDEX.read_text()
|
||||
assert "var chatSafetyPlanBtn = document.getElementById('chat-safety-plan-btn');" in html
|
||||
assert "chatSafetyPlanBtn.addEventListener('click'" in html
|
||||
assert "function openSafetyPlan(triggerEl)" in html
|
||||
assert "safetyPlanModal.classList.add('active');" in html
|
||||
assert "openSafetyPlan(chatSafetyPlanBtn);" in html
|
||||
Reference in New Issue
Block a user