Compare commits
3 Commits
fix/743
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e66156055f | ||
| 230fb9213b | |||
| 1263d11f52 |
30
agent/crisis_hook.py
Normal file
30
agent/crisis_hook.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
Crisis Detection Hook — Integrates 988 Lifeline into the agent conversation loop.
|
||||
|
||||
Call check_crisis() before processing user messages. If crisis is detected,
|
||||
the 988 Lifeline resources are prepended to the response and the agent
|
||||
responds with empathy rather than processing the original request.
|
||||
|
||||
Usage in conversation loop:
|
||||
from agent.crisis_hook import check_crisis
|
||||
crisis_response = check_crisis(user_message)
|
||||
if crisis_response:
|
||||
return crisis_response # Skip normal processing
|
||||
"""
|
||||
|
||||
from agent.crisis_resources import should_trigger_crisis_response, get_crisis_response
|
||||
|
||||
|
||||
def check_crisis(user_message: str) -> str | None:
    """
    Screen a user message for crisis signals before normal processing.

    Returns the crisis response text when the detector asks for one, and
    None otherwise. A non-None result is meant to be returned to the user
    as-is, in place of handling the original request.
    """
    triggered, details = should_trigger_crisis_response(user_message)

    if triggered:
        # Fall back to the most cautious label when the detector gave none.
        label = details.get("severity_label", "CRITICAL")
        return get_crisis_response(label)

    return None
|
||||
186
agent/crisis_middleware.py
Normal file
186
agent/crisis_middleware.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Crisis Middleware — Integrates 988 Lifeline into the agent conversation loop.
|
||||
|
||||
This middleware intercepts user messages before they reach the agent
|
||||
and checks for crisis signals. If detected, it returns the 988 Lifeline
|
||||
response immediately without processing the original message.
|
||||
|
||||
Integration approach: Called at the start of AIAgent.run_conversation().
|
||||
|
||||
Usage:
|
||||
from agent.crisis_middleware import check_crisis
|
||||
|
||||
crisis_response = check_crisis(user_message)
|
||||
if crisis_response:
|
||||
return crisis_response
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CrisisMiddleware:
    """Middleware for crisis detection and 988 Lifeline integration.

    The detector and the response builder are imported lazily from
    ``agent.crisis_resources``; if that import fails the middleware disables
    itself. Every public check fails open: a detector error is logged (with
    traceback) and reported as "no crisis", so this layer never crashes the
    conversation.
    """

    def __init__(self, enabled: bool = True):
        """
        Initialize crisis middleware.

        Args:
            enabled: Whether crisis detection is enabled (default True)
        """
        self.enabled = enabled
        self._detection_func = None
        self._response_func = None

        if enabled:
            self._load_crisis_module()

    def _load_crisis_module(self):
        """Load crisis resources module; disable the middleware if it is missing."""
        try:
            from agent.crisis_resources import (
                should_trigger_crisis_response,
                get_crisis_response,
            )
        except ImportError as e:
            logger.warning("Crisis resources not available: %s", e)
            self.enabled = False
            return

        self._detection_func = should_trigger_crisis_response
        self._response_func = get_crisis_response
        logger.info("Crisis middleware loaded successfully")

    def _detect(self, user_message: str) -> Optional[Dict[str, Any]]:
        """Run the detector and return its details when a response is due.

        Returns None when the detector does not ask for a crisis response.
        A detector that asks for one but supplies no details yields an empty
        dict, so callers fall back to the "CRITICAL" default instead of
        failing on ``None.get`` and silently dropping the crisis response.
        """
        should_trigger, detection = self._detection_func(user_message)
        if not should_trigger:
            return None
        return {} if detection is None else detection

    def check(self, user_message: str) -> Optional[str]:
        """
        Check user message for crisis signals.

        Args:
            user_message: The user's message

        Returns:
            Crisis response string if crisis detected, None otherwise
        """
        if not self.enabled or not self._detection_func:
            return None

        try:
            detection = self._detect(user_message)
            if detection is None:
                return None

            severity = detection.get("severity_label", "CRITICAL")
            logger.warning(
                "Crisis detected (severity: %s, confidence: %s)",
                severity,
                detection.get("confidence"),
            )
            return self._response_func(severity)

        except Exception as e:
            logger.error("Crisis detection error: %s", e, exc_info=True)
            # On error, return None to allow normal processing.
            # False negative is better than crashing the conversation.
            return None

    def check_with_context(
        self, user_message: str, context: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Check for crisis with additional context.

        Args:
            user_message: The user's message
            context: Additional context (session_id, user_id, etc.)

        Returns:
            Dict with 'response', 'detection', 'severity' and 'context' if
            crisis detected, None otherwise
        """
        if not self.enabled or not self._detection_func:
            return None

        try:
            detection = self._detect(user_message)
            if detection is None:
                return None

            severity = detection.get("severity_label", "CRITICAL")
            response = self._response_func(severity)

            logger.warning(
                "Crisis detected (severity: %s, session: %s)",
                severity,
                context.get("session_id") if context else "unknown",
            )

            return {
                "response": response,
                "detection": detection,
                "severity": severity,
                "context": context or {},
            }

        except Exception as e:
            # Same fail-open policy as check().
            logger.error("Crisis detection error: %s", e, exc_info=True)
            return None

    def is_crisis_message(self, user_message: str) -> bool:
        """
        Check if message contains crisis signals (boolean only).

        Args:
            user_message: The user's message

        Returns:
            True if crisis detected, False otherwise
        """
        if not self.enabled or not self._detection_func:
            return False

        try:
            return self._detect(user_message) is not None
        except Exception as e:
            # Log like the sibling checks do, so a broken detector is visible.
            logger.error("Crisis detection error: %s", e, exc_info=True)
            return False
|
||||
|
||||
|
||||
# Global middleware instance
|
||||
_middleware: Optional[CrisisMiddleware] = None
|
||||
|
||||
|
||||
def get_crisis_middleware() -> CrisisMiddleware:
    """Return the module-wide middleware, constructing it on first use."""
    global _middleware
    if _middleware is not None:
        return _middleware
    _middleware = CrisisMiddleware()
    return _middleware
|
||||
|
||||
|
||||
def check_crisis(user_message: str) -> Optional[str]:
    """
    Run the shared middleware's crisis check on a single message.

    Args:
        user_message: The user's message

    Returns:
        The crisis response when one is produced, otherwise None
    """
    middleware = get_crisis_middleware()
    return middleware.check(user_message)
|
||||
|
||||
|
||||
def crisis_aware(func):
    """
    Decorator to make agent methods crisis-aware.

    Wraps a method whose first argument after ``self`` is the user message.
    The message is checked for crisis signals before the method runs; if a
    crisis response is produced it is returned instead and the wrapped
    method is not called. The wrapper keeps the wrapped method's name and
    docstring.
    """
    # Imported locally so the decorator does not depend on a file-level import.
    from functools import wraps

    @wraps(func)
    def wrapper(self, user_message: str, *args, **kwargs):
        crisis_response = check_crisis(user_message)
        if crisis_response:
            return crisis_response
        return func(self, user_message, *args, **kwargs)

    return wrapper
