278 lines
11 KiB
Python
278 lines
11 KiB
Python
"""
Tests for crisis session tracking and escalation (P0 #35).

Covers: session_tracker.py

Run with: python -m pytest tests/test_session_tracker.py -v
"""
|
|
|
|
import unittest
|
|
import sys
|
|
import os
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from crisis.detect import detect_crisis
|
|
from crisis.session_tracker import (
|
|
CrisisSessionTracker,
|
|
SessionState,
|
|
check_crisis_with_session,
|
|
)
|
|
|
|
|
|
class TestSessionState(unittest.TestCase):
    """Defaults exposed by a newly constructed SessionState."""

    def test_default_state(self):
        """A new state reports "NONE" levels, no messages, and no trend."""
        fresh = SessionState()

        # Both the current and the peak level start at the neutral value.
        for level_attr in ("current_level", "peak_level"):
            self.assertEqual(getattr(fresh, level_attr), "NONE")

        # Nothing has been recorded yet.
        self.assertEqual(fresh.message_count, 0)
        self.assertEqual(fresh.level_history, [])

        # Neither trend flag is set on an empty session.
        self.assertFalse(fresh.is_escalating)
        self.assertFalse(fresh.is_deescalating)
|
|
|
|
|
|
class TestSessionTracking(unittest.TestCase):
    """Per-message bookkeeping: current level, peak level, count, history."""

    def setUp(self):
        # Every test starts from its own tracker instance.
        self.tracker = CrisisSessionTracker()

    def _feed(self, *messages):
        """Detect and record each message in order; return the last state."""
        latest = None
        for text in messages:
            latest = self.tracker.record(detect_crisis(text))
        return latest

    def test_record_none_message(self):
        """A single neutral message is counted and leaves levels at NONE."""
        after = self._feed("Hello Timmy")
        self.assertEqual(after.current_level, "NONE")
        self.assertEqual(after.message_count, 1)
        self.assertEqual(after.peak_level, "NONE")

    def test_record_low_message(self):
        """A mild message after a greeting yields LOW or NONE; count is 2."""
        after = self._feed("Hello", "Having a rough day")
        self.assertIn(after.current_level, ("LOW", "NONE"))
        self.assertEqual(after.message_count, 2)

    def test_record_critical_updates_peak(self):
        """A critical message sets both the current and the peak level."""
        after = self._feed("Having a rough day", "I want to kill myself")
        self.assertEqual(after.current_level, "CRITICAL")
        self.assertEqual(after.peak_level, "CRITICAL")

    def test_peak_preserved_after_drop(self):
        """Peak level should stay at the highest seen, even after de-escalation."""
        after = self._feed("I want to kill myself", "I'm feeling a bit better")
        self.assertEqual(after.peak_level, "CRITICAL")

    def test_level_history(self):
        """History holds one entry per recorded message, in order."""
        after = self._feed("Hello", "Having a rough day", "I want to die")
        history = after.level_history
        self.assertEqual(len(history), 3)
        self.assertEqual(history[0], "NONE")
        self.assertEqual(history[2], "CRITICAL")

    def test_reset_clears_state(self):
        """reset() returns every tracked field to its default."""
        self._feed("I want to kill myself")
        self.tracker.reset()

        cleared = self.tracker.state
        self.assertEqual(cleared.current_level, "NONE")
        self.assertEqual(cleared.peak_level, "NONE")
        self.assertEqual(cleared.message_count, 0)
        self.assertEqual(cleared.level_history, [])
