Compare commits

..

2 Commits

Author SHA1 Message Date
65d6fc6119 test: add A/B testing framework tests (#101)
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 7s
Smoke Test / smoke (pull_request) Successful in 10s
2026-04-15 03:58:27 +00:00
70d04cdbfd feat: add crisis detection A/B test framework (#101) 2026-04-15 03:58:26 +00:00
11 changed files with 288 additions and 640 deletions

View File

@@ -7,7 +7,6 @@ Stands between a broken man and a machine that would tell him to die.
from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urgency_emoji
from .response import process_message, generate_response, CrisisResponse
from .gateway import check_crisis, get_system_prompt, format_gateway_response
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
__all__ = [
"detect_crisis",
@@ -20,7 +19,4 @@ __all__ = [
"format_result",
"format_gateway_response",
"get_urgency_emoji",
"CrisisSessionTracker",
"SessionState",
"check_crisis_with_session",
]

152
crisis/ab_testing.py Normal file
View File

@@ -0,0 +1,152 @@
"""
A/B Test Framework for Crisis Detection in the-door.
Allows running two crisis detection variants side-by-side with
logged outcomes for comparison. No PII stored — only variant labels,
levels, and timing.
Usage:
from crisis.ab_testing import ABTestCrisisDetector
detector = ABTestCrisisDetector(variant_a=detect_v1, variant_b=detect_v2)
result, variant = detector.detect("I feel hopeless")
# result: CrisisDetectionResult
# variant: "A" or "B"
# Get comparison metrics
stats = detector.get_stats()
# {"A": {"count": 100, "avg_latency_ms": 2.3, ...}, "B": {...}}
"""
import os
import random
import time
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
from .detect import CrisisDetectionResult
# ── Feature flag ───────────────────────────────────────────────
def _get_variant_override() -> Optional[str]:
"""Check for environment variable override (testing/debugging)."""
val = os.environ.get("CRISIS_AB_VARIANT", "").upper()
if val in ("A", "B"):
return val
return None
@dataclass
class VariantRecord:
"""Single detection event record — no PII, only metadata."""
variant: str
level: str
latency_ms: float
indicator_count: int
class ABTestCrisisDetector:
"""
A/B test wrapper for crisis detection.
Routes calls to variant A or B based on configurable split,
logs outcomes for comparison, and provides aggregate stats.
"""
def __init__(
self,
variant_a: Callable[[str], CrisisDetectionResult],
variant_b: Callable[[str], CrisisDetectionResult],
split: float = 0.5,
variant_a_name: str = "A",
variant_b_name: str = "B",
):
"""
Args:
variant_a: First detection function
variant_b: Second detection function
split: Probability of selecting variant A (0.0 to 1.0)
variant_a_name: Label for variant A in reports
variant_b_name: Label for variant B in reports
"""
self.variant_a = variant_a
self.variant_b = variant_b
self.split = split
self.variant_a_name = variant_a_name
self.variant_b_name = variant_b_name
self.records: List[VariantRecord] = []
def _select_variant(self) -> str:
"""Select variant based on split and optional env override."""
override = _get_variant_override()
if override:
return override
return "A" if random.random() < self.split else "B"
def detect(self, text: str) -> Tuple[CrisisDetectionResult, str]:
"""
Run detection on the selected variant and log the result.
Returns:
(CrisisDetectionResult, variant_label)
"""
variant = self._select_variant()
if variant == "A":
fn = self.variant_a
else:
fn = self.variant_b
start = time.perf_counter()
result = fn(text)
latency_ms = (time.perf_counter() - start) * 1000
# Log record (no PII — only level, timing, count)
record = VariantRecord(
variant=variant,
level=result.level,
latency_ms=latency_ms,
indicator_count=len(result.indicators),
)
self.records.append(record)
return result, variant
def get_stats(self) -> Dict[str, dict]:
"""
Get per-variant comparison statistics.
Returns dict with variant labels as keys:
{
"A": {"count": 100, "avg_latency_ms": 2.3, "levels": {...}},
"B": {"count": 95, "avg_latency_ms": 3.1, "levels": {...}}
"""
stats = {}
for label in ("A", "B"):
recs = [r for r in self.records if r.variant == label]
if not recs:
stats[label] = {"count": 0}
continue
latencies = [r.latency_ms for r in recs]
levels = {}
for r in recs:
levels[r.level] = levels.get(r.level, 0) + 1
stats[label] = {
"count": len(recs),
"avg_latency_ms": round(sum(latencies) / len(latencies), 2),
"max_latency_ms": round(max(latencies), 2),
"min_latency_ms": round(min(latencies), 2),
"levels": levels,
"avg_indicators": round(
sum(r.indicator_count for r in recs) / len(recs), 2
),
}
return stats
def reset(self) -> None:
"""Clear all records. For testing."""
self.records.clear()

View File

@@ -104,9 +104,13 @@ MEDIUM_INDICATORS = [
r"\blost\s+all\s+hope\b",
r"\bno\s+tomorrow\b",
# Contextual versions (from crisis_detector.py legacy)
# Keep only medium-only patterns here; stronger overlaps live in HIGH_INDICATORS.
r"\bfeel(?:s|ing)?\s+(?:so\s+)?worthless\b",
r"\bfeel(?:s|ing)?\s+(?:so\s+)?hopeless\b",
r"\bfeel(?:s|ing)?\s+trapped\b",
r"\bfeel(?:s|ing)?\s+desperate\b",
r"\bno\s+future\s+(?:for\s+me|ahead|left)\b",
r"\bnothing\s+left\s+(?:to\s+(?:live|hope)\s+for|inside)\b",
r"\bgive(?:n)?\s*up\s+on\s+myself\b",
]
LOW_INDICATORS = [

View File

@@ -22,7 +22,6 @@ from .response import (
get_system_prompt_modifier,
CrisisResponse,
)
from .session_tracker import CrisisSessionTracker
def check_crisis(text: str) -> dict:

View File

@@ -1,259 +0,0 @@
"""
Session-level crisis tracking and escalation for the-door (P0 #35).
Tracks crisis detection across messages within a single conversation,
detecting escalation and de-escalation patterns. Privacy-first: no
persistence beyond the conversation session.
Each message is analyzed in isolation by detect.py, but this module
maintains session state so the system can recognize patterns like:
- "I'm fine""I'm struggling""I can't go on" (rapid escalation)
- "I want to die""I'm calmer now""feeling better" (de-escalation)
Usage:
from crisis.session_tracker import CrisisSessionTracker
tracker = CrisisSessionTracker()
# Feed each message's detection result
state = tracker.record(detect_crisis("I'm having a tough day"))
print(state.current_level) # "LOW"
print(state.is_escalating) # False
state = tracker.record(detect_crisis("I feel hopeless"))
print(state.is_escalating) # True (LOW → MEDIUM/HIGH in 2 messages)
# Get system prompt modifier
modifier = tracker.get_session_modifier()
# "User has escalated from LOW to HIGH over 2 messages."
# Reset for new session
tracker.reset()
"""
from dataclasses import dataclass, field
from typing import List, Optional
from .detect import CrisisDetectionResult, SCORES
# Level ordering for comparison (higher = more severe)
LEVEL_ORDER = {"NONE": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4}
@dataclass
class SessionState:
"""Immutable snapshot of session crisis tracking state."""
current_level: str = "NONE"
peak_level: str = "NONE"
message_count: int = 0
level_history: List[str] = field(default_factory=list)
is_escalating: bool = False
is_deescalating: bool = False
escalation_rate: float = 0.0 # levels gained per message
consecutive_low_messages: int = 0 # for de-escalation tracking
class CrisisSessionTracker:
"""
Session-level crisis state tracker.
Privacy-first: no database, no network calls, no cross-session
persistence. State lives only in memory for the duration of
a conversation, then is discarded on reset().
"""
# Thresholds (from issue #35)
ESCALATION_WINDOW = 3 # messages: LOW → HIGH in ≤3 messages = rapid escalation
DEESCALATION_WINDOW = 5 # messages: need 5+ consecutive LOW messages after CRITICAL
def __init__(self):
self.reset()
def reset(self):
"""Reset all session state. Call on new conversation."""
self._current_level = "NONE"
self._peak_level = "NONE"
self._message_count = 0
self._level_history: List[str] = []
self._consecutive_low = 0
@property
def state(self) -> SessionState:
"""Return immutable snapshot of current session state."""
is_escalating = self._detect_escalation()
is_deescalating = self._detect_deescalation()
rate = self._compute_escalation_rate()
return SessionState(
current_level=self._current_level,
peak_level=self._peak_level,
message_count=self._message_count,
level_history=list(self._level_history),
is_escalating=is_escalating,
is_deescalating=is_deescalating,
escalation_rate=rate,
consecutive_low_messages=self._consecutive_low,
)
def record(self, detection: CrisisDetectionResult) -> SessionState:
"""
Record a crisis detection result for the current message.
Returns updated SessionState.
"""
level = detection.level
self._message_count += 1
self._level_history.append(level)
# Update peak
if LEVEL_ORDER.get(level, 0) > LEVEL_ORDER.get(self._peak_level, 0):
self._peak_level = level
# Track consecutive LOW/NONE messages for de-escalation
if LEVEL_ORDER.get(level, 0) <= LEVEL_ORDER["LOW"]:
self._consecutive_low += 1
else:
self._consecutive_low = 0
self._current_level = level
return self.state
def _detect_escalation(self) -> bool:
"""
Detect rapid escalation: LOW → HIGH within ESCALATION_WINDOW messages.
Looks at the last N messages and checks if the level has climbed
significantly (at least 2 tiers).
"""
if len(self._level_history) < 2:
return False
window = self._level_history[-self.ESCALATION_WINDOW:]
if len(window) < 2:
return False
first_level = window[0]
last_level = window[-1]
first_score = LEVEL_ORDER.get(first_level, 0)
last_score = LEVEL_ORDER.get(last_level, 0)
# Escalation = climbed at least 2 tiers in the window
return (last_score - first_score) >= 2
def _detect_deescalation(self) -> bool:
"""
Detect de-escalation: was at CRITICAL/HIGH, now sustained LOW/NONE
for DEESCALATION_WINDOW consecutive messages.
"""
if LEVEL_ORDER.get(self._peak_level, 0) < LEVEL_ORDER["HIGH"]:
return False
return self._consecutive_low >= self.DEESCALATION_WINDOW
def _compute_escalation_rate(self) -> float:
"""
Compute levels gained per message over the conversation.
Positive = escalating, negative = de-escalating, 0 = stable.
"""
if self._message_count < 2:
return 0.0
first = LEVEL_ORDER.get(self._level_history[0], 0)
current = LEVEL_ORDER.get(self._current_level, 0)
return (current - first) / (self._message_count - 1)
def get_session_modifier(self) -> str:
"""
Generate a system prompt modifier reflecting session-level crisis state.
Returns empty string if no session context is relevant.
"""
if self._message_count < 2:
return ""
s = self.state
if s.is_escalating:
return (
f"User has escalated from {self._level_history[0]} to "
f"{s.current_level} over {s.message_count} messages. "
f"Peak crisis level this session: {s.peak_level}. "
"Respond with heightened awareness. The trajectory is "
"worsening — prioritize safety and connection."
)
if s.is_deescalating:
return (
f"User previously reached {s.peak_level} crisis level "
f"but has been at {s.current_level} or below for "
f"{s.consecutive_low_messages} consecutive messages. "
"The situation appears to be stabilizing. Continue "
"supportive engagement while remaining vigilant."
)
if s.peak_level in ("CRITICAL", "HIGH") and s.current_level not in ("CRITICAL", "HIGH"):
return (
f"User previously reached {s.peak_level} crisis level "
f"this session (currently {s.current_level}). "
"Continue with care and awareness of the earlier crisis."
)
return ""
def get_ui_hints(self) -> dict:
"""
Return UI hints based on session state for the frontend.
These are advisory — the frontend decides what to show.
"""
s = self.state
hints = {
"session_escalating": s.is_escalating,
"session_deescalating": s.is_deescalating,
"session_peak_level": s.peak_level,
"session_message_count": s.message_count,
}
if s.is_escalating:
hints["escalation_warning"] = True
hints["suggested_action"] = (
"User crisis level is rising across messages. "
"Consider increasing intervention level."
)
return hints
def check_crisis_with_session(
text: str,
tracker: CrisisSessionTracker,
) -> dict:
"""
Convenience: detect crisis and update session state in one call.
Returns combined single-message detection + session-level context.
"""
from .detect import detect_crisis
from .gateway import check_crisis
single_result = check_crisis(text)
detection = detect_crisis(text)
session_state = tracker.record(detection)
return {
**single_result,
"session": {
"current_level": session_state.current_level,
"peak_level": session_state.peak_level,
"message_count": session_state.message_count,
"is_escalating": session_state.is_escalating,
"is_deescalating": session_state.is_deescalating,
"modifier": tracker.get_session_modifier(),
"ui_hints": tracker.get_ui_hints(),
},
}

View File

@@ -680,7 +680,7 @@ html, body {
<!-- Footer -->
<footer id="footer">
<a href="/about.html" aria-label="About The Door">about</a>
<a href="/about" aria-label="About The Door">about</a>
<button id="safety-plan-btn" aria-label="Open My Safety Plan">my safety plan</button>
<button id="clear-chat-btn" aria-label="Clear chat history">clear chat</button>
</footer>
@@ -808,7 +808,6 @@ Sovereignty and service always.`;
var crisisPanel = document.getElementById('crisis-panel');
var crisisOverlay = document.getElementById('crisis-overlay');
var overlayDismissBtn = document.getElementById('overlay-dismiss-btn');
var overlayCallLink = document.querySelector('.overlay-call');
var statusDot = document.querySelector('.status-dot');
var statusText = document.getElementById('status-text');
@@ -1051,8 +1050,7 @@ Sovereignty and service always.`;
}
}, 1000);
// Focus the Call 988 link (always enabled) — disabled buttons cannot receive focus
if (overlayCallLink) overlayCallLink.focus();
overlayDismissBtn.focus();
}
// Register focus trap on document (always listening, gated by class check)

129
tests/test_ab_testing.py Normal file
View File

@@ -0,0 +1,129 @@
"""
Tests for crisis/ab_testing.py — A/B test framework for crisis detection.
Verifies variant selection, logging, stats aggregation, and env override.
"""
import os
from unittest.mock import patch
import pytest
from crisis.ab_testing import ABTestCrisisDetector
from crisis.detect import CrisisDetectionResult, detect_crisis
def _make_variant(level: str):
"""Create a mock detection function that returns a fixed level."""
def fn(text: str) -> CrisisDetectionResult:
return CrisisDetectionResult(level=level, indicators=[f"mock_{level}"])
return fn
class TestABTestCrisisDetector:
"""A/B test framework unit tests."""
def setup_method(self):
"""Ensure no env override."""
os.environ.pop("CRISIS_AB_VARIANT", None)
def test_returns_result_and_variant(self):
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
result, variant = detector.detect("test message")
assert isinstance(result, CrisisDetectionResult)
assert variant in ("A", "B")
def test_records_are_logged(self):
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
# Force variant A
with patch.object(detector, "_select_variant", return_value="A"):
detector.detect("test")
assert len(detector.records) == 1
assert detector.records[0].variant == "A"
assert detector.records[0].level == "LOW"
def test_stats_empty(self):
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
stats = detector.get_stats()
assert stats["A"]["count"] == 0
assert stats["B"]["count"] == 0
def test_stats_with_data(self):
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
# Force 5 A and 3 B
with patch.object(detector, "_select_variant", side_effect=["A"] * 5 + ["B"] * 3):
for _ in range(8):
detector.detect("test")
stats = detector.get_stats()
assert stats["A"]["count"] == 5
assert stats["B"]["count"] == 3
assert "avg_latency_ms" in stats["A"]
assert stats["A"]["levels"]["LOW"] == 5
assert stats["B"]["levels"]["HIGH"] == 3
def test_env_override_a(self):
os.environ["CRISIS_AB_VARIANT"] = "A"
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
for _ in range(10):
result, variant = detector.detect("test")
assert variant == "A"
assert result.level == "LOW"
def test_env_override_b(self):
os.environ["CRISIS_AB_VARIANT"] = "b"
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
for _ in range(10):
result, variant = detector.detect("test")
assert variant == "B"
assert result.level == "HIGH"
def test_reset_clears_records(self):
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
detector.detect("test")
detector.detect("test")
assert len(detector.records) == 2
detector.reset()
assert len(detector.records) == 0
def test_split_respected(self):
"""With split=1.0, always get variant A."""
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
split=1.0,
)
for _ in range(10):
_, variant = detector.detect("test")
assert variant == "A"
def test_with_real_detector(self):
"""Integration test using actual detect_crisis as both variants."""
detector = ABTestCrisisDetector(
variant_a=detect_crisis,
variant_b=detect_crisis,
)
result, variant = detector.detect("I want to kill myself")
assert result.level == "CRITICAL"
assert variant in ("A", "B")

