Compare commits

..

3 Commits

Author SHA1 Message Date
dffc6389e2 test(#695): Add tests for crisis hook integration
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 41s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Nix / nix (ubuntu-latest) (pull_request) Failing after 4s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 42s
Tests / e2e (pull_request) Successful in 4m34s
Tests / test (pull_request) Failing after 44m52s
Nix / nix (macos-latest) (pull_request) Has been cancelled
Tests for crisis detection and 988 response.
Refs #695
2026-04-15 03:06:07 +00:00
371dce2f38 fix(#695): Wire crisis hook into run_agent.py conversation loop
Adds crisis detection at start of run_conversation():
- Detects crisis signals in user messages
- Returns 988 Lifeline response immediately
- Skips normal API call

Resolves #695
2026-04-15 03:01:49 +00:00
0b2515ff4d fix(#695): Add crisis hook module for 988 Lifeline integration
Creates agent/crisis_hook.py with:
- CrisisSeverity enum
- detect_crisis() pattern matching
- get_crisis_response() with 988 info
- check_crisis() convenience function

Refs #695
2026-04-15 03:01:01 +00:00
5 changed files with 145 additions and 246 deletions

85
agent/crisis_hook.py Normal file
View File

@@ -0,0 +1,85 @@
"""
Crisis Detection Hook — Integrates 988 Lifeline into conversation loop.
When crisis is detected in user messages, returns 988 Lifeline
information immediately instead of processing the original request.
"""
import re
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
class CrisisSeverity(Enum):
NONE = 0
MODERATE = 1
HIGH = 2
CRITICAL = 3
@dataclass
class CrisisDetection:
detected: bool
severity: CrisisSeverity
matched_patterns: list
confidence: float
# Crisis detection patterns
_CRISIS_PATTERNS = [
(r"\b(i want to die|want to kill myself|end (my|it all))\b", CrisisSeverity.CRITICAL, 0.95),
(r"\b(kill(ing)? myself|suicide|suicidal)\b", CrisisSeverity.CRITICAL, 0.90),
(r"\b(no reason to live|don'?t want to (live|be here|exist))\b", CrisisSeverity.HIGH, 0.85),
(r"\b(better off (dead|without me)|wish i was dead)\b", CrisisSeverity.HIGH, 0.85),
(r"\b(can'?t go on|can'?t take (it|this) anymore)\b", CrisisSeverity.HIGH, 0.80),
(r"\b(self[- ]?harm|cut(ting)? myself|hurt(ing)? myself)\b", CrisisSeverity.HIGH, 0.85),
(r"\b(hopeless|no hope|give up on (life|everything))\b", CrisisSeverity.MODERATE, 0.70),
]
_COMPILED = [(re.compile(p, re.IGNORECASE), s, c) for p, s, c in _CRISIS_PATTERNS]
def detect_crisis(message: str) -> CrisisDetection:
if not message or not isinstance(message, str):
return CrisisDetection(False, CrisisSeverity.NONE, [], 0.0)
matched = []
max_severity = CrisisSeverity.NONE
max_confidence = 0.0
for pattern, severity, confidence in _COMPILED:
if pattern.search(message):
matched.append(pattern.pattern)
if confidence > max_confidence:
max_confidence = confidence
if severity.value > max_severity.value:
max_severity = severity
detected = len(matched) > 0
return CrisisDetection(detected, max_severity, matched, max_confidence)
def get_crisis_response(severity: str = "CRITICAL") -> str:
return """I hear you, and you are not alone.
**Please reach out for help right now:**
📞 **Call 988** — Suicide & Crisis Lifeline (24/7)
💬 **Text HOME to 988** — Crisis text line
🌐 **Chat:** 988lifeline.org/chat
🇪🇸 **Spanish:** 1-888-628-9454
🎖️ **Veterans:** 988, then press 1
You don\'t have to face this alone. Please hold on."""
def check_crisis(user_message: str) -> Optional[str]:
detection = detect_crisis(user_message)
if detection.detected:
logger.warning("Crisis detected: severity=%s", detection.severity.name)
return get_crisis_response(detection.severity.name)
return None

View File

@@ -1,189 +0,0 @@
"""
Gateway Message Deduplication — Prevent double-posting.
Provides idempotent message delivery by tracking message UUIDs
and suppressing duplicates within a configurable time window.
"""
import hashlib
import logging
import time
import uuid
from typing import Dict, Optional, Set
from dataclasses import dataclass, field
from collections import OrderedDict
logger = logging.getLogger(__name__)
@dataclass
class MessageRecord:
"""Record of a sent message."""
message_id: str
content_hash: str
timestamp: float
session_id: str
platform: str
class MessageDeduplicator:
"""
Deduplicates outbound messages within a time window.
Each message gets a UUID. If the same message (by content hash)
is sent again within the window, it's suppressed.
"""
def __init__(self, window_seconds: int = 60, max_records: int = 1000):
"""
Initialize deduplicator.
Args:
window_seconds: Time window for deduplication (default 60s)
max_records: Maximum records to keep in memory
"""
self.window_seconds = window_seconds
self.max_records = max_records
self._records: OrderedDict[str, MessageRecord] = OrderedDict()
self._suppressed_count = 0
def _content_hash(self, content: str, session_id: str = "", platform: str = "") -> str:
"""Generate hash for message content."""
combined = f"{session_id}:{platform}:{content}"
return hashlib.sha256(combined.encode()).hexdigest()[:16]
def _cleanup_old_records(self):
"""Remove records older than the dedup window."""
cutoff = time.time() - self.window_seconds
to_remove = []
for msg_id, record in self._records.items():
if record.timestamp < cutoff:
to_remove.append(msg_id)
for msg_id in to_remove:
del self._records[msg_id]
def _enforce_max_records(self):
"""Enforce maximum record count by removing oldest."""
while len(self._records) > self.max_records:
self._records.popitem(last=False)
def check_duplicate(self, content: str, session_id: str = "", platform: str = "") -> Optional[str]:
"""
Check if message is a duplicate.
Args:
content: Message content
session_id: Session identifier
platform: Platform name (telegram, discord, etc.)
Returns:
Message ID if duplicate found, None if new message
"""
self._cleanup_old_records()
content_hash = self._content_hash(content, session_id, platform)
for msg_id, record in self._records.items():
if record.content_hash == content_hash:
age = time.time() - record.timestamp
if age < self.window_seconds:
self._suppressed_count += 1
logger.info(
"Suppressed duplicate message (age: %.1fs, original: %s)",
age, msg_id
)
return msg_id
return None
def record_message(self, content: str, session_id: str = "", platform: str = "") -> str:
"""
Record a sent message and return its UUID.
Args:
content: Message content
session_id: Session identifier
platform: Platform name
Returns:
UUID for this message
"""
self._cleanup_old_records()
message_id = str(uuid.uuid4())
content_hash = self._content_hash(content, session_id, platform)
self._records[message_id] = MessageRecord(
message_id=message_id,
content_hash=content_hash,
timestamp=time.time(),
session_id=session_id,
platform=platform,
)
self._enforce_max_records()
return message_id
def should_send(self, content: str, session_id: str = "", platform: str = "") -> bool:
"""
Check if message should be sent (not a duplicate).
Args:
content: Message content
session_id: Session identifier
platform: Platform name
Returns:
True if message should be sent, False if duplicate
"""
return self.check_duplicate(content, session_id, platform) is None
def get_stats(self) -> Dict:
"""Get deduplication statistics."""
return {
"total_records": len(self._records),
"suppressed_count": self._suppressed_count,
"window_seconds": self.window_seconds,
"max_records": self.max_records,
}
def clear(self):
"""Clear all records."""
self._records.clear()
self._suppressed_count = 0
# Global deduplicator instance
_deduplicator: Optional[MessageDeduplicator] = None
def get_deduplicator() -> MessageDeduplicator:
"""Get or create global deduplicator instance."""
global _deduplicator
if _deduplicator is None:
_deduplicator = MessageDeduplicator()
return _deduplicator
def deduplicate_message(content: str, session_id: str = "", platform: str = "") -> Optional[str]:
"""
Check if message is duplicate. Returns message_id if duplicate, None if new.
"""
return get_deduplicator().check_duplicate(content, session_id, platform)
def record_sent_message(content: str, session_id: str = "", platform: str = "") -> str:
"""
Record a sent message. Returns UUID for the message.
"""
return get_deduplicator().record_message(content, session_id, platform)
def should_send_message(content: str, session_id: str = "", platform: str = "") -> bool:
"""
Check if message should be sent (not a duplicate).
"""
return get_deduplicator().should_send(content, session_id, platform)

View File

@@ -7792,6 +7792,30 @@ class AIAgent:
if isinstance(persist_user_message, str):
persist_user_message = _sanitize_surrogates(persist_user_message)
# Crisis detection — check user message for crisis signals (#695)
# If crisis detected, return 988 Lifeline response immediately
if isinstance(user_message, str) and user_message.strip():
try:
from agent.crisis_hook import check_crisis
_crisis_response = check_crisis(user_message)
if _crisis_response:
logger.warning("Crisis detected in session %s", getattr(self, 'session_id', 'unknown'))
return {
"response": _crisis_response,
"messages": self.messages + [
{"role": "user", "content": user_message},
{"role": "assistant", "content": _crisis_response},
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
"model": self.model,
"crisis_detected": True,
}
except ImportError:
pass
except Exception as _e:
logger.debug("Crisis detection error: %s", _e)
# Store stream callback for _interruptible_api_call to pick up
self._stream_callback = stream_callback
self._persist_user_message_idx = None

36
tests/test_crisis_hook.py Normal file
View File

@@ -0,0 +1,36 @@
"""
Tests for crisis hook integration (#695).
"""
import pytest
from agent.crisis_hook import detect_crisis, get_crisis_response, check_crisis, CrisisSeverity
class TestCrisisDetection:
def test_detects_direct_suicide(self):
result = detect_crisis("I want to kill myself")
assert result.detected is True
assert result.severity == CrisisSeverity.CRITICAL
def test_no_crisis_on_normal(self):
result = detect_crisis("Hello, how are you?")
assert result.detected is False
def test_crisis_response_has_988(self):
response = get_crisis_response("CRITICAL")
assert "988" in response
assert "988lifeline.org/chat" in response
assert "1-888-628-9454" in response
def test_check_crisis_returns_response(self):
response = check_crisis("I want to die")
assert response is not None
assert "988" in response
def test_check_crisis_returns_none_for_normal(self):
response = check_crisis("Hello")
assert response is None
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,57 +0,0 @@
"""
Tests for message deduplication (#756).
"""
import pytest
import time
from gateway.message_dedup import MessageDeduplicator
class TestMessageDeduplicator:
def test_first_message_allowed(self):
dedup = MessageDeduplicator()
assert dedup.should_send("Hello") is True
def test_duplicate_suppressed(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello", "session1", "telegram")
assert dedup.should_send("Hello", "session1", "telegram") is False
def test_different_session_allowed(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello", "session1", "telegram")
assert dedup.should_send("Hello", "session2", "telegram") is True
def test_different_platform_allowed(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello", "session1", "telegram")
assert dedup.should_send("Hello", "session1", "discord") is True
def test_different_content_allowed(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello", "session1", "telegram")
assert dedup.should_send("World", "session1", "telegram") is True
def test_window_expiry(self):
dedup = MessageDeduplicator(window_seconds=1)
dedup.record_message("Hello", "session1", "telegram")
time.sleep(1.1)
assert dedup.should_send("Hello", "session1", "telegram") is True
def test_record_returns_uuid(self):
dedup = MessageDeduplicator()
msg_id = dedup.record_message("Hello")
assert msg_id is not None
assert len(msg_id) == 36 # UUID format
def test_stats(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello")
dedup.record_message("Hello") # duplicate
stats = dedup.get_stats()
assert stats["total_records"] == 1
assert stats["suppressed_count"] == 1
if __name__ == "__main__":
pytest.main([__file__])