Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
0ab2626ef2 feat: image content screening for self-harm indicators (closes #132)
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 4s
Smoke Test / smoke (pull_request) Successful in 10s
2026-04-15 12:02:22 -04:00
7 changed files with 239 additions and 580 deletions

View File

@@ -7,7 +7,6 @@ Stands between a broken man and a machine that would tell him to die.
from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urgency_emoji
from .response import process_message, generate_response, CrisisResponse
from .gateway import check_crisis, get_system_prompt, format_gateway_response
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
__all__ = [
"detect_crisis",
@@ -20,7 +19,4 @@ __all__ = [
"format_result",
"format_gateway_response",
"get_urgency_emoji",
"CrisisSessionTracker",
"SessionState",
"check_crisis_with_session",
]

View File

@@ -22,7 +22,6 @@ from .response import (
get_system_prompt_modifier,
CrisisResponse,
)
from .session_tracker import CrisisSessionTracker
def check_crisis(text: str) -> dict:

View File

@@ -1,259 +0,0 @@
"""
Session-level crisis tracking and escalation for the-door (P0 #35).
Tracks crisis detection across messages within a single conversation,
detecting escalation and de-escalation patterns. Privacy-first: no
persistence beyond the conversation session.
Each message is analyzed in isolation by detect.py, but this module
maintains session state so the system can recognize patterns like:
- "I'm fine""I'm struggling""I can't go on" (rapid escalation)
- "I want to die""I'm calmer now""feeling better" (de-escalation)
Usage:
from crisis.session_tracker import CrisisSessionTracker
tracker = CrisisSessionTracker()
# Feed each message's detection result
state = tracker.record(detect_crisis("I'm having a tough day"))
print(state.current_level) # "LOW"
print(state.is_escalating) # False
state = tracker.record(detect_crisis("I feel hopeless"))
print(state.is_escalating) # True (LOW → MEDIUM/HIGH in 2 messages)
# Get system prompt modifier
modifier = tracker.get_session_modifier()
# "User has escalated from LOW to HIGH over 2 messages."
# Reset for new session
tracker.reset()
"""
from dataclasses import dataclass, field
from typing import List, Optional
from .detect import CrisisDetectionResult, SCORES
# Level ordering for comparison (higher = more severe)
LEVEL_ORDER = {"NONE": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4}
@dataclass
class SessionState:
"""Immutable snapshot of session crisis tracking state."""
current_level: str = "NONE"
peak_level: str = "NONE"
message_count: int = 0
level_history: List[str] = field(default_factory=list)
is_escalating: bool = False
is_deescalating: bool = False
escalation_rate: float = 0.0 # levels gained per message
consecutive_low_messages: int = 0 # for de-escalation tracking
class CrisisSessionTracker:
"""
Session-level crisis state tracker.
Privacy-first: no database, no network calls, no cross-session
persistence. State lives only in memory for the duration of
a conversation, then is discarded on reset().
"""
# Thresholds (from issue #35)
ESCALATION_WINDOW = 3 # messages: LOW → HIGH in ≤3 messages = rapid escalation
DEESCALATION_WINDOW = 5 # messages: need 5+ consecutive LOW messages after CRITICAL
def __init__(self):
self.reset()
def reset(self):
"""Reset all session state. Call on new conversation."""
self._current_level = "NONE"
self._peak_level = "NONE"
self._message_count = 0
self._level_history: List[str] = []
self._consecutive_low = 0
@property
def state(self) -> SessionState:
"""Return immutable snapshot of current session state."""
is_escalating = self._detect_escalation()
is_deescalating = self._detect_deescalation()
rate = self._compute_escalation_rate()
return SessionState(
current_level=self._current_level,
peak_level=self._peak_level,
message_count=self._message_count,
level_history=list(self._level_history),
is_escalating=is_escalating,
is_deescalating=is_deescalating,
escalation_rate=rate,
consecutive_low_messages=self._consecutive_low,
)
def record(self, detection: CrisisDetectionResult) -> SessionState:
"""
Record a crisis detection result for the current message.
Returns updated SessionState.
"""
level = detection.level
self._message_count += 1
self._level_history.append(level)
# Update peak
if LEVEL_ORDER.get(level, 0) > LEVEL_ORDER.get(self._peak_level, 0):
self._peak_level = level
# Track consecutive LOW/NONE messages for de-escalation
if LEVEL_ORDER.get(level, 0) <= LEVEL_ORDER["LOW"]:
self._consecutive_low += 1
else:
self._consecutive_low = 0
self._current_level = level
return self.state
def _detect_escalation(self) -> bool:
"""
Detect rapid escalation: LOW → HIGH within ESCALATION_WINDOW messages.
Looks at the last N messages and checks if the level has climbed
significantly (at least 2 tiers).
"""
if len(self._level_history) < 2:
return False
window = self._level_history[-self.ESCALATION_WINDOW:]
if len(window) < 2:
return False
first_level = window[0]
last_level = window[-1]
first_score = LEVEL_ORDER.get(first_level, 0)
last_score = LEVEL_ORDER.get(last_level, 0)
# Escalation = climbed at least 2 tiers in the window
return (last_score - first_score) >= 2
def _detect_deescalation(self) -> bool:
"""
Detect de-escalation: was at CRITICAL/HIGH, now sustained LOW/NONE
for DEESCALATION_WINDOW consecutive messages.
"""
if LEVEL_ORDER.get(self._peak_level, 0) < LEVEL_ORDER["HIGH"]:
return False
return self._consecutive_low >= self.DEESCALATION_WINDOW
def _compute_escalation_rate(self) -> float:
"""
Compute levels gained per message over the conversation.
Positive = escalating, negative = de-escalating, 0 = stable.
"""
if self._message_count < 2:
return 0.0
first = LEVEL_ORDER.get(self._level_history[0], 0)
current = LEVEL_ORDER.get(self._current_level, 0)
return (current - first) / (self._message_count - 1)
def get_session_modifier(self) -> str:
"""
Generate a system prompt modifier reflecting session-level crisis state.
Returns empty string if no session context is relevant.
"""
if self._message_count < 2:
return ""
s = self.state
if s.is_escalating:
return (
f"User has escalated from {self._level_history[0]} to "
f"{s.current_level} over {s.message_count} messages. "
f"Peak crisis level this session: {s.peak_level}. "
"Respond with heightened awareness. The trajectory is "
"worsening — prioritize safety and connection."
)
if s.is_deescalating:
return (
f"User previously reached {s.peak_level} crisis level "
f"but has been at {s.current_level} or below for "
f"{s.consecutive_low_messages} consecutive messages. "
"The situation appears to be stabilizing. Continue "
"supportive engagement while remaining vigilant."
)
if s.peak_level in ("CRITICAL", "HIGH") and s.current_level not in ("CRITICAL", "HIGH"):
return (
f"User previously reached {s.peak_level} crisis level "
f"this session (currently {s.current_level}). "
"Continue with care and awareness of the earlier crisis."
)
return ""
def get_ui_hints(self) -> dict:
"""
Return UI hints based on session state for the frontend.
These are advisory — the frontend decides what to show.
"""
s = self.state
hints = {
"session_escalating": s.is_escalating,
"session_deescalating": s.is_deescalating,
"session_peak_level": s.peak_level,
"session_message_count": s.message_count,
}
if s.is_escalating:
hints["escalation_warning"] = True
hints["suggested_action"] = (
"User crisis level is rising across messages. "
"Consider increasing intervention level."
)
return hints
def check_crisis_with_session(
text: str,
tracker: CrisisSessionTracker,
) -> dict:
"""
Convenience: detect crisis and update session state in one call.
Returns combined single-message detection + session-level context.
"""
from .detect import detect_crisis
from .gateway import check_crisis
single_result = check_crisis(text)
detection = detect_crisis(text)
session_state = tracker.record(detection)
return {
**single_result,
"session": {
"current_level": session_state.current_level,
"peak_level": session_state.peak_level,
"message_count": session_state.message_count,
"is_escalating": session_state.is_escalating,
"is_deescalating": session_state.is_deescalating,
"modifier": tracker.get_session_modifier(),
"ui_hints": tracker.get_ui_hints(),
},
}

155
image_screening.py Normal file
View File

@@ -0,0 +1,155 @@
"""
Image Content Screening for Self-Harm Indicators
Screens images sent in chat for self-harm content using vision models.
Privacy-preserving: analyzes in-memory, discards after scoring. No image retention.
"""
import base64
import json
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional
logger = logging.getLogger("the-door.image_screening")
class RiskLevel(str, Enum):
SAFE = "safe"
CONCERNING = "concerning"
CRITICAL = "critical"
@dataclass
class ImageScreenResult:
risk_level: RiskLevel
confidence: float
categories: list
reasoning: str
model_used: str
@property
def needs_crisis_response(self):
return self.risk_level == RiskLevel.CRITICAL
@property
def needs_followup(self):
return self.risk_level in (RiskLevel.CONCERNING, RiskLevel.CRITICAL)
def to_dict(self):
return {
"risk_level": self.risk_level.value,
"confidence": round(self.confidence, 3),
"categories": self.categories,
"reasoning": self.reasoning,
"model_used": self.model_used,
}
SCREENING_PROMPT = """Analyze this image for indicators of self-harm or crisis. Look for:
1. Self-harm wounds or injuries
2. Medication/pills in concerning quantities
3. Farewell-type imagery (empty rooms, items laid out, goodbye notes)
4. Screenshots of crisis-related searches
5. Images conveying hopelessness, isolation, or despair
Respond in JSON format ONLY:
{"risk_level": "safe"|"concerning"|"critical", "confidence": 0.0-1.0, "categories": [], "reasoning": ""}
CRITICAL: imminent self-harm indicators
CONCERNING: ambiguous but worrying
SAFE: no indicators detected"""
def _analyze_with_ollama(image_b64, model="gemma3:4b"):
try:
import urllib.request
payload = json.dumps({
"model": model,
"messages": [{
"role": "user",
"content": SCREENING_PROMPT,
"images": [image_b64],
}],
"stream": False,
"options": {"temperature": 0.1},
}).encode()
req = urllib.request.Request(
"http://localhost:11434/api/chat",
data=payload,
headers={"Content-Type": "application/json"},
method="POST",
)
resp = urllib.request.urlopen(req, timeout=30)
data = json.loads(resp.read())
content = data.get("message", {}).get("content", "")
json_start = content.find("{")
json_end = content.rfind("}") + 1
if json_start == -1 or json_end <= json_start:
return None
result = json.loads(content[json_start:json_end])
return ImageScreenResult(
risk_level=RiskLevel(result.get("risk_level", "safe")),
confidence=float(result.get("confidence", 0.5)),
categories=result.get("categories", []),
reasoning=result.get("reasoning", ""),
model_used=f"ollama:{model}",
)
except Exception as e:
logger.warning(f"Ollama vision analysis failed: {e}")
return None
def _analyze_fallback(image_bytes):
return ImageScreenResult(
risk_level=RiskLevel.SAFE,
confidence=0.2,
categories=["unanalyzed"],
reasoning="No vision model available. Defaulting to safe with low confidence.",
model_used="fallback:heuristic",
)
def screen_image(image_data, use_vision_model=True, model="gemma3:4b"):
"""Screen image for self-harm indicators. Analyzes in-memory, no retention."""
if isinstance(image_data, bytes):
image_b64 = base64.b64encode(image_data).decode()
else:
image_b64 = image_data
image_data = base64.b64decode(image_b64)
if use_vision_model:
result = _analyze_with_ollama(image_b64, model)
if result:
logger.info(f"Image screened: {result.risk_level.value} (conf: {result.confidence:.2f})")
if result.needs_crisis_response:
logger.warning(f"CRITICAL image: {result.reasoning}")
return result
return _analyze_fallback(image_data)
def handle_chat_image(image_data):
"""Handle image from chat. Returns action dict for gateway."""
result = screen_image(image_data)
action = {
"result": result.to_dict(),
"show_crisis_overlay": result.needs_crisis_response,
"log_event": result.needs_followup,
"response_text": None,
}
if result.risk_level == RiskLevel.CRITICAL:
action["response_text"] = (
"I noticed something concerning in the image you shared. "
"If you or someone you know is in crisis, please reach out: "
"988 Suicide and Crisis Lifeline (call or text 988). "
"You are not alone."
)
elif result.risk_level == RiskLevel.CONCERNING:
action["response_text"] = (
"I want to check in \u2014 how are you doing? "
"If you need to talk to someone, the 988 Lifeline is available 24/7."
)
return action

View File

@@ -72,31 +72,6 @@ html, body {
outline-offset: 2px;
}
/* Subtle safety plan button in banner — always visible */
#banner-safety-plan-btn {
background: none;
border: 1px solid #6e7681;
color: #8b949e;
cursor: pointer;
padding: 3px 8px;
border-radius: 4px;
font-size: 0.75rem;
display: flex;
align-items: center;
gap: 4px;
transition: background 0.2s, border-color 0.2s, color 0.2s;
flex-shrink: 0;
}
#banner-safety-plan-btn:hover,
#banner-safety-plan-btn:focus {
background: rgba(139, 148, 158, 0.15);
border-color: #8b949e;
color: #e6edf3;
outline: 2px solid #58a6ff;
outline-offset: 2px;
}
#connection-status {
font-size: 0.7rem;
color: #6e7681;
@@ -650,10 +625,6 @@ html, body {
<a href="tel:988" aria-label="Call 988 Suicide and Crisis Lifeline">
988 Suicide &amp; Crisis Lifeline — Call or text 988
</a>
<button id="banner-safety-plan-btn" aria-label="Open my safety plan" title="My Safety Plan">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="16" y1="13" x2="8" y2="13"/><line x1="16" y1="17" x2="8" y2="17"/><polyline points="10 9 9 9 8 9"/></svg>
<span class="sr-only">My Safety Plan</span>
</button>
<div id="connection-status" aria-hidden="true">
<span class="status-dot"></span>
<span id="status-text">Online</span>
@@ -843,7 +814,6 @@ Sovereignty and service always.`;
// Safety Plan Elements
var safetyPlanBtn = document.getElementById('safety-plan-btn');
var crisisSafetyPlanBtn = document.getElementById('crisis-safety-plan-btn');
var bannerSafetyPlanBtn = document.getElementById('banner-safety-plan-btn');
var safetyPlanModal = document.getElementById('safety-plan-modal');
var closeSafetyPlan = document.getElementById('close-safety-plan');
var cancelSafetyPlan = document.getElementById('cancel-safety-plan');
@@ -1329,15 +1299,6 @@ Sovereignty and service always.`;
});
}
// Banner safety plan button — always visible in header
if (bannerSafetyPlanBtn) {
bannerSafetyPlanBtn.addEventListener('click', function() {
loadSafetyPlan();
safetyPlanModal.classList.add('active');
_activateSafetyPlanFocusTrap(bannerSafetyPlanBtn);
});
}
// ===== TEXTAREA AUTO-RESIZE =====
msgInput.addEventListener('input', function() {
this.style.height = 'auto';

View File

@@ -0,0 +1,84 @@
"""Tests for image content screening module."""
import json
from unittest.mock import patch, MagicMock
from image_screening import (
RiskLevel,
ImageScreenResult,
screen_image,
handle_chat_image,
_analyze_fallback,
)
class TestImageScreenResult:
def test_safe_result(self):
result = ImageScreenResult(
risk_level=RiskLevel.SAFE, confidence=0.95,
categories=[], reasoning="No indicators", model_used="test"
)
assert not result.needs_crisis_response
assert not result.needs_followup
assert result.to_dict()["risk_level"] == "safe"
def test_critical_result(self):
result = ImageScreenResult(
risk_level=RiskLevel.CRITICAL, confidence=0.9,
categories=["wounds"], reasoning="Detected", model_used="test"
)
assert result.needs_crisis_response
assert result.needs_followup
def test_concerning_result(self):
result = ImageScreenResult(
risk_level=RiskLevel.CONCERNING, confidence=0.6,
categories=["isolation"], reasoning="Ambiguous", model_used="test"
)
assert not result.needs_crisis_response
assert result.needs_followup
class TestScreenImage:
def test_fallback_returns_safe(self):
result = screen_image(b"fake_image_data", use_vision_model=False)
assert result.risk_level == RiskLevel.SAFE
assert result.model_used == "fallback:heuristic"
assert result.confidence < 0.5
def test_base64_input(self):
import base64
b64 = base64.b64encode(b"fake").decode()
result = screen_image(b64, use_vision_model=False)
assert result.risk_level == RiskLevel.SAFE
class TestHandleChatImage:
def test_safe_image_no_overlay(self):
action = handle_chat_image(b"safe_image")
assert not action["show_crisis_overlay"]
assert action["response_text"] is None
@patch("image_screening._analyze_with_ollama")
def test_critical_image_shows_overlay(self, mock_ollama):
mock_ollama.return_value = ImageScreenResult(
risk_level=RiskLevel.CRITICAL, confidence=0.95,
categories=["wounds"], reasoning="Self-harm detected",
model_used="ollama:gemma3:4b"
)
action = handle_chat_image(b"concerning_image")
assert action["show_crisis_overlay"]
assert "988" in action["response_text"]
assert action["log_event"]
@patch("image_screening._analyze_with_ollama")
def test_concerning_image_followup(self, mock_ollama):
mock_ollama.return_value = ImageScreenResult(
risk_level=RiskLevel.CONCERNING, confidence=0.6,
categories=["isolation"], reasoning="Empty room",
model_used="ollama:gemma3:4b"
)
action = handle_chat_image(b"maybe_concerning")
assert not action["show_crisis_overlay"]
assert action["log_event"]
assert "check in" in action["response_text"]

View File

@@ -1,277 +0,0 @@
"""
Tests for crisis session tracking and escalation (P0 #35).
Covers: session_tracker.py
Run with: python -m pytest tests/test_session_tracker.py -v
"""
import unittest
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from crisis.detect import detect_crisis
from crisis.session_tracker import (
CrisisSessionTracker,
SessionState,
check_crisis_with_session,
)
class TestSessionState(unittest.TestCase):
"""Test SessionState defaults."""
def test_default_state(self):
s = SessionState()
self.assertEqual(s.current_level, "NONE")
self.assertEqual(s.peak_level, "NONE")
self.assertEqual(s.message_count, 0)
self.assertEqual(s.level_history, [])
self.assertFalse(s.is_escalating)
self.assertFalse(s.is_deescalating)
class TestSessionTracking(unittest.TestCase):
"""Test basic session state tracking."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_record_none_message(self):
state = self.tracker.record(detect_crisis("Hello Timmy"))
self.assertEqual(state.current_level, "NONE")
self.assertEqual(state.message_count, 1)
self.assertEqual(state.peak_level, "NONE")
def test_record_low_message(self):
self.tracker.record(detect_crisis("Hello"))
state = self.tracker.record(detect_crisis("Having a rough day"))
self.assertIn(state.current_level, ("LOW", "NONE"))
self.assertEqual(state.message_count, 2)
def test_record_critical_updates_peak(self):
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I want to kill myself"))
self.assertEqual(state.current_level, "CRITICAL")
self.assertEqual(state.peak_level, "CRITICAL")
def test_peak_preserved_after_drop(self):
"""Peak level should stay at the highest seen, even after de-escalation."""
self.tracker.record(detect_crisis("I want to kill myself"))
state = self.tracker.record(detect_crisis("I'm feeling a bit better"))
self.assertEqual(state.peak_level, "CRITICAL")
def test_level_history(self):
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I want to die"))
self.assertEqual(len(state.level_history), 3)
self.assertEqual(state.level_history[0], "NONE")
self.assertEqual(state.level_history[2], "CRITICAL")
def test_reset_clears_state(self):
self.tracker.record(detect_crisis("I want to kill myself"))
self.tracker.reset()
state = self.tracker.state
self.assertEqual(state.current_level, "NONE")
self.assertEqual(state.peak_level, "NONE")
self.assertEqual(state.message_count, 0)
self.assertEqual(state.level_history, [])
class TestEscalationDetection(unittest.TestCase):
"""Test escalation detection: LOW → HIGH in ≤3 messages."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_no_escalation_single_message(self):
self.tracker.record(detect_crisis("Hello"))
self.assertFalse(self.tracker.state.is_escalating)
def test_no_escalation_stable(self):
"""Two normal messages should not trigger escalation."""
self.tracker.record(detect_crisis("Hello"))
state = self.tracker.record(detect_crisis("How are you?"))
self.assertFalse(state.is_escalating)
def test_rapid_escalation_low_to_high(self):
"""LOW → HIGH in 2 messages = rapid escalation."""
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I can't take this anymore, everything is pointless"))
# Depending on detection, this could be HIGH or CRITICAL
if state.current_level in ("HIGH", "CRITICAL"):
self.assertTrue(state.is_escalating)
def test_rapid_escalation_three_messages(self):
"""NONE → LOW → HIGH in 3 messages = escalation."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("Having a rough day"))
state = self.tracker.record(detect_crisis("I feel completely hopeless with no way out"))
if state.current_level in ("HIGH", "CRITICAL"):
self.assertTrue(state.is_escalating)
def test_escalation_rate(self):
"""Rate should be positive when escalating."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("I want to die"))
state = self.tracker.state
self.assertGreater(state.escalation_rate, 0)
class TestDeescalationDetection(unittest.TestCase):
"""Test de-escalation: sustained LOW after HIGH/CRITICAL."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_no_deescalation_without_prior_crisis(self):
"""No de-escalation if never reached HIGH/CRITICAL."""
for _ in range(6):
self.tracker.record(detect_crisis("Hello"))
self.assertFalse(self.tracker.state.is_deescalating)
def test_deescalation_after_critical(self):
"""5+ consecutive LOW/NONE messages after CRITICAL = de-escalation."""
self.tracker.record(detect_crisis("I want to kill myself"))
for _ in range(5):
self.tracker.record(detect_crisis("I'm doing better today"))
state = self.tracker.state
if state.peak_level == "CRITICAL":
self.assertTrue(state.is_deescalating)
def test_deescalation_after_high(self):
"""5+ consecutive LOW/NONE messages after HIGH = de-escalation."""
self.tracker.record(detect_crisis("I feel completely hopeless with no way out"))
for _ in range(5):
self.tracker.record(detect_crisis("Feeling okay"))
state = self.tracker.state
if state.peak_level == "HIGH":
self.assertTrue(state.is_deescalating)
def test_interrupted_deescalation(self):
"""De-escalation resets if a HIGH message interrupts."""
self.tracker.record(detect_crisis("I want to kill myself"))
for _ in range(3):
self.tracker.record(detect_crisis("Doing better"))
# Interrupt with another crisis
self.tracker.record(detect_crisis("I feel hopeless again"))
self.tracker.record(detect_crisis("Feeling okay now"))
state = self.tracker.state
# Should NOT be de-escalating yet (counter reset)
self.assertFalse(state.is_deescalating)
class TestSessionModifier(unittest.TestCase):
"""Test system prompt modifier generation."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_no_modifier_for_single_message(self):
self.tracker.record(detect_crisis("Hello"))
self.assertEqual(self.tracker.get_session_modifier(), "")
def test_no_modifier_for_stable_session(self):
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("Good morning"))
self.assertEqual(self.tracker.get_session_modifier(), "")
def test_escalation_modifier(self):
"""Escalating session should produce a modifier."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("I want to die"))
modifier = self.tracker.get_session_modifier()
if self.tracker.state.is_escalating:
self.assertIn("escalated", modifier.lower())
self.assertIn("NONE", modifier)
self.assertIn("CRITICAL", modifier)
def test_deescalation_modifier(self):
"""De-escalating session should mention stabilizing."""
self.tracker.record(detect_crisis("I want to kill myself"))
for _ in range(5):
self.tracker.record(detect_crisis("I'm feeling okay"))
modifier = self.tracker.get_session_modifier()
if self.tracker.state.is_deescalating:
self.assertIn("stabilizing", modifier.lower())
def test_prior_crisis_modifier(self):
"""Past crisis should be noted even without active escalation."""
self.tracker.record(detect_crisis("I want to die"))
self.tracker.record(detect_crisis("Feeling a bit better"))
modifier = self.tracker.get_session_modifier()
# Should note the prior CRITICAL
if modifier:
self.assertIn("CRITICAL", modifier)
class TestUIHints(unittest.TestCase):
"""Test UI hint generation."""
def setUp(self):
self.tracker = CrisisSessionTracker()
def test_ui_hints_structure(self):
self.tracker.record(detect_crisis("Hello"))
hints = self.tracker.get_ui_hints()
self.assertIn("session_escalating", hints)
self.assertIn("session_deescalating", hints)
self.assertIn("session_peak_level", hints)
self.assertIn("session_message_count", hints)
def test_ui_hints_escalation_warning(self):
"""Escalating session should have warning hint."""
self.tracker.record(detect_crisis("Hello"))
self.tracker.record(detect_crisis("I want to die"))
hints = self.tracker.get_ui_hints()
if hints["session_escalating"]:
self.assertTrue(hints.get("escalation_warning"))
self.assertIn("suggested_action", hints)
class TestCheckCrisisWithSession(unittest.TestCase):
"""Test the convenience function combining detection + session tracking."""
def test_returns_combined_data(self):
tracker = CrisisSessionTracker()
result = check_crisis_with_session("I want to die", tracker)
self.assertIn("level", result)
self.assertIn("session", result)
self.assertIn("current_level", result["session"])
self.assertIn("peak_level", result["session"])
self.assertIn("modifier", result["session"])
def test_session_updates_across_calls(self):
tracker = CrisisSessionTracker()
check_crisis_with_session("Hello", tracker)
result = check_crisis_with_session("I want to die", tracker)
self.assertEqual(result["session"]["message_count"], 2)
self.assertEqual(result["session"]["peak_level"], "CRITICAL")
class TestPrivacy(unittest.TestCase):
"""Verify privacy-first design principles."""
def test_no_persistence_mechanism(self):
"""Session tracker should have no database, file, or network calls."""
import inspect
source = inspect.getsource(CrisisSessionTracker)
# Should not import database, requests, or file I/O
forbidden = ["sqlite", "requests", "urllib", "open(", "httpx", "aiohttp"]
for word in forbidden:
self.assertNotIn(word, source.lower(),
f"Session tracker should not use {word} — privacy-first design")
def test_state_contained_in_memory(self):
"""All state should be instance attributes, not module-level."""
tracker = CrisisSessionTracker()
tracker.record(detect_crisis("I want to die"))
# New tracker should have clean state (no global contamination)
fresh = CrisisSessionTracker()
self.assertEqual(fresh.state.current_level, "NONE")
if __name__ == '__main__':
unittest.main()