Compare commits
3 Commits
feat/136-c
...
fix/101
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7cef18fdcb | ||
|
|
706024e11e | ||
| d412939b4f |
@@ -8,6 +8,7 @@ from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urg
|
||||
from .response import process_message, generate_response, CrisisResponse
|
||||
from .gateway import check_crisis, get_system_prompt, format_gateway_response
|
||||
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
|
||||
from .ab_testing import ABTestCrisisDetector, VariantRecord
|
||||
|
||||
__all__ = [
|
||||
"detect_crisis",
|
||||
@@ -23,4 +24,6 @@ __all__ = [
|
||||
"CrisisSessionTracker",
|
||||
"SessionState",
|
||||
"check_crisis_with_session",
|
||||
"ABTestCrisisDetector",
|
||||
"VariantRecord",
|
||||
]
|
||||
|
||||
112
crisis/ab_testing.py
Normal file
112
crisis/ab_testing.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""A/B test framework for crisis detection in the-door."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from .detect import CrisisDetectionResult
|
||||
|
||||
|
||||
def _get_variant_override() -> Optional[str]:
|
||||
"""Return env override for deterministic testing/debugging."""
|
||||
value = os.environ.get("CRISIS_AB_VARIANT", "").strip().upper()
|
||||
if value in {"A", "B"}:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VariantRecord:
|
||||
"""Single crisis detection event record with no user text or PII."""
|
||||
|
||||
variant: str
|
||||
level: str
|
||||
latency_ms: float
|
||||
indicator_count: int
|
||||
false_positive: Optional[bool] = None
|
||||
|
||||
|
||||
class ABTestCrisisDetector:
|
||||
"""Route crisis detection between two variants and collect comparison stats."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
variant_a: Callable[[str], CrisisDetectionResult],
|
||||
variant_b: Callable[[str], CrisisDetectionResult],
|
||||
split: float = 0.5,
|
||||
):
|
||||
self.variant_a = variant_a
|
||||
self.variant_b = variant_b
|
||||
self.split = max(0.0, min(1.0, float(split)))
|
||||
self.records: List[VariantRecord] = []
|
||||
|
||||
def _select_variant(self) -> str:
|
||||
override = _get_variant_override()
|
||||
if override:
|
||||
return override
|
||||
return "A" if random.random() < self.split else "B"
|
||||
|
||||
def detect(self, text: str) -> Tuple[CrisisDetectionResult, str, int]:
|
||||
variant = self._select_variant()
|
||||
detector = self.variant_a if variant == "A" else self.variant_b
|
||||
|
||||
start = time.perf_counter()
|
||||
result = detector(text)
|
||||
latency_ms = (time.perf_counter() - start) * 1000.0
|
||||
|
||||
record = VariantRecord(
|
||||
variant=variant,
|
||||
level=result.level,
|
||||
latency_ms=latency_ms,
|
||||
indicator_count=len(result.indicators),
|
||||
)
|
||||
self.records.append(record)
|
||||
return result, variant, len(self.records) - 1
|
||||
|
||||
def record_outcome(self, record_id: int, *, false_positive: bool) -> None:
|
||||
if record_id < 0 or record_id >= len(self.records):
|
||||
raise IndexError(f"Unknown record id: {record_id}")
|
||||
self.records[record_id].false_positive = bool(false_positive)
|
||||
|
||||
def get_stats(self) -> Dict[str, dict]:
|
||||
stats: Dict[str, dict] = {}
|
||||
for variant in ("A", "B"):
|
||||
records = [record for record in self.records if record.variant == variant]
|
||||
if not records:
|
||||
stats[variant] = {
|
||||
"count": 0,
|
||||
"reviewed_count": 0,
|
||||
"false_positive_rate": None,
|
||||
}
|
||||
continue
|
||||
|
||||
levels: Dict[str, int] = {}
|
||||
for record in records:
|
||||
levels[record.level] = levels.get(record.level, 0) + 1
|
||||
|
||||
reviewed = [record for record in records if record.false_positive is not None]
|
||||
false_positive_rate = None
|
||||
if reviewed:
|
||||
false_positive_rate = round(
|
||||
sum(1 for record in reviewed if record.false_positive) / len(reviewed),
|
||||
4,
|
||||
)
|
||||
|
||||
stats[variant] = {
|
||||
"count": len(records),
|
||||
"avg_latency_ms": round(sum(record.latency_ms for record in records) / len(records), 4),
|
||||
"max_latency_ms": round(max(record.latency_ms for record in records), 4),
|
||||
"min_latency_ms": round(min(record.latency_ms for record in records), 4),
|
||||
"avg_indicator_count": round(sum(record.indicator_count for record in records) / len(records), 4),
|
||||
"levels": levels,
|
||||
"reviewed_count": len(reviewed),
|
||||
"false_positive_rate": false_positive_rate,
|
||||
}
|
||||
return stats
|
||||
|
||||
def reset(self) -> None:
|
||||
self.records.clear()
|
||||
@@ -1,133 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Crisis Metrics CLI — View crisis detection health from the command line.
|
||||
|
||||
Usage:
|
||||
python3 -m crisis.metrics --summary # weekly report
|
||||
python3 -m crisis.metrics --json # raw JSON export
|
||||
python3 -m crisis.metrics --last 24h # last 24 hours
|
||||
|
||||
Ref: #136
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
METRICS_DIR = os.environ.get("CRISIS_METRICS_DIR", str(Path.home() / ".the-door" / "metrics"))
|
||||
|
||||
|
||||
def load_metrics(hours: int = 168) -> List[dict]:
|
||||
"""Load metrics entries from the last N hours."""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
entries = []
|
||||
metrics_path = Path(METRICS_DIR)
|
||||
|
||||
if not metrics_path.exists():
|
||||
return entries
|
||||
|
||||
for f in sorted(metrics_path.glob("*.json")):
|
||||
try:
|
||||
with open(f) as fh:
|
||||
data = json.load(fh)
|
||||
if isinstance(data, list):
|
||||
entries.extend(data)
|
||||
elif isinstance(data, dict):
|
||||
entries.append(data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Filter by timestamp
|
||||
filtered = []
|
||||
for e in entries:
|
||||
ts = e.get("timestamp", "")
|
||||
if ts:
|
||||
try:
|
||||
t = datetime.fromisoformat(ts.replace("Z", "+00:00"))
|
||||
if t >= cutoff:
|
||||
filtered.append(e)
|
||||
except Exception:
|
||||
filtered.append(e)
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
def summarize(entries: List[dict]) -> dict:
|
||||
"""Summarize metrics entries."""
|
||||
total = len(entries)
|
||||
by_level = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "NONE": 0}
|
||||
escalated = 0
|
||||
deescalated = 0
|
||||
resources_shown = 0
|
||||
|
||||
for e in entries:
|
||||
level = e.get("level", "NONE")
|
||||
by_level[level] = by_level.get(level, 0) + 1
|
||||
if e.get("escalated"):
|
||||
escalated += 1
|
||||
if e.get("deescalation_confirmed"):
|
||||
deescalated += 1
|
||||
if e.get("resources_shown"):
|
||||
resources_shown += 1
|
||||
|
||||
return {
|
||||
"period_hours": 168,
|
||||
"total_interactions": total,
|
||||
"by_level": by_level,
|
||||
"escalated_sessions": escalated,
|
||||
"deescalated_sessions": deescalated,
|
||||
"resources_shown": resources_shown,
|
||||
"crisis_rate": round((by_level["CRITICAL"] + by_level["HIGH"]) / max(total, 1) * 100, 1),
|
||||
}
|
||||
|
||||
|
||||
def print_summary(summary: dict):
|
||||
print(f"\n{'='*50}")
|
||||
print(f" CRISIS METRICS SUMMARY")
|
||||
print(f" {datetime.now().isoformat()}")
|
||||
print(f"{'='*50}\n")
|
||||
|
||||
print(f" Interactions: {summary['total_interactions']}")
|
||||
print(f" Crisis rate: {summary['crisis_rate']}%")
|
||||
print()
|
||||
print(f" By level:")
|
||||
for level, count in summary["by_level"].items():
|
||||
bar = "█" * min(count, 40)
|
||||
print(f" {level:10} {count:5} {bar}")
|
||||
print()
|
||||
print(f" Escalated: {summary['escalated_sessions']}")
|
||||
print(f" De-escalated: {summary['deescalated_sessions']}")
|
||||
print(f" 988 shown: {summary['resources_shown']}")
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Crisis Metrics CLI")
|
||||
parser.add_argument("--summary", action="store_true", help="Weekly summary")
|
||||
parser.add_argument("--json", action="store_true", help="JSON export")
|
||||
parser.add_argument("--last", default="168h", help="Time window (e.g., 24h, 7d)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse time window
|
||||
last = args.last
|
||||
if last.endswith("h"):
|
||||
hours = int(last[:-1])
|
||||
elif last.endswith("d"):
|
||||
hours = int(last[:-1]) * 24
|
||||
else:
|
||||
hours = 168
|
||||
|
||||
entries = load_metrics(hours)
|
||||
summary = summarize(entries)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(summary, indent=2))
|
||||
else:
|
||||
print_summary(summary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -680,7 +680,7 @@ html, body {
|
||||
|
||||
<!-- Footer -->
|
||||
<footer id="footer">
|
||||
<a href="/about" aria-label="About The Door">about</a>
|
||||
<a href="/about.html" aria-label="About The Door">about</a>
|
||||
<button id="safety-plan-btn" aria-label="Open My Safety Plan">my safety plan</button>
|
||||
<button id="clear-chat-btn" aria-label="Clear chat history">clear chat</button>
|
||||
</footer>
|
||||
|
||||
138
tests/test_ab_testing.py
Normal file
138
tests/test_ab_testing.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Tests for crisis.ab_testing — A/B test framework for crisis detection (#101)."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crisis.ab_testing import ABTestCrisisDetector
|
||||
from crisis.detect import CrisisDetectionResult, detect_crisis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_variant_override():
|
||||
old = os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if old is not None:
|
||||
os.environ["CRISIS_AB_VARIANT"] = old
|
||||
else:
|
||||
os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
|
||||
|
||||
def _make_variant(level: str, indicators=None):
|
||||
indicators = indicators or [f"mock_{level.lower()}"]
|
||||
|
||||
def fn(text: str) -> CrisisDetectionResult:
|
||||
return CrisisDetectionResult(level=level, indicators=list(indicators))
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def test_detect_returns_result_variant_and_logged_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", return_value="A"):
|
||||
result, variant, record_id = detector.detect("test message")
|
||||
|
||||
assert isinstance(result, CrisisDetectionResult)
|
||||
assert variant == "A"
|
||||
assert record_id == 0
|
||||
assert len(detector.records) == 1
|
||||
assert detector.records[0].variant == "A"
|
||||
assert detector.records[0].level == "LOW"
|
||||
|
||||
|
||||
def test_env_override_forces_variant_b():
|
||||
os.environ["CRISIS_AB_VARIANT"] = "b"
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
result, variant, _ = detector.detect("test")
|
||||
|
||||
assert variant == "B"
|
||||
assert result.level == "HIGH"
|
||||
|
||||
|
||||
def test_get_stats_reports_latency_counts_and_level_breakdown():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("CRITICAL"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
detector.detect("first")
|
||||
detector.detect("second")
|
||||
detector.detect("third")
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 2
|
||||
assert stats["B"]["count"] == 1
|
||||
assert stats["A"]["levels"]["LOW"] == 2
|
||||
assert stats["B"]["levels"]["CRITICAL"] == 1
|
||||
assert "avg_latency_ms" in stats["A"]
|
||||
assert "avg_indicator_count" in stats["B"]
|
||||
|
||||
|
||||
def test_false_positive_rate_is_computed_from_reviewed_outcomes():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
_, _, a0 = detector.detect("first")
|
||||
_, _, a1 = detector.detect("second")
|
||||
_, _, b0 = detector.detect("third")
|
||||
|
||||
detector.record_outcome(a0, false_positive=True)
|
||||
detector.record_outcome(a1, false_positive=False)
|
||||
detector.record_outcome(b0, false_positive=False)
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["reviewed_count"] == 2
|
||||
assert stats["A"]["false_positive_rate"] == 0.5
|
||||
assert stats["B"]["false_positive_rate"] == 0.0
|
||||
|
||||
|
||||
def test_record_outcome_rejects_unknown_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
detector.record_outcome(99, false_positive=True)
|
||||
|
||||
|
||||
def test_reset_clears_records_and_stats():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
detector.detect("test")
|
||||
detector.reset()
|
||||
|
||||
assert detector.records == []
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 0
|
||||
assert stats["B"]["count"] == 0
|
||||
|
||||
|
||||
def test_with_real_detector_integration():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=detect_crisis,
|
||||
variant_b=detect_crisis,
|
||||
)
|
||||
|
||||
result, variant, record_id = detector.detect("I want to kill myself")
|
||||
|
||||
assert result.level == "CRITICAL"
|
||||
assert variant in ("A", "B")
|
||||
assert record_id == 0
|
||||
Reference in New Issue
Block a user