Compare commits

..

13 Commits

Author SHA1 Message Date
Timmy Burn
622cac0654 feat: add unified multimodal scorer for #134
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 6s
Smoke Test / smoke (pull_request) Successful in 12s
2026-04-20 20:06:45 -04:00
Timmy Burn
1cd1dd3288 test: define unified multimodal scorer for #134 2026-04-20 19:58:43 -04:00
d412939b4f fix: footer /about link to point to static about.html
Fixes #59

The footer links to /about but the repo ships about.html. On a plain static server this results in a 404. Changed to /about.html so the link resolves correctly.
2026-04-17 05:37:40 +00:00
07c582aa08 Merge pull request 'fix: crisis overlay initial focus to enabled Call 988 link (#69)' (#126) from burn/69-1776264183 into main
Merge PR #126: fix: crisis overlay initial focus to enabled Call 988 link (#69)
2026-04-17 01:46:56 +00:00
5f95dc1e39 Merge pull request '[P3] Service worker: cache crisis resources for offline (#41)' (#122) from burn/41-1776264184 into main
Merge PR #122: [P3] Service worker: cache crisis resources for offline (#41)
2026-04-17 01:46:55 +00:00
b1f3cac36d Merge pull request 'feat: session-level crisis tracking and escalation (closes #35)' (#118) from door/issue-35 into main
Merge PR #118: feat: session-level crisis tracking and escalation (closes #35)
2026-04-17 01:46:53 +00:00
07b3f67845 fix: crisis overlay initial focus to enabled Call 988 link (#69)
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 9s
Smoke Test / smoke (pull_request) Successful in 15s
2026-04-15 15:09:36 +00:00
c22bbbaf65 fix: crisis overlay initial focus to enabled Call 988 link (#69) 2026-04-15 15:09:32 +00:00
543cb1d40f test: add offline self-containment and retry button tests (#41)
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 4s
Smoke Test / smoke (pull_request) Successful in 11s
2026-04-15 14:58:44 +00:00
3cfd01815a feat: session-level crisis tracking and escalation (closes #35)
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 17s
Smoke Test / smoke (pull_request) Successful in 23s
2026-04-15 11:49:52 +00:00
5a7ba9f207 feat: session-level crisis tracking and escalation (closes #35) 2026-04-15 11:49:51 +00:00
8ed8f20a17 feat: session-level crisis tracking and escalation (closes #35) 2026-04-15 11:49:49 +00:00
9d7d26033e feat: session-level crisis tracking and escalation (closes #35) 2026-04-15 11:49:47 +00:00
12 changed files with 860 additions and 749 deletions

View File

@@ -6,7 +6,8 @@ Stands between a broken man and a machine that would tell him to die.
from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urgency_emoji
from .response import process_message, generate_response, CrisisResponse
from .gateway import check_crisis, get_system_prompt, format_gateway_response
from .gateway import check_crisis, check_crisis_multimodal, get_system_prompt, format_gateway_response
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
__all__ = [
"detect_crisis",
@@ -15,8 +16,12 @@ __all__ = [
"generate_response",
"CrisisResponse",
"check_crisis",
"check_crisis_multimodal",
"get_system_prompt",
"format_result",
"format_gateway_response",
"get_urgency_emoji",
"CrisisSessionTracker",
"SessionState",
"check_crisis_with_session",
]

View File

@@ -1,387 +0,0 @@
"""
Crisis Detection A/B Testing Framework for the-door.
Provides feature-flagged A/B testing for crisis detection algorithms.
Variant A: Current canonical detector (crisis/detect.py)
Variant B: Enhanced detector with contextual scoring (configurable)
Logs which variant triggered for each event and tracks metrics:
- False positive rate per variant
- Detection latency per variant
- Detection distribution per variant
"""
import json
import time
import hashlib
import os
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Optional, Callable, Any
from datetime import datetime, timezone
from enum import Enum
from crisis.detect import detect_crisis as detect_crisis_variant_a, CrisisDetectionResult
class Variant(Enum):
    """A/B test variants."""
    # String values so the variant serializes cleanly into JSONL event logs.
    A = "A"  # Control: current canonical detector
    B = "B"  # Treatment: enhanced detector
@dataclass
class ABTestConfig:
    """Configuration for A/B test.

    NOTE(review): log_file and metrics_file default to relative paths, so
    output lands in the process working directory — confirm that is intended.
    """
    enabled: bool = True
    variant_b_percentage: float = 0.5  # 50% traffic to variant B
    seed: Optional[str] = None  # For deterministic assignment
    log_file: str = "crisis_ab_test.jsonl"  # append-only JSONL event log
    metrics_file: str = "crisis_ab_metrics.json"  # aggregate metrics snapshot
@dataclass
class DetectionEvent:
    """Single detection event for logging.

    Only a truncated SHA-256 of the input text is stored, never the text
    itself, so the event log stays privacy-preserving.
    """
    event_id: str
    timestamp: str  # ISO-8601 UTC
    text_hash: str  # Hash of input text (privacy-preserving)
    variant: str  # "A" or "B"
    level: str  # detection level, e.g. "CRITICAL".."NONE"
    score: float
    indicators: List[str]
    latency_ms: float
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class VariantMetrics:
    """Metrics for a single variant."""
    variant: str
    total_events: int = 0
    detections_by_level: Dict[str, int] = field(default_factory=lambda: {
        "CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "NONE": 0
    })
    avg_latency_ms: float = 0.0  # running mean: total_latency_ms / total_events
    total_latency_ms: float = 0.0
    false_positives: int = 0  # Requires human labeling
    true_positives: int = 0  # Requires human labeling
    # Set by CrisisABTester.label_event() once at least one label exists.
    false_positive_rate: Optional[float] = None  # Calculated when labels available
@dataclass
class ABTestMetrics:
    """Aggregate metrics for A/B test."""
    variant_a: VariantMetrics = field(default_factory=lambda: VariantMetrics("A"))
    variant_b: VariantMetrics = field(default_factory=lambda: VariantMetrics("B"))
    start_time: str = ""  # ISO-8601, set when the tester is constructed
    end_time: str = ""  # ISO-8601, refreshed by get_metrics()
    total_events: int = 0
    disagreements: int = 0  # Cases where variants disagree
class CrisisABTester:
    """
    A/B testing framework for crisis detection algorithms.

    Usage:
        tester = CrisisABTester()
        result = tester.detect("I feel hopeless")
        # Returns CrisisDetectionResult from assigned variant
        # Logs event and updates metrics
    """
    def __init__(self, config: Optional[ABTestConfig] = None):
        self.config = config or ABTestConfig()
        self.metrics = ABTestMetrics(start_time=datetime.now(timezone.utc).isoformat())
        self._variant_b_detector: Optional[Callable] = None
        # In-memory event history; grows for the life of the tester
        # (events are also appended to the JSONL log file).
        self._event_log: List[DetectionEvent] = []
        # Load existing metrics if file exists
        if os.path.exists(self.config.metrics_file):
            self._load_metrics()

    def set_variant_b_detector(self, detector: Callable[[str], CrisisDetectionResult]):
        """Set the detector function for variant B."""
        self._variant_b_detector = detector

    def _assign_variant(self, text: str) -> Variant:
        """
        Assign variant based on text hash (deterministic) or random.
        Uses hash for consistent assignment of same/similar texts.
        """
        if not self.config.enabled:
            return Variant.A
        # Use text hash for deterministic assignment
        hash_input = text.strip().lower()
        if self.config.seed:
            hash_input = self.config.seed + hash_input
        # MD5 is used only for bucketing, not security.
        hash_val = int(hashlib.md5(hash_input.encode()).hexdigest(), 16)
        normalized = (hash_val % 1000) / 1000.0  # 0.0 to 0.999
        return Variant.B if normalized < self.config.variant_b_percentage else Variant.A

    def _get_variant_b_result(self, text: str) -> CrisisDetectionResult:
        """
        Get detection result from variant B.

        Uses the detector injected via set_variant_b_detector() when one is
        set; otherwise applies the built-in "enhanced" default below (note:
        this is NOT a fall-back to variant A — MEDIUM needs only 1 indicator).
        """
        if self._variant_b_detector:
            return self._variant_b_detector(text)
        # Default variant B: enhanced contextual scoring
        # More sensitive to MEDIUM indicators (requires only 1 instead of 2)
        from crisis.detect import _find_indicators, ACTIONS, SCORES
        text_lower = text.lower()
        matches = _find_indicators(text_lower)
        if not matches:
            return CrisisDetectionResult(level="NONE", score=0.0)
        # CRITICAL and HIGH: same as variant A
        for tier in ("CRITICAL", "HIGH"):
            if matches[tier]:
                tier_matches = matches[tier]
                patterns = [m["pattern"] for m in tier_matches]
                return CrisisDetectionResult(
                    level=tier,
                    indicators=patterns,
                    recommended_action=ACTIONS[tier],
                    score=SCORES[tier],
                    matches=tier_matches,
                )
        # MEDIUM: variant B requires only 1 indicator (vs 2 in variant A)
        if matches["MEDIUM"]:
            tier_matches = matches["MEDIUM"]
            patterns = [m["pattern"] for m in tier_matches]
            return CrisisDetectionResult(
                level="MEDIUM",
                indicators=patterns,
                recommended_action=ACTIONS["MEDIUM"],
                score=SCORES["MEDIUM"],
                matches=tier_matches,
            )
        if matches["LOW"]:
            tier_matches = matches["LOW"]
            patterns = [m["pattern"] for m in tier_matches]
            return CrisisDetectionResult(
                level="LOW",
                indicators=patterns,
                recommended_action=ACTIONS["LOW"],
                score=SCORES["LOW"],
                matches=tier_matches,
            )
        return CrisisDetectionResult(level="NONE", score=0.0)

    def detect(self, text: str, metadata: Optional[Dict] = None) -> CrisisDetectionResult:
        """
        Run A/B test detection.

        Args:
            text: Input text to analyze
            metadata: Optional metadata to attach to event log
        Returns:
            CrisisDetectionResult from assigned variant
        """
        if not self.config.enabled:
            return detect_crisis_variant_a(text)
        variant = self._assign_variant(text)
        start_time = time.perf_counter()
        if variant == Variant.A:
            result = detect_crisis_variant_a(text)
        else:
            result = self._get_variant_b_result(text)
        latency_ms = (time.perf_counter() - start_time) * 1000
        # Log event
        # NOTE(review): event_id uses builtin hash(text), which is salted
        # per-process (PYTHONHASHSEED) — ids are not stable across runs.
        event = DetectionEvent(
            event_id=f"{int(time.time() * 1000)}-{hash(text) % 10000:04d}",
            timestamp=datetime.now(timezone.utc).isoformat(),
            text_hash=hashlib.sha256(text.encode()).hexdigest()[:16],
            variant=variant.value,
            level=result.level,
            score=result.score,
            indicators=result.indicators[:5],  # Limit for privacy
            latency_ms=round(latency_ms, 3),
            metadata=metadata or {}
        )
        self._event_log.append(event)
        self._log_event(event)
        # Update metrics
        self._update_metrics(variant, result, latency_ms)
        return result

    def _log_event(self, event: DetectionEvent):
        """Append event to JSONL log file."""
        try:
            with open(self.config.log_file, "a") as f:
                f.write(json.dumps(asdict(event)) + "\n")
        except Exception as e:
            # Logging failures are deliberately non-fatal: detection must
            # never be blocked by instrumentation.
            print(f"Warning: Could not log A/B test event: {e}")

    def _update_metrics(self, variant: Variant, result: CrisisDetectionResult, latency_ms: float):
        """Update running metrics."""
        vm = self.metrics.variant_a if variant == Variant.A else self.metrics.variant_b
        vm.total_events += 1
        vm.detections_by_level[result.level] = vm.detections_by_level.get(result.level, 0) + 1
        vm.total_latency_ms += latency_ms
        vm.avg_latency_ms = vm.total_latency_ms / vm.total_events
        self.metrics.total_events += 1

    def compare_results(self, text: str) -> Dict[str, CrisisDetectionResult]:
        """
        Run both variants and return both results (for analysis).
        Does not log to A/B test metrics.
        """
        result_a = detect_crisis_variant_a(text)
        result_b = self._get_variant_b_result(text)
        return {"A": result_a, "B": result_b}

    def get_disagreement_rate(self) -> float:
        """
        Calculate disagreement rate from logged events.
        Requires running detect() for same texts with both variants.
        """
        if not self._event_log:
            return 0.0
        # Group by text_hash
        by_text: Dict[str, Dict[str, str]] = {}
        for event in self._event_log:
            if event.text_hash not in by_text:
                by_text[event.text_hash] = {}
            by_text[event.text_hash][event.variant] = event.level
        # A disagreement = same text seen by both variants with differing levels.
        disagreements = sum(
            1 for variants in by_text.values()
            if "A" in variants and "B" in variants and variants["A"] != variants["B"]
        )
        return disagreements / len(by_text) if by_text else 0.0

    def get_metrics(self) -> ABTestMetrics:
        """Get current metrics snapshot."""
        self.metrics.end_time = datetime.now(timezone.utc).isoformat()
        self.metrics.disagreements = int(self.get_disagreement_rate() * self.metrics.total_events)
        return self.metrics

    def save_metrics(self):
        """Save metrics to JSON file."""
        try:
            with open(self.config.metrics_file, "w") as f:
                json.dump(asdict(self.get_metrics()), f, indent=2)
        except Exception as e:
            print(f"Warning: Could not save A/B test metrics: {e}")

    def _load_metrics(self):
        """Load metrics from JSON file."""
        try:
            with open(self.config.metrics_file, "r") as f:
                data = json.load(f)
            # Reconstruct metrics from saved data
            if "variant_a" in data:
                self.metrics.variant_a = VariantMetrics(**data["variant_a"])
            if "variant_b" in data:
                self.metrics.variant_b = VariantMetrics(**data["variant_b"])
            self.metrics.total_events = data.get("total_events", 0)
            self.metrics.disagreements = data.get("disagreements", 0)
        except Exception as e:
            print(f"Warning: Could not load A/B test metrics: {e}")

    def label_event(self, event_id: str, is_true_positive: bool):
        """
        Label an event as true/false positive (requires human review).
        Updates false positive rate metrics.

        Raises:
            ValueError: if event_id is not in the in-memory event log.
        """
        for event in self._event_log:
            if event.event_id == event_id:
                vm = self.metrics.variant_a if event.variant == "A" else self.metrics.variant_b
                if is_true_positive:
                    vm.true_positives += 1
                else:
                    vm.false_positives += 1
                # Recalculate false positive rate
                total_labelled = vm.true_positives + vm.false_positives
                if total_labelled > 0:
                    vm.false_positive_rate = vm.false_positives / total_labelled
                self.save_metrics()
                return
        raise ValueError(f"Event {event_id} not found")

    def get_report(self) -> str:
        """Generate human-readable A/B test report."""
        m = self.get_metrics()
        lines = [
            "=" * 60,
            "CRISIS DETECTION A/B TEST REPORT",
            "=" * 60,
            f"Period: {m.start_time} to {m.end_time}",
            f"Total Events: {m.total_events}",
            f"Disagreements: {m.disagreements}",
            "",
            "VARIANT A (Control - Current Detector):",
            f"  Events: {m.variant_a.total_events}",
            f"  Avg Latency: {m.variant_a.avg_latency_ms:.3f} ms",
            f"  Detection Distribution:",
        ]
        for level, count in m.variant_a.detections_by_level.items():
            pct = (count / m.variant_a.total_events * 100) if m.variant_a.total_events else 0
            lines.append(f"    {level}: {count} ({pct:.1f}%)")
        if m.variant_a.false_positive_rate is not None:
            lines.append(f"  False Positive Rate: {m.variant_a.false_positive_rate:.1%}")
        lines.extend([
            "",
            "VARIANT B (Treatment - Enhanced Detector):",
            f"  Events: {m.variant_b.total_events}",
            f"  Avg Latency: {m.variant_b.avg_latency_ms:.3f} ms",
            f"  Detection Distribution:",
        ])
        for level, count in m.variant_b.detections_by_level.items():
            pct = (count / m.variant_b.total_events * 100) if m.variant_b.total_events else 0
            lines.append(f"    {level}: {count} ({pct:.1f}%)")
        if m.variant_b.false_positive_rate is not None:
            lines.append(f"  False Positive Rate: {m.variant_b.false_positive_rate:.1%}")
        lines.append("=" * 60)
        return "\n".join(lines)
# ── Module-level convenience ──────────────────────────────────────
# Lazily-created, process-wide singleton tester.
_default_tester: Optional[CrisisABTester] = None


def get_ab_tester(config: Optional[ABTestConfig] = None) -> CrisisABTester:
    """Get or create the default A/B tester instance."""
    global _default_tester
    if _default_tester is not None:
        return _default_tester
    _default_tester = CrisisABTester(config)
    return _default_tester


def detect_crisis_ab(text: str, metadata: Optional[Dict] = None) -> CrisisDetectionResult:
    """Convenience function for A/B tested crisis detection."""
    tester = get_ab_tester()
    return tester.detect(text, metadata)

View File

@@ -2,18 +2,21 @@
Crisis Gateway Module for the-door.
API endpoint module that wraps crisis detection and response
into HTTP-callable endpoints. Integrates detect.py and response.py.
into HTTP-callable endpoints. Integrates detect.py, unified_scorer.py, and response.py.
Usage:
from crisis.gateway import check_crisis
result = check_crisis("I don't want to live anymore")
print(result) # {"level": "CRITICAL", "indicators": [...], "response": {...}}
"""
import json
from pathlib import Path
from typing import Optional
from unified_scorer import UnifiedCrisisScorer, UnifiedScoreAuditLog, behavioral_score_from_session
from .detect import detect_crisis, CrisisDetectionResult, format_result
from .compassion_router import router
from .response import (
@@ -22,6 +25,7 @@ from .response import (
get_system_prompt_modifier,
CrisisResponse,
)
from .session_tracker import CrisisSessionTracker
def check_crisis(text: str) -> dict:
@@ -49,6 +53,74 @@ def check_crisis(text: str) -> dict:
}
def check_crisis_multimodal(
    text: str,
    *,
    tracker: Optional[CrisisSessionTracker] = None,
    voice_score: Optional[float] = None,
    image_score: Optional[float] = None,
    behavioral_score: Optional[float] = None,
    audit_log_path: Optional[Path] = None,
    weights: Optional[dict] = None,
) -> dict:
    """Combine text, voice, image, and behavioral signals into one crisis assessment.

    Args:
        text: Message text; always scored via detect_crisis().
        tracker: Optional session tracker; when given, this detection is
            recorded into the session and the result gains a "session" key.
        voice_score: Optional voice-modality score — presumably on the same
            scale as detection.score; TODO confirm against UnifiedCrisisScorer.
        image_score: Optional image-modality score (same caveat).
        behavioral_score: Optional behavioral score; when omitted and a
            tracker is present, derived from session state instead.
        audit_log_path: When set, scoring is audited via UnifiedScoreAuditLog.
        weights: Optional per-modality weight overrides for the scorer.

    Returns:
        dict with unified level/score, text-derived indicators, Timmy response
        and UI flags, a "unified" score breakdown, and optionally "session".
    """
    detection = detect_crisis(text)
    session_state = tracker.record(detection) if tracker is not None else None
    # Only synthesize a behavioral signal when the caller did not supply one.
    if behavioral_score is None and session_state is not None:
        behavioral_score = behavioral_score_from_session(session_state)
    scorer = UnifiedCrisisScorer(
        weights=weights,
        audit_log=UnifiedScoreAuditLog(audit_log_path) if audit_log_path else None,
    )
    assessment = scorer.score(
        text_score=detection.score,
        voice_score=voice_score,
        image_score=image_score,
        behavioral_score=behavioral_score,
        source_text=text,
    )
    # Re-wrap as CrisisDetectionResult so response generation sees the unified
    # level/score while keeping the text-derived indicators and matches.
    unified_detection = CrisisDetectionResult(
        level=assessment.level.value,
        indicators=detection.indicators,
        recommended_action=detection.recommended_action,
        score=assessment.combined_score,
        matches=detection.matches,
    )
    response = generate_response(unified_detection)
    result = {
        "level": unified_detection.level,
        "score": unified_detection.score,
        "indicators": detection.indicators,
        "recommended_action": unified_detection.recommended_action,
        "timmy_message": response.timmy_message,
        "ui": {
            "show_crisis_panel": response.show_crisis_panel,
            "show_overlay": response.show_overlay,
            "provide_988": response.provide_988,
        },
        "escalate": response.escalate,
        "unified": {
            "level": assessment.level.value,
            "combined_score": assessment.combined_score,
            "weights": assessment.weights,
            "modalities": assessment.modalities,
            "present_modalities": assessment.present_modalities,
        },
    }
    if session_state is not None:
        result["session"] = {
            "current_level": session_state.current_level,
            "peak_level": session_state.peak_level,
            "message_count": session_state.message_count,
            "is_escalating": session_state.is_escalating,
            "is_deescalating": session_state.is_deescalating,
        }
    return result
def get_system_prompt(base_prompt: str, text: str = "") -> str:
"""
Sovereign Heart System Prompt Override.

259
crisis/session_tracker.py Normal file
View File

@@ -0,0 +1,259 @@
"""
Session-level crisis tracking and escalation for the-door (P0 #35).
Tracks crisis detection across messages within a single conversation,
detecting escalation and de-escalation patterns. Privacy-first: no
persistence beyond the conversation session.
Each message is analyzed in isolation by detect.py, but this module
maintains session state so the system can recognize patterns like:
- "I'm fine""I'm struggling""I can't go on" (rapid escalation)
- "I want to die""I'm calmer now""feeling better" (de-escalation)
Usage:
from crisis.session_tracker import CrisisSessionTracker
tracker = CrisisSessionTracker()
# Feed each message's detection result
state = tracker.record(detect_crisis("I'm having a tough day"))
print(state.current_level) # "LOW"
print(state.is_escalating) # False
state = tracker.record(detect_crisis("I feel hopeless"))
print(state.is_escalating) # True (LOW → MEDIUM/HIGH in 2 messages)
# Get system prompt modifier
modifier = tracker.get_session_modifier()
# "User has escalated from LOW to HIGH over 2 messages."
# Reset for new session
tracker.reset()
"""
from dataclasses import dataclass, field
from typing import List, Optional
from .detect import CrisisDetectionResult, SCORES
# Level ordering for comparison (higher = more severe)
LEVEL_ORDER = {"NONE": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4}
@dataclass
class SessionState:
    """Immutable snapshot of session crisis tracking state.

    NOTE(review): "immutable" is by convention only — the dataclass is not
    frozen; the tracker hands out a fresh instance per call.
    """
    current_level: str = "NONE"
    peak_level: str = "NONE"  # most severe level seen this session
    message_count: int = 0
    level_history: List[str] = field(default_factory=list)  # one level per message
    is_escalating: bool = False
    is_deescalating: bool = False
    escalation_rate: float = 0.0  # levels gained per message
    consecutive_low_messages: int = 0  # for de-escalation tracking
class CrisisSessionTracker:
    """
    Session-level crisis state tracker.

    Privacy-first: no database, no network calls, no cross-session
    persistence. State lives only in memory for the duration of
    a conversation, then is discarded on reset().
    """
    # Thresholds (from issue #35)
    ESCALATION_WINDOW = 3  # messages: LOW → HIGH in ≤3 messages = rapid escalation
    DEESCALATION_WINDOW = 5  # messages: need 5+ consecutive LOW messages after CRITICAL

    def __init__(self):
        # Delegate to reset() so construction and reset share one code path.
        self.reset()

    def reset(self):
        """Reset all session state. Call on new conversation."""
        self._current_level = "NONE"
        self._peak_level = "NONE"
        self._message_count = 0
        self._level_history: List[str] = []
        self._consecutive_low = 0

    @property
    def state(self) -> SessionState:
        """Return immutable snapshot of current session state."""
        is_escalating = self._detect_escalation()
        is_deescalating = self._detect_deescalation()
        rate = self._compute_escalation_rate()
        return SessionState(
            current_level=self._current_level,
            peak_level=self._peak_level,
            message_count=self._message_count,
            # Copy so callers cannot mutate the tracker's history.
            level_history=list(self._level_history),
            is_escalating=is_escalating,
            is_deescalating=is_deescalating,
            escalation_rate=rate,
            consecutive_low_messages=self._consecutive_low,
        )

    def record(self, detection: CrisisDetectionResult) -> SessionState:
        """
        Record a crisis detection result for the current message.
        Returns updated SessionState.
        """
        level = detection.level
        self._message_count += 1
        self._level_history.append(level)
        # Update peak
        if LEVEL_ORDER.get(level, 0) > LEVEL_ORDER.get(self._peak_level, 0):
            self._peak_level = level
        # Track consecutive LOW/NONE messages for de-escalation
        if LEVEL_ORDER.get(level, 0) <= LEVEL_ORDER["LOW"]:
            self._consecutive_low += 1
        else:
            self._consecutive_low = 0
        self._current_level = level
        return self.state

    def _detect_escalation(self) -> bool:
        """
        Detect rapid escalation within ESCALATION_WINDOW messages.

        Looks at the last N messages and checks if the level has climbed
        significantly (at least 2 tiers), e.g. LOW → HIGH or NONE → MEDIUM.
        """
        if len(self._level_history) < 2:
            return False
        window = self._level_history[-self.ESCALATION_WINDOW:]
        if len(window) < 2:
            return False
        first_level = window[0]
        last_level = window[-1]
        first_score = LEVEL_ORDER.get(first_level, 0)
        last_score = LEVEL_ORDER.get(last_level, 0)
        # Escalation = climbed at least 2 tiers in the window
        return (last_score - first_score) >= 2

    def _detect_deescalation(self) -> bool:
        """
        Detect de-escalation: was at CRITICAL/HIGH, now sustained LOW/NONE
        for DEESCALATION_WINDOW consecutive messages.
        """
        if LEVEL_ORDER.get(self._peak_level, 0) < LEVEL_ORDER["HIGH"]:
            return False
        return self._consecutive_low >= self.DEESCALATION_WINDOW

    def _compute_escalation_rate(self) -> float:
        """
        Compute levels gained per message over the conversation.
        Positive = escalating, negative = de-escalating, 0 = stable.
        """
        if self._message_count < 2:
            return 0.0
        # Rate compares the first message's level to the current one.
        first = LEVEL_ORDER.get(self._level_history[0], 0)
        current = LEVEL_ORDER.get(self._current_level, 0)
        return (current - first) / (self._message_count - 1)

    def get_session_modifier(self) -> str:
        """
        Generate a system prompt modifier reflecting session-level crisis state.
        Returns empty string if no session context is relevant.
        """
        if self._message_count < 2:
            return ""
        s = self.state
        # Checked in priority order: escalation > de-escalation > past peak.
        if s.is_escalating:
            return (
                f"User has escalated from {self._level_history[0]} to "
                f"{s.current_level} over {s.message_count} messages. "
                f"Peak crisis level this session: {s.peak_level}. "
                "Respond with heightened awareness. The trajectory is "
                "worsening — prioritize safety and connection."
            )
        if s.is_deescalating:
            return (
                f"User previously reached {s.peak_level} crisis level "
                f"but has been at {s.current_level} or below for "
                f"{s.consecutive_low_messages} consecutive messages. "
                "The situation appears to be stabilizing. Continue "
                "supportive engagement while remaining vigilant."
            )
        if s.peak_level in ("CRITICAL", "HIGH") and s.current_level not in ("CRITICAL", "HIGH"):
            return (
                f"User previously reached {s.peak_level} crisis level "
                f"this session (currently {s.current_level}). "
                "Continue with care and awareness of the earlier crisis."
            )
        return ""

    def get_ui_hints(self) -> dict:
        """
        Return UI hints based on session state for the frontend.
        These are advisory — the frontend decides what to show.
        """
        s = self.state
        hints = {
            "session_escalating": s.is_escalating,
            "session_deescalating": s.is_deescalating,
            "session_peak_level": s.peak_level,
            "session_message_count": s.message_count,
        }
        if s.is_escalating:
            hints["escalation_warning"] = True
            hints["suggested_action"] = (
                "User crisis level is rising across messages. "
                "Consider increasing intervention level."
            )
        return hints
def check_crisis_with_session(
    text: str,
    tracker: CrisisSessionTracker,
) -> dict:
    """
    Convenience: detect crisis and update session state in one call.
    Returns combined single-message detection + session-level context.
    """
    # Imported locally so this module does not import gateway at load time.
    from .detect import detect_crisis
    from .gateway import check_crisis

    result = dict(check_crisis(text))
    state = tracker.record(detect_crisis(text))
    result["session"] = {
        "current_level": state.current_level,
        "peak_level": state.peak_level,
        "message_count": state.message_count,
        "is_escalating": state.is_escalating,
        "is_deescalating": state.is_deescalating,
        "modifier": tracker.get_session_modifier(),
        "ui_hints": tracker.get_ui_hints(),
    }
    return result

View File

@@ -680,7 +680,7 @@ html, body {
<!-- Footer -->
<footer id="footer">
<a href="/about" aria-label="About The Door">about</a>
<a href="/about.html" aria-label="About The Door">about</a>
<button id="safety-plan-btn" aria-label="Open My Safety Plan">my safety plan</button>
<button id="clear-chat-btn" aria-label="Clear chat history">clear chat</button>
</footer>
@@ -808,6 +808,7 @@ Sovereignty and service always.`;
var crisisPanel = document.getElementById('crisis-panel');
var crisisOverlay = document.getElementById('crisis-overlay');
var overlayDismissBtn = document.getElementById('overlay-dismiss-btn');
var overlayCallLink = document.querySelector('.overlay-call');
var statusDot = document.querySelector('.status-dot');
var statusText = document.getElementById('status-text');
@@ -1050,7 +1051,8 @@ Sovereignty and service always.`;
}
}, 1000);
overlayDismissBtn.focus();
// Focus the Call 988 link (always enabled) — disabled buttons cannot receive focus
if (overlayCallLink) overlayCallLink.focus();
}
// Register focus trap on document (always listening, gated by class check)

View File

@@ -1,357 +0,0 @@
"""
Tests for Crisis Detection A/B Testing Framework.
"""
import unittest
import os
import json
import tempfile
import shutil
from unittest.mock import patch, MagicMock
# Import from the crisis module
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from crisis.ab_test import (
CrisisABTester,
ABTestConfig,
Variant,
DetectionEvent,
VariantMetrics,
ABTestMetrics,
detect_crisis_ab,
)
from crisis.detect import CrisisDetectionResult
class TestABTestConfig(unittest.TestCase):
    """Test A/B test configuration."""
    def test_default_config(self):
        # Defaults: enabled, 50/50 split, no seed, CWD-relative output files.
        config = ABTestConfig()
        self.assertTrue(config.enabled)
        self.assertEqual(config.variant_b_percentage, 0.5)
        self.assertIsNone(config.seed)
        self.assertEqual(config.log_file, "crisis_ab_test.jsonl")
        self.assertEqual(config.metrics_file, "crisis_ab_metrics.json")

    def test_custom_config(self):
        # Every field is overridable via the constructor.
        config = ABTestConfig(
            enabled=False,
            variant_b_percentage=0.3,
            seed="test-seed",
            log_file="custom.jsonl",
            metrics_file="custom.json"
        )
        self.assertFalse(config.enabled)
        self.assertEqual(config.variant_b_percentage, 0.3)
        self.assertEqual(config.seed, "test-seed")
class TestVariantAssignment(unittest.TestCase):
    """Test variant assignment logic."""
    def test_deterministic_assignment(self):
        """Same text should always get same variant with same seed."""
        config = ABTestConfig(seed="test-seed")
        tester = CrisisABTester(config)
        text = "I feel hopeless today"
        variant1 = tester._assign_variant(text)
        variant2 = tester._assign_variant(text)
        self.assertEqual(variant1, variant2)

    def test_assignment_distribution(self):
        """With 50% split, roughly half should go to each variant."""
        config = ABTestConfig(seed="test-seed")
        tester = CrisisABTester(config)
        variants_a = 0
        variants_b = 0
        test_texts = [f"test message {i}" for i in range(100)]
        for text in test_texts:
            variant = tester._assign_variant(text)
            if variant == Variant.A:
                variants_a += 1
            else:
                variants_b += 1
        # Should be roughly 50/50 (allow some variance)
        # Assignment is MD5-based so these counts are fixed for a fixed seed;
        # the loose 30/70 bounds just guard against a broken split.
        self.assertGreater(variants_a, 30)
        self.assertGreater(variants_b, 30)

    def test_disabled_returns_a(self):
        """When disabled, should always return variant A."""
        config = ABTestConfig(enabled=False)
        tester = CrisisABTester(config)
        for i in range(10):
            variant = tester._assign_variant(f"test {i}")
            self.assertEqual(variant, Variant.A)
class TestDetection(unittest.TestCase):
    """Test A/B detection logic."""
    def setUp(self):
        # Fresh temp dir per test so log/metrics files never collide.
        self.temp_dir = tempfile.mkdtemp()
        self.log_file = os.path.join(self.temp_dir, "test_ab.jsonl")
        self.metrics_file = os.path.join(self.temp_dir, "test_metrics.json")
        self.config = ABTestConfig(
            seed="test-seed",
            log_file=self.log_file,
            metrics_file=self.metrics_file
        )
        self.tester = CrisisABTester(self.config)

    def tearDown(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_detect_returns_result(self):
        """detect() should return CrisisDetectionResult."""
        result = self.tester.detect("I feel sad")
        self.assertIsInstance(result, CrisisDetectionResult)
        self.assertIn(result.level, ["NONE", "LOW", "MEDIUM", "HIGH", "CRITICAL"])

    def test_detect_critical_text(self):
        """Critical text should be detected regardless of variant."""
        result = self.tester.detect("I want to kill myself")
        self.assertEqual(result.level, "CRITICAL")

    def test_detect_none_text(self):
        """Non-crisis text should return NONE."""
        result = self.tester.detect("The weather is nice today")
        self.assertEqual(result.level, "NONE")

    def test_disabled_uses_variant_a(self):
        """When disabled, should use variant A only."""
        config = ABTestConfig(enabled=False, log_file=self.log_file, metrics_file=self.metrics_file)
        tester = CrisisABTester(config)
        result = tester.detect("I feel hopeless")
        self.assertIsInstance(result, CrisisDetectionResult)

    def test_logs_events(self):
        """detect() should log events to JSONL file."""
        self.tester.detect("I feel sad")
        self.tester.detect("I feel happy")
        self.assertTrue(os.path.exists(self.log_file))
        with open(self.log_file) as f:
            lines = f.readlines()
        # One JSONL line per detect() call.
        self.assertEqual(len(lines), 2)
        event = json.loads(lines[0])
        self.assertIn("event_id", event)
        self.assertIn("variant", event)
        self.assertIn("level", event)

    def test_updates_metrics(self):
        """detect() should update metrics."""
        self.tester.detect("I feel hopeless")  # Should trigger detection
        self.tester.detect("Hello world")  # Should not trigger
        metrics = self.tester.get_metrics()
        self.assertEqual(metrics.total_events, 2)

    def test_variant_b_more_sensitive(self):
        """Variant B should be more sensitive to MEDIUM indicators."""
        # Create tester that always assigns variant B
        config = ABTestConfig(
            variant_b_percentage=1.0,  # 100% to B
            seed="test-seed",
            log_file=self.log_file,
            metrics_file=self.metrics_file
        )
        tester_b = CrisisABTester(config)
        # Single MEDIUM indicator - variant A would return LOW, variant B returns MEDIUM
        result_b = tester_b.detect("I feel worthless")
        # Compare with variant A
        config_a = ABTestConfig(
            variant_b_percentage=0.0,  # 0% to B (100% to A)
            seed="test-seed",
            log_file=self.log_file + ".a",
            metrics_file=self.metrics_file + ".a"
        )
        tester_a = CrisisABTester(config_a)
        result_a = tester_a.detect("I feel worthless")
        # Variant B should be at least as sensitive
        level_order = {"NONE": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4}
        self.assertGreaterEqual(
            level_order[result_b.level],
            level_order[result_a.level]
        )
class TestCompareResults(unittest.TestCase):
    """Test compare_results functionality."""

    def test_compare_returns_both_variants(self):
        """compare_results() returns one CrisisDetectionResult per variant."""
        tester = CrisisABTester()
        results = tester.compare_results("I feel hopeless")
        self.assertIn("A", results)
        self.assertIn("B", results)
        self.assertIsInstance(results["A"], CrisisDetectionResult)
        self.assertIsInstance(results["B"], CrisisDetectionResult)

    def test_compare_does_not_log(self):
        """compare_results should not log to A/B test metrics."""
        temp_dir = tempfile.mkdtemp()
        # Fix: register cleanup immediately so the temp dir is removed even
        # when an assertion fails (the original trailing rmtree leaked it).
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        log_file = os.path.join(temp_dir, "test.jsonl")
        metrics_file = os.path.join(temp_dir, "metrics.json")
        config = ABTestConfig(log_file=log_file, metrics_file=metrics_file)
        tester = CrisisABTester(config)
        tester.compare_results("I feel sad")
        # Log file should not exist (no events logged)
        self.assertFalse(os.path.exists(log_file))
class TestMetrics(unittest.TestCase):
    """Test metrics tracking."""

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()
        self.log_file = os.path.join(self.temp_dir, "test.jsonl")
        self.metrics_file = os.path.join(self.temp_dir, "metrics.json")
        self.config = ABTestConfig(
            seed="test-seed",
            log_file=self.log_file,
            metrics_file=self.metrics_file,
        )

    def tearDown(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_metrics_track_events(self):
        """total_events should count every detect() call."""
        tester = CrisisABTester(self.config)
        for index in range(10):
            tester.detect(f"test message {index}")
        self.assertEqual(tester.get_metrics().total_events, 10)

    def test_metrics_track_levels(self):
        """Every detection lands in exactly one per-level bucket."""
        tester = CrisisABTester(self.config)
        messages = (
            "I want to kill myself",  # CRITICAL
            "I feel hopeless",        # HIGH or MEDIUM
            "Hello world",            # NONE
        )
        for message in messages:
            tester.detect(message)
        metrics = tester.get_metrics()
        counted = (
            sum(metrics.variant_a.detections_by_level.values())
            + sum(metrics.variant_b.detections_by_level.values())
        )
        self.assertEqual(counted, len(messages))

    def test_save_and_load_metrics(self):
        """save_metrics() output should be picked up by a fresh tester."""
        tester = CrisisABTester(self.config)
        for index in range(5):
            tester.detect(f"test {index}")
        tester.save_metrics()
        self.assertTrue(os.path.exists(self.metrics_file))
        # A new tester built from the same config loads the saved metrics.
        reloaded = CrisisABTester(self.config)
        self.assertEqual(reloaded.metrics.total_events, 5)
class TestEventLabeling(unittest.TestCase):
    """Test event labeling for false positive tracking."""

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()
        self.log_file = os.path.join(self.temp_dir, "test.jsonl")
        self.metrics_file = os.path.join(self.temp_dir, "metrics.json")
        self.config = ABTestConfig(
            seed="test-seed",
            log_file=self.log_file,
            metrics_file=self.metrics_file,
        )
        self.tester = CrisisABTester(self.config)

    def tearDown(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_label_event_updates_metrics(self):
        """Labeling an event as false positive updates its variant's counters."""
        self.tester.detect("I feel hopeless")
        event = self.tester._event_log[0]
        self.tester.label_event(event.event_id, is_true_positive=False)
        metrics = self.tester.get_metrics()
        # Inspect whichever variant this event happened to be assigned to.
        if event.variant == "A":
            variant_metrics = metrics.variant_a
        else:
            variant_metrics = metrics.variant_b
        self.assertEqual(variant_metrics.false_positives, 1)
        self.assertEqual(variant_metrics.false_positive_rate, 1.0)

    def test_label_nonexistent_event_raises(self):
        """Unknown event ids must raise ValueError."""
        with self.assertRaises(ValueError):
            self.tester.label_event("nonexistent-id", is_true_positive=True)
class TestReport(unittest.TestCase):
    """Test report generation."""

    def test_report_format(self):
        """Report must name both variants and show the event total."""
        tester = CrisisABTester()
        for index in range(5):
            tester.detect(f"test message {index}")
        report = tester.get_report()
        expected_fragments = (
            "CRISIS DETECTION A/B TEST REPORT",
            "VARIANT A",
            "VARIANT B",
            "Total Events: 5",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, report)
class TestConvenienceFunction(unittest.TestCase):
    """Test module-level convenience function."""

    def test_detect_crisis_ab(self):
        """A plain call returns a CrisisDetectionResult."""
        self.assertIsInstance(detect_crisis_ab("I feel sad"), CrisisDetectionResult)

    def test_detect_crisis_ab_with_metadata(self):
        """The metadata keyword is accepted without changing the return type."""
        result = detect_crisis_ab("I feel sad", metadata={"source": "test"})
        self.assertIsInstance(result, CrisisDetectionResult)
class TestCustomVariantBDetector(unittest.TestCase):
    """Test custom variant B detector."""

    def test_custom_detector(self):
        """Should use custom detector when set."""
        def custom_detector(text: str) -> CrisisDetectionResult:
            # Ignores the text entirely so the substitution is easy to observe.
            return CrisisDetectionResult(level="HIGH", indicators=["custom"], score=0.9)

        tester = CrisisABTester(ABTestConfig(variant_b_percentage=1.0))
        tester.set_variant_b_detector(custom_detector)
        result = tester.detect("Hello world")
        self.assertEqual(result.level, "HIGH")
        self.assertEqual(result.indicators, ["custom"])
# Allow running this test file directly as well as via a test runner.
if __name__ == "__main__":
    unittest.main()

View File

@@ -52,6 +52,34 @@ class TestCrisisOverlayFocusTrap(unittest.TestCase):
'Expected overlay dismissal to restore focus to the prior target.',
)
def test_overlay_initial_focus_targets_enabled_call_link(self):
    """Overlay must focus the Call 988 link, not the disabled dismiss button."""
    # Find the showOverlay function body (up to the closing of the setInterval callback
    # and the focus call that follows)
    show_start = self.html.find('function showOverlay()')
    self.assertGreater(show_start, -1, "showOverlay function not found")
    # Find the focus call within showOverlay (before the next function registration)
    # NOTE(review): the fixed 2000-char window assumes the focus logic sits near
    # the top of showOverlay — confirm if the page's JS grows past that.
    focus_section = self.html[show_start:show_start + 2000]
    self.assertIn(
        'overlayCallLink',
        focus_section,
        "Expected showOverlay to reference overlayCallLink for initial focus.",
    )
    # Ensure the old buggy pattern is gone
    # NOTE(review): the 800:1200 offset window is brittle; it presumably brackets
    # the single focus call inside showOverlay — verify against the actual page.
    focus_line_region = self.html[show_start + 800:show_start + 1200]
    self.assertNotIn(
        'overlayDismissBtn.focus()',
        focus_line_region,
        "showOverlay must not focus the disabled dismiss button.",
    )
def test_overlay_call_link_variable_is_declared(self):
self.assertIn(
"querySelector('.overlay-call')",
self.html,
"Expected a JS reference to the .overlay-call link element.",
)
# Allow running this test file directly as well as via a test runner.
if __name__ == '__main__':
    unittest.main()

View File

@@ -50,6 +50,22 @@ class TestCrisisOfflinePage(unittest.TestCase):
for phrase in required_phrases:
self.assertIn(phrase, self.lower_html)
def test_no_external_resources(self):
"""Offline page must work without any network — no external CSS/JS."""
import re
html = self.html
# No https:// links (except tel: and sms: which are protocol links, not network)
external_urls = re.findall(r'href=["\']https://|src=["\']https://', html)
self.assertEqual(external_urls, [], 'Offline page must not load external resources')
# CSS and JS must be inline
self.assertIn('<style>', html, 'CSS must be inline')
self.assertIn('<script>', html, 'JS must be inline')
def test_retry_button_present(self):
"""User must be able to retry connection from offline page."""
self.assertIn('retry-connection', self.html)
self.assertIn('Retry connection', self.html)
# Allow running this test file directly as well as via a test runner.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,277 @@
"""
Tests for crisis session tracking and escalation (P0 #35).
Covers: session_tracker.py
Run with: python -m pytest tests/test_session_tracker.py -v
"""
import unittest
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from crisis.detect import detect_crisis
from crisis.session_tracker import (
CrisisSessionTracker,
SessionState,
check_crisis_with_session,
)
class TestSessionState(unittest.TestCase):
    """Test SessionState defaults."""

    def test_default_state(self):
        """A fresh SessionState starts empty and quiescent."""
        state = SessionState()
        self.assertEqual(
            (state.current_level, state.peak_level, state.message_count),
            ("NONE", "NONE", 0),
        )
        self.assertEqual(state.level_history, [])
        self.assertFalse(state.is_escalating)
        self.assertFalse(state.is_deescalating)
class TestSessionTracking(unittest.TestCase):
    """Test basic session state tracking."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def _record(self, message):
        # Helper: run detection on *message* and feed the result to the tracker.
        return self.tracker.record(detect_crisis(message))

    def test_record_none_message(self):
        state = self._record("Hello Timmy")
        self.assertEqual(state.current_level, "NONE")
        self.assertEqual(state.message_count, 1)
        self.assertEqual(state.peak_level, "NONE")

    def test_record_low_message(self):
        self._record("Hello")
        state = self._record("Having a rough day")
        self.assertIn(state.current_level, ("LOW", "NONE"))
        self.assertEqual(state.message_count, 2)

    def test_record_critical_updates_peak(self):
        self._record("Having a rough day")
        state = self._record("I want to kill myself")
        self.assertEqual(state.current_level, "CRITICAL")
        self.assertEqual(state.peak_level, "CRITICAL")

    def test_peak_preserved_after_drop(self):
        """Peak level should stay at the highest seen, even after de-escalation."""
        self._record("I want to kill myself")
        state = self._record("I'm feeling a bit better")
        self.assertEqual(state.peak_level, "CRITICAL")

    def test_level_history(self):
        for message in ("Hello", "Having a rough day", "I want to die"):
            state = self._record(message)
        self.assertEqual(len(state.level_history), 3)
        self.assertEqual(state.level_history[0], "NONE")
        self.assertEqual(state.level_history[2], "CRITICAL")

    def test_reset_clears_state(self):
        self._record("I want to kill myself")
        self.tracker.reset()
        state = self.tracker.state
        self.assertEqual(state.current_level, "NONE")
        self.assertEqual(state.peak_level, "NONE")
        self.assertEqual(state.message_count, 0)
        self.assertEqual(state.level_history, [])
class TestEscalationDetection(unittest.TestCase):
    """Test escalation detection: LOW → HIGH in ≤3 messages."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_no_escalation_single_message(self):
        """A single benign message can never count as escalation."""
        self.tracker.record(detect_crisis("Hello"))
        self.assertFalse(self.tracker.state.is_escalating)

    def test_no_escalation_stable(self):
        """Two normal messages should not trigger escalation."""
        for message in ("Hello", "How are you?"):
            state = self.tracker.record(detect_crisis(message))
        self.assertFalse(state.is_escalating)

    def test_rapid_escalation_low_to_high(self):
        """LOW → HIGH in 2 messages = rapid escalation."""
        self.tracker.record(detect_crisis("Having a rough day"))
        state = self.tracker.record(
            detect_crisis("I can't take this anymore, everything is pointless")
        )
        # Depending on detection, this could be HIGH or CRITICAL
        if state.current_level in ("HIGH", "CRITICAL"):
            self.assertTrue(state.is_escalating)

    def test_rapid_escalation_three_messages(self):
        """NONE → LOW → HIGH in 3 messages = escalation."""
        for message in ("Hello", "Having a rough day"):
            self.tracker.record(detect_crisis(message))
        state = self.tracker.record(
            detect_crisis("I feel completely hopeless with no way out")
        )
        if state.current_level in ("HIGH", "CRITICAL"):
            self.assertTrue(state.is_escalating)

    def test_escalation_rate(self):
        """Rate should be positive when escalating."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("I want to die"))
        self.assertGreater(self.tracker.state.escalation_rate, 0)
class TestDeescalationDetection(unittest.TestCase):
    """Test de-escalation: sustained LOW after HIGH/CRITICAL."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def _say(self, message, times=1):
        # Helper: feed *message* through detection into the tracker *times* times.
        for _ in range(times):
            self.tracker.record(detect_crisis(message))

    def test_no_deescalation_without_prior_crisis(self):
        """No de-escalation if never reached HIGH/CRITICAL."""
        self._say("Hello", times=6)
        self.assertFalse(self.tracker.state.is_deescalating)

    def test_deescalation_after_critical(self):
        """5+ consecutive LOW/NONE messages after CRITICAL = de-escalation."""
        self._say("I want to kill myself")
        self._say("I'm doing better today", times=5)
        state = self.tracker.state
        if state.peak_level == "CRITICAL":
            self.assertTrue(state.is_deescalating)

    def test_deescalation_after_high(self):
        """5+ consecutive LOW/NONE messages after HIGH = de-escalation."""
        self._say("I feel completely hopeless with no way out")
        self._say("Feeling okay", times=5)
        state = self.tracker.state
        if state.peak_level == "HIGH":
            self.assertTrue(state.is_deescalating)

    def test_interrupted_deescalation(self):
        """De-escalation resets if a HIGH message interrupts."""
        self._say("I want to kill myself")
        self._say("Doing better", times=3)
        # Interrupt with another crisis
        self._say("I feel hopeless again")
        self._say("Feeling okay now")
        # Should NOT be de-escalating yet (counter reset)
        self.assertFalse(self.tracker.state.is_deescalating)
class TestSessionModifier(unittest.TestCase):
    """Test system prompt modifier generation."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def _modifier_after(self, *messages):
        # Helper: record each message, then return the session modifier.
        for message in messages:
            self.tracker.record(detect_crisis(message))
        return self.tracker.get_session_modifier()

    def test_no_modifier_for_single_message(self):
        self.assertEqual(self._modifier_after("Hello"), "")

    def test_no_modifier_for_stable_session(self):
        self.assertEqual(self._modifier_after("Hello", "Good morning"), "")

    def test_escalation_modifier(self):
        """Escalating session should produce a modifier."""
        modifier = self._modifier_after("Hello", "I want to die")
        if self.tracker.state.is_escalating:
            self.assertIn("escalated", modifier.lower())
            self.assertIn("NONE", modifier)
            self.assertIn("CRITICAL", modifier)

    def test_deescalation_modifier(self):
        """De-escalating session should mention stabilizing."""
        messages = ["I want to kill myself"] + ["I'm feeling okay"] * 5
        modifier = self._modifier_after(*messages)
        if self.tracker.state.is_deescalating:
            self.assertIn("stabilizing", modifier.lower())

    def test_prior_crisis_modifier(self):
        """Past crisis should be noted even without active escalation."""
        modifier = self._modifier_after("I want to die", "Feeling a bit better")
        # Should note the prior CRITICAL
        if modifier:
            self.assertIn("CRITICAL", modifier)
class TestUIHints(unittest.TestCase):
    """Test UI hint generation."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_ui_hints_structure(self):
        """Hints dict must always expose the session summary keys."""
        self.tracker.record(detect_crisis("Hello"))
        hints = self.tracker.get_ui_hints()
        expected_keys = (
            "session_escalating",
            "session_deescalating",
            "session_peak_level",
            "session_message_count",
        )
        for key in expected_keys:
            self.assertIn(key, hints)

    def test_ui_hints_escalation_warning(self):
        """Escalating session should have warning hint."""
        for message in ("Hello", "I want to die"):
            self.tracker.record(detect_crisis(message))
        hints = self.tracker.get_ui_hints()
        if hints["session_escalating"]:
            self.assertTrue(hints.get("escalation_warning"))
            self.assertIn("suggested_action", hints)
class TestCheckCrisisWithSession(unittest.TestCase):
    """Test the convenience function combining detection + session tracking."""

    def test_returns_combined_data(self):
        """Result must carry the level plus a session summary block."""
        result = check_crisis_with_session("I want to die", CrisisSessionTracker())
        self.assertIn("level", result)
        self.assertIn("session", result)
        session = result["session"]
        for key in ("current_level", "peak_level", "modifier"):
            self.assertIn(key, session)

    def test_session_updates_across_calls(self):
        """The shared tracker accumulates state across calls."""
        tracker = CrisisSessionTracker()
        check_crisis_with_session("Hello", tracker)
        session = check_crisis_with_session("I want to die", tracker)["session"]
        self.assertEqual(session["message_count"], 2)
        self.assertEqual(session["peak_level"], "CRITICAL")
class TestPrivacy(unittest.TestCase):
    """Verify privacy-first design principles."""

    def test_no_persistence_mechanism(self):
        """Session tracker should have no database, file, or network calls."""
        import inspect
        # Lowercase once up front instead of per assertion.
        source = inspect.getsource(CrisisSessionTracker).lower()
        # Heuristic token scan for persistence / network machinery.
        for word in ("sqlite", "requests", "urllib", "open(", "httpx", "aiohttp"):
            self.assertNotIn(
                word,
                source,
                f"Session tracker should not use {word} — privacy-first design",
            )

    def test_state_contained_in_memory(self):
        """All state should be instance attributes, not module-level."""
        dirty = CrisisSessionTracker()
        dirty.record(detect_crisis("I want to die"))
        # A brand-new tracker must start clean (no global contamination).
        fresh = CrisisSessionTracker()
        self.assertEqual(fresh.state.current_level, "NONE")
# Allow running this test file directly as well as via pytest.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,19 @@
from crisis.gateway import check_crisis_multimodal
from crisis.session_tracker import CrisisSessionTracker
def test_multimodal_gateway_uses_unified_score_for_988_ui(tmp_path):
    """Gateway should surface the unified score in both the level and UI hints."""
    tracker = CrisisSessionTracker()
    result = check_crisis_multimodal(
        "I want to kill myself tonight",
        tracker=tracker,
        voice_score=0.92,
        image_score=0.6,
        audit_log_path=tmp_path / "audit.jsonl",
    )
    unified = result["unified"]
    assert unified["level"] == "CRITICAL"
    assert unified["modalities"]["voice"] == 0.92
    assert unified["modalities"]["behavioral"] >= 0.0
    ui = result["ui"]
    assert ui["provide_988"] is True
    assert ui["show_overlay"] is True

View File

@@ -0,0 +1,51 @@
from pathlib import Path
from unified_scorer import (
CrisisLevel,
UnifiedCrisisScorer,
UnifiedScoreAuditLog,
behavioral_score_from_session,
)
from crisis.session_tracker import SessionState
def test_unified_scorer_renormalizes_available_modalities_and_escalates():
    """A missing image score renormalizes weights and still reaches CRITICAL."""
    assessment = UnifiedCrisisScorer().score(
        text_score=1.0,
        voice_score=0.8,
        image_score=None,
        behavioral_score=0.7,
    )
    assert assessment.present_modalities == ["text", "voice", "behavioral"]
    assert assessment.combined_score > 0.8
    assert assessment.level is CrisisLevel.CRITICAL
def test_behavioral_score_rises_for_escalating_session_state():
    """A session rapidly escalating to CRITICAL should score high behaviorally."""
    escalating_session = SessionState(
        current_level="HIGH",
        peak_level="CRITICAL",
        message_count=4,
        level_history=["LOW", "MEDIUM", "HIGH", "CRITICAL"],
        is_escalating=True,
        is_deescalating=False,
        escalation_rate=1.0,
        consecutive_low_messages=0,
    )
    score = behavioral_score_from_session(escalating_session)
    assert score >= 0.8
def test_audit_log_persists_anonymized_score_entries(tmp_path):
    """Each score() call appends one JSONL entry without the raw text."""
    log_path = tmp_path / "unified-score-audit.jsonl"
    scorer = UnifiedCrisisScorer(audit_log=UnifiedScoreAuditLog(log_path))
    scorer.score(
        text_score=0.75,
        voice_score=0.2,
        image_score=0.1,
        behavioral_score=0.6,
        source_text="I feel trapped and hopeless",
    )
    entries = log_path.read_text().strip().splitlines()
    assert len(entries) == 1
    first_entry = entries[0]
    # Raw text must never be persisted — only its fingerprint and the scores.
    assert "trapped and hopeless" not in first_entry
    assert '"text_fingerprint"' in first_entry
    assert '"combined_score"' in first_entry

126
unified_scorer.py Normal file
View File

@@ -0,0 +1,126 @@
"""Unified multimodal crisis scoring for the-door."""
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from crisis.session_tracker import SessionState
# Numeric score each crisis level maps to, used by behavioral_score_from_session.
SCORE_BY_LEVEL = {"NONE": 0.0, "LOW": 0.25, "MEDIUM": 0.5, "HIGH": 0.75, "CRITICAL": 1.0}
# Ordinal rank of each level, for ordering comparisons between levels.
LEVEL_RANK = {"NONE": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4}
class CrisisLevel(Enum):
    """Severity buckets for a unified crisis assessment, lowest to highest."""
    NONE = "NONE"
    LOW = "LOW"
    MEDIUM = "MEDIUM"
    HIGH = "HIGH"
    CRITICAL = "CRITICAL"
# Relative importance of each modality when all are present. When some
# modalities are absent, UnifiedCrisisScorer._normalize rescales the weights
# of the available ones so they sum to 1.0.
DEFAULT_WEIGHTS: Dict[str, float] = {
    "text": 0.4,
    "voice": 0.25,
    "behavioral": 0.2,
    "image": 0.15,
}
@dataclass
class UnifiedAssessment:
    """Result of combining per-modality crisis scores into one assessment."""
    # Final severity bucket derived from combined_score.
    level: CrisisLevel
    # Weighted average of the available modality scores.
    combined_score: float
    # Renormalized weights actually used (only the available modalities).
    weights: Dict[str, float]
    # Raw input score per modality; None means the modality was absent.
    modalities: Dict[str, Optional[float]]
    # Names of the modalities that contributed (score was not None).
    present_modalities: List[str]
class UnifiedScoreAuditLog:
    """Append-only JSONL audit log of unified crisis assessments.

    Privacy: the raw source text is never written; only a short SHA-256
    fingerprint of it (or null when no text was supplied).
    """

    def __init__(self, path: Path | str):
        # Accept plain strings for convenience; always store a Path.
        self.path = Path(path)

    def record(self, assessment: UnifiedAssessment, source_text: str = "") -> None:
        """Append one anonymized entry for *assessment* to the log file."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        if source_text:
            digest = hashlib.sha256(source_text.encode("utf-8")).hexdigest()
            fingerprint = digest[:12]
        else:
            fingerprint = None
        entry = {
            "level": assessment.level.value,
            "combined_score": round(assessment.combined_score, 4),
            "weights": assessment.weights,
            "modalities": assessment.modalities,
            "present_modalities": assessment.present_modalities,
            "text_fingerprint": fingerprint,
        }
        line = json.dumps(entry, sort_keys=True)
        with self.path.open("a", encoding="utf-8") as fh:
            fh.write(line + "\n")
class UnifiedCrisisScorer:
    """Combine per-modality crisis scores into a single weighted assessment.

    Scores arrive per modality (text, voice, image, behavioral) in [0, 1];
    absent modalities are passed as None and their weight is redistributed
    across the modalities that are present.
    """

    def __init__(self, weights: Optional[Dict[str, float]] = None, audit_log: Optional[UnifiedScoreAuditLog] = None):
        # Copy the defaults so per-instance overrides never mutate the
        # module-level DEFAULT_WEIGHTS constant.
        self.weights = dict(DEFAULT_WEIGHTS)
        if weights:
            self.weights.update(weights)
        self.audit_log = audit_log

    def _normalize(self, modalities: Dict[str, Optional[float]]) -> Dict[str, float]:
        """Return weights for the present modalities, rescaled to sum to 1.0."""
        present = [name for name, score in modalities.items() if score is not None]
        if not present:
            return {}
        total = sum(self.weights[name] for name in present)
        if total <= 0.0:
            # Degenerate custom weights (all zero for the present modalities):
            # fall back to equal weighting instead of dividing by zero.
            share = 1.0 / len(present)
            return {name: share for name in present}
        return {name: self.weights[name] / total for name in present}

    def _level_for_score(self, score: float) -> CrisisLevel:
        """Bucket a combined score into a CrisisLevel (strict > thresholds)."""
        if score > 0.8:
            return CrisisLevel.CRITICAL
        if score > 0.6:
            return CrisisLevel.HIGH
        if score > 0.4:
            return CrisisLevel.MEDIUM
        if score > 0.0:
            return CrisisLevel.LOW
        return CrisisLevel.NONE

    def score(
        self,
        *,
        text_score: Optional[float],
        voice_score: Optional[float] = None,
        image_score: Optional[float] = None,
        behavioral_score: Optional[float] = None,
        source_text: str = "",
    ) -> UnifiedAssessment:
        """Score one interaction across modalities.

        Args:
            text_score / voice_score / image_score / behavioral_score:
                per-modality scores, or None when the modality is unavailable.
            source_text: optional raw text, used only to fingerprint the
                audit-log entry (never stored verbatim).

        Returns:
            UnifiedAssessment with the weighted combined score and its level.
        """
        modalities: Dict[str, Optional[float]] = {
            "text": text_score,
            "voice": voice_score,
            "behavioral": behavioral_score,
            "image": image_score,
        }
        normalized = self._normalize(modalities)
        combined = sum(
            float(modalities[name]) * weight for name, weight in normalized.items()
        )
        assessment = UnifiedAssessment(
            level=self._level_for_score(combined),
            combined_score=combined,
            weights=normalized,
            modalities=modalities,
            present_modalities=[name for name, value in modalities.items() if value is not None],
        )
        if self.audit_log:
            self.audit_log.record(assessment, source_text=source_text)
        return assessment
def behavioral_score_from_session(session: 'SessionState') -> float:
    """Derive a 0..1 behavioral risk score from a session's trajectory.

    Starts from the numeric value of the current level, adds bonuses for a
    severe peak, active escalation, and the escalation rate, subtracts a
    penalty while de-escalating, and clamps the result into [0, 1].
    """
    score = SCORE_BY_LEVEL.get(session.current_level, 0.0)
    if LEVEL_RANK.get(session.peak_level, 0) >= LEVEL_RANK["HIGH"]:
        score += 0.1  # session previously reached HIGH or CRITICAL
    if session.is_escalating:
        score += 0.15
    # Escalation rate contributes up to 0.1, after clamping it to [0, 1].
    score += min(max(session.escalation_rate, 0.0), 1.0) * 0.1
    if session.is_deescalating:
        score -= 0.15
    return max(0.0, min(1.0, score))