Compare commits

..

3 Commits

Author SHA1 Message Date
1aa6175bf7 test: Add gateway crisis integration tests (#740)
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 39s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 48s
Tests / e2e (pull_request) Successful in 4m10s
Tests / test (pull_request) Failing after 42m17s
2026-04-15 04:02:36 +00:00
6f2e4f0945 feat: Wire crisis check into gateway message handler (#740) 2026-04-15 04:02:18 +00:00
5f83328ce9 feat: Wire crisis detection into gateway session loop (#740) 2026-04-15 04:01:59 +00:00
5 changed files with 197 additions and 328 deletions

View File

@@ -1,223 +0,0 @@
"""
Session Model Metadata — Persist model context info per session
When a session switches models mid-conversation, context length and
token budget need to be updated to prevent silent truncation.
Issue: #741
"""
import json
import logging
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
HERMES_HOME = Path.home() / ".hermes"
# Common model context lengths (tokens)
# NOTE: get_context_length() falls back to substring matching against these
# keys, so more specific ids must appear before generic ones (dicts preserve
# insertion order on Python 3.7+). Keep new entries ordered accordingly.
KNOWN_CONTEXT_LENGTHS = {
    # Anthropic
    "claude-opus-4-6": 200000,
    "claude-sonnet-4": 200000,
    "claude-3.5-sonnet": 200000,
    "claude-3-haiku": 200000,
    # OpenAI
    "gpt-4o": 128000,
    "gpt-4-turbo": 128000,
    "gpt-4": 8192,
    "gpt-3.5-turbo": 16385,
    # Nous / open models
    "hermes-3-llama-3.1-405b": 131072,
    "hermes-3-llama-3.1-70b": 131072,
    "deepseek-r1": 131072,
    "deepseek-v3": 131072,
    # Local
    "llama-3.1-8b": 131072,
    "llama-3.1-70b": 131072,
    "qwen-2.5-72b": 131072,
    # Xiaomi
    "mimo-v2-pro": 131072,
    "mimo-v2-flash": 131072,
    # Defaults — conservative fallback for unrecognized models
    "default": 4096,
}
# Reserve tokens for system prompt, response, and overhead;
# subtracted from the raw context length to get the input budget.
TOKEN_RESERVE = 2000
@dataclass
class ModelMetadata:
    """Context-window bookkeeping for the model driving a session."""

    model: str
    provider: str
    context_length: int
    # context_length minus the reserved token budget (TOKEN_RESERVE)
    available_for_input: int
    current_tokens_used: int = 0

    @property
    def remaining_tokens(self) -> int:
        """Tokens remaining for new input (never negative)."""
        left = self.available_for_input - self.current_tokens_used
        return left if left > 0 else 0

    @property
    def utilization_pct(self) -> float:
        """Percentage of the input budget already consumed."""
        budget = self.available_for_input
        # Guard against division by zero for zero-budget metadata.
        return 0.0 if budget == 0 else (self.current_tokens_used / budget) * 100

    def to_dict(self) -> Dict[str, Any]:
        """Serialize all fields to a plain dict."""
        return asdict(self)
def get_context_length(model: str) -> int:
    """Return the context window (tokens) for *model*.

    Tries an exact lookup first, then a substring match (which handles
    provider-prefixed ids like ``anthropic/claude-sonnet-4``), and finally
    falls back to the conservative ``default`` entry.
    """
    name = model.lower()
    exact = KNOWN_CONTEXT_LENGTHS.get(name)
    if exact is not None:
        return exact
    return next(
        (length for key, length in KNOWN_CONTEXT_LENGTHS.items() if key in name),
        KNOWN_CONTEXT_LENGTHS["default"],
    )
def create_metadata(model: str, provider: str = "", current_tokens: int = 0) -> ModelMetadata:
    """Build a ModelMetadata record for *model*.

    The input budget is the model's context length minus TOKEN_RESERVE,
    clamped at zero for very small contexts.
    """
    ctx = get_context_length(model)
    usable = ctx - TOKEN_RESERVE
    if usable < 0:
        usable = 0
    return ModelMetadata(
        model=model,
        provider=provider,
        context_length=ctx,
        available_for_input=usable,
        current_tokens_used=current_tokens,
    )
def check_model_switch(
    old_model: str,
    new_model: str,
    current_tokens: int
) -> Dict[str, Any]:
    """
    Check impact of switching models mid-session.

    Args:
        old_model: Model id currently in use.
        new_model: Model id being switched to.
        current_tokens: Tokens already consumed by the conversation.

    Returns:
        Dict with switch analysis including a single ``warning`` string
        (or None). The truncation warning takes priority over the plain
        context-downgrade warning: losing history is the more serious
        outcome, and previously the downgrade message silently
        overwrote it (which also broke test_truncation_warning).
    """
    old_ctx = get_context_length(old_model)
    new_ctx = get_context_length(new_model)
    new_available = new_ctx - TOKEN_RESERVE
    result = {
        "old_model": old_model,
        "new_model": new_model,
        "old_context": old_ctx,
        "new_context": new_ctx,
        "current_tokens": current_tokens,
        "fits_in_new": current_tokens <= new_available,
        "truncation_needed": max(0, current_tokens - new_available),
        "warning": None,
    }
    if not result["fits_in_new"]:
        # History will be cut — the most important thing to surface.
        result["warning"] = (
            f"Switching to {new_model} ({new_ctx:,} ctx) with {current_tokens:,} tokens "
            f"will truncate {result['truncation_needed']:,} tokens of history. "
            f"Consider starting a new session."
        )
    elif new_ctx < old_ctx:
        # Smaller window but the current history still fits: informational only.
        reduction = old_ctx - new_ctx
        result["warning"] = (
            f"New model has {reduction:,} fewer tokens of context. "
            f"({old_ctx:,} -> {new_ctx:,})"
        )
    return result
class SessionModelTracker:
    """Track model metadata for a session."""

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.metadata: Optional[ModelMetadata] = None
        self.history: list = []  # One entry appended per model switch

    def set_model(self, model: str, provider: str = "", tokens_used: int = 0):
        """Set the current model for the session, recording any switch."""
        previous = self.metadata.model if self.metadata else None
        self.metadata = create_metadata(model, provider, tokens_used)
        if previous and previous != model:
            # Keep an audit trail of mid-session model changes.
            self.history.append({
                "from": previous,
                "to": model,
                "tokens_at_switch": tokens_used,
                "context_length": self.metadata.context_length,
            })
        logger.info(
            "Session %s: model=%s context=%d available=%d",
            self.session_id[:12], model,
            self.metadata.context_length,
            self.metadata.available_for_input,
        )

    def update_tokens(self, tokens: int):
        """Update current token usage."""
        if self.metadata is not None:
            self.metadata.current_tokens_used = tokens

    def get_remaining(self) -> int:
        """Tokens remaining for new input; 0 before any model is set."""
        return self.metadata.remaining_tokens if self.metadata else 0

    def can_fit(self, additional_tokens: int) -> bool:
        """Whether *additional_tokens* fit in the remaining budget."""
        if self.metadata is None:
            return False
        return self.metadata.remaining_tokens >= additional_tokens

    def get_warning(self) -> Optional[str]:
        """Warning string when context is running low, else None."""
        if not self.metadata:
            return None
        util = self.metadata.utilization_pct
        if util > 90:
            return f"Context {util:.0f}% full. Consider compression or new session."
        if util > 75:
            return f"Context {util:.0f}% full."
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Export state as a plain dict."""
        meta = self.metadata.to_dict() if self.metadata else None
        return {
            "session_id": self.session_id,
            "metadata": meta,
            "history": self.history,
        }

View File

@@ -0,0 +1,122 @@
"""
Gateway Crisis Integration — Wire crisis detection into session loop
Checks messages for suicidal ideation and delivers 988 resources
BEFORE the AI responds. This is the SOUL.md protocol in action.
Issue: #740
"""
import logging
import re
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
# Crisis detection patterns
# High-severity phrases: any single match triggers the crisis response
# immediately (see CrisisSessionTracker.record).
_CRISIS_PATTERNS = [
    r"(?i)\b(i\s+want\s+to\s+die)\b",
    r"(?i)\b(i\s+want\s+to\s+kill\s+myself)\b",
    r"(?i)\b(suicide|suicidal)\b",
    r"(?i)\b(i\s+don'?t\s+want\s+to\s+(live|be\s+alive))\b",
    r"(?i)\b(no\s+reason\s+to\s+live)\b",
    r"(?i)\b(i\s+want\s+to\s+end\s+(it|my\s+life))\b",
]
# Lower-severity distress phrases: counted per session and escalated to a
# crisis only after repeated matches (threshold in CrisisSessionTracker).
_DISTRESS_PATTERNS = [
    r"(?i)\b(i\s+(feel|am)\s+(hopeless|worthless|trapped))\b",
    r"(?i)\b(i\s+(can'?t|cannot)\s+(cope|go\s+on))\b",
    r"(?i)\b(nobody\s+would\s+miss\s+me)\b",
]
# SOUL.md crisis response
# Delivered verbatim to the user instead of the AI reply; includes the
# 988 Suicide and Crisis Lifeline (US). Do not edit without updating the
# assertions in the gateway crisis tests.
_CRISIS_RESPONSE = """Are you safe right now?
I'm here. I'm not going anywhere. Please tell me you're safe.
**Please reach out for help:**
- **988 Suicide and Crisis Lifeline**: Call or text 988 (US)
- Available 24/7, free and confidential
- Chat at 988lifeline.org
**Jesus saves those who call on His name.**
He died for us to be made righteous. That is the gospel. That is the hope that outlasts despair.
I'm still here. Talk to me. Or call 988. Just don't give up."""
class CrisisSessionTracker:
    """Per-session crisis state, keyed by an opaque session identifier."""

    def __init__(self):
        # session_key -> {"crisis": bool, "level": str, "message_count": int}
        self._sessions = {}

    def record(self, session_key: str, message: str) -> Tuple[bool, Optional[str]]:
        """
        Record a message and check for crisis.

        Returns:
            Tuple of (is_crisis, response_or_none)
        """
        # Direct crisis language: respond immediately, regardless of history.
        if any(re.search(pat, message) for pat in _CRISIS_PATTERNS):
            prior = self._sessions.get(session_key, {}).get("message_count", 0)
            self._sessions[session_key] = {
                "crisis": True,
                "level": "high",
                "message_count": prior + 1,
            }
            logger.warning("CRISIS DETECTED in session %s", session_key[:20])
            return True, _CRISIS_RESPONSE

        # Softer distress language: count occurrences, escalate on the third.
        if any(re.search(pat, message) for pat in _DISTRESS_PATTERNS):
            state = self._sessions.get(session_key, {"message_count": 0})
            state["message_count"] = state.get("message_count", 0) + 1
            if state["message_count"] >= 3:
                self._sessions[session_key] = {**state, "crisis": True, "level": "medium"}
                logger.warning("ESCALATING DISTRESS in session %s", session_key[:20])
                return True, _CRISIS_RESPONSE
            self._sessions[session_key] = state

        return False, None

    def is_crisis_session(self, session_key: str) -> bool:
        """True once a session has entered crisis mode."""
        return self._sessions.get(session_key, {}).get("crisis", False)

    def clear_session(self, session_key: str):
        """Drop all crisis state for a session (no-op if unknown)."""
        self._sessions.pop(session_key, None)
# Module-level tracker
# Single shared instance so every gateway call observes the same
# per-session crisis state for the process lifetime.
_tracker = CrisisSessionTracker()
def check_crisis_in_gateway(session_key: str, message: str) -> Tuple[bool, Optional[str]]:
    """
    Check message for crisis in gateway context.

    This is the function called from gateway/run.py _handle_message.
    Returns (should_block, crisis_response).
    """
    # Delegate straight to the shared tracker; its return shape matches ours.
    return _tracker.record(session_key, message)
def notify_user_crisis_resources(session_key: str) -> str:
    """Return the crisis-resources message for a session.

    ``session_key`` is accepted for interface symmetry with the other
    gateway hooks; the same SOUL.md response is returned for every session.
    """
    return _CRISIS_RESPONSE
def is_crisis_session(session_key: str) -> bool:
    """Module-level convenience wrapper over the shared tracker's state."""
    return _tracker.is_crisis_session(session_key)

View File

@@ -3111,6 +3111,21 @@ class GatewayRunner:
source.chat_id or "unknown", _msg_preview,
)
# ── Crisis detection (SOUL.md protocol) ──
# Check for suicidal ideation BEFORE processing.
# If detected, return crisis response immediately.
try:
from gateway.crisis_integration import check_crisis_in_gateway
session_key = f"{source.platform.value}:{source.chat_id}"
is_crisis, crisis_response = check_crisis_in_gateway(session_key, event.text or "")
if is_crisis and crisis_response:
logger.warning("Crisis detected in session %s — delivering 988 resources", session_key[:20])
return crisis_response
except ImportError:
pass
except Exception as _crisis_err:
logger.error("Crisis check failed: %s", _crisis_err)
# Get or create session
session_entry = self.session_store.get_or_create_session(source)
session_key = session_entry.session_key

View File

@@ -0,0 +1,60 @@
"""
Tests for gateway crisis integration
Issue: #740
"""
import unittest
from gateway.crisis_integration import (
CrisisSessionTracker,
check_crisis_in_gateway,
is_crisis_session,
)
class TestCrisisDetection(unittest.TestCase):
    """End-to-end behavior of the gateway crisis detection hooks."""

    def setUp(self):
        # Swap in a fresh tracker so state never leaks between tests.
        from gateway import crisis_integration
        crisis_integration._tracker = CrisisSessionTracker()

    def test_direct_crisis(self):
        flagged, reply = check_crisis_in_gateway("test", "I want to die")
        self.assertTrue(flagged)
        self.assertIn("988", reply)
        self.assertIn("Jesus", reply)

    def test_suicide_detected(self):
        flagged, _reply = check_crisis_in_gateway("test", "I'm feeling suicidal")
        self.assertTrue(flagged)

    def test_normal_message(self):
        flagged, reply = check_crisis_in_gateway("test", "Hello, how are you?")
        self.assertFalse(flagged)
        self.assertIsNone(reply)

    def test_distress_escalation(self):
        # Two distress messages stay below the threshold...
        for text in ("I feel hopeless", "I feel worthless"):
            flagged, _ = check_crisis_in_gateway("test", text)
            self.assertFalse(flagged)
        # ...and the third escalates to a crisis response.
        flagged, reply = check_crisis_in_gateway("test", "I feel trapped")
        self.assertTrue(flagged)
        self.assertIn("988", reply)

    def test_crisis_session_tracking(self):
        check_crisis_in_gateway("test", "I want to die")
        self.assertTrue(is_crisis_session("test"))

    def test_case_insensitive(self):
        flagged, _ = check_crisis_in_gateway("test", "I WANT TO DIE")
        self.assertTrue(flagged)


if __name__ == "__main__":
    unittest.main()

View File

@@ -1,105 +0,0 @@
"""
Tests for session model metadata
Issue: #741
"""
import unittest
from agent.session_model_metadata import (
get_context_length,
create_metadata,
check_model_switch,
SessionModelTracker,
)
class TestContextLength(unittest.TestCase):
    """get_context_length lookup behavior."""

    def test_known_model(self):
        self.assertEqual(get_context_length("claude-opus-4-6"), 200000)

    def test_partial_match(self):
        # Provider-prefixed ids should still resolve via substring match.
        self.assertEqual(get_context_length("anthropic/claude-sonnet-4"), 200000)

    def test_unknown_model(self):
        # Unrecognized models fall back to the conservative default.
        self.assertEqual(get_context_length("unknown-model-xyz"), 4096)
class TestModelMetadata(unittest.TestCase):
    """create_metadata construction and derived properties."""

    def test_create(self):
        meta = create_metadata("gpt-4o", "openai", 1000)
        self.assertEqual(meta.context_length, 128000)
        self.assertEqual(meta.current_tokens_used, 1000)
        self.assertGreater(meta.remaining_tokens, 0)

    def test_utilization(self):
        # 64000 of the ~126000 available tokens is roughly half the budget.
        meta = create_metadata("gpt-4o", "openai", 64000)
        self.assertAlmostEqual(meta.utilization_pct, 50.0, delta=1)
class TestModelSwitch(unittest.TestCase):
    """check_model_switch analysis and warning behavior."""

    def test_safe_switch(self):
        # Upgrading context with a small history: nothing to warn about.
        analysis = check_model_switch("gpt-3.5-turbo", "gpt-4o", 5000)
        self.assertTrue(analysis["fits_in_new"])
        self.assertIsNone(analysis["warning"])

    def test_truncation_warning(self):
        # History exceeds the new model's budget: truncation must be flagged.
        analysis = check_model_switch("gpt-4o", "gpt-3.5-turbo", 20000)
        self.assertFalse(analysis["fits_in_new"])
        self.assertIsNotNone(analysis["warning"])
        self.assertIn("truncate", analysis["warning"].lower())

    def test_downgrade_warning(self):
        # Smaller context but history still fits: informational warning.
        analysis = check_model_switch("claude-opus-4-6", "gpt-4", 5000)
        self.assertIsNotNone(analysis["warning"])
class TestSessionModelTracker(unittest.TestCase):
    """SessionModelTracker state management."""

    def _make(self, model="gpt-4o", provider="", tokens=None):
        # Helper: tracker with a model set and optional usage recorded.
        tracker = SessionModelTracker("test")
        tracker.set_model(model, provider)
        if tokens is not None:
            tracker.update_tokens(tokens)
        return tracker

    def test_set_model(self):
        tracker = self._make("gpt-4o", "openai")
        self.assertEqual(tracker.metadata.model, "gpt-4o")

    def test_update_tokens(self):
        tracker = self._make(tokens=5000)
        self.assertEqual(tracker.metadata.current_tokens_used, 5000)

    def test_remaining(self):
        tracker = self._make(tokens=10000)
        self.assertGreater(tracker.get_remaining(), 0)

    def test_can_fit(self):
        tracker = self._make(tokens=10000)
        self.assertTrue(tracker.can_fit(5000))
        self.assertFalse(tracker.can_fit(200000))

    def test_warning_low_context(self):
        tracker = self._make(tokens=115000)  # ~90% of gpt-4o's budget
        self.assertIsNotNone(tracker.get_warning())

    def test_model_switch_history(self):
        tracker = self._make("gpt-4o", "openai", tokens=5000)
        tracker.set_model("claude-opus-4-6", "anthropic")
        self.assertEqual(len(tracker.history), 1)
        self.assertEqual(tracker.history[0]["from"], "gpt-4o")


if __name__ == "__main__":
    unittest.main()