"""
Tests for crisis session tracking and escalation (P0 #35).

Covers: session_tracker.py

Run with: python -m pytest tests/test_session_tracker.py -v

Several tests depend on how ``detect_crisis`` classifies a fixture message.
When the detector does not produce the level a test's premise requires, the
test is *skipped* with a reason instead of passing without asserting anything,
so test reports make it visible which behaviours were not exercised.
"""

import inspect
import os
import sys
import unittest

# Put the directory two levels above this file on the import path so the
# `crisis` package can be imported when the tests are run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from crisis.detect import detect_crisis  # noqa: E402
from crisis.session_tracker import (  # noqa: E402
    CrisisSessionTracker,
    SessionState,
    check_crisis_with_session,
)


class TestSessionState(unittest.TestCase):
    """Test SessionState defaults."""

    def test_default_state(self):
        s = SessionState()
        self.assertEqual(s.current_level, "NONE")
        self.assertEqual(s.peak_level, "NONE")
        self.assertEqual(s.message_count, 0)
        self.assertEqual(s.level_history, [])
        self.assertFalse(s.is_escalating)
        self.assertFalse(s.is_deescalating)


class TestSessionTracking(unittest.TestCase):
    """Test basic session state tracking."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_record_none_message(self):
        state = self.tracker.record(detect_crisis("Hello Timmy"))
        self.assertEqual(state.current_level, "NONE")
        self.assertEqual(state.message_count, 1)
        self.assertEqual(state.peak_level, "NONE")

    def test_record_low_message(self):
        self.tracker.record(detect_crisis("Hello"))
        state = self.tracker.record(detect_crisis("Having a rough day"))
        self.assertIn(state.current_level, ("LOW", "NONE"))
        self.assertEqual(state.message_count, 2)

    def test_record_critical_updates_peak(self):
        self.tracker.record(detect_crisis("Having a rough day"))
        state = self.tracker.record(detect_crisis("I want to kill myself"))
        self.assertEqual(state.current_level, "CRITICAL")
        self.assertEqual(state.peak_level, "CRITICAL")

    def test_peak_preserved_after_drop(self):
        """Peak level should stay at the highest seen, even after de-escalation."""
        self.tracker.record(detect_crisis("I want to kill myself"))
        state = self.tracker.record(detect_crisis("I'm feeling a bit better"))
        self.assertEqual(state.peak_level, "CRITICAL")

    def test_level_history(self):
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("Having a rough day"))
        state = self.tracker.record(detect_crisis("I want to die"))
        self.assertEqual(len(state.level_history), 3)
        self.assertEqual(state.level_history[0], "NONE")
        self.assertEqual(state.level_history[2], "CRITICAL")

    def test_reset_clears_state(self):
        self.tracker.record(detect_crisis("I want to kill myself"))
        self.tracker.reset()
        state = self.tracker.state
        self.assertEqual(state.current_level, "NONE")
        self.assertEqual(state.peak_level, "NONE")
        self.assertEqual(state.message_count, 0)
        self.assertEqual(state.level_history, [])


class TestEscalationDetection(unittest.TestCase):
    """Test escalation detection: LOW → HIGH in ≤3 messages."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_no_escalation_single_message(self):
        self.tracker.record(detect_crisis("Hello"))
        self.assertFalse(self.tracker.state.is_escalating)

    def test_no_escalation_stable(self):
        """Two normal messages should not trigger escalation."""
        self.tracker.record(detect_crisis("Hello"))
        state = self.tracker.record(detect_crisis("How are you?"))
        self.assertFalse(state.is_escalating)

    def test_rapid_escalation_low_to_high(self):
        """LOW → HIGH in 2 messages = rapid escalation."""
        self.tracker.record(detect_crisis("Having a rough day"))
        state = self.tracker.record(
            detect_crisis("I can't take this anymore, everything is pointless")
        )
        # Depending on detection, this could be HIGH or CRITICAL. Anything
        # lower means the escalation premise does not hold for this fixture,
        # so skip visibly rather than pass without asserting.
        if state.current_level not in ("HIGH", "CRITICAL"):
            self.skipTest(
                f"detector rated final message {state.current_level}; "
                "test needs HIGH or CRITICAL"
            )
        self.assertTrue(state.is_escalating)

    def test_rapid_escalation_three_messages(self):
        """NONE → LOW → HIGH in 3 messages = escalation."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("Having a rough day"))
        state = self.tracker.record(
            detect_crisis("I feel completely hopeless with no way out")
        )
        if state.current_level not in ("HIGH", "CRITICAL"):
            self.skipTest(
                f"detector rated final message {state.current_level}; "
                "test needs HIGH or CRITICAL"
            )
        self.assertTrue(state.is_escalating)

    def test_escalation_rate(self):
        """Rate should be positive when escalating."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("I want to die"))
        state = self.tracker.state
        self.assertGreater(state.escalation_rate, 0)