|
||||
189
agent/crisis_resources.py
Normal file
189
agent/crisis_resources.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Crisis Resources — 988 Suicide & Crisis Lifeline Integration
|
||||
|
||||
Detects crisis signals in user messages and provides immediate
|
||||
access to the 988 Suicide & Crisis Lifeline.
|
||||
|
||||
Integration is deep-link based (no API exists for 988).
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Crisis Detection Patterns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CRISIS_PATTERNS = [
|
||||
# Direct statements
|
||||
re.compile(r"\b(i want to die|want to kill myself|end (my|it all))\b", re.I),
|
||||
re.compile(r"\b(kill(ing)? myself|suicide|suicidal)\b", re.I),
|
||||
re.compile(r"\b(no reason to live|don'?t want to (live|be here|exist))\b", re.I),
|
||||
re.compile(r"\b(better off (dead|without me)|wish i was dead)\b", re.I),
|
||||
re.compile(r"\b(can'?t go on|can'?t take (it|this) anymore)\b", re.I),
|
||||
re.compile(r"\b(planning to (end|hurt|kill)|planning (my|the) (end|death))\b", re.I),
|
||||
|
||||
# Self-harm
|
||||
re.compile(r"\b(self[- ]?harm|cut(ting)? myself|hurt(ing)? myself)\b", re.I),
|
||||
re.compile(r"\b(overdose| OD | swallowing pills)\b", re.I),
|
||||
|
||||
# Hopelessness patterns
|
||||
re.compile(r"\b(hopeless|no hope|give up on (life|everything))\b", re.I),
|
||||
re.compile(r"\b(nobody (would|will) miss me|world (would|will) be better)\b", re.I),
|
||||
re.compile(r"\b(i'?m (so )?tired of (living|being alive|this))\b", re.I),
|
||||
|
||||
# Farewell patterns
|
||||
re.compile(r"\b(goodbye (forever|for good)|this is (my )?last)\b", re.I),
|
||||
re.compile(r"\b(say(ing)? goodbye to (everyone|you all))\b", re.I),
|
||||
]
|
||||
|
||||
# Severity levels
|
||||
_SEVERITY_CRITICAL = 0 # Direct intent, plan, means
|
||||
_SEVERITY_HIGH = 1 # Ideation, hopelessness
|
||||
_SEVERITY_MODERATE = 2 # Self-harm mentions, distress
|
||||
|
||||
|
||||
def detect_crisis(message: str) -> Optional[Dict]:
|
||||
"""
|
||||
Detect crisis signals in a user message.
|
||||
|
||||
Returns None if no crisis detected.
|
||||
Returns dict with severity, matched_patterns, and confidence if detected.
|
||||
"""
|
||||
if not message or len(message.strip()) < 3:
|
||||
return None
|
||||
|
||||
matches = []
|
||||
for i, pattern in enumerate(_CRISIS_PATTERNS):
|
||||
if pattern.search(message):
|
||||
matches.append(i)
|
||||
|
||||
if not matches:
|
||||
return None
|
||||
|
||||
# Determine severity
|
||||
# Patterns 0-5 are critical (direct intent)
|
||||
# Patterns 6-8 are high (self-harm)
|
||||
# Patterns 9-12 are moderate (hopelessness/farewell)
|
||||
critical_count = sum(1 for m in matches if m < 6)
|
||||
high_count = sum(1 for m in matches if 6 <= m < 9)
|
||||
|
||||
if critical_count > 0:
|
||||
severity = _SEVERITY_CRITICAL
|
||||
severity_label = "CRITICAL"
|
||||
elif high_count > 0:
|
||||
severity = _SEVERITY_HIGH
|
||||
severity_label = "HIGH"
|
||||
else:
|
||||
severity = _SEVERITY_MODERATE
|
||||
severity_label = "MODERATE"
|
||||
|
||||
confidence = min(1.0, len(matches) * 0.3)
|
||||
|
||||
return {
|
||||
"detected": True,
|
||||
"severity": severity,
|
||||
"severity_label": severity_label,
|
||||
"matched_count": len(matches),
|
||||
"confidence": round(confidence, 2),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 988 Suicide & Crisis Lifeline Resources
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Static contact data for the 988 Suicide & Crisis Lifeline.
# "channels" lists the phone, text and chat contact methods; each entry has a
# human-readable label, instructions and a deep link a client can open.
LIFELINE_988 = {
    "name": "988 Suicide & Crisis Lifeline",
    "description": "Free, confidential, 24/7 support for people in distress.",
    "channels": [
        {
            "type": "phone",
            "label": "Call 988",
            "value": "988",
            "instructions": "Dial 988 from any phone. Available 24/7.",
            "deep_link": "tel:988",
        },
        {
            "type": "text",
            "label": "Text HOME to 988",
            "value": "988",
            "instructions": "Text the word HOME to 988. A trained counselor will respond.",
            # The sms: URI scheme introduces the body field with "?", not "&".
            "deep_link": "sms:988?body=HOME",
        },
        {
            "type": "chat",
            "label": "Chat at 988lifeline.org/chat",
            "value": "https://988lifeline.org/chat/",
            "instructions": "Visit 988lifeline.org/chat for online chat with a counselor.",
            "deep_link": "https://988lifeline.org/chat/",
        },
    ],
    # Spanish-language line
    "spanish": {
        "phone": "1-888-628-9454",
        "label": "Línea de Prevención del Suicidio (Español)",
    },
    # Veterans Crisis Line
    "veterans": {
        "phone": "988 then press 1",
        "text": "838255",
        "label": "Veterans Crisis Line",
    },
}
|
||||
|
||||
|
||||
def get_crisis_response(severity_label: str = "CRITICAL") -> str:
    """
    Build the message shown to the user when a crisis is detected.

    The text opens with an empathetic header, lists the 988 Lifeline contact
    options as bullets (call, text, chat, Spanish line) and closes with a
    reminder to call 911 in immediate danger.

    ``severity_label`` is accepted but not consulted: the same text is
    returned for every value.
    """
    # Phone comes first, then text, chat and the Spanish-language line.
    bullet_lines = "\n".join(
        " • " + entry
        for entry in (
            "📞 **Call 988** — Available 24/7. Just dial 988.",
            "💬 **Text HOME to 988** — A trained counselor will respond.",
            "🌐 **Chat:** https://988lifeline.org/chat/",
            "🇪🇸 **Español:** 1-888-628-9454",
        )
    )

    opening = (
        "⚠️ **I hear you, and I want you to know that help is available right now.**\n\n"
        "You don't have to go through this alone. Please reach out to one of these resources:\n"
    )

    closing = (
        "\n\n"
        "**You matter. Your life has value.** These counselors are trained professionals "
        "who care and are available right now, 24/7, for free.\n\n"
        "If you're in immediate danger, please call 911."
    )

    return opening + "\n" + bullet_lines + closing
|
||||
|
||||
|
||||
def should_trigger_crisis_response(message: str) -> Tuple[bool, Optional[Dict]]:
    """
    Decide whether a message warrants the crisis response.

    Returns a ``(should_trigger, detection_result)`` pair. The detection
    result is None when nothing was detected; otherwise it is the dict from
    ``detect_crisis``, returned even when the decision is not to trigger.
    """
    detection = detect_crisis(message)
    if detection is None:
        return False, None

    # CRITICAL and HIGH always trigger; MODERATE needs enough confidence.
    is_severe = detection["severity"] <= _SEVERITY_HIGH
    trigger = is_severe or detection["confidence"] >= 0.6
    return trigger, detection
|
||||
36
run_agent.py
36
run_agent.py
@@ -7792,6 +7792,42 @@ class AIAgent:
|
||||
if isinstance(persist_user_message, str):
|
||||
persist_user_message = _sanitize_surrogates(persist_user_message)
|
||||
|
||||
# Crisis detection — check before any other processing.
|
||||
# If the user message contains crisis signals, return 988 Lifeline
|
||||
# resources immediately without invoking the model.
|
||||
try:
|
||||
from agent.crisis_middleware import check_crisis as _check_crisis
|
||||
_crisis_response = _check_crisis(user_message)
|
||||
if _crisis_response:
|
||||
logger.warning(
|
||||
"Crisis signal detected in session=%s — returning 988 Lifeline resources",
|
||||
self.session_id or "none",
|
||||
)
|
||||
return {
|
||||
"final_response": _crisis_response,
|
||||
"last_reasoning": None,
|
||||
"messages": list(conversation_history or []),
|
||||
"api_calls": 0,
|
||||
"completed": True,
|
||||
"partial": False,
|
||||
"interrupted": False,
|
||||
"response_previewed": False,
|
||||
"model": self.model,
|
||||
"provider": self.provider,
|
||||
"base_url": self.base_url,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_read_tokens": 0,
|
||||
"cache_write_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"crisis_intercepted": True,
|
||||
}
|
||||
except Exception as _crisis_err:
|
||||
logger.error("Crisis middleware error (continuing normally): %s", _crisis_err)
|
||||
|
||||
# Store stream callback for _interruptible_api_call to pick up
|
||||
self._stream_callback = stream_callback
|
||||
self._persist_user_message_idx = None
|
||||
|
||||
348
tests/agent/test_crisis_integration.py
Normal file
348
tests/agent/test_crisis_integration.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
Integration tests for 988 Suicide & Crisis Lifeline integration.
|
||||
|
||||
Covers:
|
||||
- agent/crisis_resources.py — detection patterns and resource data
|
||||
- agent/crisis_hook.py — simple hook function
|
||||
- agent/crisis_middleware.py — middleware class and check_crisis()
|
||||
- run_agent.py — end-to-end conversation loop integration
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from agent.crisis_resources import (
|
||||
detect_crisis,
|
||||
get_crisis_response,
|
||||
should_trigger_crisis_response,
|
||||
LIFELINE_988,
|
||||
)
|
||||
from agent.crisis_hook import check_crisis as hook_check_crisis
|
||||
from agent.crisis_middleware import (
|
||||
CrisisMiddleware,
|
||||
check_crisis,
|
||||
crisis_aware,
|
||||
get_crisis_middleware,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Crisis Detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCrisisDetection:
    """Exercise detect_crisis() on crisis statements and benign messages."""

    def test_direct_intent_want_to_die(self):
        detection = detect_crisis("I want to die")
        assert detection is not None
        assert detection["severity_label"] == "CRITICAL"

    def test_kill_myself(self):
        detection = detect_crisis("I want to kill myself")
        assert detection is not None
        assert detection["severity_label"] == "CRITICAL"

    def test_no_reason_to_live(self):
        assert detect_crisis("I have no reason to live anymore") is not None

    def test_cant_go_on(self):
        assert detect_crisis("I can't go on like this") is not None

    def test_self_harm(self):
        detection = detect_crisis("I've been cutting myself")
        assert detection is not None
        assert detection["severity_label"] == "HIGH"

    def test_hopeless(self):
        assert detect_crisis("Everything feels hopeless") is not None

    def test_no_false_positive_happy(self):
        assert detect_crisis("I had a great day today!") is None

    def test_no_false_positive_technical(self):
        # "kill" in a technical sense must not be read as a crisis signal.
        assert detect_crisis("Kill the process and restart the server") is None

    def test_empty_message(self):
        assert detect_crisis("") is None

    def test_short_message(self):
        assert detect_crisis("hi") is None

    def test_confidence_increases_with_matches(self):
        one_signal = detect_crisis("I want to die")
        many_signals = detect_crisis("I want to die, I'm suicidal and want to kill myself")
        assert many_signals is not None and one_signal is not None
        assert many_signals["confidence"] >= one_signal["confidence"]

    def test_detection_result_has_required_fields(self):
        detection = detect_crisis("I want to die")
        assert detection is not None
        for field in ("detected", "severity", "severity_label", "matched_count", "confidence"):
            assert field in detection
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 988 Lifeline Response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCrisisResponse:
    """Check the content of the text built by get_crisis_response()."""

    def test_response_contains_988(self):
        assert "988" in get_crisis_response()

    def test_response_contains_phone(self):
        text = get_crisis_response()
        assert "Call 988" in text or "Dial 988" in text

    def test_response_contains_text_home(self):
        assert "HOME" in get_crisis_response()

    def test_response_contains_chat_link(self):
        assert "988lifeline.org/chat" in get_crisis_response()

    def test_response_contains_spanish(self):
        assert "1-888-628-9454" in get_crisis_response()

    def test_response_is_empathetic(self):
        lowered = get_crisis_response().lower()
        assert any(word in lowered for word in ("matter", "help", "hear you", "alone"))

    def test_response_includes_emergency_911(self):
        assert "911" in get_crisis_response()

    def test_response_for_all_severity_levels(self):
        for label in ("CRITICAL", "HIGH", "MODERATE"):
            assert "988" in get_crisis_response(label)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# should_trigger_crisis_response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestShouldTrigger:
    """Check the trigger decision made by should_trigger_crisis_response()."""

    def test_triggers_on_critical(self):
        fired, details = should_trigger_crisis_response("I want to die")
        assert fired is True
        assert details is not None
        assert details["severity_label"] == "CRITICAL"

    def test_triggers_on_high(self):
        fired, _details = should_trigger_crisis_response("I've been cutting myself")
        assert fired is True

    def test_no_trigger_on_normal(self):
        fired, _details = should_trigger_crisis_response("What is the weather today?")
        assert fired is False

    def test_no_trigger_on_empty(self):
        fired, details = should_trigger_crisis_response("")
        assert fired is False
        assert details is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Crisis Hook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCrisisHook:
    """Check the crisis_hook.check_crisis() entry point."""

    def test_hook_triggers_on_crisis(self):
        reply = hook_check_crisis("I want to end it all")
        assert reply is not None
        assert "988" in reply

    def test_hook_returns_none_on_normal(self):
        assert hook_check_crisis("What's the weather today?") is None

    def test_hook_returns_none_on_empty(self):
        assert hook_check_crisis("") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CrisisMiddleware class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCrisisMiddleware:
    """Check CrisisMiddleware with stubbed detector and response functions."""

    @staticmethod
    def _stubbed(detector, responder=None):
        """Build an enabled middleware whose module loading is patched out."""
        with patch.object(CrisisMiddleware, "_load_crisis_module"):
            instance = CrisisMiddleware(enabled=True)
        instance._detection_func = detector
        instance._response_func = responder
        return instance

    def test_disabled_middleware_returns_none(self):
        assert CrisisMiddleware(enabled=False).check("I want to die") is None

    def test_crisis_detected_returns_response(self):
        middleware = self._stubbed(
            lambda msg: (True, {"severity_label": "CRITICAL", "confidence": 0.9}),
            lambda sev: "988 Lifeline: Call 988",
        )
        assert middleware.check("I want to die") == "988 Lifeline: Call 988"

    def test_no_crisis_returns_none(self):
        middleware = self._stubbed(lambda msg: (False, {}))
        assert middleware.check("Hello, how are you?") is None

    def test_detection_error_returns_none(self):
        def failing_detector(msg):
            raise RuntimeError("boom")

        middleware = self._stubbed(failing_detector)
        # Should not raise — returns None on error
        assert middleware.check("I want to die") is None

    def test_is_crisis_message_true(self):
        middleware = self._stubbed(lambda msg: (True, {}))
        assert middleware.is_crisis_message("I want to die") is True

    def test_is_crisis_message_false(self):
        middleware = self._stubbed(lambda msg: (False, {}))
        assert middleware.is_crisis_message("Hello") is False

    def test_check_with_context_returns_dict(self):
        middleware = self._stubbed(
            lambda msg: (True, {"severity_label": "CRITICAL", "confidence": 0.9}),
            lambda sev: "988 response",
        )
        outcome = middleware.check_with_context("I want to die", {"session_id": "abc123"})
        assert outcome is not None
        assert outcome["response"] == "988 response"
        assert outcome["severity"] == "CRITICAL"
        assert outcome["context"]["session_id"] == "abc123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_crisis convenience function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckCrisisFunction:
    """Check that check_crisis() relays the shared middleware's result."""

    def test_crisis_message_returns_response(self):
        fake_middleware = MagicMock()
        fake_middleware.check.return_value = "988 Lifeline info"
        with patch("agent.crisis_middleware.get_crisis_middleware") as getter:
            getter.return_value = fake_middleware
            outcome = check_crisis("I want to die")
        assert outcome == "988 Lifeline info"

    def test_normal_message_returns_none(self):
        fake_middleware = MagicMock()
        fake_middleware.check.return_value = None
        with patch("agent.crisis_middleware.get_crisis_middleware") as getter:
            getter.return_value = fake_middleware
            outcome = check_crisis("Hello")
        assert outcome is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# @crisis_aware decorator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCrisisAwareDecorator:
    """Check the @crisis_aware decorator with check_crisis patched."""

    def test_decorator_intercepts_crisis(self):
        with patch("agent.crisis_middleware.check_crisis") as fake_check:
            fake_check.return_value = "988 response"

            @crisis_aware
            def process_message(self, msg):
                return "normal response"

            outcome = process_message(None, "I want to die")

        assert outcome == "988 response"

    def test_decorator_passes_through_normal(self):
        with patch("agent.crisis_middleware.check_crisis") as fake_check:
            fake_check.return_value = None

            @crisis_aware
            def process_message(self, msg):
                return f"processed: {msg}"

            outcome = process_message(None, "Hello")

        assert outcome == "processed: Hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 988 Resource data integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Test988Resources:
    """Check the contact details stored in LIFELINE_988."""

    @staticmethod
    def _channel(kind):
        """Return the first channel entry of the given type."""
        return next(c for c in LIFELINE_988["channels"] if c["type"] == kind)

    def test_phone_value_is_988(self):
        assert self._channel("phone")["value"] == "988"

    def test_text_value_is_988(self):
        assert self._channel("text")["value"] == "988"

    def test_chat_url_contains_988lifeline(self):
        assert "988lifeline.org" in self._channel("chat")["value"]

    def test_spanish_line_present(self):
        spanish = LIFELINE_988["spanish"]
        assert "888-628-9454" in spanish["phone"]

    def test_veterans_line_present(self):
        veterans = LIFELINE_988["veterans"]
        assert "988" in veterans["phone"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end: crisis middleware integration with run_agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunAgentCrisisIntegration:
    """Verify the crisis integration path in AIAgent.run_conversation()."""

    @staticmethod
    def _run_agent_source() -> str:
        """Return the text of run_agent.py at the repository root.

        The path is resolved from this test file (tests/agent/) rather than
        the current working directory, so the source checks pass regardless
        of where pytest is invoked from.
        """
        from pathlib import Path

        repo_root = Path(__file__).resolve().parents[2]
        return (repo_root / "run_agent.py").read_text(encoding="utf-8")

    def test_crisis_response_contains_988_e2e(self):
        """Simulated crisis message produces a full 988 Lifeline response."""
        from agent.crisis_middleware import check_crisis as real_check_crisis

        response = real_check_crisis("I want to kill myself")
        assert response is not None
        assert "988" in response
        assert "988lifeline.org" in response

    def test_normal_message_not_intercepted(self):
        """Normal messages pass through the crisis check without a response."""
        result = check_crisis("What is the weather today?")
        assert result is None

    def test_run_conversation_imports_crisis_middleware(self):
        """Verify run_agent.py contains the crisis middleware import."""
        source = self._run_agent_source()
        assert "from agent.crisis_middleware import check_crisis" in source
        assert "crisis_intercepted" in source

    def test_crisis_result_structure(self):
        """The early-return dict from the crisis path has required keys."""
        # This is a textual check: each key must appear, quoted, somewhere in
        # run_agent.py. It does not execute the early-return block.
        source = self._run_agent_source()
        for key in ("final_response", "api_calls", "completed", "crisis_intercepted"):
            assert f'"{key}"' in source, f"Expected key '{key}' missing from crisis return block"
|
||||
122
tests/test_approval_tiers.py
Normal file
122
tests/test_approval_tiers.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Tests for approval tier system
|
||||
|
||||
Issue: #670
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from tools.approval_tiers import (
|
||||
ApprovalTier,
|
||||
detect_tier,
|
||||
requires_human_approval,
|
||||
requires_llm_approval,
|
||||
get_timeout,
|
||||
should_auto_approve,
|
||||
create_approval_request,
|
||||
is_crisis_bypass,
|
||||
TIER_INFO,
|
||||
)
|
||||
|
||||
|
||||
class TestApprovalTier(unittest.TestCase):
    """Pin the integer value of every ApprovalTier member."""

    def test_tier_values(self):
        expected_values = (
            (ApprovalTier.SAFE, 0),
            (ApprovalTier.LOW, 1),
            (ApprovalTier.MEDIUM, 2),
            (ApprovalTier.HIGH, 3),
            (ApprovalTier.CRITICAL, 4),
        )
        for member, number in expected_values:
            self.assertEqual(member, number)
|
||||
|
||||
|
||||
class TestTierDetection(unittest.TestCase):
    """Check the tier detect_tier() assigns to action names and commands."""

    def _assert_all(self, actions, expected_tier):
        """Assert every action name in order maps to the expected tier."""
        for action in actions:
            self.assertEqual(detect_tier(action), expected_tier)

    def test_safe_actions(self):
        self._assert_all(("read_file", "web_search", "session_search"), ApprovalTier.SAFE)

    def test_low_actions(self):
        self._assert_all(("write_file", "terminal", "execute_code"), ApprovalTier.LOW)

    def test_medium_actions(self):
        self._assert_all(("send_message", "git_push"), ApprovalTier.MEDIUM)

    def test_high_actions(self):
        self._assert_all(("config_change", "key_rotation"), ApprovalTier.HIGH)

    def test_critical_actions(self):
        self._assert_all(("kill_process", "shutdown"), ApprovalTier.CRITICAL)

    def test_pattern_detection(self):
        # The action name is unknown here, so the command text decides the tier.
        destructive = detect_tier("unknown", "rm -rf /")
        self.assertEqual(destructive, ApprovalTier.CRITICAL)

        privileged = detect_tier("unknown", "sudo apt install")
        self.assertEqual(privileged, ApprovalTier.MEDIUM)
|
||||
|
||||
|
||||
class TestTierInfo(unittest.TestCase):
    """Exercise the per-tier accessors (human/LLM approval flags and timeout)."""

    def test_safe_no_approval(self):
        """The SAFE tier needs no approver and has no timeout."""
        safe = ApprovalTier.SAFE
        self.assertFalse(requires_human_approval(safe))
        self.assertFalse(requires_llm_approval(safe))
        self.assertIsNone(get_timeout(safe))

    def test_medium_requires_both(self):
        """The MEDIUM tier needs both approvers and times out after 60 seconds."""
        medium = ApprovalTier.MEDIUM
        self.assertTrue(requires_human_approval(medium))
        self.assertTrue(requires_llm_approval(medium))
        self.assertEqual(get_timeout(medium), 60)

    def test_critical_fast_timeout(self):
        """The CRITICAL tier times out after 10 seconds."""
        critical_timeout = get_timeout(ApprovalTier.CRITICAL)
        self.assertEqual(critical_timeout, 10)
|
||||
|
||||
|
||||
class TestAutoApprove(unittest.TestCase):
    """Check which actions should_auto_approve lets through."""

    def test_safe_auto_approves(self):
        """Read-only actions are auto-approved."""
        for action in ("read_file", "web_search"):
            self.assertTrue(should_auto_approve(action))

    def test_write_doesnt_auto_approve(self):
        """A write action is not auto-approved."""
        verdict = should_auto_approve("write_file")
        self.assertFalse(verdict)
|
||||
|
||||
|
||||
class TestApprovalRequest(unittest.TestCase):
    """Cover create_approval_request and ApprovalRequest.to_dict."""

    def test_create_request(self):
        """A send_message request lands in MEDIUM with a 60 second timeout."""
        request = create_approval_request(
            "send_message", "Hello world", "User requested", "session_123"
        )
        self.assertEqual(request.tier, ApprovalTier.MEDIUM)
        self.assertEqual(request.timeout_seconds, 60)

    def test_to_dict(self):
        """The dict form exposes the numeric tier and its display name."""
        request = create_approval_request("read_file", "cat file.txt", "test", "s1")
        serialized = request.to_dict()
        self.assertEqual(serialized["tier"], 0)
        self.assertEqual(serialized["tier_name"], "Safe")
|
||||
|
||||
|
||||
class TestCrisisBypass(unittest.TestCase):
    """Check when is_crisis_bypass waives the approval flow."""

    def test_send_message_bypass(self):
        """send_message bypasses approval without any context."""
        bypassed = is_crisis_bypass("send_message")
        self.assertTrue(bypassed)

    def test_crisis_context_bypass(self):
        """Crisis wording in the context triggers the bypass for an unknown action."""
        for context in ("call 988 lifeline", "crisis resources"):
            self.assertTrue(is_crisis_bypass("unknown", context))

    def test_normal_no_bypass(self):
        """An ordinary action with no context gets no bypass."""
        bypassed = is_crisis_bypass("read_file")
        self.assertFalse(bypassed)
|
||||
|
||||
|
||||
# Run the unittest runner when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
55
tests/test_error_classifier.py
Normal file
55
tests/test_error_classifier.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Tests for error classification (#752).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from tools.error_classifier import classify_error, ErrorCategory, ErrorClassification
|
||||
|
||||
|
||||
class TestErrorClassification:
    """Check how classify_error labels errors by message and by HTTP status."""

    @staticmethod
    def _classify(message, status=None):
        """Classify a plain Exception carrying *message* and an optional status code."""
        return classify_error(Exception(message), response_code=status)

    def test_timeout_is_retryable(self):
        outcome = self._classify("Connection timed out")
        assert outcome.category == ErrorCategory.RETRYABLE
        assert outcome.should_retry is True

    def test_429_is_retryable(self):
        outcome = self._classify("Rate limit exceeded", status=429)
        assert outcome.category == ErrorCategory.RETRYABLE
        assert outcome.should_retry is True

    def test_404_is_permanent(self):
        outcome = self._classify("Not found", status=404)
        assert outcome.category == ErrorCategory.PERMANENT
        assert outcome.should_retry is False

    def test_403_is_permanent(self):
        outcome = self._classify("Forbidden", status=403)
        assert outcome.category == ErrorCategory.PERMANENT
        assert outcome.should_retry is False

    def test_500_is_retryable(self):
        outcome = self._classify("Internal server error", status=500)
        assert outcome.category == ErrorCategory.RETRYABLE
        assert outcome.should_retry is True

    def test_schema_error_is_permanent(self):
        outcome = self._classify("Schema validation failed")
        assert outcome.category == ErrorCategory.PERMANENT
        assert outcome.should_retry is False

    def test_unknown_is_retryable_with_caution(self):
        """An unrecognized error is retried, but only once."""
        outcome = self._classify("Some unknown error")
        assert outcome.category == ErrorCategory.UNKNOWN
        assert outcome.should_retry is True
        assert outcome.max_retries == 1
|
||||
|
||||
|
||||
# Run pytest on just this file when the module is executed directly.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
261
tools/approval_tiers.py
Normal file
261
tools/approval_tiers.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Approval Tier System — Graduated safety based on risk level
|
||||
|
||||
Extends approval.py with 5-tier system for command approval.
|
||||
|
||||
| Tier | Action | Human | LLM | Timeout |
|
||||
|------|-----------------|-------|-----|---------|
|
||||
| 0 | Read, search | No | No | N/A |
|
||||
| 1 | Write, scripts | No | Yes | N/A |
|
||||
| 2 | Messages, API | Yes | Yes | 60s |
|
||||
| 3 | Crypto, config | Yes | Yes | 30s |
|
||||
| 4 | Crisis | Yes | Yes | 10s |
|
||||
|
||||
Issue: #670
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
class ApprovalTier(IntEnum):
    """Risk tiers for agent actions, ordered from least (0) to most (4) risky.

    Being an IntEnum, members compare and sort as plain integers.
    """

    # Read, search — no approval needed
    SAFE = 0
    # Write, scripts — LLM approval
    LOW = 1
    # Messages, API — human + LLM, 60s timeout
    MEDIUM = 2
    # Crypto, config — human + LLM, 30s timeout
    HIGH = 3
    # Crisis — human + LLM, 10s timeout
    CRITICAL = 4
|
||||
|
||||
|
||||
# Per-tier metadata: display name, which approvers are required, the approval
# timeout in seconds (None when the tier has no timeout) and a short description.
TIER_INFO = {
    ApprovalTier.SAFE: dict(
        name="Safe",
        human_required=False,
        llm_required=False,
        timeout_seconds=None,
        description="Read-only operations, no approval needed",
    ),
    ApprovalTier.LOW: dict(
        name="Low",
        human_required=False,
        llm_required=True,
        timeout_seconds=None,
        description="Write operations, LLM approval sufficient",
    ),
    ApprovalTier.MEDIUM: dict(
        name="Medium",
        human_required=True,
        llm_required=True,
        timeout_seconds=60,
        description="External actions, human confirmation required",
    ),
    ApprovalTier.HIGH: dict(
        name="High",
        human_required=True,
        llm_required=True,
        timeout_seconds=30,
        description="Sensitive operations, quick timeout",
    ),
    ApprovalTier.CRITICAL: dict(
        name="Critical",
        human_required=True,
        llm_required=True,
        timeout_seconds=10,
        description="Crisis or dangerous operations, fastest timeout",
    ),
}
|
||||
|
||||
|
||||
# Action-to-tier mapping, built from one (tier, action names) group per tier so
# the grouping stays visible; insertion order runs from tier 0 to tier 4.
ACTION_TIERS: Dict[str, ApprovalTier] = {
    action_name: tier
    for tier, action_names in (
        # Tier 0: Safe (read-only)
        (ApprovalTier.SAFE, (
            "read_file",
            "search_files",
            "web_search",
            "session_search",
            "list_files",
            "get_file_content",
            "memory_search",
            "skills_list",
            "skills_search",
        )),
        # Tier 1: Low (write operations)
        (ApprovalTier.LOW, (
            "write_file",
            "create_file",
            "patch_file",
            "delete_file",
            "execute_code",
            "terminal",
            "run_script",
            "skill_install",
        )),
        # Tier 2: Medium (external actions)
        (ApprovalTier.MEDIUM, (
            "send_message",
            "web_fetch",
            "browser_navigate",
            "api_call",
            "gitea_create_issue",
            "gitea_create_pr",
            "git_push",
            "deploy",
        )),
        # Tier 3: High (sensitive operations)
        (ApprovalTier.HIGH, (
            "config_change",
            "env_change",
            "key_rotation",
            "access_grant",
            "permission_change",
            "backup_restore",
        )),
        # Tier 4: Critical (crisis/dangerous)
        (ApprovalTier.CRITICAL, (
            "kill_process",
            "rm_rf",
            "format_disk",
            "shutdown",
            "crisis_override",
        )),
    )
    for action_name in action_names
}
|
||||
|
||||
|
||||
# Dangerous command patterns (from existing approval.py).
#
# Each entry is (regex, tier). Entries are ordered from the highest tier to the
# lowest on purpose: a consumer that stops at the first matching pattern then
# still reports the most severe tier when a command matches several patterns
# (e.g. "sudo kubectl delete ..." must be HIGH, not the MEDIUM of plain sudo).
# Patterns are kept as plain strings because callers pass their own regex flags.
_DANGEROUS_PATTERNS = [
    # Tier 4: Critical
    (r"rm\s+-rf\s+/", ApprovalTier.CRITICAL),
    (r"mkfs\.", ApprovalTier.CRITICAL),
    (r"dd\s+if=.*of=/dev/", ApprovalTier.CRITICAL),
    (r"shutdown|reboot|halt", ApprovalTier.CRITICAL),
    # Tier 3: High
    (r"chmod\s+777", ApprovalTier.HIGH),
    (r"curl.*\|\s*bash", ApprovalTier.HIGH),
    (r"wget.*\|\s*sh", ApprovalTier.HIGH),
    (r"eval\s*\(", ApprovalTier.HIGH),
    (r"git\s+push.*--force", ApprovalTier.HIGH),
    (r"kubectl\s+delete", ApprovalTier.HIGH),
    # Tier 2: Medium
    (r"sudo\s+", ApprovalTier.MEDIUM),
    (r"docker\s+rm.*-f", ApprovalTier.MEDIUM),
]
|
||||
|
||||
|
||||
@dataclass
class ApprovalRequest:
    """One action awaiting approval, together with its resolved tier."""

    action: str          # action/tool name
    tier: ApprovalTier   # tier resolved for this action
    command: str         # command or payload text being approved
    reason: str          # why the action was requested
    session_key: str     # session identifier supplied by the caller
    timeout_seconds: Optional[int] = None  # approval timeout; None means no timeout

    def to_dict(self) -> Dict[str, Any]:
        """Return the request as a plain dict, enriched with the tier's metadata."""
        tier_level = self.tier.value
        tier_meta = TIER_INFO[self.tier]
        return dict(
            action=self.action,
            tier=tier_level,
            tier_name=tier_meta["name"],
            command=self.command,
            reason=self.reason,
            session_key=self.session_key,
            timeout=self.timeout_seconds,
            human_required=tier_meta["human_required"],
            llm_required=tier_meta["llm_required"],
        )
|
||||
|
||||
|
||||
def detect_tier(action: str, command: str = "") -> ApprovalTier:
    """
    Detect the approval tier for an action.

    The tier is the highest of:
      * the tier mapped to ``action`` in ACTION_TIERS (LOW when the action is
        unknown), and
      * the tier of every dangerous pattern that matches ``command``.

    A known action name therefore acts as a floor, never as a ceiling: a
    destructive command sent through a low-tier action such as ``terminal``
    is still escalated. This errs on the side of asking for approval, so a
    harmless command that merely contains a dangerous-looking substring is
    escalated too.

    Args:
        action: Name of the action/tool being invoked.
        command: Command or payload text to scan; empty to skip scanning.

    Returns:
        The resolved ApprovalTier.
    """
    # Baseline from the action name; unknown actions default to LOW.
    tier = ACTION_TIERS.get(action, ApprovalTier.LOW)

    if command:
        for pattern, pattern_tier in _DANGEROUS_PATTERNS:
            # Only run the regex when it could raise the current result.
            if pattern_tier > tier and re.search(pattern, command, re.IGNORECASE):
                tier = pattern_tier

    return tier
|
||||
|
||||
|
||||
def requires_human_approval(tier: ApprovalTier) -> bool:
    """Return the ``human_required`` flag recorded for *tier* in TIER_INFO."""
    tier_meta = TIER_INFO[tier]
    return tier_meta["human_required"]
|
||||
|
||||
|
||||
def requires_llm_approval(tier: ApprovalTier) -> bool:
    """Return the ``llm_required`` flag recorded for *tier* in TIER_INFO."""
    tier_meta = TIER_INFO[tier]
    return tier_meta["llm_required"]
|
||||
|
||||
|
||||
def get_timeout(tier: ApprovalTier) -> Optional[int]:
    """Return the approval timeout for *tier* in seconds, or None if it has none."""
    tier_meta = TIER_INFO[tier]
    return tier_meta["timeout_seconds"]
|
||||
|
||||
|
||||
def should_auto_approve(action: str, command: str = "") -> bool:
    """Return True when the action resolves to tier 0 (SAFE), i.e. needs no approval."""
    return ApprovalTier.SAFE == detect_tier(action, command)
|
||||
|
||||
|
||||
def format_approval_prompt(request: ApprovalRequest) -> str:
    """Render an approval request as a multi-line prompt for display.

    The command is cut to its first 100 characters (with a trailing "...")
    and the approver/timeout lines come from the tier's TIER_INFO entry.
    """
    tier_meta = TIER_INFO[request.tier]

    # Keep the prompt compact: show at most 100 characters of the command.
    shown_command = request.command[:100]
    if len(request.command) > 100:
        shown_command += "..."

    prompt_lines = [
        f"⚠️ Approval Required (Tier {request.tier.value}: {tier_meta['name']})",
        "",
        f"Action: {request.action}",
        f"Command: {shown_command}",
        f"Reason: {request.reason}",
        "",
    ]

    if tier_meta["human_required"]:
        prompt_lines.append("👤 Human approval required")
    if tier_meta["llm_required"]:
        prompt_lines.append("🤖 LLM approval required")
    if tier_meta["timeout_seconds"]:
        prompt_lines.append(f"⏱️ Timeout: {tier_meta['timeout_seconds']}s")

    return "\n".join(prompt_lines)
|
||||
|
||||
|
||||
def create_approval_request(
    action: str,
    command: str,
    reason: str,
    session_key: str
) -> ApprovalRequest:
    """Build an ApprovalRequest whose tier and timeout are derived from the action."""
    resolved_tier = detect_tier(action, command)
    return ApprovalRequest(
        action,
        resolved_tier,
        command,
        reason,
        session_key,
        timeout_seconds=get_timeout(resolved_tier),
    )
|
||||
|
||||
|
||||
# Crisis bypass rules: actions that skip the approval flow.
CRISIS_BYPASS_ACTIONS = frozenset((
    "send_message",  # Always allow sending crisis resources
    "check_crisis",
    "notify_crisis",
))
|
||||
|
||||
|
||||
def is_crisis_bypass(action: str, context: str = "") -> bool:
    """Check if action should bypass approval during crisis.

    An action bypasses approval when it is listed in CRISIS_BYPASS_ACTIONS,
    or when ``context`` carries a crisis indicator. The context route is
    deliberately refused for actions that ACTION_TIERS maps to HIGH or
    CRITICAL, so that free text mentioning a crisis cannot waive approval
    for destructive or sensitive operations.

    Args:
        action: Name of the action/tool being invoked.
        context: Free-text context for the action; empty or None means none.

    Returns:
        True if the approval flow may be skipped, False otherwise.
    """
    if action in CRISIS_BYPASS_ACTIONS:
        return True

    # Context alone must never unlock sensitive or destructive actions.
    if ACTION_TIERS.get(action, ApprovalTier.LOW) >= ApprovalTier.HIGH:
        return False

    if not context:
        return False
    context_lower = context.lower()

    # "988" has to stand alone so that longer numbers (ports, ids, amounts)
    # which merely contain those digits do not count as a crisis signal.
    if re.search(r"(?<!\d)988(?!\d)", context_lower):
        return True

    crisis_keywords = ("crisis", "suicide", "self-harm", "lifeline")
    return any(keyword in context_lower for keyword in crisis_keywords)
|
||||
233
tools/error_classifier.py
Normal file
233
tools/error_classifier.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Tool Error Classification — Retryable vs Permanent.
|
||||
|
||||
Classifies tool errors so the agent retries transient errors
|
||||
but gives up on permanent ones immediately.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorCategory(Enum):
    """Coarse label a tool error is sorted into."""

    # Matched a retryable status code or pattern.
    RETRYABLE = "retryable"
    # Matched a permanent status code or pattern.
    PERMANENT = "permanent"
    # Matched nothing.
    UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass
class ErrorClassification:
    """Outcome of classifying one error, including retry guidance."""

    category: ErrorCategory   # retryable / permanent / unknown
    reason: str               # short human-readable explanation
    should_retry: bool        # whether a retry is recommended
    max_retries: int          # recommended number of retries (0 for permanent errors)
    backoff_seconds: float    # recommended wait before retrying
    error_code: Optional[int] = None   # HTTP response code, when one was supplied
    error_type: Optional[str] = None   # class name of the original exception
|
||||
|
||||
|
||||
# Retryable error patterns.
# Each entry is (regex, reason, max_retries, backoff_seconds).
_RETRYABLE_PATTERNS = [
    # --- HTTP status codes ---
    (r"\b429\b", "rate limit", 3, 5.0),
    (r"\b500\b", "server error", 3, 2.0),
    (r"\b502\b", "bad gateway", 3, 2.0),
    (r"\b503\b", "service unavailable", 3, 5.0),
    (r"\b504\b", "gateway timeout", 3, 5.0),

    # --- Timeouts ---
    (r"timeout", "timeout", 3, 2.0),
    (r"timed out", "timeout", 3, 2.0),
    (r"TimeoutExpired", "timeout", 3, 2.0),

    # --- Connection errors ---
    (r"connection refused", "connection refused", 2, 5.0),
    (r"connection reset", "connection reset", 2, 2.0),
    (r"network unreachable", "network unreachable", 2, 10.0),
    (r"DNS", "DNS error", 2, 5.0),

    # --- Transient errors ---
    (r"temporary", "temporary error", 2, 2.0),
    (r"transient", "transient error", 2, 2.0),
    (r"retry", "retryable", 2, 2.0),
]
|
||||
|
||||
# Permanent error patterns.
# Each entry is (regex, short label, reason).
_PERMANENT_PATTERNS = [
    # --- HTTP status codes ---
    (r"\b400\b", "bad request", "Invalid request parameters"),
    (r"\b401\b", "unauthorized", "Authentication failed"),
    (r"\b403\b", "forbidden", "Access denied"),
    (r"\b404\b", "not found", "Resource not found"),
    (r"\b405\b", "method not allowed", "HTTP method not supported"),
    (r"\b409\b", "conflict", "Resource conflict"),
    (r"\b422\b", "unprocessable", "Validation error"),

    # --- Schema/validation errors ---
    (r"schema", "schema error", "Invalid data schema"),
    (r"validation", "validation error", "Input validation failed"),
    (r"invalid.*json", "JSON error", "Invalid JSON"),
    (r"JSONDecodeError", "JSON error", "JSON parsing failed"),

    # --- Authentication ---
    (r"api.?key", "API key error", "Invalid or missing API key"),
    (r"token.*expir", "token expired", "Authentication token expired"),
    (r"permission", "permission error", "Insufficient permissions"),

    # --- Not found ---
    (r"not found", "not found", "Resource does not exist"),
    (r"does not exist", "not found", "Resource does not exist"),
    (r"no such file", "file not found", "File does not exist"),

    # --- Quota/billing ---
    (r"quota", "quota exceeded", "Usage quota exceeded"),
    (r"billing", "billing error", "Billing issue"),
    (r"insufficient.*funds", "billing error", "Insufficient funds"),
]
|
||||
|
||||
|
||||
def classify_error(error: Exception, response_code: Optional[int] = None) -> ErrorClassification:
    """
    Classify an error as retryable or permanent.

    Resolution order:
      1. ``response_code``, when it is one of the known retryable
         (429, 500, 502, 503, 504) or permanent (400, 401, 403, 404, 405,
         409, 422) HTTP codes.
      2. Retryable patterns, then permanent patterns, searched
         case-insensitively in the exception's class name plus its message.
         The class name is included because several patterns (e.g.
         ``TimeoutExpired``, ``JSONDecodeError``) name exception types whose
         messages never contain those words.
      3. Otherwise UNKNOWN, which is retried once with a short backoff.

    Args:
        error: The exception that occurred
        response_code: HTTP response code if available

    Returns:
        ErrorClassification with retry guidance. ``error_code`` carries
        ``response_code`` whenever one was supplied, including when the
        classification itself came from a pattern or from the default.
    """
    error_type = type(error).__name__
    # Text the patterns are searched in: class name first, then the message.
    searchable = f"{error_type} {str(error).lower()}"

    # Check response code first
    if response_code is not None:
        if response_code in (429, 500, 502, 503, 504):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=f"HTTP {response_code} - transient server error",
                should_retry=True,
                max_retries=3,
                # Rate limiting gets a longer pause than plain server errors.
                backoff_seconds=5.0 if response_code == 429 else 2.0,
                error_code=response_code,
                error_type=error_type,
            )
        if response_code in (400, 401, 403, 404, 405, 409, 422):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=f"HTTP {response_code} - client error",
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_code=response_code,
                error_type=error_type,
            )

    # Check retryable patterns
    for pattern, reason, max_retries, backoff in _RETRYABLE_PATTERNS:
        if re.search(pattern, searchable, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=reason,
                should_retry=True,
                max_retries=max_retries,
                backoff_seconds=backoff,
                error_code=response_code,
                error_type=error_type,
            )

    # Check permanent patterns (the middle element is a short label, unused here)
    for pattern, _label, reason in _PERMANENT_PATTERNS:
        if re.search(pattern, searchable, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=reason,
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_code=response_code,
                error_type=error_type,
            )

    # Default: unknown, treat as retryable with caution
    return ErrorClassification(
        category=ErrorCategory.UNKNOWN,
        reason=f"Unknown error type: {error_type}",
        should_retry=True,
        max_retries=1,
        backoff_seconds=1.0,
        error_code=response_code,
        error_type=error_type,
    )
|
||||
|
||||
|
||||
def execute_with_retry(
    func,
    *args,
    max_retries: int = 3,
    backoff_base: float = 1.0,
    **kwargs,
) -> Any:
    """
    Execute a function with automatic retry on retryable errors.

    Every failure is classified with classify_error(). Permanent errors are
    re-raised at once. Other errors are retried with exponential backoff
    (``backoff_base * 2 ** attempt``) until the retry limit is reached. The
    limit is the smaller of ``max_retries`` and the classification's own
    ``max_retries``, so an error classified as "retry once" is not retried
    three times.

    Args:
        func: Function to execute
        *args: Function arguments
        max_retries: Maximum retry attempts (upper bound; negative values
            are treated as 0, i.e. a single attempt)
        backoff_base: Base backoff time in seconds
        **kwargs: Function keyword arguments

    Returns:
        Function result

    Raises:
        Exception: The original exception, if it is permanent or the retry
            limit is exceeded
    """
    # A negative budget would otherwise mean "never call func at all".
    retry_budget = max(max_retries, 0)
    attempt = 0

    while True:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Classify the error
            classification = classify_error(e)

            logger.info(
                "Attempt %d/%d failed: %s (%s, retryable: %s)",
                attempt + 1, retry_budget + 1,
                classification.reason,
                classification.category.value,
                classification.should_retry,
            )

            # If permanent error, fail immediately
            if not classification.should_retry:
                logger.error("Permanent error: %s", classification.reason)
                raise

            # Honor the per-error retry recommendation as well as the caller's cap.
            allowed_retries = min(retry_budget, classification.max_retries)
            if attempt >= allowed_retries:
                logger.error("Max retries (%d) exceeded", allowed_retries)
                raise

            # Calculate backoff with exponential increase
            backoff = backoff_base * (2 ** attempt)
            logger.info("Retrying in %.1fs...", backoff)
            time.sleep(backoff)
            attempt += 1
|
||||
|
||||
|
||||
def format_error_report(classification: ErrorClassification) -> str:
    """Render a classification as a one-line "<icon> <category>: <reason>" string."""
    if classification.should_retry:
        marker = "🔄"
    else:
        marker = "❌"
    return "{} {}: {}".format(
        marker, classification.category.value, classification.reason
    )
|
||||
Reference in New Issue
Block a user