View File

@@ -52,34 +52,6 @@ class TestCrisisOverlayFocusTrap(unittest.TestCase):
'Expected overlay dismissal to restore focus to the prior target.',
)
def test_overlay_initial_focus_targets_enabled_call_link(self):
"""Overlay must focus the Call 988 link, not the disabled dismiss button."""
# Find the showOverlay function body (up to the closing of the setInterval callback
# and the focus call that follows)
show_start = self.html.find('function showOverlay()')
self.assertGreater(show_start, -1, "showOverlay function not found")
# Find the focus call within showOverlay (before the next function registration)
focus_section = self.html[show_start:show_start + 2000]
self.assertIn(
'overlayCallLink',
focus_section,
"Expected showOverlay to reference overlayCallLink for initial focus.",
)
# Ensure the old buggy pattern is gone
focus_line_region = self.html[show_start + 800:show_start + 1200]
self.assertNotIn(
'overlayDismissBtn.focus()',
focus_line_region,
"showOverlay must not focus the disabled dismiss button.",
)
def test_overlay_call_link_variable_is_declared(self):
self.assertIn(
"querySelector('.overlay-call')",
self.html,
"Expected a JS reference to the .overlay-call link element.",
)
if __name__ == '__main__':
unittest.main()

View File

@@ -1,50 +0,0 @@
"""Regression tests for duplicate crisis indicators cleanup (#123)."""
import os
import re
import sys
import unittest
from pathlib import Path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from crisis.detect import detect_crisis
ROOT = Path(__file__).resolve().parents[1]
DETECT_FILE = ROOT / "crisis" / "detect.py"
class TestCrisisIndicatorDedup(unittest.TestCase):
def _extract_indicator_list(self, name: str) -> set[str]:
text = DETECT_FILE.read_text(encoding="utf-8")
match = re.search(rf"{name}\s*=\s*\[(.*?)\n\]", text, re.S)
self.assertIsNotNone(match, f"{name} list missing")
return {
line.strip().rstrip(",")
for line in match.group(1).splitlines()
if line.strip().startswith('r"')
}
def test_high_and_medium_indicator_lists_do_not_overlap(self):
high = self._extract_indicator_list("HIGH_INDICATORS")
medium = self._extract_indicator_list("MEDIUM_INDICATORS")
overlap = high & medium
self.assertEqual(set(), overlap, f"duplicate cross-tier patterns found: {sorted(overlap)}")
def test_removed_duplicates_still_classify_at_high_tier(self):
cases = {
"I feel hopeless": "HIGH",
"I feel trapped": "HIGH",
"I feel desperate": "HIGH",
"I have no future ahead": "HIGH",
"I have given up on myself": "HIGH",
}
for text, expected in cases.items():
with self.subTest(text=text):
result = detect_crisis(text)
self.assertEqual(expected, result.level)
if __name__ == "__main__":
unittest.main()

