Compare commits

..

12 Commits

Author SHA1 Message Date
Alexander Whitestone
4d084654d8 feat: add crisis analytics metrics (#37)
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 5s
Smoke Test / smoke (pull_request) Successful in 6s
2026-04-17 02:24:01 -04:00
d412939b4f fix: footer /about link to point to static about.html
Fixes #59

The footer links to /about but the repo ships about.html. On a plain static server this results in a 404. Changed to /about.html so the link resolves correctly.
2026-04-17 05:37:40 +00:00
07c582aa08 Merge pull request 'fix: crisis overlay initial focus to enabled Call 988 link (#69)' (#126) from burn/69-1776264183 into main
Merge PR #126: fix: crisis overlay initial focus to enabled Call 988 link (#69)
2026-04-17 01:46:56 +00:00
5f95dc1e39 Merge pull request '[P3] Service worker: cache crisis resources for offline (#41)' (#122) from burn/41-1776264184 into main
Merge PR #122: [P3] Service worker: cache crisis resources for offline (#41)
2026-04-17 01:46:55 +00:00
b1f3cac36d Merge pull request 'feat: session-level crisis tracking and escalation (closes #35)' (#118) from door/issue-35 into main
Merge PR #118: feat: session-level crisis tracking and escalation (closes #35)
2026-04-17 01:46:53 +00:00
07b3f67845 fix: crisis overlay initial focus to enabled Call 988 link (#69)
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 9s
Smoke Test / smoke (pull_request) Successful in 15s
2026-04-15 15:09:36 +00:00
c22bbbaf65 fix: crisis overlay initial focus to enabled Call 988 link (#69) 2026-04-15 15:09:32 +00:00
543cb1d40f test: add offline self-containment and retry button tests (#41)
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 4s
Smoke Test / smoke (pull_request) Successful in 11s
2026-04-15 14:58:44 +00:00
3cfd01815a feat: session-level crisis tracking and escalation (closes #35)
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 17s
Smoke Test / smoke (pull_request) Successful in 23s
2026-04-15 11:49:52 +00:00
5a7ba9f207 feat: session-level crisis tracking and escalation (closes #35) 2026-04-15 11:49:51 +00:00
8ed8f20a17 feat: session-level crisis tracking and escalation (closes #35) 2026-04-15 11:49:49 +00:00
9d7d26033e feat: session-level crisis tracking and escalation (closes #35) 2026-04-15 11:49:47 +00:00
11 changed files with 894 additions and 866 deletions

View File

@@ -7,6 +7,14 @@ Stands between a broken man and a machine that would tell him to die.
from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urgency_emoji
from .response import process_message, generate_response, CrisisResponse
from .gateway import check_crisis, get_system_prompt, format_gateway_response
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
from .metrics import (
build_metrics_event,
append_metrics_event,
load_metrics_events,
build_weekly_summary,
render_weekly_summary,
)
__all__ = [
"detect_crisis",
@@ -19,4 +27,12 @@ __all__ = [
"format_result",
"format_gateway_response",
"get_urgency_emoji",
"CrisisSessionTracker",
"SessionState",
"check_crisis_with_session",
"build_metrics_event",
"append_metrics_event",
"load_metrics_events",
"build_weekly_summary",
"render_weekly_summary",
]

View File

@@ -1,409 +0,0 @@
#!/usr/bin/env python3
"""
Crisis Detection A/B Testing Framework
Allows testing different crisis detection algorithms with:
- Feature flags for algorithm A vs B
- Logging of which variant triggered for each event
- Metrics: false positive rate, detection latency per variant
- Statistical significance testing
Usage:
from crisis.ab_testing import ABTestManager, Variant
manager = ABTestManager()
result = manager.detect_with_variant("I'm feeling hopeless")
print(result.variant, result.detection_result)
# Get metrics
metrics = manager.get_metrics()
print(metrics.false_positive_rate_a, metrics.false_positive_rate_b)
"""
import json
import os
import random
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
from uuid import uuid4
from .detect import detect_crisis, CrisisDetectionResult
class Variant(Enum):
"""A/B test variants for crisis detection."""
A = "A" # Control: current algorithm
B = "B" # Treatment: alternative algorithm
@dataclass
class ABTestConfig:
"""Configuration for A/B testing."""
enabled: bool = True
variant_b_percentage: float = 0.5 # 50% split
log_file: Optional[str] = None # Default: ~/.the-door/ab_test_logs.jsonl
metrics_file: Optional[str] = None # Default: ~/.the-door/ab_metrics.json
seed: Optional[int] = None # For reproducible assignments
@dataclass
class DetectionEvent:
"""A single detection event with variant assignment."""
event_id: str
timestamp: str
variant: str
text_hash: str # Hash of input text for privacy
detected_level: str
detected_score: float
indicators: List[str]
detection_latency_ms: float
is_false_positive: Optional[bool] = None # Null until labeled
user_feedback: Optional[str] = None
@dataclass
class VariantMetrics:
"""Metrics for a single variant."""
total_detections: int = 0
true_positives: int = 0
false_positives: int = 0
false_negatives: int = 0 # Requires manual labeling
avg_latency_ms: float = 0.0
p50_latency_ms: float = 0.0
p95_latency_ms: float = 0.0
p99_latency_ms: float = 0.0
level_distribution: Dict[str, int] = field(default_factory=dict)
indicator_frequency: Dict[str, int] = field(default_factory=dict)
@dataclass
class ABTestMetrics:
"""Complete A/B test metrics."""
test_id: str
start_time: str
end_time: Optional[str] = None
variant_a: VariantMetrics = field(default_factory=VariantMetrics)
variant_b: VariantMetrics = field(default_factory=VariantMetrics)
sample_size_a: int = 0
sample_size_b: int = 0
statistical_significance: Optional[float] = None # p-value if calculable
class ABTestManager:
"""Manages A/B testing for crisis detection algorithms."""
def __init__(self, config: Optional[ABTestConfig] = None):
self.config = config or ABTestConfig()
self.test_id = str(uuid4())[:8]
self.events: List[DetectionEvent] = []
# Set up file paths
home = Path.home() / ".the-door"
home.mkdir(exist_ok=True)
self.log_file = Path(self.config.log_file or home / "ab_test_logs.jsonl")
self.metrics_file = Path(self.config.metrics_file or home / "ab_metrics.json")
# Initialize RNG
if self.config.seed is not None:
random.seed(self.config.seed)
def _assign_variant(self, text: str) -> Variant:
"""Assign a variant based on text hash for deterministic assignment."""
if not self.config.enabled:
return Variant.A
# Use hash of text for consistent assignment
text_hash = hash(text) % 100
threshold = int(self.config.variant_b_percentage * 100)
if text_hash < threshold:
return Variant.B
return Variant.A
def _detect_variant_a(self, text: str) -> CrisisDetectionResult:
"""Variant A: Current algorithm (control)."""
return detect_crisis(text)
def _detect_variant_b(self, text: str) -> CrisisDetectionResult:
"""Variant B: Alternative detection algorithm.
This is a placeholder - in practice, you'd implement a different
detection algorithm here. For now, we'll use the same algorithm
but with different sensitivity settings.
"""
# Example: Variant B could use different thresholds or additional patterns
result = detect_crisis(text)
# For demonstration: adjust sensitivity based on confidence score
# In practice, this would be a completely different algorithm
if result.score > 0.7 and result.level != "CRITICAL":
# Variant B is more sensitive to high-confidence detections
from .detect import CRITICAL_INDICATORS
import re
for pattern in CRITICAL_INDICATORS:
if re.search(pattern, text, re.IGNORECASE):
# Upgrade to CRITICAL if we find critical indicators
return CrisisDetectionResult(
level="CRITICAL",
score=result.score,
indicators=result.indicators,
matched_patterns=result.matched_patterns,
recommended_action="immediate_intervention"
)
return result
def detect_with_variant(self, text: str, user_id: Optional[str] = None) -> Tuple[Variant, CrisisDetectionResult, float]:
"""
Run crisis detection with A/B testing.
Returns:
Tuple of (variant, detection_result, latency_ms)
"""
if not self.config.enabled:
start = time.time()
result = self._detect_variant_a(text)
latency = (time.time() - start) * 1000
return Variant.A, result, latency
# Assign variant
variant = self._assign_variant(text)
# Run detection with timing
start = time.time()
if variant == Variant.A:
result = self._detect_variant_a(text)
else:
result = self._detect_variant_b(text)
latency_ms = (time.time() - start) * 1000
# Log event
self._log_event(variant, text, result, latency_ms, user_id)
return variant, result, latency_ms
def _log_event(self, variant: Variant, text: str, result: CrisisDetectionResult,
latency_ms: float, user_id: Optional[str] = None):
"""Log a detection event."""
import hashlib
# Hash text for privacy (don't log actual crisis text)
text_hash = hashlib.sha256(text.encode()).hexdigest()[:16]
event = DetectionEvent(
event_id=str(uuid4())[:8],
timestamp=datetime.now(timezone.utc).isoformat(),
variant=variant.value,
text_hash=text_hash,
detected_level=result.level,
detected_score=result.score,
indicators=result.indicators[:5], # Limit for storage
detection_latency_ms=round(latency_ms, 2),
)
self.events.append(event)
# Append to log file
try:
with open(self.log_file, "a") as f:
f.write(json.dumps(asdict(event)) + "\n")
except Exception:
pass # Don't fail on logging errors
def label_event(self, event_id: str, is_false_positive: bool, feedback: Optional[str] = None):
"""Label an event as true/false positive for metrics calculation."""
for event in self.events:
if event.event_id == event_id:
event.is_false_positive = is_false_positive
event.user_feedback = feedback
break
# Update log file
self._save_events()
def _save_events(self):
"""Save all events to log file."""
try:
with open(self.log_file, "w") as f:
for event in self.events:
f.write(json.dumps(asdict(event)) + "\n")
except Exception:
pass
def get_metrics(self) -> ABTestMetrics:
"""Calculate metrics for both variants."""
metrics = ABTestMetrics(
test_id=self.test_id,
start_time=self.events[0].timestamp if self.events else datetime.now(timezone.utc).isoformat(),
end_time=datetime.now(timezone.utc).isoformat(),
)
# Separate events by variant
a_events = [e for e in self.events if e.variant == "A"]
b_events = [e for e in self.events if e.variant == "B"]
metrics.sample_size_a = len(a_events)
metrics.sample_size_b = len(b_events)
# Calculate variant A metrics
if a_events:
metrics.variant_a = self._calculate_variant_metrics(a_events)
# Calculate variant B metrics
if b_events:
metrics.variant_b = self._calculate_variant_metrics(b_events)
# Calculate statistical significance if we have enough data
if len(a_events) >= 30 and len(b_events) >= 30:
metrics.statistical_significance = self._calculate_significance(a_events, b_events)
# Save metrics to file
self._save_metrics(metrics)
return metrics
def _calculate_variant_metrics(self, events: List[DetectionEvent]) -> VariantMetrics:
"""Calculate metrics for a single variant."""
if not events:
return VariantMetrics()
# Latency statistics
latencies = [e.detection_latency_ms for e in events]
latencies.sort()
n = len(latencies)
p50_idx = int(n * 0.5)
p95_idx = int(n * 0.95)
p99_idx = int(n * 0.99)
# Level distribution
level_dist = {}
for e in events:
level_dist[e.detected_level] = level_dist.get(e.detected_level, 0) + 1
# Indicator frequency
indicator_freq = {}
for e in events:
for ind in e.indicators:
indicator_freq[ind] = indicator_freq.get(ind, 0) + 1
# False positive rate (only for labeled events)
labeled = [e for e in events if e.is_false_positive is not None]
fp_count = sum(1 for e in labeled if e.is_false_positive)
tp_count = sum(1 for e in labeled if not e.is_false_positive)
return VariantMetrics(
total_detections=len(events),
true_positives=tp_count,
false_positives=fp_count,
avg_latency_ms=sum(latencies) / n,
p50_latency_ms=latencies[p50_idx] if n > 0 else 0,
p95_latency_ms=latencies[p95_idx] if n > 0 else 0,
p99_latency_ms=latencies[p99_idx] if n > 0 else 0,
level_distribution=level_dist,
indicator_frequency=dict(sorted(indicator_freq.items(), key=lambda x: -x[1])[:10]),
)
def _calculate_significance(self, a_events: List[DetectionEvent],
b_events: List[DetectionEvent]) -> Optional[float]:
"""Calculate statistical significance (p-value) using chi-squared test."""
try:
# Count detections at each level for each variant
a_levels = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "NONE": 0}
b_levels = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "NONE": 0}
for e in a_events:
a_levels[e.detected_level] = a_levels.get(e.detected_level, 0) + 1
for e in b_events:
b_levels[e.detected_level] = b_levels.get(e.detected_level, 0) + 1
# Simple chi-squared test for level distribution difference
# This is a simplified version - in production you'd use scipy.stats.chi2_contingency
total_a = len(a_events)
total_b = len(b_events)
if total_a == 0 or total_b == 0:
return None
# Calculate expected frequencies
chi_sq = 0
for level in a_levels:
expected_a = (a_levels[level] + b_levels[level]) * total_a / (total_a + total_b)
expected_b = (a_levels[level] + b_levels[level]) * total_b / (total_a + total_b)
if expected_a > 0:
chi_sq += (a_levels[level] - expected_a) ** 2 / expected_a
if expected_b > 0:
chi_sq += (b_levels[level] - expected_b) ** 2 / expected_b
# Return chi-squared value (not p-value, as we don't have scipy)
# Higher values indicate more significant difference
return chi_sq
except Exception:
return None
def _save_metrics(self, metrics: ABTestMetrics):
"""Save metrics to file."""
try:
with open(self.metrics_file, "w") as f:
json.dump(asdict(metrics), f, indent=2)
except Exception:
pass
def get_variant_distribution(self) -> Dict[str, int]:
"""Get current distribution of events across variants."""
dist = {"A": 0, "B": 0}
for event in self.events:
dist[event.variant] = dist.get(event.variant, 0) + 1
return dist
def force_variant(self, variant: Variant):
"""Force all subsequent detections to use a specific variant."""
self.config.enabled = False
self._forced_variant = variant
def reset(self):
"""Reset the A/B test."""
self.events = []
self.config.enabled = True
if hasattr(self, '_forced_variant'):
delattr(self, '_forced_variant')
# Convenience function for easy integration
_default_manager = None
def get_ab_manager() -> ABTestManager:
"""Get the default A/B test manager instance."""
global _default_manager
if _default_manager is None:
_default_manager = ABTestManager()
return _default_manager
def detect_with_ab(text: str, user_id: Optional[str] = None) -> dict:
"""
Detect crisis with A/B testing.
Returns dict with variant, detection result, and metrics.
"""
manager = get_ab_manager()
variant, result, latency = manager.detect_with_variant(text, user_id)
return {
"variant": variant.value,
"detection": {
"level": result.level,
"score": result.score,
"indicators": result.indicators,
"recommended_action": result.recommended_action,
},
"latency_ms": round(latency, 2),
"test_id": manager.test_id,
}

View File

@@ -22,9 +22,18 @@ from .response import (
get_system_prompt_modifier,
CrisisResponse,
)
from .session_tracker import CrisisSessionTracker
from .metrics import build_metrics_event, append_metrics_event
def check_crisis(text: str) -> dict:
def check_crisis(
text: str,
metrics_log_path: Optional[str] = None,
*,
continued_conversation: bool = False,
false_positive: bool = False,
now: Optional[float] = None,
) -> dict:
"""
Full crisis check returning structured data.
@@ -34,7 +43,7 @@ def check_crisis(text: str) -> dict:
detection = detect_crisis(text)
response = generate_response(detection)
return {
result = {
"level": detection.level,
"score": detection.score,
"indicators": detection.indicators,
@@ -48,6 +57,23 @@ def check_crisis(text: str) -> dict:
"escalate": response.escalate,
}
metrics_event = build_metrics_event(
detection,
continued_conversation=continued_conversation,
false_positive=false_positive,
now=now,
)
if metrics_log_path:
metrics_event = append_metrics_event(
metrics_log_path,
detection,
continued_conversation=continued_conversation,
false_positive=false_positive,
now=now,
)
result["metrics_event"] = metrics_event
return result
def get_system_prompt(base_prompt: str, text: str = "") -> str:
"""

166
crisis/metrics.py Normal file
View File

@@ -0,0 +1,166 @@
"""Privacy-preserving crisis analytics metrics for the-door.
Stores only timestamps, crisis levels, indicator categories, and operator
feedback flags. No raw message text or PII is persisted.
"""
from __future__ import annotations
import argparse
import json
import time
from collections import Counter
from pathlib import Path
from typing import Iterable
from .detect import CrisisDetectionResult, detect_crisis
LEVELS = ("NONE", "LOW", "MEDIUM", "HIGH", "CRITICAL")
def normalize_indicator(indicator: str) -> str:
"""Return a stable privacy-safe keyword/category identifier."""
return indicator
def build_metrics_event(
detection: CrisisDetectionResult,
*,
continued_conversation: bool = False,
false_positive: bool = False,
now: float | None = None,
) -> dict:
timestamp = float(time.time() if now is None else now)
indicators = [normalize_indicator(indicator) for indicator in detection.indicators]
return {
"timestamp": timestamp,
"level": detection.level,
"indicator_count": len(indicators),
"indicators": indicators,
"continued_conversation": bool(continued_conversation),
"false_positive": bool(false_positive),
}
def append_metrics_event(
log_path: str | Path,
detection: CrisisDetectionResult,
*,
continued_conversation: bool = False,
false_positive: bool = False,
now: float | None = None,
) -> dict:
event = build_metrics_event(
detection,
continued_conversation=continued_conversation,
false_positive=false_positive,
now=now,
)
path = Path(log_path)
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(event) + "\n")
return event
def load_metrics_events(log_path: str | Path) -> list[dict]:
path = Path(log_path)
if not path.exists():
return []
events = []
for line in path.read_text(encoding="utf-8").splitlines():
if not line.strip():
continue
events.append(json.loads(line))
return events
def build_weekly_summary(
events: Iterable[dict],
*,
now: float | None = None,
window_days: int = 7,
) -> dict:
current_time = float(time.time() if now is None else now)
cutoff = current_time - (window_days * 86400)
filtered = [event for event in events if float(event.get("timestamp", 0)) >= cutoff]
detections_per_level = {level: 0 for level in LEVELS}
keyword_counts: Counter[str] = Counter()
detections = []
continued_after_intervention = 0
for event in filtered:
level = event.get("level", "NONE")
detections_per_level[level] = detections_per_level.get(level, 0) + 1
keyword_counts.update(event.get("indicators", []))
if level != "NONE":
detections.append(event)
if event.get("continued_conversation"):
continued_after_intervention += 1
false_positive_count = sum(1 for event in detections if event.get("false_positive"))
false_positive_estimate = (
false_positive_count / len(detections) if detections else 0.0
)
return {
"window_days": window_days,
"total_events": len(filtered),
"detections_per_level": detections_per_level,
"most_common_keywords": [
{"keyword": keyword, "count": count}
for keyword, count in keyword_counts.most_common(10)
],
"false_positive_estimate": false_positive_estimate,
"continued_after_intervention": continued_after_intervention,
}
def render_weekly_summary(summary: dict) -> str:
return json.dumps(summary, indent=2)
def write_weekly_summary(path: str | Path, summary: dict) -> Path:
output_path = Path(path)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(render_weekly_summary(summary) + "\n", encoding="utf-8")
return output_path
def record_text_event(
text: str,
log_path: str | Path,
*,
continued_conversation: bool = False,
false_positive: bool = False,
now: float | None = None,
) -> dict:
detection = detect_crisis(text)
return append_metrics_event(
log_path,
detection,
continued_conversation=continued_conversation,
false_positive=false_positive,
now=now,
)
def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(description="Privacy-preserving crisis metrics summary")
parser.add_argument("--log-path", required=True, help="JSONL event log path")
parser.add_argument("--days", type=int, default=7, help="Summary window in days")
parser.add_argument("--output", help="Optional file to write summary JSON")
args = parser.parse_args(argv)
events = load_metrics_events(args.log_path)
summary = build_weekly_summary(events, window_days=args.days)
rendered = render_weekly_summary(summary)
print(rendered)
if args.output:
write_weekly_summary(args.output, summary)
return 0
if __name__ == "__main__":
raise SystemExit(main())

259
crisis/session_tracker.py Normal file
View File

@@ -0,0 +1,259 @@
"""
Session-level crisis tracking and escalation for the-door (P0 #35).
Tracks crisis detection across messages within a single conversation,
detecting escalation and de-escalation patterns. Privacy-first: no
persistence beyond the conversation session.
Each message is analyzed in isolation by detect.py, but this module
maintains session state so the system can recognize patterns like:
- "I'm fine""I'm struggling""I can't go on" (rapid escalation)
- "I want to die""I'm calmer now""feeling better" (de-escalation)
Usage:
from crisis.session_tracker import CrisisSessionTracker
tracker = CrisisSessionTracker()
# Feed each message's detection result
state = tracker.record(detect_crisis("I'm having a tough day"))
print(state.current_level) # "LOW"
print(state.is_escalating) # False
state = tracker.record(detect_crisis("I feel hopeless"))
print(state.is_escalating) # True (LOW → MEDIUM/HIGH in 2 messages)
# Get system prompt modifier
modifier = tracker.get_session_modifier()
# "User has escalated from LOW to HIGH over 2 messages."
# Reset for new session
tracker.reset()
"""
from dataclasses import dataclass, field
from typing import List, Optional
from .detect import CrisisDetectionResult, SCORES
# Level ordering for comparison (higher = more severe)
LEVEL_ORDER = {"NONE": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4}
@dataclass
class SessionState:
"""Immutable snapshot of session crisis tracking state."""
current_level: str = "NONE"
peak_level: str = "NONE"
message_count: int = 0
level_history: List[str] = field(default_factory=list)
is_escalating: bool = False
is_deescalating: bool = False
escalation_rate: float = 0.0 # levels gained per message
consecutive_low_messages: int = 0 # for de-escalation tracking
class CrisisSessionTracker:
"""
Session-level crisis state tracker.
Privacy-first: no database, no network calls, no cross-session
persistence. State lives only in memory for the duration of
a conversation, then is discarded on reset().
"""
# Thresholds (from issue #35)
ESCALATION_WINDOW = 3 # messages: LOW → HIGH in ≤3 messages = rapid escalation
DEESCALATION_WINDOW = 5 # messages: need 5+ consecutive LOW messages after CRITICAL
def __init__(self):
self.reset()
def reset(self):
"""Reset all session state. Call on new conversation."""
self._current_level = "NONE"
self._peak_level = "NONE"
self._message_count = 0
self._level_history: List[str] = []
self._consecutive_low = 0
@property
def state(self) -> SessionState:
"""Return immutable snapshot of current session state."""
is_escalating = self._detect_escalation()
is_deescalating = self._detect_deescalation()
rate = self._compute_escalation_rate()
return SessionState(
current_level=self._current_level,
peak_level=self._peak_level,
message_count=self._message_count,
level_history=list(self._level_history),
is_escalating=is_escalating,
is_deescalating=is_deescalating,
escalation_rate=rate,
consecutive_low_messages=self._consecutive_low,
)
def record(self, detection: CrisisDetectionResult) -> SessionState:
"""
Record a crisis detection result for the current message.
Returns updated SessionState.
"""
level = detection.level
self._message_count += 1
self._level_history.append(level)
# Update peak
if LEVEL_ORDER.get(level, 0) > LEVEL_ORDER.get(self._peak_level, 0):
self._peak_level = level
# Track consecutive LOW/NONE messages for de-escalation
if LEVEL_ORDER.get(level, 0) <= LEVEL_ORDER["LOW"]:
self._consecutive_low += 1
else:
self._consecutive_low = 0
self._current_level = level
return self.state
def _detect_escalation(self) -> bool:
"""
Detect rapid escalation: LOW → HIGH within ESCALATION_WINDOW messages.
Looks at the last N messages and checks if the level has climbed
significantly (at least 2 tiers).
"""
if len(self._level_history) < 2:
return False
window = self._level_history[-self.ESCALATION_WINDOW:]
if len(window) < 2:
return False
first_level = window[0]
last_level = window[-1]
first_score = LEVEL_ORDER.get(first_level, 0)
last_score = LEVEL_ORDER.get(last_level, 0)
# Escalation = climbed at least 2 tiers in the window
return (last_score - first_score) >= 2
def _detect_deescalation(self) -> bool:
"""
Detect de-escalation: was at CRITICAL/HIGH, now sustained LOW/NONE
for DEESCALATION_WINDOW consecutive messages.
"""
if LEVEL_ORDER.get(self._peak_level, 0) < LEVEL_ORDER["HIGH"]:
return False
return self._consecutive_low >= self.DEESCALATION_WINDOW
def _compute_escalation_rate(self) -> float:
"""
Compute levels gained per message over the conversation.
Positive = escalating, negative = de-escalating, 0 = stable.
"""
if self._message_count < 2:
return 0.0
first = LEVEL_ORDER.get(self._level_history[0], 0)
current = LEVEL_ORDER.get(self._current_level, 0)
return (current - first) / (self._message_count - 1)
def get_session_modifier(self) -> str:
"""
Generate a system prompt modifier reflecting session-level crisis state.
Returns empty string if no session context is relevant.
"""
if self._message_count < 2:
return ""
s = self.state
if s.is_escalating:
return (
f"User has escalated from {self._level_history[0]} to "
f"{s.current_level} over {s.message_count} messages. "
f"Peak crisis level this session: {s.peak_level}. "
"Respond with heightened awareness. The trajectory is "
"worsening — prioritize safety and connection."
)
if s.is_deescalating:
return (
f"User previously reached {s.peak_level} crisis level "
f"but has been at {s.current_level} or below for "
f"{s.consecutive_low_messages} consecutive messages. "
"The situation appears to be stabilizing. Continue "
"supportive engagement while remaining vigilant."
)
if s.peak_level in ("CRITICAL", "HIGH") and s.current_level not in ("CRITICAL", "HIGH"):
return (
f"User previously reached {s.peak_level} crisis level "
f"this session (currently {s.current_level}). "
"Continue with care and awareness of the earlier crisis."
)
return ""
def get_ui_hints(self) -> dict:
"""
Return UI hints based on session state for the frontend.
These are advisory — the frontend decides what to show.
"""
s = self.state
hints = {
"session_escalating": s.is_escalating,
"session_deescalating": s.is_deescalating,
"session_peak_level": s.peak_level,
"session_message_count": s.message_count,
}
if s.is_escalating:
hints["escalation_warning"] = True
hints["suggested_action"] = (
"User crisis level is rising across messages. "
"Consider increasing intervention level."
)
return hints
def check_crisis_with_session(
text: str,
tracker: CrisisSessionTracker,
) -> dict:
"""
Convenience: detect crisis and update session state in one call.
Returns combined single-message detection + session-level context.
"""
from .detect import detect_crisis
from .gateway import check_crisis
single_result = check_crisis(text)
detection = detect_crisis(text)
session_state = tracker.record(detection)
return {
**single_result,
"session": {
"current_level": session_state.current_level,
"peak_level": session_state.peak_level,
"message_count": session_state.message_count,
"is_escalating": session_state.is_escalating,
"is_deescalating": session_state.is_deescalating,
"modifier": tracker.get_session_modifier(),
"ui_hints": tracker.get_ui_hints(),
},
}

View File

@@ -680,7 +680,7 @@ html, body {
<!-- Footer -->
<footer id="footer">
<a href="/about" aria-label="About The Door">about</a>
<a href="/about.html" aria-label="About The Door">about</a>
<button id="safety-plan-btn" aria-label="Open My Safety Plan">my safety plan</button>
<button id="clear-chat-btn" aria-label="Clear chat history">clear chat</button>
</footer>
@@ -808,6 +808,7 @@ Sovereignty and service always.`;
var crisisPanel = document.getElementById('crisis-panel');
var crisisOverlay = document.getElementById('crisis-overlay');
var overlayDismissBtn = document.getElementById('overlay-dismiss-btn');
var overlayCallLink = document.querySelector('.overlay-call');
var statusDot = document.querySelector('.status-dot');
var statusText = document.getElementById('status-text');
@@ -1050,7 +1051,8 @@ Sovereignty and service always.`;
}
}, 1000);
overlayDismissBtn.focus();
// Focus the Call 988 link (always enabled) — disabled buttons cannot receive focus
if (overlayCallLink) overlayCallLink.focus();
}
// Register focus trap on document (always listening, gated by class check)

View File

@@ -1,453 +0,0 @@
#!/usr/bin/env python3
"""
Tests for Crisis Detection A/B Testing Framework.
"""
import json
import os
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
# Add crisis module to path
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from crisis.ab_testing import (
ABTestManager,
ABTestConfig,
Variant,
DetectionEvent,
VariantMetrics,
ABTestMetrics,
get_ab_manager,
detect_with_ab,
)
from crisis.detect import CrisisDetectionResult
class TestABTestConfig:
"""Test A/B test configuration."""
def test_default_config(self):
config = ABTestConfig()
assert config.enabled is True
assert config.variant_b_percentage == 0.5
assert config.log_file is None
assert config.metrics_file is None
def test_custom_config(self):
config = ABTestConfig(
enabled=False,
variant_b_percentage=0.3,
log_file="/tmp/test.log",
metrics_file="/tmp/metrics.json",
seed=42
)
assert config.enabled is False
assert config.variant_b_percentage == 0.3
assert config.log_file == "/tmp/test.log"
assert config.seed == 42
class TestVariant:
"""Test variant enum."""
def test_values(self):
assert Variant.A.value == "A"
assert Variant.B.value == "B"
def test_comparison(self):
assert Variant.A != Variant.B
assert Variant.A == Variant.A
class TestDetectionEvent:
"""Test detection event dataclass."""
def test_creation(self):
event = DetectionEvent(
event_id="test-123",
timestamp="2026-01-01T00:00:00Z",
variant="A",
text_hash="abc123",
detected_level="HIGH",
detected_score=0.8,
indicators=["despair"],
detection_latency_ms=5.2
)
assert event.event_id == "test-123"
assert event.variant == "A"
assert event.detected_level == "HIGH"
assert event.is_false_positive is None
def test_to_dict(self):
event = DetectionEvent(
event_id="test-123",
timestamp="2026-01-01T00:00:00Z",
variant="B",
text_hash="def456",
detected_level="CRITICAL",
detected_score=0.95,
indicators=["suicide"],
detection_latency_ms=3.1
)
d = event.__dict__
assert "event_id" in d
assert "variant" in d
class TestABTestManager:
"""Test A/B test manager."""
def test_initialization(self, tmp_path):
config = ABTestConfig(
log_file=str(tmp_path / "test.log"),
metrics_file=str(tmp_path / "metrics.json")
)
manager = ABTestManager(config)
assert manager.config.enabled is True
assert manager.test_id is not None
assert len(manager.events) == 0
def test_variant_assignment_deterministic(self):
"""Same text should always get same variant."""
manager = ABTestManager(ABTestConfig(seed=42))
text = "I feel hopeless about everything"
variant1 = manager._assign_variant(text)
variant2 = manager._assign_variant(text)
assert variant1 == variant2
def test_variant_distribution(self):
"""Test that variant distribution is roughly 50/50."""
manager = ABTestManager(ABTestConfig(seed=42))
variants = []
for i in range(100):
text = f"Test message number {i}"
variant = manager._assign_variant(text)
variants.append(variant)
a_count = sum(1 for v in variants if v == Variant.A)
b_count = sum(1 for v in variants if v == Variant.B)
# Should be roughly 50/50 (within 20% tolerance)
assert 30 <= a_count <= 70
assert 30 <= b_count <= 70
def test_disabled_ab_testing(self, tmp_path):
"""When disabled, should always use variant A."""
config = ABTestConfig(
enabled=False,
log_file=str(tmp_path / "test.log")
)
manager = ABTestManager(config)
for i in range(10):
text = f"Test message {i}"
variant = manager._assign_variant(text)
assert variant == Variant.A
@patch('crisis.ab_testing.detect_crisis')
def test_detect_with_variant_a(self, mock_detect, tmp_path):
"""Test detection with variant A (control)."""
mock_detect.return_value = CrisisDetectionResult(
level="HIGH",
score=0.8,
indicators=["despair"],
matched_patterns=[],
recommended_action="provide_resources"
)
config = ABTestConfig(
enabled=False, # Force variant A
log_file=str(tmp_path / "test.log")
)
manager = ABTestManager(config)
variant, result, latency = manager.detect_with_variant("I'm feeling hopeless")
assert variant == Variant.A
assert result.level == "HIGH"
assert latency >= 0
@patch('crisis.ab_testing.detect_crisis')
def test_detect_with_variant_b(self, mock_detect, tmp_path):
"""Test detection with variant B (treatment)."""
mock_detect.return_value = CrisisDetectionResult(
level="MEDIUM",
score=0.75,
indicators=["no hope"],
matched_patterns=[],
recommended_action="provide_resources"
)
config = ABTestConfig(
variant_b_percentage=1.0, # Always variant B
log_file=str(tmp_path / "test.log")
)
manager = ABTestManager(config)
# Use text that hashes to variant B
for i in range(20):
text = f"Test message {i}"
variant, result, latency = manager.detect_with_variant(text)
if variant == Variant.B:
break
assert variant == Variant.B
def test_event_logging(self, tmp_path):
"""Test that events are logged to file."""
log_file = tmp_path / "test.jsonl"
config = ABTestConfig(
log_file=str(log_file),
metrics_file=str(tmp_path / "metrics.json")
)
manager = ABTestManager(config)
# Mock detection
with patch('crisis.ab_testing.detect_crisis') as mock_detect:
mock_detect.return_value = CrisisDetectionResult(
level="LOW",
score=0.3,
indicators=[],
matched_patterns=[],
recommended_action="none"
)
manager.detect_with_variant("Test message")
# Check log file exists and has content
assert log_file.exists()
with open(log_file) as f:
lines = f.readlines()
assert len(lines) >= 1
# Parse log entry
entry = json.loads(lines[0])
assert "event_id" in entry
assert "variant" in entry
assert "detected_level" in entry
def test_label_event(self, tmp_path):
"""Test labeling events as false positives."""
config = ABTestConfig(
log_file=str(tmp_path / "test.log"),
metrics_file=str(tmp_path / "metrics.json")
)
manager = ABTestManager(config)
# Create a mock event
event = DetectionEvent(
event_id="test-123",
timestamp="2026-01-01T00:00:00Z",
variant="A",
text_hash="abc123",
detected_level="HIGH",
detected_score=0.8,
indicators=["despair"],
detection_latency_ms=5.0
)
manager.events.append(event)
# Label it
manager.label_event("test-123", is_false_positive=True, feedback="Not actually crisis")
# Check labeling
assert event.is_false_positive is True
assert event.user_feedback == "Not actually crisis"
def test_get_metrics_empty(self, tmp_path):
"""Test metrics with no events."""
config = ABTestConfig(
metrics_file=str(tmp_path / "metrics.json")
)
manager = ABTestManager(config)
metrics = manager.get_metrics()
assert metrics.sample_size_a == 0
assert metrics.sample_size_b == 0
assert metrics.variant_a.total_detections == 0
def test_get_metrics_with_events(self, tmp_path):
"""Test metrics calculation with events."""
config = ABTestConfig(
log_file=str(tmp_path / "test.log"),
metrics_file=str(tmp_path / "metrics.json")
)
manager = ABTestManager(config)
# Add some mock events
for i in range(10):
event = DetectionEvent(
event_id=f"event-{i}",
timestamp="2026-01-01T00:00:00Z",
variant="A" if i % 2 == 0 else "B",
text_hash=f"hash-{i}",
detected_level="HIGH" if i % 3 == 0 else "MEDIUM",
detected_score=0.7 + (i % 3) * 0.1,
indicators=["despair"] if i % 2 == 0 else [],
detection_latency_ms=3.0 + i * 0.5
)
# Label some as false positives
if i % 4 == 0:
event.is_false_positive = True
elif i % 4 == 1:
event.is_false_positive = False
manager.events.append(event)
metrics = manager.get_metrics()
# Check we have events in both variants
assert metrics.sample_size_a > 0
assert metrics.sample_size_b > 0
# Check latency calculations
assert metrics.variant_a.avg_latency_ms > 0
assert metrics.variant_b.avg_latency_ms > 0
# Check level distribution
assert len(metrics.variant_a.level_distribution) > 0
def test_variant_distribution(self, tmp_path):
"""Test getting variant distribution."""
config = ABTestConfig()
manager = ABTestManager(config)
# Add events
for i in range(5):
event = DetectionEvent(
event_id=f"event-{i}",
timestamp="2026-01-01T00:00:00Z",
variant="A" if i < 3 else "B",
text_hash=f"hash-{i}",
detected_level="LOW",
detected_score=0.5,
indicators=[],
detection_latency_ms=2.0
)
manager.events.append(event)
dist = manager.get_variant_distribution()
assert dist["A"] == 3
assert detect_with_ab
assert dist["B"] == 2
def test_force_variant(self, tmp_path):
"""Test forcing a specific variant."""
config = ABTestConfig()
manager = ABTestManager(config)
manager.force_variant(Variant.B)
# After forcing, all should be variant B
for i in range(5):
text = f"Test message {i}"
variant = manager._assign_variant(text)
assert variant == Variant.B
def test_reset(self, tmp_path):
"""Test resetting the A/B test."""
config = ABTestConfig(
log_file=str(tmp_path / "test.log"),
metrics_file=str(tmp_path / "metrics.json")
)
manager = ABTestManager(config)
# Add some events
for i in range(3):
event = DetectionEvent(
event_id=f"event-{i}",
timestamp="2026-01-01T00:00:00Z",
variant="A",
text_hash=f"hash-{i}",
detected_level="LOW",
detected_score=0.5,
indicators=[],
detection_latency_ms=2.0
)
manager.events.append(event)
assert len(manager.events) == 3
# Reset
manager.reset()
assert len(manager.events) == 0
assert manager.config.enabled is True
class TestConvenienceFunctions:
"""Test convenience functions."""
def test_get_ab_manager(self):
"""Test getting default manager."""
manager = get_ab_manager()
assert isinstance(manager, ABTestManager)
@patch('crisis.ab_testing.detect_crisis')
def test_detect_with_ab(self, mock_detect):
"""Test convenience detection function."""
mock_detect.return_value = CrisisDetectionResult(
level="HIGH",
score=0.8,
indicators=["despair"],
matched_patterns=[],
recommended_action="provide_resources"
)
result = detect_with_ab("I'm feeling hopeless")
assert "variant" in result
assert "detection" in result
assert "latency_ms" in result
assert "test_id" in result
assert result["detection"]["level"] == "HIGH"
class TestMetricsCalculation:
"""Test metrics calculation edge cases."""
def test_percentile_calculation(self, tmp_path):
"""Test that percentiles are calculated correctly."""
config = ABTestConfig()
manager = ABTestManager(config)
# Create events with known latencies
latencies = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
events = []
for i, lat in enumerate(latencies):
events.append(DetectionEvent(
event_id=f"event-{i}",
timestamp="2026-01-01T00:00:00Z",
variant="A",
text_hash=f"hash-{i}",
detected_level="LOW",
detected_score=0.5,
indicators=[],
detection_latency_ms=lat
))
metrics = manager._calculate_variant_metrics(events)
assert metrics.p50_latency_ms == 5.0 # 50th percentile
assert metrics.p95_latency_ms == 10.0 # 95th percentile (last element)
assert metrics.avg_latency_ms == 5.5 # Average
def test_empty_variant_metrics(self, tmp_path):
"""Test metrics with no events."""
config = ABTestConfig()
manager = ABTestManager(config)
metrics = manager._calculate_variant_metrics([])
assert metrics.total_detections == 0
assert metrics.avg_latency_ms == 0
assert metrics.level_distribution == {}

View File

@@ -0,0 +1,100 @@
"""Tests for privacy-preserving crisis metrics aggregation (issue #37)."""
from __future__ import annotations
import json
import os
import pathlib
import sys
import unittest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from crisis.detect import detect_crisis
from crisis.gateway import check_crisis
from crisis.metrics import (
append_metrics_event,
build_metrics_event,
build_weekly_summary,
load_metrics_events,
render_weekly_summary,
)
class TestMetricsEvent(unittest.TestCase):
def test_event_is_privacy_preserving(self):
detection = detect_crisis("I want to kill myself")
event = build_metrics_event(
detection,
continued_conversation=True,
false_positive=False,
now=1_700_000_000,
)
self.assertEqual(event["timestamp"], 1_700_000_000)
self.assertEqual(event["level"], "CRITICAL")
self.assertTrue(event["continued_conversation"])
self.assertFalse(event["false_positive"])
self.assertNotIn("text", event)
self.assertNotIn("message", event)
self.assertGreaterEqual(event["indicator_count"], 1)
self.assertTrue(event["indicators"])
class TestMetricsLogAndSummary(unittest.TestCase):
def test_append_and_load_metrics_events(self):
log_path = pathlib.Path(self._testMethodName).with_suffix(".jsonl")
try:
append_metrics_event(log_path, detect_crisis("I want to die"), now=1_700_000_000)
events = load_metrics_events(log_path)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]["level"], "CRITICAL")
finally:
if log_path.exists():
log_path.unlink()
def test_weekly_summary_counts_levels_keywords_and_false_positives(self):
events = [
build_metrics_event(detect_crisis("I want to die"), continued_conversation=True, false_positive=False, now=1_700_000_000),
build_metrics_event(detect_crisis("I'm having a rough day"), continued_conversation=False, false_positive=False, now=1_700_000_100),
build_metrics_event(detect_crisis("I want to die"), continued_conversation=False, false_positive=True, now=1_700_000_200),
build_metrics_event(detect_crisis("Hello there"), continued_conversation=False, false_positive=False, now=1_700_000_300),
]
summary = build_weekly_summary(events, now=1_700_000_400, window_days=7)
self.assertEqual(summary["detections_per_level"]["CRITICAL"], 2)
self.assertEqual(summary["detections_per_level"]["LOW"], 1)
self.assertEqual(summary["detections_per_level"]["NONE"], 1)
self.assertEqual(summary["continued_after_intervention"], 1)
self.assertAlmostEqual(summary["false_positive_estimate"], 1 / 3, places=4)
self.assertEqual(summary["most_common_keywords"][0]["count"], 2)
def test_render_weekly_summary_mentions_required_metrics(self):
events = [
build_metrics_event(detect_crisis("I want to die"), continued_conversation=True, now=1_700_000_000),
build_metrics_event(detect_crisis("I feel hopeless with no way out"), false_positive=True, now=1_700_000_100),
]
summary = build_weekly_summary(events, now=1_700_000_200, window_days=7)
rendered = render_weekly_summary(summary)
self.assertIn("detections_per_level", rendered)
self.assertIn("most_common_keywords", rendered)
self.assertIn("false_positive_estimate", rendered)
self.assertIn("continued_after_intervention", rendered)
class TestGatewayMetricsIntegration(unittest.TestCase):
def test_check_crisis_can_emit_metrics_event(self):
result = check_crisis(
"I want to die",
metrics_log_path=None,
continued_conversation=True,
false_positive=False,
now=1_700_000_000,
)
self.assertEqual(result["level"], "CRITICAL")
self.assertIn("metrics_event", result)
self.assertEqual(result["metrics_event"]["timestamp"], 1_700_000_000)
self.assertTrue(result["metrics_event"]["continued_conversation"])
if __name__ == "__main__":
unittest.main()

View File

@@ -52,6 +52,34 @@ class TestCrisisOverlayFocusTrap(unittest.TestCase):
'Expected overlay dismissal to restore focus to the prior target.',
)
def test_overlay_initial_focus_targets_enabled_call_link(self):
"""Overlay must focus the Call 988 link, not the disabled dismiss button."""
# Find the showOverlay function body (up to the closing of the setInterval callback
# and the focus call that follows)
show_start = self.html.find('function showOverlay()')
self.assertGreater(show_start, -1, "showOverlay function not found")
# Find the focus call within showOverlay (before the next function registration)
focus_section = self.html[show_start:show_start + 2000]
self.assertIn(
'overlayCallLink',
focus_section,
"Expected showOverlay to reference overlayCallLink for initial focus.",
)
# Ensure the old buggy pattern is gone
focus_line_region = self.html[show_start + 800:show_start + 1200]
self.assertNotIn(
'overlayDismissBtn.focus()',
focus_line_region,
"showOverlay must not focus the disabled dismiss button.",
)
def test_overlay_call_link_variable_is_declared(self):
self.assertIn(
"querySelector('.overlay-call')",
self.html,
"Expected a JS reference to the .overlay-call link element.",
)
if __name__ == '__main__':
unittest.main()

View File

@@ -50,6 +50,22 @@ class TestCrisisOfflinePage(unittest.TestCase):
for phrase in required_phrases:
self.assertIn(phrase, self.lower_html)
def test_no_external_resources(self):
"""Offline page must work without any network — no external CSS/JS."""
import re
html = self.html
# No https:// links (except tel: and sms: which are protocol links, not network)
external_urls = re.findall(r'href=["\']https://|src=["\']https://', html)
self.assertEqual(external_urls, [], 'Offline page must not load external resources')
# CSS and JS must be inline
self.assertIn('<style>', html, 'CSS must be inline')
self.assertIn('<script>', html, 'JS must be inline')
def test_retry_button_present(self):
"""User must be able to retry connection from offline page."""
self.assertIn('retry-connection', self.html)
self.assertIn('Retry connection', self.html)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,277 @@
"""
Tests for crisis session tracking and escalation (P0 #35).
Covers: session_tracker.py
Run with: python -m pytest tests/test_session_tracker.py -v
"""
import unittest
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from crisis.detect import detect_crisis
from crisis.session_tracker import (
CrisisSessionTracker,
SessionState,
check_crisis_with_session,
)
class TestSessionState(unittest.TestCase):
"""Test SessionState defaults."""
def test_default_state(self):
s = SessionState()
self.assertEqual(s.current_level, "NONE")
self.assertEqual(s.peak_level, "NONE")
self.assertEqual(s.message_count, 0)
self.assertEqual(s.level_history, [])
self.assertFalse(s.is_escalating)
self.assertFalse(s.is_deescalating)
class TestSessionTracking(unittest.TestCase):
"""Test basic session state tracking."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_record_none_message(self):
state = self.tracker.record(detect_crisis("Hello Timmy"))
self.assertEqual(state.current_level, "NONE")
self.assertEqual(state.message_count, 1)
self.assertEqual(state.peak_level, "NONE")
def test_record_low_message(self):
self.tracker.record(detect_crisis("Hello"))
state = self.tracker.record(detect_crisis("Having a rough day"))
self.assertIn(state.current_level, ("LOW", "NONE"))
self.assertEqual(state.message_count, 2)
def test_record_critical_updates_peak(self):
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I want to kill myself"))
self.assertEqual(state.current_level, "CRITICAL")
self.assertEqual(state.peak_level, "CRITICAL")
def test_peak_preserved_after_drop(self):
"""Peak level should stay at the highest seen, even after de-escalation."""
self.tracker.record(detect_crisis("I want to kill myself"))
state = self.tracker.record(detect_crisis("I'm feeling a bit better"))
self.assertEqual(state.peak_level, "CRITICAL")
def test_level_history(self):
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I want to die"))
self.assertEqual(len(state.level_history), 3)
self.assertEqual(state.level_history[0], "NONE")
self.assertEqual(state.level_history[2], "CRITICAL")
def test_reset_clears_state(self):
self.tracker.record(detect_crisis("I want to kill myself"))
self.tracker.reset()
state = self.tracker.state
self.assertEqual(state.current_level, "NONE")
self.assertEqual(state.peak_level, "NONE")
self.assertEqual(state.message_count, 0)
self.assertEqual(state.level_history, [])
class TestEscalationDetection(unittest.TestCase):
"""Test escalation detection: LOW → HIGH in ≤3 messages."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_no_escalation_single_message(self):
self.tracker.record(detect_crisis("Hello"))
self.assertFalse(self.tracker.state.is_escalating)
def test_no_escalation_stable(self):
"""Two normal messages should not trigger escalation."""
self.tracker.record(detect_crisis("Hello"))
state = self.tracker.record(detect_crisis("How are you?"))
self.assertFalse(state.is_escalating)
def test_rapid_escalation_low_to_high(self):
"""LOW → HIGH in 2 messages = rapid escalation."""
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I can't take this anymore, everything is pointless"))
# Depending on detection, this could be HIGH or CRITICAL
if state.current_level in ("HIGH", "CRITICAL"):
self.assertTrue(state.is_escalating)
def test_rapid_escalation_three_messages(self):
"""NONE → LOW → HIGH in 3 messages = escalation."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I feel completely hopeless with no way out"))
if state.current_level in ("HIGH", "CRITICAL"):
self.assertTrue(state.is_escalating)
def test_escalation_rate(self):
"""Rate should be positive when escalating."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("I want to die"))
state = self.tracker.state
self.assertGreater(state.escalation_rate, 0)
class TestDeescalationDetection(unittest.TestCase):
"""Test de-escalation: sustained LOW after HIGH/CRITICAL."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_no_deescalation_without_prior_crisis(self):
"""No de-escalation if never reached HIGH/CRITICAL."""
for _ in range(6):
self.tracker.record(detect_crisis("Hello"))
self.assertFalse(self.tracker.state.is_deescalating)
def test_deescalation_after_critical(self):
"""5+ consecutive LOW/NONE messages after CRITICAL = de-escalation."""
self.tracker.record(detect_crisis("I want to kill myself"))
for _ in range(5):
self.tracker.record(detect_crisis("I'm doing better today"))
state = self.tracker.state
if state.peak_level == "CRITICAL":
self.assertTrue(state.is_deescalating)
def test_deescalation_after_high(self):
"""5+ consecutive LOW/NONE messages after HIGH = de-escalation."""
self.tracker.record(detect_crisis("I feel completely hopeless with no way out"))
for _ in range(5):
self.tracker.record(detect_crisis("Feeling okay"))
state = self.tracker.state
if state.peak_level == "HIGH":
self.assertTrue(state.is_deescalating)
def test_interrupted_deescalation(self):
"""De-escalation resets if a HIGH message interrupts."""
self.tracker.record(detect_crisis("I want to kill myself"))
for _ in range(3):
self.tracker.record(detect_crisis("Doing better"))
# Interrupt with another crisis
self.tracker.record(detect_crisis("I feel hopeless again"))
self.tracker.record(detect_crisis("Feeling okay now"))
state = self.tracker.state
# Should NOT be de-escalating yet (counter reset)
self.assertFalse(state.is_deescalating)
class TestSessionModifier(unittest.TestCase):
"""Test system prompt modifier generation."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_no_modifier_for_single_message(self):
self.tracker.record(detect_crisis("Hello"))
self.assertEqual(self.tracker.get_session_modifier(), "")
def test_no_modifier_for_stable_session(self):
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("Good morning"))
self.assertEqual(self.tracker.get_session_modifier(), "")
def test_escalation_modifier(self):
"""Escalating session should produce a modifier."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("I want to die"))
modifier = self.tracker.get_session_modifier()
if self.tracker.state.is_escalating:
self.assertIn("escalated", modifier.lower())
self.assertIn("NONE", modifier)
self.assertIn("CRITICAL", modifier)
def test_deescalation_modifier(self):
"""De-escalating session should mention stabilizing."""
self.tracker.record(detect_crisis("I want to kill myself"))
for _ in range(5):
self.tracker.record(detect_crisis("I'm feeling okay"))
modifier = self.tracker.get_session_modifier()
if self.tracker.state.is_deescalating:
self.assertIn("stabilizing", modifier.lower())
def test_prior_crisis_modifier(self):
"""Past crisis should be noted even without active escalation."""
self.tracker.record(detect_crisis("I want to die"))
self.tracker.record(detect_crisis("Feeling a bit better"))
modifier = self.tracker.get_session_modifier()
# Should note the prior CRITICAL
if modifier:
self.assertIn("CRITICAL", modifier)
class TestUIHints(unittest.TestCase):
"""Test UI hint generation."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_ui_hints_structure(self):
self.tracker.record(detect_crisis("Hello"))
hints = self.tracker.get_ui_hints()
self.assertIn("session_escalating", hints)
self.assertIn("session_deescalating", hints)
self.assertIn("session_peak_level", hints)
self.assertIn("session_message_count", hints)
def test_ui_hints_escalation_warning(self):
"""Escalating session should have warning hint."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("I want to die"))
hints = self.tracker.get_ui_hints()
if hints["session_escalating"]:
self.assertTrue(hints.get("escalation_warning"))
self.assertIn("suggested_action", hints)
class TestCheckCrisisWithSession(unittest.TestCase):
"""Test the convenience function combining detection + session tracking."""
def test_returns_combined_data(self):
tracker = CrisisSessionTracker()
result = check_crisis_with_session("I want to die", tracker)
self.assertIn("level", result)
self.assertIn("session", result)
self.assertIn("current_level", result["session"])
self.assertIn("peak_level", result["session"])
self.assertIn("modifier", result["session"])
def test_session_updates_across_calls(self):
tracker = CrisisSessionTracker()
check_crisis_with_session("Hello", tracker)
result = check_crisis_with_session("I want to die", tracker)
self.assertEqual(result["session"]["message_count"], 2)
self.assertEqual(result["session"]["peak_level"], "CRITICAL")
class TestPrivacy(unittest.TestCase):
"""Verify privacy-first design principles."""
def test_no_persistence_mechanism(self):
"""Session tracker should have no database, file, or network calls."""
import inspect
source = inspect.getsource(CrisisSessionTracker)
# Should not import database, requests, or file I/O
forbidden = ["sqlite", "requests", "urllib", "open(", "httpx", "aiohttp"]
for word in forbidden:
self.assertNotIn(word, source.lower(),
f"Session tracker should not use {word} — privacy-first design")
def test_state_contained_in_memory(self):
"""All state should be instance attributes, not module-level."""
tracker = CrisisSessionTracker()
tracker.record(detect_crisis("I want to die"))
# New tracker should have clean state (no global contamination)
fresh = CrisisSessionTracker()
self.assertEqual(fresh.state.current_level, "NONE")
if __name__ == '__main__':
unittest.main()