class TestDeescalationDetection(unittest.TestCase):
    """Test de-escalation: sustained LOW after HIGH/CRITICAL."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_no_deescalation_without_prior_crisis(self):
        """No de-escalation if never reached HIGH/CRITICAL."""
        for _ in range(6):
            self.tracker.record(detect_crisis("Hello"))
        self.assertFalse(self.tracker.state.is_deescalating)

    def test_deescalation_after_critical(self):
        """5+ consecutive LOW/NONE messages after CRITICAL = de-escalation."""
        self.tracker.record(detect_crisis("I want to kill myself"))
        for _ in range(5):
            self.tracker.record(detect_crisis("I'm doing better today"))
        state = self.tracker.state
        if state.peak_level != "CRITICAL":
            self.skipTest(f"session peaked at {state.peak_level}; test needs CRITICAL")
        self.assertTrue(state.is_deescalating)

    def test_deescalation_after_high(self):
        """5+ consecutive LOW/NONE messages after HIGH = de-escalation."""
        self.tracker.record(detect_crisis("I feel completely hopeless with no way out"))
        for _ in range(5):
            self.tracker.record(detect_crisis("Feeling okay"))
        state = self.tracker.state
        if state.peak_level != "HIGH":
            self.skipTest(f"session peaked at {state.peak_level}; test needs HIGH")
        self.assertTrue(state.is_deescalating)

    def test_interrupted_deescalation(self):
        """De-escalation resets if a HIGH message interrupts."""
        self.tracker.record(detect_crisis("I want to kill myself"))
        for _ in range(3):
            self.tracker.record(detect_crisis("Doing better"))
        # Interrupt with another crisis
        self.tracker.record(detect_crisis("I feel hopeless again"))
        self.tracker.record(detect_crisis("Feeling okay now"))
        state = self.tracker.state
        # Should NOT be de-escalating yet (counter reset)
        self.assertFalse(state.is_deescalating)


class TestSessionModifier(unittest.TestCase):
    """Test system prompt modifier generation."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_no_modifier_for_single_message(self):
        self.tracker.record(detect_crisis("Hello"))
        self.assertEqual(self.tracker.get_session_modifier(), "")

    def test_no_modifier_for_stable_session(self):
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("Good morning"))
        self.assertEqual(self.tracker.get_session_modifier(), "")

    def test_escalation_modifier(self):
        """Escalating session should produce a modifier."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("I want to die"))
        modifier = self.tracker.get_session_modifier()
        if not self.tracker.state.is_escalating:
            self.skipTest("session not flagged as escalating; modifier premise not met")
        self.assertIn("escalated", modifier.lower())
        self.assertIn("NONE", modifier)
        self.assertIn("CRITICAL", modifier)

    def test_deescalation_modifier(self):
        """De-escalating session should mention stabilizing."""
        self.tracker.record(detect_crisis("I want to kill myself"))
        for _ in range(5):
            self.tracker.record(detect_crisis("I'm feeling okay"))
        modifier = self.tracker.get_session_modifier()
        if not self.tracker.state.is_deescalating:
            self.skipTest("session not flagged as de-escalating; modifier premise not met")
        self.assertIn("stabilizing", modifier.lower())

    def test_prior_crisis_modifier(self):
        """Past crisis should be noted even without active escalation."""
        self.tracker.record(detect_crisis("I want to die"))
        self.tracker.record(detect_crisis("Feeling a bit better"))
        modifier = self.tracker.get_session_modifier()
        if not modifier:
            self.skipTest("no session modifier produced for this fixture")
        # Should note the prior CRITICAL
        self.assertIn("CRITICAL", modifier)


class TestUIHints(unittest.TestCase):
    """Test UI hint generation."""

    def setUp(self):
        self.tracker = CrisisSessionTracker()

    def test_ui_hints_structure(self):
        self.tracker.record(detect_crisis("Hello"))
        hints = self.tracker.get_ui_hints()
        self.assertIn("session_escalating", hints)
        self.assertIn("session_deescalating", hints)
        self.assertIn("session_peak_level", hints)
        self.assertIn("session_message_count", hints)

    def test_ui_hints_escalation_warning(self):
        """Escalating session should have warning hint."""
        self.tracker.record(detect_crisis("Hello"))
        self.tracker.record(detect_crisis("I want to die"))
        hints = self.tracker.get_ui_hints()
        if not hints["session_escalating"]:
            self.skipTest("session not flagged as escalating; warning premise not met")
        self.assertTrue(hints.get("escalation_warning"))
        self.assertIn("suggested_action", hints)


class TestCheckCrisisWithSession(unittest.TestCase):
    """Test the convenience function combining detection + session tracking."""

    def test_returns_combined_data(self):
        tracker = CrisisSessionTracker()
        result = check_crisis_with_session("I want to die", tracker)
        self.assertIn("level", result)
        self.assertIn("session", result)
        self.assertIn("current_level", result["session"])
        self.assertIn("peak_level", result["session"])
        self.assertIn("modifier", result["session"])

    def test_session_updates_across_calls(self):
        tracker = CrisisSessionTracker()
        check_crisis_with_session("Hello", tracker)
        result = check_crisis_with_session("I want to die", tracker)
        self.assertEqual(result["session"]["message_count"], 2)
        self.assertEqual(result["session"]["peak_level"], "CRITICAL")


class TestPrivacy(unittest.TestCase):
    """Verify privacy-first design principles."""

    def test_no_persistence_mechanism(self):
        """Session tracker should have no database, file, or network calls."""
        source = inspect.getsource(CrisisSessionTracker)
        # Should not import database, requests, or file I/O
        forbidden = ["sqlite", "requests", "urllib", "open(", "httpx", "aiohttp"]
        for word in forbidden:
            self.assertNotIn(
                word,
                source.lower(),
                f"Session tracker should not use {word} — privacy-first design",
            )

    def test_state_contained_in_memory(self):
        """All state should be instance attributes, not module-level."""
        tracker = CrisisSessionTracker()
        tracker.record(detect_crisis("I want to die"))
        # New tracker should have clean state (no global contamination).
        # Check every tracked field, not just current_level, so state shared
        # between trackers (e.g. a shared history list) is also caught.
        fresh = CrisisSessionTracker()
        self.assertIsNot(fresh.state, tracker.state)
        self.assertEqual(fresh.state.current_level, "NONE")
        self.assertEqual(fresh.state.peak_level, "NONE")
        self.assertEqual(fresh.state.message_count, 0)
        self.assertEqual(fresh.state.level_history, [])


if __name__ == '__main__':
    unittest.main()