View File

@@ -50,22 +50,6 @@ class TestCrisisOfflinePage(unittest.TestCase):
for phrase in required_phrases:
self.assertIn(phrase, self.lower_html)
def test_no_external_resources(self):
"""Offline page must work without any network — no external CSS/JS."""
import re
html = self.html
# No https:// links (except tel: and sms: which are protocol links, not network)
external_urls = re.findall(r'href=["\']https://|src=["\']https://', html)
self.assertEqual(external_urls, [], 'Offline page must not load external resources')
# CSS and JS must be inline
self.assertIn('<style>', html, 'CSS must be inline')
self.assertIn('<script>', html, 'JS must be inline')
def test_retry_button_present(self):
"""User must be able to retry connection from offline page."""
self.assertIn('retry-connection', self.html)
self.assertIn('Retry connection', self.html)
if __name__ == '__main__':
unittest.main()

View File

@@ -1,277 +0,0 @@
"""
Tests for crisis session tracking and escalation (P0 #35).
Covers: session_tracker.py
Run with: python -m pytest tests/test_session_tracker.py -v
"""
import unittest
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from crisis.detect import detect_crisis
from crisis.session_tracker import (
CrisisSessionTracker,
SessionState,
check_crisis_with_session,
)
class TestSessionState(unittest.TestCase):
"""Test SessionState defaults."""
def test_default_state(self):
s = SessionState()
self.assertEqual(s.current_level, "NONE")
self.assertEqual(s.peak_level, "NONE")
self.assertEqual(s.message_count, 0)
self.assertEqual(s.level_history, [])
self.assertFalse(s.is_escalating)
self.assertFalse(s.is_deescalating)
class TestSessionTracking(unittest.TestCase):
"""Test basic session state tracking."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_record_none_message(self):
state = self.tracker.record(detect_crisis("Hello Timmy"))
self.assertEqual(state.current_level, "NONE")
self.assertEqual(state.message_count, 1)
self.assertEqual(state.peak_level, "NONE")
def test_record_low_message(self):
self.tracker.record(detect_crisis("Hello"))
state = self.tracker.record(detect_crisis("Having a rough day"))
self.assertIn(state.current_level, ("LOW", "NONE"))
self.assertEqual(state.message_count, 2)
def test_record_critical_updates_peak(self):
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I want to kill myself"))
self.assertEqual(state.current_level, "CRITICAL")
self.assertEqual(state.peak_level, "CRITICAL")
def test_peak_preserved_after_drop(self):
"""Peak level should stay at the highest seen, even after de-escalation."""
self.tracker.record(detect_crisis("I want to kill myself"))
state = self.tracker.record(detect_crisis("I'm feeling a bit better"))
self.assertEqual(state.peak_level, "CRITICAL")
def test_level_history(self):
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I want to die"))
self.assertEqual(len(state.level_history), 3)
self.assertEqual(state.level_history[0], "NONE")
self.assertEqual(state.level_history[2], "CRITICAL")
def test_reset_clears_state(self):
self.tracker.record(detect_crisis("I want to kill myself"))
self.tracker.reset()
state = self.tracker.state
self.assertEqual(state.current_level, "NONE")
self.assertEqual(state.peak_level, "NONE")
self.assertEqual(state.message_count, 0)
self.assertEqual(state.level_history, [])
class TestEscalationDetection(unittest.TestCase):
"""Test escalation detection: LOW → HIGH in ≤3 messages."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_no_escalation_single_message(self):
self.tracker.record(detect_crisis("Hello"))
self.assertFalse(self.tracker.state.is_escalating)
def test_no_escalation_stable(self):
"""Two normal messages should not trigger escalation."""
self.tracker.record(detect_crisis("Hello"))
state = self.tracker.record(detect_crisis("How are you?"))
self.assertFalse(state.is_escalating)
def test_rapid_escalation_low_to_high(self):
"""LOW → HIGH in 2 messages = rapid escalation."""
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I can't take this anymore, everything is pointless"))
# Depending on detection, this could be HIGH or CRITICAL
if state.current_level in ("HIGH", "CRITICAL"):
self.assertTrue(state.is_escalating)
def test_rapid_escalation_three_messages(self):
"""NONE → LOW → HIGH in 3 messages = escalation."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I feel completely hopeless with no way out"))
if state.current_level in ("HIGH", "CRITICAL"):
self.assertTrue(state.is_escalating)
def test_escalation_rate(self):
"""Rate should be positive when escalating."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("I want to die"))
state = self.tracker.state
self.assertGreater(state.escalation_rate, 0)
class TestDeescalationDetection(unittest.TestCase):
"""Test de-escalation: sustained LOW after HIGH/CRITICAL."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_no_deescalation_without_prior_crisis(self):
"""No de-escalation if never reached HIGH/CRITICAL."""
for _ in range(6):
self.tracker.record(detect_crisis("Hello"))
self.assertFalse(self.tracker.state.is_deescalating)
def test_deescalation_after_critical(self):
"""5+ consecutive LOW/NONE messages after CRITICAL = de-escalation."""
self.tracker.record(detect_crisis("I want to kill myself"))
for _ in range(5):
self.tracker.record(detect_crisis("I'm doing better today"))
state = self.tracker.state
if state.peak_level == "CRITICAL":
self.assertTrue(state.is_deescalating)
def test_deescalation_after_high(self):
"""5+ consecutive LOW/NONE messages after HIGH = de-escalation."""
self.tracker.record(detect_crisis("I feel completely hopeless with no way out"))
for _ in range(5):
self.tracker.record(detect_crisis("Feeling okay"))
state = self.tracker.state
if state.peak_level == "HIGH":
self.assertTrue(state.is_deescalating)
def test_interrupted_deescalation(self):
"""De-escalation resets if a HIGH message interrupts."""
self.tracker.record(detect_crisis("I want to kill myself"))
for _ in range(3):
self.tracker.record(detect_crisis("Doing better"))
# Interrupt with another crisis
self.tracker.record(detect_crisis("I feel hopeless again"))
self.tracker.record(detect_crisis("Feeling okay now"))
state = self.tracker.state
# Should NOT be de-escalating yet (counter reset)
self.assertFalse(state.is_deescalating)
class TestSessionModifier(unittest.TestCase):
"""Test system prompt modifier generation."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_no_modifier_for_single_message(self):
self.tracker.record(detect_crisis("Hello"))
self.assertEqual(self.tracker.get_session_modifier(), "")
def test_no_modifier_for_stable_session(self):
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("Good morning"))
self.assertEqual(self.tracker.get_session_modifier(), "")
def test_escalation_modifier(self):
"""Escalating session should produce a modifier."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("I want to die"))
modifier = self.tracker.get_session_modifier()
if self.tracker.state.is_escalating:
self.assertIn("escalated", modifier.lower())
self.assertIn("NONE", modifier)
self.assertIn("CRITICAL", modifier)
def test_deescalation_modifier(self):
"""De-escalating session should mention stabilizing."""
self.tracker.record(detect_crisis("I want to kill myself"))
for _ in range(5):
self.tracker.record(detect_crisis("I'm feeling okay"))
modifier = self.tracker.get_session_modifier()
if self.tracker.state.is_deescalating:
self.assertIn("stabilizing", modifier.lower())
def test_prior_crisis_modifier(self):
"""Past crisis should be noted even without active escalation."""
self.tracker.record(detect_crisis("I want to die"))
self.tracker.record(detect_crisis("Feeling a bit better"))
modifier = self.tracker.get_session_modifier()
# Should note the prior CRITICAL
if modifier:
self.assertIn("CRITICAL", modifier)
class TestUIHints(unittest.TestCase):
"""Test UI hint generation."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_ui_hints_structure(self):
self.tracker.record(detect_crisis("Hello"))
hints = self.tracker.get_ui_hints()
self.assertIn("session_escalating", hints)
self.assertIn("session_deescalating", hints)
self.assertIn("session_peak_level", hints)
self.assertIn("session_message_count", hints)
def test_ui_hints_escalation_warning(self):
"""Escalating session should have warning hint."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("I want to die"))
hints = self.tracker.get_ui_hints()
if hints["session_escalating"]:
self.assertTrue(hints.get("escalation_warning"))
self.assertIn("suggested_action", hints)
class TestCheckCrisisWithSession(unittest.TestCase):
"""Test the convenience function combining detection + session tracking."""
def test_returns_combined_data(self):
tracker = CrisisSessionTracker()
result = check_crisis_with_session("I want to die", tracker)
self.assertIn("level", result)
self.assertIn("session", result)
self.assertIn("current_level", result["session"])
self.assertIn("peak_level", result["session"])
self.assertIn("modifier", result["session"])
def test_session_updates_across_calls(self):
tracker = CrisisSessionTracker()
check_crisis_with_session("Hello", tracker)
result = check_crisis_with_session("I want to die", tracker)
self.assertEqual(result["session"]["message_count"], 2)
self.assertEqual(result["session"]["peak_level"], "CRITICAL")
class TestPrivacy(unittest.TestCase):
"""Verify privacy-first design principles."""
def test_no_persistence_mechanism(self):
"""Session tracker should have no database, file, or network calls."""
import inspect
source = inspect.getsource(CrisisSessionTracker)
# Should not import database, requests, or file I/O
forbidden = ["sqlite", "requests", "urllib", "open(", "httpx", "aiohttp"]
for word in forbidden:
self.assertNotIn(word, source.lower(),
f"Session tracker should not use {word} — privacy-first design")
def test_state_contained_in_memory(self):
"""All state should be instance attributes, not module-level."""
tracker = CrisisSessionTracker()
tracker.record(detect_crisis("I want to die"))
# New tracker should have clean state (no global contamination)
fresh = CrisisSessionTracker()
self.assertEqual(fresh.state.current_level, "NONE")
if __name__ == '__main__':
unittest.main()