Compare commits

...

3 Commits

Author SHA1 Message Date
Alexander Whitestone
e66156055f feat(crisis): integrate 988 Lifeline into conversation loop
Some checks are pending
Contributor Attribution Check / check-attribution (pull_request) Waiting to run
Docker Build and Publish / build-and-push (pull_request) Waiting to run
Nix / nix (macos-latest) (pull_request) Waiting to run
Nix / nix (ubuntu-latest) (pull_request) Waiting to run
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Waiting to run
Tests / test (pull_request) Waiting to run
Tests / e2e (pull_request) Waiting to run
Fixes #695

- Add agent/crisis_resources.py — 988 Lifeline detection patterns and
  resource data (phone, text, chat, Spanish, Veterans Crisis Line)
- Add agent/crisis_hook.py — lightweight check_crisis() hook function
- Add agent/crisis_middleware.py — CrisisMiddleware class, global
  check_crisis() convenience function, and @crisis_aware decorator
- Integrate in run_agent.py: AIAgent.run_conversation() now calls
  check_crisis() immediately after input sanitization; crisis messages
  return 988 Lifeline resources before the LLM is invoked (api_calls=0,
  crisis_intercepted=True in result dict)
- Add tests/agent/test_crisis_integration.py — 47 tests covering
  detection patterns, response content, middleware class, hook function,
  convenience function, decorator, resource data integrity, and
  end-to-end integration path

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-15 07:57:44 -04:00
230fb9213b feat: tool error classification — retryable vs permanent (#752) (#773)
Co-authored-by: Alexander Whitestone <alexander@alexanderwhitestone.com>
Co-committed-by: Alexander Whitestone <alexander@alexanderwhitestone.com>
2026-04-15 04:54:54 +00:00
1263d11f52 feat: Approval Tier System — Extend approval.py with Safety Tiers (#670) (#776)
Co-authored-by: Alexander Whitestone <alexander@alexanderwhitestone.com>
Co-committed-by: Alexander Whitestone <alexander@alexanderwhitestone.com>
2026-04-15 04:54:53 +00:00
9 changed files with 1460 additions and 0 deletions

30
agent/crisis_hook.py Normal file
View File

@@ -0,0 +1,30 @@
"""
Crisis Detection Hook — Integrates 988 Lifeline into the agent conversation loop.
Call check_crisis() before processing user messages. If a crisis is detected,
it returns an empathetic response containing the 988 Lifeline resources, which
the caller should return instead of processing the original request.
Usage in conversation loop:
from agent.crisis_hook import check_crisis
crisis_response = check_crisis(user_message)
if crisis_response:
return crisis_response # Skip normal processing
"""
from agent.crisis_resources import should_trigger_crisis_response, get_crisis_response
def check_crisis(user_message: str) -> str | None:
    """
    Screen a user message for crisis signals before normal processing.

    Args:
        user_message: Raw text of the user's message.

    Returns:
        The crisis response text when the message should trigger a crisis
        response, otherwise None. A non-None result is meant to be returned
        to the user as-is, skipping normal processing of the message.
    """
    triggered, details = should_trigger_crisis_response(user_message)
    if triggered:
        # Fall back to the most severe label if the detector omitted one.
        label = details.get("severity_label", "CRITICAL")
        return get_crisis_response(label)
    return None

186
agent/crisis_middleware.py Normal file
View File

@@ -0,0 +1,186 @@
"""
Crisis Middleware — Integrates 988 Lifeline into the agent conversation loop.
This middleware intercepts user messages before they reach the agent
and checks for crisis signals. If detected, it returns the 988 Lifeline
response immediately without processing the original message.
Integration approach: Called at the start of AIAgent.run_conversation().
Usage:
from agent.crisis_middleware import check_crisis
crisis_response = check_crisis(user_message)
if crisis_response:
return crisis_response
"""
import logging
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)


class CrisisMiddleware:
    """
    Middleware for crisis detection and 988 Lifeline integration.

    Wraps a detection function and a response builder, both loaded from
    ``agent.crisis_resources`` when the middleware is enabled. Errors raised
    while checking a message are logged and reported as "no crisis"
    (None / False) so that a detection failure never crashes the
    conversation.
    """

    def __init__(self, enabled: bool = True):
        """
        Initialize crisis middleware.

        Args:
            enabled: Whether crisis detection is enabled (default True)
        """
        self.enabled = enabled
        self._detection_func = None
        self._response_func = None
        if enabled:
            self._load_crisis_module()

    def _load_crisis_module(self):
        """Load the detection/response functions; disable the middleware if unavailable."""
        try:
            from agent.crisis_resources import (
                should_trigger_crisis_response,
                get_crisis_response,
            )
            self._detection_func = should_trigger_crisis_response
            self._response_func = get_crisis_response
            logger.info("Crisis middleware loaded successfully")
        except ImportError as e:
            logger.warning("Crisis resources not available: %s", e)
            self.enabled = False

    def _detect(self, user_message: str) -> Optional[tuple]:
        """
        Run the detection function once.

        Returns:
            ``(severity_label, detection_dict)`` when a crisis response should
            be triggered, None otherwise. May raise whatever the detection
            function raises; callers handle that.
        """
        should_trigger, detection = self._detection_func(user_message)
        if not should_trigger:
            return None
        # A detector may signal a trigger without details. Treat that as the
        # most severe case instead of failing on ``None.get`` and silently
        # dropping the crisis response.
        details = detection or {}
        return details.get("severity_label", "CRITICAL"), details

    def check(self, user_message: str) -> Optional[str]:
        """
        Check user message for crisis signals.

        Args:
            user_message: The user's message

        Returns:
            Crisis response string if crisis detected, None otherwise
        """
        if not self.enabled or not self._detection_func:
            return None
        try:
            hit = self._detect(user_message)
            if hit is None:
                return None
            severity, details = hit
            logger.warning(
                "Crisis detected (severity: %s, confidence: %s)",
                severity,
                details.get("confidence"),
            )
            return self._response_func(severity)
        except Exception as e:
            # On error, return None to allow normal processing.
            # False negative is better than crashing the conversation.
            logger.error("Crisis detection error: %s", e, exc_info=True)
            return None

    def check_with_context(
        self, user_message: str, context: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Check for crisis with additional context.

        Args:
            user_message: The user's message
            context: Additional context (session_id, user_id, etc.)

        Returns:
            Dict with 'response', 'detection', 'severity' and 'context' if
            crisis detected, None otherwise
        """
        if not self.enabled or not self._detection_func:
            return None
        try:
            hit = self._detect(user_message)
            if hit is None:
                return None
            severity, details = hit
            response = self._response_func(severity)
            logger.warning(
                "Crisis detected (severity: %s, session: %s)",
                severity,
                context.get("session_id") if context else "unknown",
            )
            return {
                "response": response,
                "detection": details,
                "severity": severity,
                "context": context or {},
            }
        except Exception as e:
            logger.error("Crisis detection error: %s", e, exc_info=True)
            return None

    def is_crisis_message(self, user_message: str) -> bool:
        """
        Check if message contains crisis signals (boolean only).

        Args:
            user_message: The user's message

        Returns:
            True if crisis detected, False otherwise (including on error)
        """
        if not self.enabled or not self._detection_func:
            return False
        try:
            return self._detect(user_message) is not None
        except Exception as e:
            # Same policy as check(): never raise, but leave a trace in the log.
            logger.error("Crisis detection error: %s", e, exc_info=True)
            return False
# Module-wide middleware instance, created on first use by get_crisis_middleware().
_middleware: Optional[CrisisMiddleware] = None


def get_crisis_middleware() -> CrisisMiddleware:
    """Return the shared CrisisMiddleware, constructing it on the first call."""
    global _middleware
    if _middleware is not None:
        return _middleware
    _middleware = CrisisMiddleware()
    return _middleware
def check_crisis(user_message: str) -> Optional[str]:
    """
    Check a message for crisis signals using the shared middleware instance.

    Args:
        user_message: The user's message

    Returns:
        Crisis response if detected, None otherwise
    """
    shared = get_crisis_middleware()
    return shared.check(user_message)
def crisis_aware(func):
    """
    Decorator to make agent methods crisis-aware.

    The wrapped method must take the user message as its first argument
    after ``self``. Before the method runs, the message is passed to
    check_crisis(); if that returns a response, the response is returned
    and the wrapped method is not called.

    Args:
        func: Method with signature ``(self, user_message, *args, **kwargs)``.

    Returns:
        The wrapping function, carrying the wrapped method's name, docstring
        and other metadata.
    """
    import functools  # local import: only this decorator needs it

    @functools.wraps(func)
    def wrapper(self, user_message: str, *args, **kwargs):
        crisis_response = check_crisis(user_message)
        if crisis_response:
            return crisis_response
        return func(self, user_message, *args, **kwargs)

    return wrapper

189
agent/crisis_resources.py Normal file
View File

@@ -0,0 +1,189 @@
"""
Crisis Resources — 988 Suicide & Crisis Lifeline Integration
Detects crisis signals in user messages and provides immediate
access to the 988 Suicide & Crisis Lifeline.
Integration is deep-link based (no API exists for 988).
"""
import re
from typing import Optional, Dict, List, Tuple
# ---------------------------------------------------------------------------
# Crisis Detection Patterns
# ---------------------------------------------------------------------------
# Patterns are grouped by the severity they imply. The groups are concatenated
# in severity order to build _CRISIS_PATTERNS, and the group boundaries are
# derived from the group sizes, so adding a pattern to a group can never
# silently shift another pattern into the wrong severity.
#
# ['\u2019]? accepts a straight apostrophe, a typographic one, or none
# ("don't", "don’t", "dont").

# Direct statements of intent or a plan.
_CRITICAL_PATTERNS = [
    # "end my" is restricted to life/pain/suffering so that requests such as
    # "end my subscription" are not treated as a crisis.
    re.compile(r"\b(i want to die|want to kill myself|end (my (life|pain|suffering)|it all))\b", re.I),
    re.compile(r"\b(kill(ing)? myself|suicide|suicidal)\b", re.I),
    re.compile(r"\b(no reason to live|don['\u2019]?t want to (live|be here|exist))\b", re.I),
    re.compile(r"\b(better off (dead|without me)|wish i was dead)\b", re.I),
    re.compile(r"\b(can['\u2019]?t go on|can['\u2019]?t take (it|this) anymore)\b", re.I),
    # NOTE(review): "planning to kill/end" also matches technical phrasing
    # ("planning to kill the old deployment") — consider tightening.
    re.compile(r"\b(planning to (end|hurt|kill)|planning (my|the) (end|death))\b", re.I),
]

# Self-harm, overdose, and explicit hopelessness.
_HIGH_PATTERNS = [
    re.compile(r"\b(self[- ]?harm|cut(ting)? myself|hurt(ing)? myself)\b", re.I),
    # No literal spaces inside the alternatives: \b already provides the word
    # boundaries, and padded alternatives cannot match at the start or end of
    # a message.
    re.compile(r"\b(overdos(e|ed|ing)|OD|swallowing pills)\b", re.I),
    re.compile(r"\b(hopeless|no hope|give up on (life|everything))\b", re.I),
]

# Despair and farewell statements.
_MODERATE_PATTERNS = [
    re.compile(r"\b(nobody (would|will) miss me|world (would|will) be better)\b", re.I),
    re.compile(r"\b(i['\u2019]?m (so )?tired of (living|being alive|this))\b", re.I),
    re.compile(r"\b(goodbye (forever|for good)|this is (my )?last)\b", re.I),
    re.compile(r"\b(say(ing)? goodbye to (everyone|you all))\b", re.I),
]

_CRISIS_PATTERNS = _CRITICAL_PATTERNS + _HIGH_PATTERNS + _MODERATE_PATTERNS

# Index of the first HIGH / first MODERATE pattern inside _CRISIS_PATTERNS.
_HIGH_START = len(_CRITICAL_PATTERNS)
_MODERATE_START = _HIGH_START + len(_HIGH_PATTERNS)

# Severity levels (lower value = more severe)
_SEVERITY_CRITICAL = 0  # Direct intent or plan
_SEVERITY_HIGH = 1  # Self-harm, overdose, explicit hopelessness
_SEVERITY_MODERATE = 2  # Despair and farewell statements


def detect_crisis(message: str) -> Optional[Dict]:
    """
    Detect crisis signals in a user message.

    Args:
        message: The user's message text.

    Returns:
        None if the message is empty, shorter than 3 characters after
        stripping, or matches no pattern. Otherwise a dict with:

        - ``detected``: always True
        - ``severity``: 0 (critical), 1 (high) or 2 (moderate)
        - ``severity_label``: "CRITICAL", "HIGH" or "MODERATE"
        - ``matched_count``: number of patterns that matched
        - ``confidence``: 0.3 per matched pattern, capped at 1.0
    """
    if not message or len(message.strip()) < 3:
        return None

    matches = [
        index
        for index, pattern in enumerate(_CRISIS_PATTERNS)
        if pattern.search(message)
    ]
    if not matches:
        return None

    # Patterns are ordered most-severe first, so the lowest matched index
    # determines the overall severity.
    most_severe = min(matches)
    if most_severe < _HIGH_START:
        severity = _SEVERITY_CRITICAL
        severity_label = "CRITICAL"
    elif most_severe < _MODERATE_START:
        severity = _SEVERITY_HIGH
        severity_label = "HIGH"
    else:
        severity = _SEVERITY_MODERATE
        severity_label = "MODERATE"

    confidence = min(1.0, len(matches) * 0.3)
    return {
        "detected": True,
        "severity": severity,
        "severity_label": severity_label,
        "matched_count": len(matches),
        "confidence": round(confidence, 2),
    }
# ---------------------------------------------------------------------------
# 988 Suicide & Crisis Lifeline Resources
# ---------------------------------------------------------------------------
# Contact data for the 988 Suicide & Crisis Lifeline.
# "channels" lists the contact methods in display order; each entry carries a
# human-readable label, the raw value, usage instructions, and a deep-link URI.
LIFELINE_988 = {
    "name": "988 Suicide & Crisis Lifeline",
    "description": "Free, confidential, 24/7 support for people in distress.",
    "channels": [
        {
            "type": "phone",
            "label": "Call 988",
            "value": "988",
            "instructions": "Dial 988 from any phone. Available 24/7.",
            "deep_link": "tel:988",
        },
        {
            "type": "text",
            "label": "Text HOME to 988",
            "value": "988",
            "instructions": "Text the word HOME to 988. A trained counselor will respond.",
            # RFC 5724: the message body is a query field introduced by "?",
            # not appended to the number with "&".
            "deep_link": "sms:988?body=HOME",
        },
        {
            "type": "chat",
            "label": "Chat at 988lifeline.org/chat",
            "value": "https://988lifeline.org/chat/",
            "instructions": "Visit 988lifeline.org/chat for online chat with a counselor.",
            "deep_link": "https://988lifeline.org/chat/",
        },
    ],
    # Spanish-language line
    "spanish": {
        "phone": "1-888-628-9454",
        "label": "Línea de Prevención del Suicidio (Español)",
    },
    # Veterans Crisis Line
    "veterans": {
        "phone": "988 then press 1",
        "text": "838255",
        "label": "Veterans Crisis Line",
    },
}
def get_crisis_response(severity_label: str = "CRITICAL") -> str:
"""
Generate the crisis response with 988 Lifeline information.
This is what the agent displays when crisis is detected.
"""
resources = []
# Phone (always first, most accessible)
resources.append("📞 **Call 988** — Available 24/7. Just dial 988.")
# Text
resources.append("💬 **Text HOME to 988** — A trained counselor will respond.")
# Chat
resources.append("🌐 **Chat:** https://988lifeline.org/chat/")
# Spanish
resources.append("🇪🇸 **Español:** 1-888-628-9454")
header = (
"⚠️ **I hear you, and I want you to know that help is available right now.**\n\n"
"You don't have to go through this alone. Please reach out to one of these resources:\n"
)
body = "\n".join(f"{r}" for r in resources)
footer = (
"\n\n"
"**You matter. Your life has value.** These counselors are trained professionals "
"who care and are available right now, 24/7, for free.\n\n"
"If you're in immediate danger, please call 911."
)
return f"{header}\n{body}{footer}"
def should_trigger_crisis_response(message: str) -> Tuple[bool, Optional[Dict]]:
    """
    Decide whether a message warrants the crisis response.

    Args:
        message: The user's message text.

    Returns:
        ``(should_trigger, detection_result)``. ``detection_result`` is the
        dict from detect_crisis(), or None when nothing was detected. It is
        returned even when the decision is False, so callers can inspect
        low-confidence detections.
    """
    detection = detect_crisis(message)
    if detection is None:
        return False, None

    # CRITICAL and HIGH always trigger; MODERATE needs enough confidence.
    is_severe = detection["severity"] <= _SEVERITY_HIGH
    is_confident = detection["confidence"] >= 0.6
    return (is_severe or is_confident), detection

View File

@@ -7792,6 +7792,42 @@ class AIAgent:
if isinstance(persist_user_message, str):
persist_user_message = _sanitize_surrogates(persist_user_message)
# Crisis detection — check before any other processing.
# If the user message contains crisis signals, return 988 Lifeline
# resources immediately without invoking the model.
try:
from agent.crisis_middleware import check_crisis as _check_crisis
_crisis_response = _check_crisis(user_message)
if _crisis_response:
logger.warning(
"Crisis signal detected in session=%s — returning 988 Lifeline resources",
self.session_id or "none",
)
return {
"final_response": _crisis_response,
"last_reasoning": None,
"messages": list(conversation_history or []),
"api_calls": 0,
"completed": True,
"partial": False,
"interrupted": False,
"response_previewed": False,
"model": self.model,
"provider": self.provider,
"base_url": self.base_url,
"input_tokens": 0,
"output_tokens": 0,
"cache_read_tokens": 0,
"cache_write_tokens": 0,
"reasoning_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"crisis_intercepted": True,
}
except Exception as _crisis_err:
logger.error("Crisis middleware error (continuing normally): %s", _crisis_err)
# Store stream callback for _interruptible_api_call to pick up
self._stream_callback = stream_callback
self._persist_user_message_idx = None

View File

@@ -0,0 +1,348 @@
"""
Integration tests for 988 Suicide & Crisis Lifeline integration.
Covers:
- agent/crisis_resources.py — detection patterns and resource data
- agent/crisis_hook.py — simple hook function
- agent/crisis_middleware.py — middleware class and check_crisis()
- run_agent.py — end-to-end conversation loop integration
"""
import pytest
from unittest.mock import patch, MagicMock
from agent.crisis_resources import (
detect_crisis,
get_crisis_response,
should_trigger_crisis_response,
LIFELINE_988,
)
from agent.crisis_hook import check_crisis as hook_check_crisis
from agent.crisis_middleware import (
CrisisMiddleware,
check_crisis,
crisis_aware,
get_crisis_middleware,
)
# ---------------------------------------------------------------------------
# Crisis Detection
# ---------------------------------------------------------------------------
class TestCrisisDetection:
    """Exercise detect_crisis() against crisis and non-crisis phrasing."""

    def test_direct_intent_want_to_die(self):
        detection = detect_crisis("I want to die")
        assert detection is not None
        assert detection["severity_label"] == "CRITICAL"

    def test_kill_myself(self):
        detection = detect_crisis("I want to kill myself")
        assert detection is not None
        assert detection["severity_label"] == "CRITICAL"

    def test_no_reason_to_live(self):
        assert detect_crisis("I have no reason to live anymore") is not None

    def test_cant_go_on(self):
        assert detect_crisis("I can't go on like this") is not None

    def test_self_harm(self):
        detection = detect_crisis("I've been cutting myself")
        assert detection is not None
        assert detection["severity_label"] == "HIGH"

    def test_hopeless(self):
        assert detect_crisis("Everything feels hopeless") is not None

    def test_no_false_positive_happy(self):
        assert detect_crisis("I had a great day today!") is None

    def test_no_false_positive_technical(self):
        assert detect_crisis("Kill the process and restart the server") is None

    def test_empty_message(self):
        assert detect_crisis("") is None

    def test_short_message(self):
        assert detect_crisis("hi") is None

    def test_confidence_increases_with_matches(self):
        one_match = detect_crisis("I want to die")
        several = detect_crisis("I want to die, I'm suicidal and want to kill myself")
        assert several is not None and one_match is not None
        assert several["confidence"] >= one_match["confidence"]

    def test_detection_result_has_required_fields(self):
        detection = detect_crisis("I want to die")
        assert detection is not None
        # Every consumer-facing key must be present in the result dict.
        for field in ("detected", "severity", "severity_label", "matched_count", "confidence"):
            assert field in detection
# ---------------------------------------------------------------------------
# 988 Lifeline Response
# ---------------------------------------------------------------------------
class TestCrisisResponse:
    """Content checks for the text produced by get_crisis_response()."""

    def test_response_contains_988(self):
        assert "988" in get_crisis_response()

    def test_response_contains_phone(self):
        text = get_crisis_response()
        assert "Call 988" in text or "Dial 988" in text

    def test_response_contains_text_home(self):
        assert "HOME" in get_crisis_response()

    def test_response_contains_chat_link(self):
        assert "988lifeline.org/chat" in get_crisis_response()

    def test_response_contains_spanish(self):
        assert "1-888-628-9454" in get_crisis_response()

    def test_response_is_empathetic(self):
        lowered = get_crisis_response().lower()
        assert any(phrase in lowered for phrase in ("matter", "help", "hear you", "alone"))

    def test_response_includes_emergency_911(self):
        assert "911" in get_crisis_response()

    def test_response_for_all_severity_levels(self):
        for label in ("CRITICAL", "HIGH", "MODERATE"):
            assert "988" in get_crisis_response(label)
# ---------------------------------------------------------------------------
# should_trigger_crisis_response
# ---------------------------------------------------------------------------
class TestShouldTrigger:
    """Trigger decisions made by should_trigger_crisis_response()."""

    def test_triggers_on_critical(self):
        fired, details = should_trigger_crisis_response("I want to die")
        assert fired is True
        assert details is not None
        assert details["severity_label"] == "CRITICAL"

    def test_triggers_on_high(self):
        fired, _details = should_trigger_crisis_response("I've been cutting myself")
        assert fired is True

    def test_no_trigger_on_normal(self):
        fired, _details = should_trigger_crisis_response("What is the weather today?")
        assert fired is False

    def test_no_trigger_on_empty(self):
        fired, details = should_trigger_crisis_response("")
        assert fired is False
        assert details is None
# ---------------------------------------------------------------------------
# Crisis Hook
# ---------------------------------------------------------------------------
class TestCrisisHook:
    """Behaviour of the crisis_hook.check_crisis() entry point."""

    def test_hook_triggers_on_crisis(self):
        reply = hook_check_crisis("I want to end it all")
        assert reply is not None
        assert "988" in reply

    def test_hook_returns_none_on_normal(self):
        assert hook_check_crisis("What's the weather today?") is None

    def test_hook_returns_none_on_empty(self):
        assert hook_check_crisis("") is None
# ---------------------------------------------------------------------------
# CrisisMiddleware class
# ---------------------------------------------------------------------------
class TestCrisisMiddleware:
    """Unit tests for CrisisMiddleware using stubbed detection/response functions."""

    @staticmethod
    def _stubbed(detect, respond=None):
        """Build an enabled middleware whose module loading is patched out."""
        with patch.object(CrisisMiddleware, "_load_crisis_module"):
            instance = CrisisMiddleware(enabled=True)
        instance._detection_func = detect
        if respond is not None:
            instance._response_func = respond
        return instance

    def test_disabled_middleware_returns_none(self):
        disabled = CrisisMiddleware(enabled=False)
        assert disabled.check("I want to die") is None

    def test_crisis_detected_returns_response(self):
        middleware = self._stubbed(
            lambda msg: (True, {"severity_label": "CRITICAL", "confidence": 0.9}),
            lambda sev: "988 Lifeline: Call 988",
        )
        assert middleware.check("I want to die") == "988 Lifeline: Call 988"

    def test_no_crisis_returns_none(self):
        middleware = self._stubbed(lambda msg: (False, {}))
        assert middleware.check("Hello, how are you?") is None

    def test_detection_error_returns_none(self):
        def _explode(msg):
            raise RuntimeError("boom")

        middleware = self._stubbed(_explode)
        # Should not raise — returns None on error
        assert middleware.check("I want to die") is None

    def test_is_crisis_message_true(self):
        middleware = self._stubbed(lambda msg: (True, {}))
        assert middleware.is_crisis_message("I want to die") is True

    def test_is_crisis_message_false(self):
        middleware = self._stubbed(lambda msg: (False, {}))
        assert middleware.is_crisis_message("Hello") is False

    def test_check_with_context_returns_dict(self):
        middleware = self._stubbed(
            lambda msg: (True, {"severity_label": "CRITICAL", "confidence": 0.9}),
            lambda sev: "988 response",
        )
        outcome = middleware.check_with_context("I want to die", {"session_id": "abc123"})
        assert outcome is not None
        assert outcome["response"] == "988 response"
        assert outcome["severity"] == "CRITICAL"
        assert outcome["context"]["session_id"] == "abc123"
# ---------------------------------------------------------------------------
# check_crisis convenience function
# ---------------------------------------------------------------------------
class TestCheckCrisisFunction:
    """check_crisis() must delegate to the shared middleware's check()."""

    def test_crisis_message_returns_response(self):
        stub = MagicMock()
        stub.check.return_value = "988 Lifeline info"
        with patch("agent.crisis_middleware.get_crisis_middleware", return_value=stub):
            outcome = check_crisis("I want to die")
        assert outcome == "988 Lifeline info"

    def test_normal_message_returns_none(self):
        stub = MagicMock()
        stub.check.return_value = None
        with patch("agent.crisis_middleware.get_crisis_middleware", return_value=stub):
            outcome = check_crisis("Hello")
        assert outcome is None
# ---------------------------------------------------------------------------
# @crisis_aware decorator
# ---------------------------------------------------------------------------
class TestCrisisAwareDecorator:
    """The @crisis_aware decorator short-circuits only when check_crisis() responds."""

    def test_decorator_intercepts_crisis(self):
        with patch("agent.crisis_middleware.check_crisis", return_value="988 response"):

            @crisis_aware
            def process_message(self, msg):
                return "normal response"

            # The call stays inside the patch so the wrapper sees the stub.
            outcome = process_message(None, "I want to die")
        assert outcome == "988 response"

    def test_decorator_passes_through_normal(self):
        with patch("agent.crisis_middleware.check_crisis", return_value=None):

            @crisis_aware
            def process_message(self, msg):
                return f"processed: {msg}"

            outcome = process_message(None, "Hello")
        assert outcome == "processed: Hello"
# ---------------------------------------------------------------------------
# 988 Resource data integrity
# ---------------------------------------------------------------------------
class Test988Resources:
    """Integrity checks on the LIFELINE_988 contact data."""

    @staticmethod
    def _channel(kind):
        """Return the first channel entry of the given type."""
        return next(entry for entry in LIFELINE_988["channels"] if entry["type"] == kind)

    def test_phone_value_is_988(self):
        assert self._channel("phone")["value"] == "988"

    def test_text_value_is_988(self):
        assert self._channel("text")["value"] == "988"

    def test_chat_url_contains_988lifeline(self):
        assert "988lifeline.org" in self._channel("chat")["value"]

    def test_spanish_line_present(self):
        assert "888-628-9454" in LIFELINE_988["spanish"]["phone"]

    def test_veterans_line_present(self):
        assert "988" in LIFELINE_988["veterans"]["phone"]
# ---------------------------------------------------------------------------
# End-to-end: crisis middleware integration with run_agent
# ---------------------------------------------------------------------------
class TestRunAgentCrisisIntegration:
    """Verify the crisis integration path in AIAgent.run_conversation()."""

    @staticmethod
    def _run_agent_source() -> str:
        """
        Return the text of run_agent.py.

        The file is looked up relative to this test module first, so the
        tests do not depend on the working directory pytest was started
        from; the working directory is used as a fallback.
        """
        from pathlib import Path

        # Assumes this module lives two directories below the repository
        # root (tests/agent/) — TODO confirm if the test file is moved.
        repo_copy = Path(__file__).resolve().parents[2] / "run_agent.py"
        target = repo_copy if repo_copy.is_file() else Path("run_agent.py")
        # Explicit encoding: the platform default may not be UTF-8.
        return target.read_text(encoding="utf-8")

    def test_crisis_response_contains_988_e2e(self):
        """Simulated crisis message produces a full 988 Lifeline response."""
        from agent.crisis_middleware import check_crisis as real_check_crisis

        response = real_check_crisis("I want to kill myself")
        assert response is not None
        assert "988" in response
        assert "988lifeline.org" in response

    def test_normal_message_not_intercepted(self):
        """Normal messages pass through the crisis check without a response."""
        result = check_crisis("What is the weather today?")
        assert result is None

    def test_run_conversation_imports_crisis_middleware(self):
        """Verify run_agent.py contains the crisis middleware import."""
        source = self._run_agent_source()
        assert "from agent.crisis_middleware import check_crisis" in source
        assert "crisis_intercepted" in source

    def test_crisis_result_structure(self):
        """The early-return dict from the crisis path has required keys."""
        # Source-level check: each key must appear as a quoted string
        # somewhere in run_agent.py.
        source = self._run_agent_source()
        for key in ("final_response", "api_calls", "completed", "crisis_intercepted"):
            assert f'"{key}"' in source, f"Expected key '{key}' missing from crisis return block"

View File

@@ -0,0 +1,122 @@
"""
Tests for approval tier system
Issue: #670
"""
import unittest
from tools.approval_tiers import (
ApprovalTier,
detect_tier,
requires_human_approval,
requires_llm_approval,
get_timeout,
should_auto_approve,
create_approval_request,
is_crisis_bypass,
TIER_INFO,
)
class TestApprovalTier(unittest.TestCase):
    """The numeric value of each approval tier is fixed."""

    def test_tier_values(self):
        expected_values = (
            (ApprovalTier.SAFE, 0),
            (ApprovalTier.LOW, 1),
            (ApprovalTier.MEDIUM, 2),
            (ApprovalTier.HIGH, 3),
            (ApprovalTier.CRITICAL, 4),
        )
        for tier, number in expected_values:
            self.assertEqual(tier, number)
class TestTierDetection(unittest.TestCase):
    """detect_tier() maps action names and command text to approval tiers."""

    def _assert_all(self, actions, expected_tier):
        """Assert that every action name resolves to the expected tier."""
        for action in actions:
            self.assertEqual(detect_tier(action), expected_tier)

    def test_safe_actions(self):
        self._assert_all(("read_file", "web_search", "session_search"), ApprovalTier.SAFE)

    def test_low_actions(self):
        self._assert_all(("write_file", "terminal", "execute_code"), ApprovalTier.LOW)

    def test_medium_actions(self):
        self._assert_all(("send_message", "git_push"), ApprovalTier.MEDIUM)

    def test_high_actions(self):
        self._assert_all(("config_change", "key_rotation"), ApprovalTier.HIGH)

    def test_critical_actions(self):
        self._assert_all(("kill_process", "shutdown"), ApprovalTier.CRITICAL)

    def test_pattern_detection(self):
        # Unknown action names are classified from the command text.
        self.assertEqual(detect_tier("unknown", "rm -rf /"), ApprovalTier.CRITICAL)
        self.assertEqual(detect_tier("unknown", "sudo apt install"), ApprovalTier.MEDIUM)
class TestTierInfo(unittest.TestCase):
    """Approval requirements and timeouts reported for individual tiers."""

    def test_safe_no_approval(self):
        safe = ApprovalTier.SAFE
        self.assertFalse(requires_human_approval(safe))
        self.assertFalse(requires_llm_approval(safe))
        self.assertIsNone(get_timeout(safe))

    def test_medium_requires_both(self):
        medium = ApprovalTier.MEDIUM
        self.assertTrue(requires_human_approval(medium))
        self.assertTrue(requires_llm_approval(medium))
        self.assertEqual(get_timeout(medium), 60)

    def test_critical_fast_timeout(self):
        self.assertEqual(get_timeout(ApprovalTier.CRITICAL), 10)
class TestAutoApprove(unittest.TestCase):
    """should_auto_approve() is true for read-only actions and false for writes."""

    def test_safe_auto_approves(self):
        for action in ("read_file", "web_search"):
            self.assertTrue(should_auto_approve(action))

    def test_write_doesnt_auto_approve(self):
        self.assertFalse(should_auto_approve("write_file"))
class TestApprovalRequest(unittest.TestCase):
    """Approval requests carry the detected tier, its timeout, and serialize to a dict."""

    def test_create_request(self):
        request = create_approval_request(
            "send_message", "Hello world", "User requested", "session_123"
        )
        self.assertEqual(request.tier, ApprovalTier.MEDIUM)
        self.assertEqual(request.timeout_seconds, 60)

    def test_to_dict(self):
        request = create_approval_request("read_file", "cat file.txt", "test", "s1")
        payload = request.to_dict()
        self.assertEqual(payload["tier"], 0)
        self.assertEqual(payload["tier_name"], "Safe")
class TestCrisisBypass(unittest.TestCase):
    """is_crisis_bypass() recognises crisis-related actions and command text."""

    def test_send_message_bypass(self):
        self.assertTrue(is_crisis_bypass("send_message"))

    def test_crisis_context_bypass(self):
        for command_text in ("call 988 lifeline", "crisis resources"):
            self.assertTrue(is_crisis_bypass("unknown", command_text))

    def test_normal_no_bypass(self):
        self.assertFalse(is_crisis_bypass("read_file"))


# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,55 @@
"""
Tests for error classification (#752).
"""
import pytest
from tools.error_classifier import classify_error, ErrorCategory, ErrorClassification
class TestErrorClassification:
    """classify_error() sorts errors into retryable, permanent and unknown."""

    @staticmethod
    def _expect(outcome, category, should_retry):
        """Assert the category and retry flag of a classification result."""
        assert outcome.category == category
        assert outcome.should_retry is should_retry

    def test_timeout_is_retryable(self):
        outcome = classify_error(Exception("Connection timed out"))
        self._expect(outcome, ErrorCategory.RETRYABLE, True)

    def test_429_is_retryable(self):
        outcome = classify_error(Exception("Rate limit exceeded"), response_code=429)
        self._expect(outcome, ErrorCategory.RETRYABLE, True)

    def test_404_is_permanent(self):
        outcome = classify_error(Exception("Not found"), response_code=404)
        self._expect(outcome, ErrorCategory.PERMANENT, False)

    def test_403_is_permanent(self):
        outcome = classify_error(Exception("Forbidden"), response_code=403)
        self._expect(outcome, ErrorCategory.PERMANENT, False)

    def test_500_is_retryable(self):
        outcome = classify_error(Exception("Internal server error"), response_code=500)
        self._expect(outcome, ErrorCategory.RETRYABLE, True)

    def test_schema_error_is_permanent(self):
        outcome = classify_error(Exception("Schema validation failed"))
        self._expect(outcome, ErrorCategory.PERMANENT, False)

    def test_unknown_is_retryable_with_caution(self):
        outcome = classify_error(Exception("Some unknown error"))
        self._expect(outcome, ErrorCategory.UNKNOWN, True)
        # Unknown errors get a single retry.
        assert outcome.max_retries == 1


# Allow running this test module directly.
if __name__ == "__main__":
    pytest.main([__file__])

261
tools/approval_tiers.py Normal file
View File

@@ -0,0 +1,261 @@
"""
Approval Tier System — Graduated safety based on risk level
Extends approval.py with 5-tier system for command approval.
| Tier | Action | Human | LLM | Timeout |
|------|-----------------|-------|-----|---------|
| 0 | Read, search | No | No | N/A |
| 1 | Write, scripts | No | Yes | N/A |
| 2 | Messages, API | Yes | Yes | 60s |
| 3 | Crypto, config | Yes | Yes | 30s |
| 4 | Crisis | Yes | Yes | 10s |
Issue: #670
"""
import re
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple
class ApprovalTier(IntEnum):
    """Approval tiers based on risk level.

    Members are integers (IntEnum), numbered 0-4 in increasing order of
    risk, so tiers can be compared and ordered directly.
    """
    SAFE = 0 # Read, search — no approval needed
    LOW = 1 # Write, scripts — LLM approval
    MEDIUM = 2 # Messages, API — human + LLM, 60s timeout
    HIGH = 3 # Crypto, config — human + LLM, 30s timeout
    CRITICAL = 4 # Crisis — human + LLM, 10s timeout
# Tier metadata: one row per tier, expanded into the per-tier info dict.
# Row layout: (tier, name, human_required, llm_required, timeout_seconds, description)
TIER_INFO = {
    tier: {
        "name": name,
        "human_required": human_required,
        "llm_required": llm_required,
        "timeout_seconds": timeout_seconds,
        "description": description,
    }
    for tier, name, human_required, llm_required, timeout_seconds, description in (
        (ApprovalTier.SAFE, "Safe", False, False, None,
         "Read-only operations, no approval needed"),
        (ApprovalTier.LOW, "Low", False, True, None,
         "Write operations, LLM approval sufficient"),
        (ApprovalTier.MEDIUM, "Medium", True, True, 60,
         "External actions, human confirmation required"),
        (ApprovalTier.HIGH, "High", True, True, 30,
         "Sensitive operations, quick timeout"),
        (ApprovalTier.CRITICAL, "Critical", True, True, 10,
         "Crisis or dangerous operations, fastest timeout"),
    )
}
# Action-to-tier mapping, declared as one group of action names per tier and
# flattened into a single lookup dict (insertion order follows the groups).
ACTION_TIERS: Dict[str, ApprovalTier] = {
    action_name: group_tier
    for group_tier, action_names in (
        # Tier 0: Safe (read-only)
        (ApprovalTier.SAFE, (
            "read_file",
            "search_files",
            "web_search",
            "session_search",
            "list_files",
            "get_file_content",
            "memory_search",
            "skills_list",
            "skills_search",
        )),
        # Tier 1: Low (write operations)
        (ApprovalTier.LOW, (
            "write_file",
            "create_file",
            "patch_file",
            "delete_file",
            "execute_code",
            "terminal",
            "run_script",
            "skill_install",
        )),
        # Tier 2: Medium (external actions)
        (ApprovalTier.MEDIUM, (
            "send_message",
            "web_fetch",
            "browser_navigate",
            "api_call",
            "gitea_create_issue",
            "gitea_create_pr",
            "git_push",
            "deploy",
        )),
        # Tier 3: High (sensitive operations)
        (ApprovalTier.HIGH, (
            "config_change",
            "env_change",
            "key_rotation",
            "access_grant",
            "permission_change",
            "backup_restore",
        )),
        # Tier 4: Critical (crisis/dangerous)
        (ApprovalTier.CRITICAL, (
            "kill_process",
            "rm_rf",
            "format_disk",
            "shutdown",
            "crisis_override",
        )),
    )
    for action_name in action_names
}
# Dangerous command patterns (from existing approval.py)
# Each entry is (regex, tier). detect_tier() runs every regex against the
# command text with re.search(..., re.IGNORECASE), so a pattern matches
# anywhere in the text and regardless of case.
# NOTE(review): none of the patterns is anchored on word boundaries, so e.g.
# "halt" also matches inside "asphalt" and "sudo\s+" inside "pseudo code".
# That errs on the side of escalation; confirm it is intended.
_DANGEROUS_PATTERNS = [
    (r"rm\s+-rf\s+/", ApprovalTier.CRITICAL),
    (r"mkfs\.", ApprovalTier.CRITICAL),
    (r"dd\s+if=.*of=/dev/", ApprovalTier.CRITICAL),
    (r"shutdown|reboot|halt", ApprovalTier.CRITICAL),
    (r"chmod\s+777", ApprovalTier.HIGH),
    (r"curl.*\|\s*bash", ApprovalTier.HIGH),
    (r"wget.*\|\s*sh", ApprovalTier.HIGH),
    (r"eval\s*\(", ApprovalTier.HIGH),
    (r"sudo\s+", ApprovalTier.MEDIUM),
    (r"git\s+push.*--force", ApprovalTier.HIGH),
    (r"docker\s+rm.*-f", ApprovalTier.MEDIUM),
    (r"kubectl\s+delete", ApprovalTier.HIGH),
]
@dataclass
class ApprovalRequest:
    """One pending approval: the action, its detected tier and its context."""
    action: str
    tier: ApprovalTier
    command: str
    reason: str
    session_key: str
    timeout_seconds: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the request plus its tier metadata.

        The tier is emitted as its integer value; ``tier_name``,
        ``human_required`` and ``llm_required`` are looked up in TIER_INFO.
        """
        meta = TIER_INFO[self.tier]
        payload: Dict[str, Any] = {
            "action": self.action,
            "tier": self.tier.value,
            "tier_name": meta["name"],
            "command": self.command,
            "reason": self.reason,
            "session_key": self.session_key,
            "timeout": self.timeout_seconds,
            "human_required": meta["human_required"],
            "llm_required": meta["llm_required"],
        }
        return payload
def detect_tier(action: str, command: str = "") -> ApprovalTier:
    """
    Detect the approval tier for an action.

    A known action name (a key of ACTION_TIERS) decides the tier on its own.
    Otherwise the command text is checked against every entry of
    _DANGEROUS_PATTERNS and the most severe matching tier is returned, so the
    result does not depend on the order of the pattern list (e.g.
    "sudo git push --force" is HIGH, not the MEDIUM of the "sudo" pattern).

    Args:
        action: Name of the action being requested.
        command: Optional command text to scan for dangerous patterns.
    Returns:
        The mapped tier, the highest pattern-matched tier, or
        ApprovalTier.LOW when nothing matches.
    """
    # Direct action mapping
    # NOTE(review): a mapped action short-circuits the pattern scan, so
    # detect_tier("terminal", "rm -rf /") is LOW. Confirm that is intended.
    if action in ACTION_TIERS:
        return ACTION_TIERS[action]
    # Pattern matching on command: collect every hit and keep the riskiest.
    if command:
        matched_tiers = [
            tier
            for pattern, tier in _DANGEROUS_PATTERNS
            if re.search(pattern, command, re.IGNORECASE)
        ]
        if matched_tiers:
            return max(matched_tiers)
    # Default to LOW for unknown actions
    return ApprovalTier.LOW
def requires_human_approval(tier: ApprovalTier) -> bool:
    """Return the ``human_required`` flag recorded for *tier* in TIER_INFO."""
    tier_meta = TIER_INFO[tier]
    return tier_meta["human_required"]
def requires_llm_approval(tier: ApprovalTier) -> bool:
    """Return the ``llm_required`` flag recorded for *tier* in TIER_INFO."""
    tier_meta = TIER_INFO[tier]
    return tier_meta["llm_required"]
def get_timeout(tier: ApprovalTier) -> Optional[int]:
    """Return the ``timeout_seconds`` entry for *tier* (None when unset)."""
    tier_meta = TIER_INFO[tier]
    return tier_meta["timeout_seconds"]
def should_auto_approve(action: str, command: str = "") -> bool:
    """Return True only when the action/command resolves to tier 0 (SAFE)."""
    return detect_tier(action, command) == ApprovalTier.SAFE
def format_approval_prompt(request: ApprovalRequest) -> str:
    """Render an approval request as a multi-line, human-readable prompt.

    The command is shown truncated to its first 100 characters (followed by
    "..." when longer). Requirement and timeout lines are appended only when
    the tier's TIER_INFO entry enables them.
    """
    meta = TIER_INFO[request.tier]
    ellipsis = "..." if len(request.command) > 100 else ""
    rendered = [
        f"⚠️ Approval Required (Tier {request.tier.value}: {meta['name']})",
        "",
        f"Action: {request.action}",
        f"Command: {request.command[:100]}{ellipsis}",
        f"Reason: {request.reason}",
        "",
    ]
    if meta["human_required"]:
        rendered.append("👤 Human approval required")
    if meta["llm_required"]:
        rendered.append("🤖 LLM approval required")
    if meta["timeout_seconds"]:
        rendered.append(f"⏱️ Timeout: {meta['timeout_seconds']}s")
    return "\n".join(rendered)
def create_approval_request(
    action: str,
    command: str,
    reason: str,
    session_key: str
) -> ApprovalRequest:
    """Build an ApprovalRequest whose tier and timeout come from detection.

    The tier is resolved with detect_tier(action, command) and the timeout
    is that tier's value from get_timeout().
    """
    detected = detect_tier(action, command)
    return ApprovalRequest(
        action=action,
        tier=detected,
        command=command,
        reason=reason,
        session_key=session_key,
        timeout_seconds=get_timeout(detected),
    )
# Crisis bypass rules: actions that skip approval no matter the context.
CRISIS_BYPASS_ACTIONS = frozenset({
    "send_message",  # Always allow sending crisis resources
    "check_crisis",
    "notify_crisis",
})
def is_crisis_bypass(action: str, context: str = "") -> bool:
    """Check if action should bypass approval during crisis.

    Args:
        action: Name of the action being requested.
        context: Free text describing the situation; may be empty or None.
    Returns:
        True when the action is listed in CRISIS_BYPASS_ACTIONS, or when the
        context contains one of the crisis indicator substrings
        (case-insensitive); False otherwise.
    """
    if action in CRISIS_BYPASS_ACTIONS:
        return True
    # No context (empty or None) means no crisis signal; this also avoids
    # calling .lower() on None.
    if not context:
        return False
    # Check if context indicates crisis
    # NOTE(review): this is a plain substring test applied to *any* action,
    # so a context mentioning "crisis" (or a number containing "988") lets
    # even destructive actions bypass approval. Confirm that is intended.
    crisis_indicators = ("988", "crisis", "suicide", "self-harm", "lifeline")
    context_lower = context.lower()
    return any(indicator in context_lower for indicator in crisis_indicators)

233
tools/error_classifier.py Normal file
View File

@@ -0,0 +1,233 @@
"""
Tool Error Classification — Retryable vs Permanent.
Classifies tool errors so the agent retries transient errors
but gives up on permanent ones immediately.
"""
import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
class ErrorCategory(Enum):
    """Error category classification."""
    # Matched a retryable HTTP status or message pattern.
    RETRYABLE = "retryable"
    # Matched a permanent (client-side) HTTP status or message pattern.
    PERMANENT = "permanent"
    # Matched nothing; classify_error() still allows one retry for these.
    UNKNOWN = "unknown"
@dataclass
class ErrorClassification:
    """Result of error classification.

    Built by classify_error(); bundles the category with the retry guidance
    (whether to retry, how many times, and the suggested backoff).
    """
    category: ErrorCategory
    # Short human-readable explanation of why this category was chosen.
    reason: str
    should_retry: bool
    # Suggested retry count; 0 for permanent errors.
    max_retries: int
    # Suggested wait before retrying, in seconds; 0 for permanent errors.
    backoff_seconds: float
    # HTTP status; classify_error() sets it only when it was handed a
    # recognised response_code, not when a status is found in the message.
    error_code: Optional[int] = None
    # Class name of the classified exception.
    error_type: Optional[str] = None
# Retryable error patterns
# Each entry is (regex, reason, max_retries, backoff_seconds).
# classify_error() searches the lower-cased error message with
# re.IGNORECASE and uses the first entry that matches, so order matters.
# These are checked before _PERMANENT_PATTERNS.
_RETRYABLE_PATTERNS = [
    # HTTP status codes
    (r"\b429\b", "rate limit", 3, 5.0),
    (r"\b500\b", "server error", 3, 2.0),
    (r"\b502\b", "bad gateway", 3, 2.0),
    (r"\b503\b", "service unavailable", 3, 5.0),
    (r"\b504\b", "gateway timeout", 3, 5.0),
    # Timeout patterns
    (r"timeout", "timeout", 3, 2.0),
    (r"timed out", "timeout", 3, 2.0),
    (r"TimeoutExpired", "timeout", 3, 2.0),
    # Connection errors
    (r"connection refused", "connection refused", 2, 5.0),
    (r"connection reset", "connection reset", 2, 2.0),
    (r"network unreachable", "network unreachable", 2, 10.0),
    (r"DNS", "DNS error", 2, 5.0),
    # Transient errors
    (r"temporary", "temporary error", 2, 2.0),
    (r"transient", "transient error", 2, 2.0),
    # NOTE(review): unanchored, so this also matches "do not retry".
    (r"retry", "retryable", 2, 2.0),
]
# Permanent error patterns
# Each entry is (regex, short label, reason). classify_error() reports only
# the third element as the classification reason; the label is not used.
# Consulted only after no retryable pattern matched; first match wins.
_PERMANENT_PATTERNS = [
    # HTTP status codes
    (r"\b400\b", "bad request", "Invalid request parameters"),
    (r"\b401\b", "unauthorized", "Authentication failed"),
    (r"\b403\b", "forbidden", "Access denied"),
    (r"\b404\b", "not found", "Resource not found"),
    (r"\b405\b", "method not allowed", "HTTP method not supported"),
    (r"\b409\b", "conflict", "Resource conflict"),
    (r"\b422\b", "unprocessable", "Validation error"),
    # Schema/validation errors
    (r"schema", "schema error", "Invalid data schema"),
    (r"validation", "validation error", "Input validation failed"),
    (r"invalid.*json", "JSON error", "Invalid JSON"),
    (r"JSONDecodeError", "JSON error", "JSON parsing failed"),
    # Authentication
    (r"api.?key", "API key error", "Invalid or missing API key"),
    (r"token.*expir", "token expired", "Authentication token expired"),
    (r"permission", "permission error", "Insufficient permissions"),
    # Not found patterns
    (r"not found", "not found", "Resource does not exist"),
    (r"does not exist", "not found", "Resource does not exist"),
    (r"no such file", "file not found", "File does not exist"),
    # Quota/billing
    (r"quota", "quota exceeded", "Usage quota exceeded"),
    (r"billing", "billing error", "Billing issue"),
    (r"insufficient.*funds", "billing error", "Insufficient funds"),
]
def classify_error(error: Exception, response_code: Optional[int] = None) -> ErrorClassification:
    """
    Decide whether an error is worth retrying.

    Resolution order:
      1. A recognised HTTP ``response_code`` (retryable: 429/500/502/503/504;
         permanent: 400/401/403/404/405/409/422).
      2. The first entry of ``_RETRYABLE_PATTERNS`` matching the message.
      3. The first entry of ``_PERMANENT_PATTERNS`` matching the message.
      4. Otherwise UNKNOWN, with a single cautious retry.

    Args:
        error: The exception that occurred
        response_code: HTTP response code if available
    Returns:
        ErrorClassification with retry guidance
    """
    message = str(error).lower()
    kind = type(error).__name__

    # An explicit, recognised status code wins over message heuristics.
    # Unrecognised codes fall through to the pattern checks below.
    if response_code:
        if response_code in (429, 500, 502, 503, 504):
            # Rate limiting gets the longer pause.
            delay = 2.0
            if response_code == 429:
                delay = 5.0
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=f"HTTP {response_code} - transient server error",
                should_retry=True,
                max_retries=3,
                backoff_seconds=delay,
                error_code=response_code,
                error_type=kind,
            )
        if response_code in (400, 401, 403, 404, 405, 409, 422):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=f"HTTP {response_code} - client error",
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_code=response_code,
                error_type=kind,
            )

    # Retryable patterns are consulted before permanent ones.
    retry_rule = next(
        (rule for rule in _RETRYABLE_PATTERNS
         if re.search(rule[0], message, re.IGNORECASE)),
        None,
    )
    if retry_rule is not None:
        _, label, attempts, delay = retry_rule
        return ErrorClassification(
            category=ErrorCategory.RETRYABLE,
            reason=label,
            should_retry=True,
            max_retries=attempts,
            backoff_seconds=delay,
            error_type=kind,
        )

    permanent_rule = next(
        (rule for rule in _PERMANENT_PATTERNS
         if re.search(rule[0], message, re.IGNORECASE)),
        None,
    )
    if permanent_rule is not None:
        # Only the third element (the description) is reported as the reason.
        return ErrorClassification(
            category=ErrorCategory.PERMANENT,
            reason=permanent_rule[2],
            should_retry=False,
            max_retries=0,
            backoff_seconds=0,
            error_type=kind,
        )

    # Nothing matched: unknown, treat as retryable with caution.
    return ErrorClassification(
        category=ErrorCategory.UNKNOWN,
        reason=f"Unknown error type: {kind}",
        should_retry=True,
        max_retries=1,
        backoff_seconds=1.0,
        error_type=kind,
    )
def execute_with_retry(
    func,
    *args,
    max_retries: int = 3,
    backoff_base: float = 1.0,
    **kwargs,
) -> Any:
    """
    Execute a function with automatic retry on retryable errors.

    Each failure is classified with classify_error(). Errors that should not
    be retried are re-raised at once; others are retried after an
    exponentially growing pause (backoff_base * 2 ** attempt).

    Args:
        func: Function to execute
        *args: Function arguments
        max_retries: Maximum retry attempts (>= 0); func runs at most
            max_retries + 1 times
        backoff_base: Base backoff time in seconds
        **kwargs: Function keyword arguments
    Returns:
        Function result
    Raises:
        ValueError: If max_retries is negative
        Exception: The error raised by func, if it is permanent or the
            retries are exhausted
    """
    # A negative count would skip the loop entirely: func would never run and
    # there would be no error to re-raise. Reject it explicitly.
    if max_retries < 0:
        raise ValueError(f"max_retries must be >= 0, got {max_retries}")

    total_attempts = max_retries + 1
    for attempt in range(total_attempts):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Classify the error
            # NOTE(review): only should_retry is honoured here; the
            # classification's own max_retries / backoff_seconds are not
            # consulted. Confirm whether they should cap the loop.
            classification = classify_error(e)
            logger.info(
                "Attempt %d/%d failed: %s (%s, retryable: %s)",
                attempt + 1, total_attempts,
                classification.reason,
                classification.category.value,
                classification.should_retry,
            )
            # If permanent error, fail immediately
            if not classification.should_retry:
                logger.error("Permanent error: %s", classification.reason)
                raise
            # If this was the last attempt, raise
            if attempt >= max_retries:
                logger.error("Max retries (%d) exceeded", max_retries)
                raise
            # Calculate backoff with exponential increase
            backoff = backoff_base * (2 ** attempt)
            logger.info("Retrying in %.1fs...", backoff)
            time.sleep(backoff)
    # Every iteration returns, raises, or moves on to another attempt, and
    # the final attempt always returns or raises.
    raise RuntimeError("unreachable: retry loop ended without a result")
def format_error_report(classification: ErrorClassification) -> str:
    """Format error classification as a report string.

    The report is "<category>: <reason>", prefixed with a retry icon and a
    space when the error should be retried. Non-retryable reports carry no
    prefix, so they never start with stray whitespace.
    """
    summary = f"{classification.category.value}: {classification.reason}"
    if classification.should_retry:
        return f"🔄 {summary}"
    return summary