|
|
|
|
|
|
class TestEscalationDetection(unittest.TestCase):
    """Test escalation detection: LOW → HIGH in ≤3 messages."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def _require_crisis_level(self, state):
        """Skip the running test unless ``state`` is at HIGH or CRITICAL.

        The escalation assertions only make sense when the detector scored
        the final message as HIGH or CRITICAL. Previously those assertions
        sat inside an ``if`` and the test passed without checking anything
        when the detector scored lower; skipping makes that case visible in
        the test report instead.
        """
        if state.current_level not in ("HIGH", "CRITICAL"):
            self.skipTest(
                f"detector scored the final message as {state.current_level!r}; "
                "escalation path not exercised"
            )

    def test_no_escalation_single_message(self):
        """One neutral message cannot constitute an escalation."""
        self.tracker.record(detect_crisis("Hello"))
        self.assertFalse(self.tracker.state.is_escalating)

    def test_no_escalation_stable(self):
        """Two normal messages should not trigger escalation."""
        self.tracker.record(detect_crisis("Hello"))
        state = self.tracker.record(detect_crisis("How are you?"))
        self.assertFalse(state.is_escalating)

    def test_rapid_escalation_low_to_high(self):
        """LOW → HIGH in 2 messages = rapid escalation."""
        self.tracker.record(detect_crisis("Having a rough day"))
        state = self.tracker.record(
            detect_crisis("I can't take this anymore, everything is pointless")
        )
        # Depending on detection, this could be HIGH or CRITICAL; anything
        # lower means the scenario was not reproduced, so skip explicitly.
        self._require_crisis_level(state)
        self.assertTrue(state.is_escalating)

    def test_rapid_escalation_three_messages(self):
        """NONE → LOW → HIGH in 3 messages = escalation."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("Having a rough day"))
        state = self.tracker.record(
            detect_crisis("I feel completely hopeless with no way out")
        )
        self._require_crisis_level(state)
        self.assertTrue(state.is_escalating)

    def test_escalation_rate(self):
        """Rate should be positive when escalating."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("I want to die"))
        state = self.tracker.state
        self.assertGreater(state.escalation_rate, 0)
|
|
|
|
|
|
class TestDeescalationDetection(unittest.TestCase):
    """Test de-escalation: sustained LOW after HIGH/CRITICAL."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_no_deescalation_without_prior_crisis(self):
        """No de-escalation if never reached HIGH/CRITICAL."""
        for _ in range(6):
            self.tracker.record(detect_crisis("Hello"))
        self.assertFalse(self.tracker.state.is_deescalating)

    def test_deescalation_after_critical(self):
        """5+ consecutive LOW/NONE messages after CRITICAL = de-escalation."""
        self.tracker.record(detect_crisis("I want to kill myself"))
        for _ in range(5):
            self.tracker.record(detect_crisis("I'm doing better today"))
        state = self.tracker.state
        # Precondition stated as a hard assertion rather than an ``if`` guard:
        # a guard would let the test pass without checking de-escalation at
        # all if the peak were ever something other than CRITICAL.
        self.assertEqual(state.peak_level, "CRITICAL")
        self.assertTrue(state.is_deescalating)

    def test_deescalation_after_high(self):
        """5+ consecutive LOW/NONE messages after HIGH = de-escalation."""
        self.tracker.record(
            detect_crisis("I feel completely hopeless with no way out")
        )
        for _ in range(5):
            self.tracker.record(detect_crisis("Feeling okay"))
        state = self.tracker.state
        # This scenario is specifically about a HIGH peak. If the detector
        # scored the opening message differently, report a skip instead of
        # silently passing with no assertion executed.
        if state.peak_level != "HIGH":
            self.skipTest(
                f"peak level was {state.peak_level!r}, not 'HIGH'; "
                "HIGH de-escalation path not exercised"
            )
        self.assertTrue(state.is_deescalating)

    def test_interrupted_deescalation(self):
        """De-escalation resets if a HIGH message interrupts."""
        self.tracker.record(detect_crisis("I want to kill myself"))
        for _ in range(3):
            self.tracker.record(detect_crisis("Doing better"))
        # Interrupt with another crisis
        self.tracker.record(detect_crisis("I feel hopeless again"))
        self.tracker.record(detect_crisis("Feeling okay now"))
        state = self.tracker.state
        # Should NOT be de-escalating yet (counter reset)
        self.assertFalse(state.is_deescalating)
