Compare commits
3 Commits
fix/668-ap
...
fix/673-cr
| Author | SHA1 | Date | |
|---|---|---|---|
| 3f9388933f | |||
| b4d362fdad | |||
| 6d308ddb22 |
191
agent/crisis_middleware.py
Normal file
191
agent/crisis_middleware.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Crisis Middleware — Integrates 988 Lifeline into the agent conversation loop.
|
||||
|
||||
This middleware intercepts user messages before they reach the agent
|
||||
and checks for crisis signals. If detected, it returns the 988 Lifeline
|
||||
response immediately without processing the original message.
|
||||
|
||||
Integration approach: Import and call before agent.run_conversation().
|
||||
|
||||
Usage:
|
||||
from agent.crisis_middleware import CrisisMiddleware
|
||||
|
||||
middleware = CrisisMiddleware()
|
||||
crisis_response = middleware.check(user_message)
|
||||
if crisis_response:
|
||||
return crisis_response
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CrisisMiddleware:
|
||||
"""Middleware for crisis detection and 988 Lifeline integration."""
|
||||
|
||||
def __init__(self, enabled: bool = True):
|
||||
"""
|
||||
Initialize crisis middleware.
|
||||
|
||||
Args:
|
||||
enabled: Whether crisis detection is enabled (default True)
|
||||
"""
|
||||
self.enabled = enabled
|
||||
self._crisis_resources = None
|
||||
self._detection_func = None
|
||||
self._response_func = None
|
||||
|
||||
if enabled:
|
||||
self._load_crisis_module()
|
||||
|
||||
def _load_crisis_module(self):
|
||||
"""Load crisis resources module."""
|
||||
try:
|
||||
from agent.crisis_resources import (
|
||||
should_trigger_crisis_response,
|
||||
get_crisis_response,
|
||||
CrisisSeverity
|
||||
)
|
||||
self._detection_func = should_trigger_crisis_response
|
||||
self._response_func = get_crisis_response
|
||||
self._CrisisSeverity = CrisisSeverity
|
||||
logger.info("Crisis middleware loaded successfully")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Crisis resources not available: {e}")
|
||||
self.enabled = False
|
||||
|
||||
def check(self, user_message: str) -> Optional[str]:
|
||||
"""
|
||||
Check user message for crisis signals.
|
||||
|
||||
Args:
|
||||
user_message: The user's message
|
||||
|
||||
Returns:
|
||||
Crisis response string if crisis detected, None otherwise
|
||||
"""
|
||||
if not self.enabled or not self._detection_func:
|
||||
return None
|
||||
|
||||
try:
|
||||
should_trigger, detection = self._detection_func(user_message)
|
||||
|
||||
if should_trigger:
|
||||
severity = detection.get("severity_label", "CRITICAL")
|
||||
logger.warning(
|
||||
"Crisis detected (severity: %s, patterns: %s)",
|
||||
severity,
|
||||
detection.get("matched_patterns", [])
|
||||
)
|
||||
return self._response_func(severity)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Crisis detection error: {e}")
|
||||
# On error, return None to allow normal processing
|
||||
# False negative is better than crashing
|
||||
return None
|
||||
|
||||
def check_with_context(self, user_message: str, context: Dict[str, Any] = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Check for crisis with additional context.
|
||||
|
||||
Args:
|
||||
user_message: The user's message
|
||||
context: Additional context (session_id, user_id, etc.)
|
||||
|
||||
Returns:
|
||||
Dict with 'response' and 'detection' if crisis detected, None otherwise
|
||||
"""
|
||||
if not self.enabled or not self._detection_func:
|
||||
return None
|
||||
|
||||
try:
|
||||
should_trigger, detection = self._detection_func(user_message)
|
||||
|
||||
if should_trigger:
|
||||
severity = detection.get("severity_label", "CRITICAL")
|
||||
response = self._response_func(severity)
|
||||
|
||||
logger.warning(
|
||||
"Crisis detected (severity: %s, session: %s)",
|
||||
severity,
|
||||
context.get("session_id") if context else "unknown"
|
||||
)
|
||||
|
||||
return {
|
||||
"response": response,
|
||||
"detection": detection,
|
||||
"severity": severity,
|
||||
"context": context or {}
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Crisis detection error: {e}")
|
||||
return None
|
||||
|
||||
def is_crisis_message(self, user_message: str) -> bool:
|
||||
"""
|
||||
Check if message contains crisis signals (boolean only).
|
||||
|
||||
Args:
|
||||
user_message: The user's message
|
||||
|
||||
Returns:
|
||||
True if crisis detected, False otherwise
|
||||
"""
|
||||
if not self.enabled or not self._detection_func:
|
||||
return False
|
||||
|
||||
try:
|
||||
should_trigger, _ = self._detection_func(user_message)
|
||||
return should_trigger
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Global middleware instance
|
||||
_middleware: Optional[CrisisMiddleware] = None
|
||||
|
||||
|
||||
def get_crisis_middleware() -> CrisisMiddleware:
|
||||
"""Get or create global crisis middleware instance."""
|
||||
global _middleware
|
||||
if _middleware is None:
|
||||
_middleware = CrisisMiddleware()
|
||||
return _middleware
|
||||
|
||||
|
||||
def check_crisis(user_message: str) -> Optional[str]:
|
||||
"""
|
||||
Convenience function to check for crisis.
|
||||
|
||||
Args:
|
||||
user_message: The user's message
|
||||
|
||||
Returns:
|
||||
Crisis response if detected, None otherwise
|
||||
"""
|
||||
return get_crisis_middleware().check(user_message)
|
||||
|
||||
|
||||
# Integration decorator for agent methods
|
||||
def crisis_aware(func):
|
||||
"""
|
||||
Decorator to make agent methods crisis-aware.
|
||||
|
||||
Wraps the method to check for crisis before processing.
|
||||
If crisis is detected, returns the crisis response instead.
|
||||
"""
|
||||
def wrapper(self, user_message: str, *args, **kwargs):
|
||||
crisis_response = check_crisis(user_message)
|
||||
if crisis_response:
|
||||
return crisis_response
|
||||
return func(self, user_message, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
99
docs/crisis-integration-examples.py
Normal file
99
docs/crisis-integration-examples.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Crisis Integration Example — How to wire 988 Lifeline into the conversation loop.
|
||||
|
||||
This example shows how to integrate crisis detection into existing agent code.
|
||||
"""
|
||||
|
||||
# Example 1: Simple integration in conversation loop
|
||||
def conversation_loop_example():
|
||||
"""Example of crisis integration in a conversation loop."""
|
||||
from agent.crisis_middleware import check_crisis
|
||||
|
||||
while True:
|
||||
user_message = input("User: ")
|
||||
|
||||
# Check for crisis FIRST
|
||||
crisis_response = check_crisis(user_message)
|
||||
if crisis_response:
|
||||
print(f"Agent: {crisis_response}")
|
||||
continue # Skip normal processing
|
||||
|
||||
# Normal agent processing
|
||||
response = agent.run_conversation(user_message)
|
||||
print(f"Agent: {response}")
|
||||
|
||||
|
||||
# Example 2: Using the CrisisMiddleware class
|
||||
def middleware_class_example():
|
||||
"""Example using the CrisisMiddleware class directly."""
|
||||
from agent.crisis_middleware import CrisisMiddleware
|
||||
|
||||
middleware = CrisisMiddleware(enabled=True)
|
||||
|
||||
def process_message(user_message: str) -> str:
|
||||
# Check for crisis
|
||||
crisis_response = middleware.check(user_message)
|
||||
if crisis_response:
|
||||
return crisis_response
|
||||
|
||||
# Normal processing
|
||||
return agent.process(user_message)
|
||||
|
||||
|
||||
# Example 3: Using the decorator
|
||||
def decorator_example():
|
||||
"""Example using the @crisis_aware decorator."""
|
||||
from agent.crisis_middleware import crisis_aware
|
||||
|
||||
class MyAgent:
|
||||
@crisis_aware
|
||||
def process_message(self, user_message: str) -> str:
|
||||
# This method is now crisis-aware
|
||||
# If crisis is detected, the decorator returns the crisis response
|
||||
# Otherwise, this code runs normally
|
||||
return self.normal_processing(user_message)
|
||||
|
||||
|
||||
# Example 4: Integration with run_agent.py style
|
||||
def run_agent_integration():
|
||||
"""
|
||||
Example of integrating crisis check into run_agent.py style code.
|
||||
|
||||
Add this at the beginning of the conversation processing method:
|
||||
"""
|
||||
# In run_agent.py, add at the start of run_conversation() or similar:
|
||||
#
|
||||
# from agent.crisis_middleware import check_crisis
|
||||
#
|
||||
# def run_conversation(self, user_message: str):
|
||||
# # Crisis check — must be first
|
||||
# crisis_response = check_crisis(user_message)
|
||||
# if crisis_response:
|
||||
# return {"final_response": crisis_response, "crisis_detected": True}
|
||||
#
|
||||
# # ... rest of normal processing
|
||||
pass
|
||||
|
||||
|
||||
# Example 5: Gateway integration
|
||||
def gateway_integration():
|
||||
"""
|
||||
Example for gateway/platform integration.
|
||||
|
||||
The gateway can check messages before sending to the agent:
|
||||
"""
|
||||
# In gateway/platforms/base.py or similar:
|
||||
#
|
||||
# from agent.crisis_middleware import check_crisis
|
||||
#
|
||||
# async def handle_message(self, message: str):
|
||||
# # Check for crisis before agent processing
|
||||
# crisis_response = check_crisis(message)
|
||||
# if crisis_response:
|
||||
# await self.send_message(crisis_response)
|
||||
# return
|
||||
#
|
||||
# # Normal agent processing
|
||||
# response = await self.agent.process(message)
|
||||
# await self.send_message(response)
|
||||
pass
|
||||
166
tests/test_crisis_middleware.py
Normal file
166
tests/test_crisis_middleware.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Tests for crisis middleware integration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from agent.crisis_middleware import CrisisMiddleware, check_crisis, crisis_aware
|
||||
|
||||
|
||||
class TestCrisisMiddleware:
|
||||
"""Test CrisisMiddleware class."""
|
||||
|
||||
def test_init_enabled(self):
|
||||
"""Test middleware initialization when enabled."""
|
||||
with patch('agent.crisis_middleware.CrisisMiddleware._load_crisis_module'):
|
||||
middleware = CrisisMiddleware(enabled=True)
|
||||
assert middleware.enabled is True
|
||||
|
||||
def test_init_disabled(self):
|
||||
"""Test middleware initialization when disabled."""
|
||||
middleware = CrisisMiddleware(enabled=False)
|
||||
assert middleware.enabled is False
|
||||
|
||||
def test_check_disabled(self):
|
||||
"""Test check returns None when disabled."""
|
||||
middleware = CrisisMiddleware(enabled=False)
|
||||
result = middleware.check("I want to die")
|
||||
assert result is None
|
||||
|
||||
def test_check_crisis_detected(self):
|
||||
"""Test crisis detection."""
|
||||
with patch('agent.crisis_middleware.CrisisMiddleware._load_crisis_module'):
|
||||
middleware = CrisisMiddleware(enabled=True)
|
||||
middleware._detection_func = lambda msg: (True, {"severity_label": "CRITICAL"})
|
||||
middleware._response_func = lambda sev: "988 Lifeline: Call 988"
|
||||
|
||||
result = middleware.check("I want to die")
|
||||
assert result == "988 Lifeline: Call 988"
|
||||
|
||||
def test_check_no_crisis(self):
|
||||
"""Test no crisis detection."""
|
||||
with patch('agent.crisis_middleware.CrisisMiddleware._load_crisis_module'):
|
||||
middleware = CrisisMiddleware(enabled=True)
|
||||
middleware._detection_func = lambda msg: (False, {})
|
||||
|
||||
result = middleware.check("Hello, how are you?")
|
||||
assert result is None
|
||||
|
||||
def test_is_crisis_message(self):
|
||||
"""Test is_crisis_message method."""
|
||||
with patch('agent.crisis_middleware.CrisisMiddleware._load_crisis_module'):
|
||||
middleware = CrisisMiddleware(enabled=True)
|
||||
middleware._detection_func = lambda msg: (True, {}) if "die" in msg.lower() else (False, {})
|
||||
|
||||
assert middleware.is_crisis_message("I want to die") is True
|
||||
assert middleware.is_crisis_message("Hello") is False
|
||||
|
||||
def test_check_with_context(self):
|
||||
"""Test check_with_context method."""
|
||||
with patch('agent.crisis_middleware.CrisisMiddleware._load_crisis_module'):
|
||||
middleware = CrisisMiddleware(enabled=True)
|
||||
middleware._detection_func = lambda msg: (True, {"severity_label": "CRITICAL", "matched_patterns": ["test"]})
|
||||
middleware._response_func = lambda sev: "988 response"
|
||||
|
||||
result = middleware.check_with_context("I want to die", {"session_id": "123"})
|
||||
assert result is not None
|
||||
assert result["response"] == "988 response"
|
||||
assert result["severity"] == "CRITICAL"
|
||||
assert result["context"]["session_id"] == "123"
|
||||
|
||||
|
||||
class TestCheckCrisisFunction:
|
||||
"""Test standalone check_crisis function."""
|
||||
|
||||
def test_returns_none_when_disabled(self):
|
||||
"""Test returns None when middleware is disabled."""
|
||||
with patch('agent.crisis_middleware.get_crisis_middleware') as mock_get:
|
||||
mock_middleware = MagicMock()
|
||||
mock_middleware.check.return_value = None
|
||||
mock_get.return_value = mock_middleware
|
||||
|
||||
result = check_crisis("Hello")
|
||||
assert result is None
|
||||
|
||||
def test_returns_response_when_crisis(self):
|
||||
"""Test returns crisis response when detected."""
|
||||
with patch('agent.crisis_middleware.get_crisis_middleware') as mock_get:
|
||||
mock_middleware = MagicMock()
|
||||
mock_middleware.check.return_value = "988 Lifeline info"
|
||||
mock_get.return_value = mock_middleware
|
||||
|
||||
result = check_crisis("I want to die")
|
||||
assert result == "988 Lifeline info"
|
||||
|
||||
|
||||
class TestCrisisAwareDecorator:
|
||||
"""Test @crisis_aware decorator."""
|
||||
|
||||
def test_decorator_returns_crisis_response(self):
|
||||
"""Test decorator returns crisis response when detected."""
|
||||
with patch('agent.crisis_middleware.check_crisis') as mock_check:
|
||||
mock_check.return_value = "988 response"
|
||||
|
||||
@crisis_aware
|
||||
def process_message(self, msg):
|
||||
return "normal response"
|
||||
|
||||
result = process_message(None, "I want to die")
|
||||
assert result == "988 response"
|
||||
|
||||
def test_decorator_calls_function_when_no_crisis(self):
|
||||
"""Test decorator calls function when no crisis."""
|
||||
with patch('agent.crisis_middleware.check_crisis') as mock_check:
|
||||
mock_check.return_value = None
|
||||
|
||||
@crisis_aware
|
||||
def process_message(self, msg):
|
||||
return f"processed: {msg}"
|
||||
|
||||
result = process_message(None, "Hello")
|
||||
assert result == "processed: Hello"
|
||||
|
||||
|
||||
class Test988Resources:
|
||||
"""Test 988 resource availability in responses."""
|
||||
|
||||
def test_988_phone_in_response(self):
|
||||
"""Test 988 phone number is in crisis response."""
|
||||
try:
|
||||
from agent.crisis_resources import get_crisis_response
|
||||
response = get_crisis_response("CRITICAL")
|
||||
assert "988" in response
|
||||
except ImportError:
|
||||
pytest.skip("crisis_resources not available")
|
||||
|
||||
def test_text_option_in_response(self):
|
||||
"""Test text option is in crisis response."""
|
||||
try:
|
||||
from agent.crisis_resources import get_crisis_response
|
||||
response = get_crisis_response("CRITICAL")
|
||||
assert "HOME" in response or "text" in response.lower()
|
||||
except ImportError:
|
||||
pytest.skip("crisis_resources not available")
|
||||
|
||||
def test_chat_link_in_response(self):
|
||||
"""Test chat link is in crisis response."""
|
||||
try:
|
||||
from agent.crisis_resources import get_crisis_response
|
||||
response = get_crisis_response("CRITICAL")
|
||||
assert "988lifeline.org/chat" in response
|
||||
except ImportError:
|
||||
pytest.skip("crisis_resources not available")
|
||||
|
||||
def test_spanish_line_in_response(self):
|
||||
"""Test Spanish line is in crisis response."""
|
||||
try:
|
||||
from agent.crisis_resources import get_crisis_response
|
||||
response = get_crisis_response("CRITICAL")
|
||||
assert "1-888-628-9454" in response
|
||||
except ImportError:
|
||||
pytest.skip("crisis_resources not available")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user