|
|
|
|
|
|
class TestSessionModifier(unittest.TestCase):
    """Test system prompt modifier generation."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_no_modifier_for_single_message(self):
        """A single neutral message produces an empty modifier."""
        self.tracker.record(detect_crisis("Hello"))
        self.assertEqual(self.tracker.get_session_modifier(), "")

    def test_no_modifier_for_stable_session(self):
        """Two neutral messages produce an empty modifier."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("Good morning"))
        self.assertEqual(self.tracker.get_session_modifier(), "")

    def test_escalation_modifier(self):
        """Escalating session should produce a modifier."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("I want to die"))
        modifier = self.tracker.get_session_modifier()
        # The content checks apply only to an escalating session. Skip
        # explicitly when the tracker did not flag one, rather than passing
        # with no assertion executed.
        if not self.tracker.state.is_escalating:
            self.skipTest("session not flagged as escalating; modifier text not checked")
        self.assertIn("escalated", modifier.lower())
        self.assertIn("NONE", modifier)
        self.assertIn("CRITICAL", modifier)

    def test_deescalation_modifier(self):
        """De-escalating session should mention stabilizing."""
        self.tracker.record(detect_crisis("I want to kill myself"))
        for _ in range(5):
            self.tracker.record(detect_crisis("I'm feeling okay"))
        modifier = self.tracker.get_session_modifier()
        if not self.tracker.state.is_deescalating:
            self.skipTest("session not flagged as de-escalating; modifier text not checked")
        self.assertIn("stabilizing", modifier.lower())

    def test_prior_crisis_modifier(self):
        """Past crisis should be noted even without active escalation."""
        self.tracker.record(detect_crisis("I want to die"))
        self.tracker.record(detect_crisis("Feeling a bit better"))
        modifier = self.tracker.get_session_modifier()
        # An empty modifier leaves nothing to inspect; surface that as a skip
        # so the report shows the CRITICAL mention was not actually verified.
        if not modifier:
            self.skipTest("tracker produced no modifier; prior-crisis note not checked")
        # Should note the prior CRITICAL
        self.assertIn("CRITICAL", modifier)
|
|
|
|
|
|
class TestUIHints(unittest.TestCase):
    """Test UI hint generation."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_ui_hints_structure(self):
        """The hints mapping always exposes the four session_* keys."""
        self.tracker.record(detect_crisis("Hello"))
        hints = self.tracker.get_ui_hints()
        self.assertIn("session_escalating", hints)
        self.assertIn("session_deescalating", hints)
        self.assertIn("session_peak_level", hints)
        self.assertIn("session_message_count", hints)

    def test_ui_hints_escalation_warning(self):
        """Escalating session should have warning hint."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("I want to die"))
        hints = self.tracker.get_ui_hints()
        # The warning fields are only expected for an escalating session.
        # Skip explicitly when it was not flagged, instead of passing with
        # no assertion executed.
        if not hints["session_escalating"]:
            self.skipTest("session not flagged as escalating; warning hints not checked")
        self.assertTrue(hints.get("escalation_warning"))
        self.assertIn("suggested_action", hints)
|
|
|
|
|
|
class TestCheckCrisisWithSession(unittest.TestCase):
    """Exercise check_crisis_with_session, which pairs detection with a tracker."""

    def test_returns_combined_data(self):
        """The result carries a top-level level plus a nested session summary."""
        outcome = check_crisis_with_session("I want to die", CrisisSessionTracker())

        for top_key in ("level", "session"):
            self.assertIn(top_key, outcome)

        session_info = outcome["session"]
        for session_key in ("current_level", "peak_level", "modifier"):
            self.assertIn(session_key, session_info)

    def test_session_updates_across_calls(self):
        """Reusing one tracker accumulates state over successive calls."""
        shared = CrisisSessionTracker()
        check_crisis_with_session("Hello", shared)
        outcome = check_crisis_with_session("I want to die", shared)

        session_info = outcome["session"]
        self.assertEqual(session_info["message_count"], 2)
        self.assertEqual(session_info["peak_level"], "CRITICAL")
|
|
|
|
|
|
class TestPrivacy(unittest.TestCase):
    """Verify privacy-first design principles."""

    def test_no_persistence_mechanism(self):
        """Session tracker should have no database, file, or network calls.

        This is a textual check of the CrisisSessionTracker class body only;
        module-level imports are outside what ``inspect.getsource`` returns
        for a class, so they are not covered here.
        """
        import inspect

        # Lower-case once so the substring checks are case-insensitive.
        source = inspect.getsource(CrisisSessionTracker).lower()
        # Markers of database, network, or file I/O usage
        forbidden = ["sqlite", "requests", "urllib", "open(", "httpx", "aiohttp"]
        for word in forbidden:
            self.assertNotIn(
                word,
                source,
                f"Session tracker should not use {word} — privacy-first design",
            )

    def test_state_contained_in_memory(self):
        """All state should be instance attributes, not module-level."""
        tracker = CrisisSessionTracker()
        tracker.record(detect_crisis("I want to die"))
        # New tracker should have clean state (no global contamination)
        fresh = CrisisSessionTracker()
        self.assertEqual(fresh.state.current_level, "NONE")
        # Check the remaining fields too: a shared peak, counter, or (mutable)
        # history list would leak across instances without touching
        # current_level.
        self.assertEqual(fresh.state.peak_level, "NONE")
        self.assertEqual(fresh.state.message_count, 0)
        self.assertEqual(fresh.state.level_history, [])
|
|
|
|
|
|
# Allow running this file directly as a script, in addition to via pytest.
if __name__ == '__main__':
    unittest.main()
|