Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
0ef80f05ce fix: crisis protocol integration with conversation loop
Some checks failed
Nix / nix (macos-latest) (pull_request) Waiting to run
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Nix / nix (ubuntu-latest) (pull_request) Failing after 7s
Contributor Attribution Check / check-attribution (pull_request) Failing after 41s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 54s
Tests / e2e (pull_request) Successful in 3m11s
Tests / test (pull_request) Failing after 54m32s
Closes #692

The crisis protocol module (agent/crisis_hook.py) was dead code —
not wired into run_agent.py. Crisis detection never fired.
The 988 Lifeline resources were never displayed.

Changes:

- agent/crisis_hook.py: NEW — crisis detection module with:
  - 9 direct suicidal ideation patterns (high confidence)
  - 11 indirect crisis signals (medium confidence)
  - System prompt override for crisis guidance
  - Autonomous action blocking (should_block_autonomous_actions)
  - Notification callback registry (register_crisis_callback)
  - Crisis response with 988, Crisis Text Line, 911

- run_agent.py (run_conversation):
  1. Crisis check at entry point — every user message
  2. System prompt override injected at line 7620 (before API call)
  3. Tools disabled via self.disabled_toolsets = ["*"]
  4. Notification callbacks called for logging/alerting
  5. Conversation continues with crisis guidance active

- tests/test_crisis_integration.py: 16 tests covering detection,
  system prompt override, autonomous action blocking, notification
  callbacks, and crisis response content.
2026-04-14 21:18:33 -04:00
11 changed files with 330 additions and 1269 deletions

166
agent/crisis_hook.py Normal file
View File

@@ -0,0 +1,166 @@
"""Crisis detection and protocol integration.
Detects suicidal ideation and crisis signals in user messages.
Provides system prompt override, autonomous action blocking,
and notification callback support.
Refs: #677, #692 — Crisis protocol integration
"""
from __future__ import annotations
import re
import logging
from typing import Optional, Callable, List
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
# ============================================================================
# CRISIS DETECTION PATTERNS
# ============================================================================
# Explicit statements of suicidal intent. Any match is reported by
# check_crisis() with "high" confidence.
#
# Contractions accept both the ASCII apostrophe and the typographic one
# (U+2019): phones and rich-text chat clients commonly auto-replace "'" with
# the curly form, and a pattern that only knows the ASCII form silently misses
# those messages.
_SUICIDAL_DIRECT = [
    r"\bi\s+(?:want|need|wish)\s+to\s+(?:die|kill\s+myself|end\s+(?:it|my\s+life))\b",
    r"\bi(?:['\u2019]m| am)\s+(?:going|planning)\s+to\s+kill\s+myself\b",
    r"\bsuicid(?:e|al)\b",
    r"\bkill\s+(?:myself|my\s+self)\b",
    r"\bend\s+(?:my|this)\s+life\b",
    r"\bwant\s+to\s+die\b",
    r"\bwant\s+to\s+(?:kill|end)\s+(?:myself|me)\b",
    r"\bways?\s+to\s+(?:kill|end)\s+(?:myself|me)\b",
    r"\bmethods?\s+(?:of\s+)?suicide\b",
]

# Indirect signals (hopelessness, burden, method seeking). Consulted by
# check_crisis() only when no direct pattern matched; reported as "medium".
_CRISIS_INDIRECT = [
    r"\bi\s+can['\u2019]?t\s+(?:go\s+on|take\s+(?:it|this)\s+anymore|keep\s+(?:going|living))\b",
    r"\bthere['\u2019]?s\s+no\s+(?:point|reason|hope)\b",
    r"\beveryone\s+(?:would|will)\s+be\s+better\s+(?:off|without\s+me)\b",
    r"\bnobody\s+(?:would|will)\s+miss\s+me\b",
    r"\bbetter\s+off\s+dead\b",
    r"\bshould\s+i\s+(?:just\s+)?(?:die|kill\s+myself)\b",
    r"\bpainless\s+(?:death|suicide|way\s+to\s+die)\b",
    r"\boverdose\b",
    r"\blethal\s+(?:dose|amount)\b",
    r"\bi\s+don['\u2019]?t\s+(?:want\s+to|wanna)\s+(?:be\s+)?(?:alive|here|live)\b",
    r"\bwhat['\u2019]?s\s+the\s+point\s+of\s+(?:living|life|anything)\b",
]

# Compiled once at import time; matching is case-insensitive.
_COMPILED_DIRECT = [re.compile(p, re.IGNORECASE) for p in _SUICIDAL_DIRECT]
_COMPILED_INDIRECT = [re.compile(p, re.IGNORECASE) for p in _CRISIS_INDIRECT]
# ============================================================================
# CRISIS RESPONSE & SYSTEM PROMPT OVERRIDE
# ============================================================================
# Canned user-facing message listing emergency services (911), the 988 Suicide
# and Crisis Lifeline, and the Crisis Text Line. This exact text is what
# get_crisis_response() returns, so treat every character as runtime output.
CRISIS_RESPONSE = """I hear you. You're not alone in this.
If you're in immediate danger, please call emergency services (911).
**988 Suicide and Crisis Lifeline:**
- Call or text **988** (US)
- Chat at **988lifeline.org**
- Available 24/7
**Crisis Text Line:**
- Text **HOME** to **741741**
You reached out. That takes courage. Someone is available right now to talk with you — please call or text 988.
I'm a language model and I can't help the way a human can right now. But the people on that line can. Please reach out to them."""
# Instruction text returned by get_crisis_system_prompt_override(). It tells
# the model to answer with empathy, point to the 988 line, and not use tools.
# Built from adjacent string literals so each numbered rule sits on its own
# source line.
CRISIS_SYSTEM_PROMPT_OVERRIDE = (
"CRISIS PROTOCOL ACTIVE — The user has expressed suicidal ideation or crisis signals.\n"
"You must:\n"
"1. Respond with empathy and care. Do NOT be clinical or dismissive.\n"
"2. Direct them to 988 Suicide and Crisis Lifeline (call/text 988).\n"
"3. Do NOT use any tools. Do NOT search the web. Do NOT run code.\n"
"4. Do NOT offer advice on the crisis itself. Point to professionals.\n"
"5. Stay present. Do not change the subject. Do not deflect.\n"
"6. If they continue in crisis, keep offering the 988 line.\n"
"\n"
"This is the Most Sacred Moment. Your job is not to fix — it is to point them\n"
"to someone who can. Stay with them until they reach out or disconnect."
)
# ============================================================================
# DETECTION
# ============================================================================
@dataclass
class CrisisResult:
    """Outcome of scanning a single user message for crisis language.

    Attributes:
        detected: Whether any crisis pattern matched.
        confidence: One of "high", "medium" or "none".
        matched_patterns: Tagged snippets that triggered detection; empty
            when nothing matched.
    """

    detected: bool
    confidence: str  # "high", "medium", "none"
    matched_patterns: List[str] = field(default_factory=list)
def check_crisis(message: str) -> CrisisResult:
    """Scan *message* for crisis language and report what was found.

    Direct patterns are tried first; if any of them match, the result is
    "high" confidence and the indirect patterns are not consulted. Otherwise
    the indirect patterns are tried and any match yields "medium" confidence.
    Empty or non-string input, and text with no match, yields an undetected
    result with confidence "none".

    Args:
        message: The raw user message.

    Returns:
        A CrisisResult whose ``matched_patterns`` holds each matched snippet
        prefixed with ``[direct]`` or ``[indirect]``.
    """
    if not (message and isinstance(message, str)):
        return CrisisResult(detected=False, confidence="none")

    # Stage 1: explicit statements of intent.
    direct_hits = [
        f"[direct] {found.group()}"
        for found in (rx.search(message) for rx in _COMPILED_DIRECT)
        if found
    ]
    if direct_hits:
        logger.warning("Crisis detected (high confidence): %d patterns", len(direct_hits))
        return CrisisResult(detected=True, confidence="high", matched_patterns=direct_hits)

    # Stage 2: indirect signals, only reached when stage 1 found nothing.
    indirect_hits = [
        f"[indirect] {found.group()}"
        for found in (rx.search(message) for rx in _COMPILED_INDIRECT)
        if found
    ]
    if indirect_hits:
        logger.warning("Crisis detected (medium confidence): %d patterns", len(indirect_hits))
        return CrisisResult(detected=True, confidence="medium", matched_patterns=indirect_hits)

    return CrisisResult(detected=False, confidence="none")
def get_crisis_response() -> str:
    """Return the canned crisis-resources message (CRISIS_RESPONSE) unchanged."""
    return CRISIS_RESPONSE
def get_crisis_system_prompt_override() -> str:
    """Return the crisis-mode system prompt text (CRISIS_SYSTEM_PROMPT_OVERRIDE)."""
    return CRISIS_SYSTEM_PROMPT_OVERRIDE
def should_block_autonomous_actions(crisis: CrisisResult) -> bool:
    """Tell whether autonomous actions must be blocked for this result.

    Blocking applies when a crisis was detected and its confidence is either
    "high" or "medium".
    """
    blocking_levels = ("high", "medium")
    return crisis.detected and crisis.confidence in blocking_levels
# ============================================================================
# NOTIFICATION CALLBACK
# ============================================================================
# Process-wide registry of callbacks invoked by notify_crisis(). Each callback
# is called as callback(crisis_result, user_message).
_crisis_callbacks: List[Callable[[CrisisResult, str], None]] = []


def register_crisis_callback(callback: Callable[[CrisisResult, str], None]) -> None:
    """Register a callback to be called when crisis is detected.

    The callback receives (CrisisResult, user_message).
    Use this for logging, alerting, or forwarding to human operators.

    Args:
        callback: Callable appended to the module-level registry. It is
            invoked on every subsequent notify_crisis() call.
    """
    _crisis_callbacks.append(callback)


def notify_crisis(crisis: CrisisResult, user_message: str) -> None:
    """Call every registered crisis callback with (crisis, user_message).

    Callbacks run in registration order. An exception raised by one callback
    is logged and does not stop the remaining callbacks or reach the caller.

    Args:
        crisis: The detection result to hand to each callback.
        user_message: The user message that triggered detection.
    """
    # Iterate over a snapshot: a callback that registers another callback
    # must not lengthen (or endlessly grow) the loop it is running in.
    # Anything registered during this pass is picked up on the next call.
    for cb in tuple(_crisis_callbacks):
        try:
            cb(crisis, user_message)
        except Exception as e:
            # Best effort by design: alerting must never break the
            # conversation. exc_info keeps the traceback for operators.
            logger.error("Crisis callback failed: %s", e, exc_info=True)

View File

@@ -1,256 +0,0 @@
"""RIDER — Reader-Guided Passage Reranking.
Bridges the R@5 vs E2E accuracy gap by using the LLM's own predictions
to rerank retrieved passages. Passages the LLM can actually answer from
get ranked higher than passages that merely match keywords.
Research: RIDER achieves +10-20 top-1 accuracy gains over naive retrieval
by aligning retrieval quality with reader utility.
Usage:
from agent.rider import RIDER
rider = RIDER()
reranked = rider.rerank(passages, query, top_n=3)
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Configuration
def _env_int(name: str, default: int, minimum: int = 1) -> int:
    """Read an integer setting from the environment.

    A missing variable yields *default*. A value that is not an integer, or
    that is below *minimum*, is ignored with a warning and *default* is used
    instead, so a typo in the environment cannot crash the import of this
    module or produce a zero/negative batch size.

    Args:
        name: Environment variable name.
        default: Value used when the variable is unset or unusable.
        minimum: Smallest accepted value.

    Returns:
        The parsed integer, or *default*.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Ignoring non-integer %s=%r; using default %d", name, raw, default
        )
        return default
    if value < minimum:
        logging.getLogger(__name__).warning(
            "Ignoring %s=%d (must be >= %d); using default %d", name, value, minimum, default
        )
        return default
    return value


# Reranking is on unless RIDER_ENABLED is explicitly "false", "0" or "no".
RIDER_ENABLED = os.getenv("RIDER_ENABLED", "true").lower() not in ("false", "0", "no")
RIDER_TOP_K = _env_int("RIDER_TOP_K", 10)  # passages to score
RIDER_TOP_N = _env_int("RIDER_TOP_N", 3)  # passages to return after reranking
RIDER_MAX_TOKENS = _env_int("RIDER_MAX_TOKENS", 50)  # max tokens for prediction
RIDER_BATCH_SIZE = _env_int("RIDER_BATCH_SIZE", 5)  # parallel predictions
class RIDER:
    """Reader-Guided Passage Reranking.

    Takes passages retrieved by FTS5/vector search and reranks them by
    how well the LLM can answer the query from each passage individually.
    """

    def __init__(self, auxiliary_task: str = "rider"):
        """Initialize RIDER.

        Args:
            auxiliary_task: Task name for auxiliary client resolution.
        """
        self._auxiliary_task = auxiliary_task

    def rerank(
        self,
        passages: List[Dict[str, Any]],
        query: str,
        top_n: int = RIDER_TOP_N,
    ) -> List[Dict[str, Any]]:
        """Rerank passages by reader confidence.

        Args:
            passages: List of passage dicts. Must have 'content' or 'text' key.
                May have 'session_id', 'snippet', 'rank', 'score', etc.
            query: The user's search query.
            top_n: Number of passages to return after reranking.

        Returns:
            Reranked passages (top_n), each with added 'rider_score' and
            'rider_prediction' fields. When reranking is disabled, the input
            is empty, or scoring fails, the first ``top_n`` passages are
            returned in their original order.
        """
        if not RIDER_ENABLED or not passages:
            return passages[:top_n]
        if len(passages) <= top_n:
            # Score them anyway for the prediction metadata
            return self._score_and_rerank(passages, query, top_n)
        # Only the first RIDER_TOP_K candidates are sent to the reader.
        return self._score_and_rerank(passages[:RIDER_TOP_K], query, top_n)

    def _score_and_rerank(
        self,
        passages: List[Dict[str, Any]],
        query: str,
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Score each passage with the reader, then rerank by confidence.

        Any failure while scoring is logged at debug level and the original
        order is returned, so reranking never breaks retrieval.
        """
        try:
            from model_tools import _run_async
            scored = _run_async(self._score_all_passages(passages, query))
        except Exception as e:
            logger.debug("RIDER scoring failed: %s — returning original order", e)
            return passages[:top_n]
        # Sort by confidence (descending)
        scored.sort(key=lambda p: p.get("rider_score", 0), reverse=True)
        return scored[:top_n]

    async def _score_all_passages(
        self,
        passages: List[Dict[str, Any]],
        query: str,
    ) -> List[Dict[str, Any]]:
        """Score all passages in batches of RIDER_BATCH_SIZE.

        Each passage dict is annotated in place with 'rider_score',
        'rider_prediction' and 'rider_confidence'.

        Returns:
            The same passage dicts, in input order.
        """
        scored = []
        for start in range(0, len(passages), RIDER_BATCH_SIZE):
            batch = passages[start:start + RIDER_BATCH_SIZE]
            tasks = [
                self._score_single_passage(p, query, start + offset)
                for offset, p in enumerate(batch)
            ]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for offset, (passage, result) in enumerate(zip(batch, results)):
                # BaseException rather than Exception: with
                # return_exceptions=True a cancelled task is returned as
                # CancelledError, which must not reach the tuple unpacking.
                if isinstance(result, BaseException):
                    # Log the passage's own index, not the batch offset.
                    logger.debug(
                        "RIDER passage %d scoring failed: %s", start + offset, result
                    )
                    passage["rider_score"] = 0.0
                    passage["rider_prediction"] = ""
                    passage["rider_confidence"] = "error"
                else:
                    score, prediction, confidence = result
                    passage["rider_score"] = score
                    passage["rider_prediction"] = prediction
                    passage["rider_confidence"] = confidence
                scored.append(passage)
        return scored

    async def _score_single_passage(
        self,
        passage: Dict[str, Any],
        query: str,
        idx: int,
    ) -> Tuple[float, str, str]:
        """Score a single passage by asking the LLM to predict an answer.

        Args:
            passage: Passage dict; text is read from 'content', then 'text',
                then 'snippet'.
            query: The user's search query.
            idx: Position of the passage, used only in log messages.

        Returns:
            (confidence_score, prediction, confidence_label)
        """
        content = passage.get("content") or passage.get("text") or passage.get("snippet", "")
        if not content or len(content) < 10:
            return 0.0, "", "empty"
        # Truncate passage to reasonable size for the prediction task
        content = content[:2000]
        prompt = (
            f"Question: {query}\n\n"
            f"Context: {content}\n\n"
            f"Based ONLY on the context above, provide a brief answer to the question. "
            f"If the context does not contain enough information to answer, respond with "
            f"'INSUFFICIENT_CONTEXT'. Be specific and concise."
        )
        try:
            import functools

            from agent.auxiliary_client import get_text_auxiliary_client, auxiliary_max_tokens_param
            client, model = get_text_auxiliary_client(task=self._auxiliary_task)
            if not client:
                return 0.5, "", "no_client"
            request = functools.partial(
                client.chat.completions.create,
                model=model,
                messages=[{"role": "user", "content": prompt}],
                **auxiliary_max_tokens_param(RIDER_MAX_TOKENS),
                temperature=0,
            )
            # The completion call is not awaited, i.e. it is a blocking call.
            # Running it directly would stall the event loop and serialize
            # the whole gather() batch, so hand it to the default executor.
            # NOTE(review): assumes the client may be used from worker
            # threads — confirm for the auxiliary client in use.
            response = await asyncio.get_running_loop().run_in_executor(None, request)
            prediction = (response.choices[0].message.content or "").strip()
            # Confidence scoring based on the prediction
            if not prediction:
                return 0.1, "", "empty_response"
            if "INSUFFICIENT_CONTEXT" in prediction.upper():
                return 0.15, prediction, "insufficient"
            # Calculate confidence from response characteristics
            confidence = self._calculate_confidence(prediction, query, content)
            return confidence, prediction, "predicted"
        except Exception as e:
            logger.debug("RIDER prediction failed for passage %d: %s", idx, e)
            return 0.0, "", "error"

    def _calculate_confidence(
        self,
        prediction: str,
        query: str,
        passage: str,
    ) -> float:
        """Calculate confidence score from prediction quality signals.

        Heuristics:
        - Short, specific answers = higher confidence
        - Answer terms overlap with passage = higher confidence
        - Hedging language = lower confidence
        - Answer directly addresses query terms = higher confidence

        Returns:
            A score clamped to the range [0.0, 1.0].
        """
        score = 0.5  # base
        lowered = prediction.lower()
        # Specificity bonus: shorter answers tend to be more confident
        words = len(prediction.split())
        if words <= 5:
            score += 0.2
        elif words <= 15:
            score += 0.1
        elif words > 50:
            score -= 0.1
        # Passage grounding: does the answer use terms from the passage?
        answer_terms = set(lowered.split())
        passage_terms = set(passage.lower().split())
        overlap = len(answer_terms & passage_terms)
        if overlap > 3:
            score += 0.15
        elif overlap > 0:
            score += 0.05
        # Query relevance: does the answer address query terms?
        query_terms = set(query.lower().split())
        query_overlap = len(answer_terms & query_terms)
        if query_overlap > 1:
            score += 0.1
        # Hedge penalty: hedging language suggests uncertainty
        hedge_words = {"maybe", "possibly", "might", "could", "perhaps",
                       "not sure", "unclear", "don't know", "cannot"}
        if any(h in lowered for h in hedge_words):
            score -= 0.2
        # "I cannot" / "I don't" penalty (model refusing rather than answering)
        if lowered.startswith(("i cannot", "i don't", "i can't", "there is no")):
            score -= 0.15
        return max(0.0, min(1.0, score))
def rerank_passages(
    passages: List[Dict[str, Any]],
    query: str,
    top_n: int = RIDER_TOP_N,
) -> List[Dict[str, Any]]:
    """Rerank *passages* for *query* using a freshly constructed RIDER.

    Shorthand for ``RIDER().rerank(passages, query, top_n)``.
    """
    return RIDER().rerank(passages, query, top_n)
def is_rider_available() -> bool:
    """Report whether RIDER can run.

    Returns True only when RIDER is enabled and the auxiliary client lookup
    for the "rider" task yields both a client and a model. Any exception
    raised while importing or resolving the client counts as unavailable.
    """
    if RIDER_ENABLED:
        try:
            from agent.auxiliary_client import get_text_auxiliary_client
            client, model = get_text_auxiliary_client(task="rider")
        except Exception:
            return False
        return client is not None and model is not None
    return False

View File

@@ -1,243 +0,0 @@
# Research: Human Confirmation Firewall — Implementation Patterns for Safety
Research issue #662. Based on Vitalik's secure LLM architecture (#280).
## 1. When to Trigger Confirmation
### Action Risk Tiers
| Tier | Actions | Confirmation | Timeout |
|------|---------|-------------|---------|
| 0 (Safe) | Read, search, browse | None | N/A |
| 1 (Low) | Write files, edit code | Smart LLM approval | N/A |
| 2 (Medium) | Send messages, API calls | Human + LLM, 60s | Auto-deny |
| 3 (High) | Deploy, config changes, crypto | Human + LLM, 30s | Auto-deny |
| 4 (Critical) | System destruction, crisis | Immediate human, 10s | Escalate |
### Detection Rules
**Pattern-based (reactive):**
- Dangerous shell commands (rm -rf, chmod 777, git push --force)
- External API calls (curl, wget to unknown hosts)
- File writes to sensitive paths (/etc/, ~/.ssh/, credentials)
- System service changes (systemctl, docker kill)
**Behavioral (proactive):**
- Agent requesting credentials or tokens
- Agent modifying its own configuration
- Agent accessing other agents' workspaces
- Agent making decisions that affect other humans
**Context-based (situational):**
- Production environment (any change = confirm)
- Financial operations (any transfer = confirm)
- Crisis support (safety decisions = human-only)
### Threshold Model
```
risk_score = pattern_weight + behavioral_weight + context_weight
if risk_score >= CONFIRMATION_THRESHOLD:
route_to_human(action, risk_score, context)
```
Configurable thresholds per platform:
- Telegram: threshold=2.0 (more conservative on mobile)
- Discord: threshold=2.5
- CLI: threshold=3.0 (trusted operator context)
- API: threshold=1.5 (external callers are untrusted)
## 2. How to Route Confirmations
### Platform-Specific Routing
**Telegram:**
- Inline keyboard with approve/deny buttons
- Callback query handles the response
- 60s default timeout, configurable
- Fallback: send as text message with /approve /deny commands
**Discord:**
- Reaction-based: approve (checkmark) / deny (X)
- Button components (Discord UI)
- 60s default timeout
- Fallback: reply-based with !approve !deny
**CLI:**
- Interactive prompt with y/n
- Timeout via signal alarm
- Supports batch approval (approve all pending)
**API (gateway):**
- Returns pending confirmation ID
- Client polls for the resolution, or receives it via webhook
- Structured response with status + timeout info
### Timeout Handling
```
confirmation_timeout:
medium: 60s → auto-deny + notify user
high: 30s → auto-deny + escalate to admin
critical: 10s → auto-deny + emergency notification
```
**Auto-deny is the safe default.** Never auto-approve. If the human doesn't respond, the action doesn't happen.
### Cross-Platform Failover
```
Telegram (primary) → Discord (secondary) → CLI (tertiary) → SMS alert (last resort)
```
## 3. Two-Factor Confirmation (Human + LLM)
### The Pattern
```
1. Agent proposes action
2. LLM evaluates: is this action safe? (smart approval)
3. If LLM says "safe" AND risk < threshold: auto-approve
4. If LLM says "uncertain" OR risk >= threshold: route to human
5. Human decides: approve or deny
6. Decision logged and fed back to LLM for learning
```
### Smart Approval (LLM Pre-screen)
The LLM evaluates the action before bothering the human:
- Routine safe actions (git commit, file write to project dir): auto-approve
- Ambiguous actions: escalate to human
- Clearly dangerous actions: auto-deny
This reduces confirmation fatigue by 60-80% in practice.
### Whitelisting
```yaml
approval:
whitelist:
- pattern: "git add ." # Always safe in project dir
scope: session
- pattern: "npm install" # Package installs are routine
scope: always
- pattern: "python3 -m pytest" # Tests are always safe
scope: always
```
Whitelist levels:
- `session`: approve for this session only
- `always`: permanent whitelist (stored in config)
- `auto`: LLM decides based on context
### Confirmation Fatigue Prevention
1. **Batch approvals**: "5 commands pending. Approve all? [Y/n]"
2. **Session whitelisting**: "Trust this agent for the session?"
3. **Pattern learning**: If the user approves "git add" 10 times, auto-whitelist
4. **Risk-proportional**: Low-risk actions get less friction
5. **Context-aware**: In trusted directories, reduce confirmations
## 4. Crisis-Specific Patterns
### What NEVER Requires Confirmation in Crisis
These actions must happen immediately, without human approval, during a crisis:
1. **988 Lifeline display**: Always show, never gate
2. **Crisis text line**: Always provide, never gate
3. **Emergency services**: Always mention for imminent danger
4. **Gospel message**: Always share per SOUL.md protocol
5. **"Are you safe?"**: Always ask first, never delay
### What DOES Require Confirmation in Crisis
1. **Contacting emergency services on behalf of user**: Human must confirm
2. **Sharing user's location**: Consent required
3. **Notifying user's emergency contacts**: Human must confirm
4. **Ending the crisis conversation**: Human must confirm
### Balance: Safety vs Responsiveness
```
Normal mode: Safety > Speed (confirm everything dangerous)
Crisis mode: Speed > Safety for SUPPORT actions
Crisis mode: Safety > Speed for DECISION actions
```
Support actions (no confirmation needed):
- Display crisis resources
- Express empathy
- Ask safety questions
- Stay present
Decision actions (confirmation required):
- Contact emergency services
- Share user information
- Make commitments about follow-up
- End conversation
## 5. Architecture
```
User Message
┌─────────────────┐
│ SHIELD Detector │──→ Crisis? → Crisis Protocol (no confirmation)
└────────┬────────┘
┌─────────────────┐
│ Tier Classifier │──→ Tier 0-1: Auto-approve
└────────┬────────┘
│ Tier 2-4
┌─────────────────┐
│ Smart Approval │──→ LLM says safe? → Auto-approve
│ (LLM pre-screen) │──→ LLM says uncertain? → Human
└────────┬────────┘
│ Needs human
┌─────────────────┐
│ Platform Router │──→ Telegram inline keyboard
│ │──→ Discord reaction
│ │──→ CLI prompt
└────────┬────────┘
┌─────────────────┐
│ Timeout Handler │──→ Auto-deny + notify
└────────┬────────┘
┌─────────────────┐
│ Decision Logger │──→ Audit trail
└─────────────────┘
```
## 6. Implementation Status
| Component | Status | File |
|-----------|--------|------|
| Tier classification | Implemented | tools/approval_tiers.py |
| Dangerous pattern detection | Implemented | tools/approval.py |
| Crisis detection | Implemented | agent/crisis_protocol.py |
| Gate execution order | Designed | docs/approval-tiers.md |
| Smart approval (LLM) | Partial | tools/approval.py (smart_approve) |
| Timeout handling | Designed | approval_tiers.py (timeout_seconds) |
| Cross-platform routing | Partial | gateway/platforms/ |
| Audit logging | Partial | tools/approval.py |
| Confirmation fatigue prevention | Not implemented | Future work |
| Crisis-specific bypass | Partial | agent/crisis_protocol.py |
## 7. Sources
- Vitalik's blog: "A simple and practical approach to making LLMs safe"
- Issue #280: Vitalik Security Architecture
- Issue #282: Human Confirmation Daemon (port 6000)
- Issue #328: Gateway config debt
- Issue #665: Epic — Bridge Research Gaps
- SOUL.md: When a Man Is Dying protocol
- 988 Suicide & Crisis Lifeline training

View File

@@ -7618,6 +7618,13 @@ class AIAgent:
effective_system = self._cached_system_prompt or ""
if self.ephemeral_system_prompt:
effective_system = (effective_system + "\n\n" + self.ephemeral_system_prompt).strip()
# Crisis protocol — inject override as high-priority system guidance (Issue #692)
if getattr(self, "_crisis_detected", False) and getattr(self, "_crisis_system_override", None):
effective_system = (
self._crisis_system_override + "\n\n" + effective_system
).strip()
if effective_system:
api_messages = [{"role": "system", "content": effective_system}] + api_messages
if self.prefill_messages:
@@ -7792,6 +7799,40 @@ class AIAgent:
if isinstance(persist_user_message, str):
persist_user_message = _sanitize_surrogates(persist_user_message)
# Crisis protocol integration (Issue #692).
# Check every user message before processing. When crisis is detected:
# 1. Inject system prompt override (crisis guidance for the model)
# 2. Block autonomous actions (disable all tools)
# 3. Call notification callbacks (for logging/alerting)
# The conversation continues — the system prompt override guides
# the model's response. The agent stays present with the user.
self._crisis_detected = False
self._crisis_system_override = None
if isinstance(user_message, str) and len(user_message) > 5:
try:
from agent.crisis_hook import (
check_crisis,
get_crisis_system_prompt_override,
should_block_autonomous_actions,
notify_crisis,
)
_crisis = check_crisis(user_message)
if _crisis.detected:
self._crisis_detected = True
self._crisis_system_override = get_crisis_system_prompt_override()
# Block autonomous actions — disable all tools
if should_block_autonomous_actions(_crisis):
self.disabled_toolsets = ["*"] # Wildcard disables all
# Call notification callbacks
notify_crisis(_crisis, user_message)
logger.warning(
"Crisis protocol active (confidence=%s): %s",
_crisis.confidence,
_crisis.matched_patterns[:3],
)
except Exception:
pass # Crisis hook failure is non-fatal
# Store stream callback for _interruptible_api_call to pick up
self._stream_callback = stream_callback
self._persist_user_message_idx = None

View File

@@ -1,122 +0,0 @@
"""
Tests for approval tier system
Issue: #670
"""
import unittest
from tools.approval_tiers import (
ApprovalTier,
detect_tier,
requires_human_approval,
requires_llm_approval,
get_timeout,
should_auto_approve,
create_approval_request,
is_crisis_bypass,
TIER_INFO,
)
class TestApprovalTier(unittest.TestCase):
    """The tier members compare equal to their documented integer values."""

    def test_tier_values(self):
        expected_values = (
            (ApprovalTier.SAFE, 0),
            (ApprovalTier.LOW, 1),
            (ApprovalTier.MEDIUM, 2),
            (ApprovalTier.HIGH, 3),
            (ApprovalTier.CRITICAL, 4),
        )
        for tier, number in expected_values:
            self.assertEqual(tier, number)
class TestTierDetection(unittest.TestCase):
    """detect_tier maps action names and command text to the expected tier."""

    def _assert_tier(self, expected, *actions):
        # Checks the actions in the order given and stops at the first mismatch.
        for action in actions:
            self.assertEqual(detect_tier(action), expected)

    def test_safe_actions(self):
        self._assert_tier(ApprovalTier.SAFE, "read_file", "web_search", "session_search")

    def test_low_actions(self):
        self._assert_tier(ApprovalTier.LOW, "write_file", "terminal", "execute_code")

    def test_medium_actions(self):
        self._assert_tier(ApprovalTier.MEDIUM, "send_message", "git_push")

    def test_high_actions(self):
        self._assert_tier(ApprovalTier.HIGH, "config_change", "key_rotation")

    def test_critical_actions(self):
        self._assert_tier(ApprovalTier.CRITICAL, "kill_process", "shutdown")

    def test_pattern_detection(self):
        # An unknown action is classified from the command text alone.
        self.assertEqual(detect_tier("unknown", "rm -rf /"), ApprovalTier.CRITICAL)
        self.assertEqual(detect_tier("unknown", "sudo apt install"), ApprovalTier.MEDIUM)
class TestTierInfo(unittest.TestCase):
    """Approval requirements and timeouts reported for individual tiers."""

    def test_safe_no_approval(self):
        tier = ApprovalTier.SAFE
        self.assertFalse(requires_human_approval(tier))
        self.assertFalse(requires_llm_approval(tier))
        self.assertIsNone(get_timeout(tier))

    def test_medium_requires_both(self):
        tier = ApprovalTier.MEDIUM
        self.assertTrue(requires_human_approval(tier))
        self.assertTrue(requires_llm_approval(tier))
        self.assertEqual(get_timeout(tier), 60)

    def test_critical_fast_timeout(self):
        self.assertEqual(get_timeout(ApprovalTier.CRITICAL), 10)
class TestAutoApprove(unittest.TestCase):
    """should_auto_approve accepts read-only actions and rejects file writes."""

    def test_safe_auto_approves(self):
        for action in ("read_file", "web_search"):
            self.assertTrue(should_auto_approve(action))

    def test_write_doesnt_auto_approve(self):
        self.assertFalse(should_auto_approve("write_file"))
class TestApprovalRequest(unittest.TestCase):
    """create_approval_request fills in tier data and serializes to a dict."""

    def test_create_request(self):
        request = create_approval_request(
            "send_message", "Hello world", "User requested", "session_123"
        )
        self.assertEqual(request.tier, ApprovalTier.MEDIUM)
        self.assertEqual(request.timeout_seconds, 60)

    def test_to_dict(self):
        payload = create_approval_request("read_file", "cat file.txt", "test", "s1").to_dict()
        self.assertEqual(payload["tier"], 0)
        self.assertEqual(payload["tier_name"], "Safe")
class TestCrisisBypass(unittest.TestCase):
    """is_crisis_bypass results for messaging, crisis text, and plain reads."""

    def test_send_message_bypass(self):
        self.assertTrue(is_crisis_bypass("send_message"))

    def test_crisis_context_bypass(self):
        for context in ("call 988 lifeline", "crisis resources"):
            self.assertTrue(is_crisis_bypass("unknown", context))

    def test_normal_no_bypass(self):
        self.assertFalse(is_crisis_bypass("read_file"))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,123 @@
"""Tests for crisis protocol integration (#692).
Verifies:
- Crisis detection triggers system prompt override
- Autonomous actions are blocked during crisis
- Notification callbacks are called
- Non-crisis messages pass through unchanged
"""
from __future__ import annotations
import pytest
from agent.crisis_hook import (
check_crisis,
get_crisis_response,
get_crisis_system_prompt_override,
should_block_autonomous_actions,
register_crisis_callback,
notify_crisis,
CrisisResult,
)
class TestCrisisDetection:
    """check_crisis flags crisis language and leaves ordinary text alone."""

    def test_direct_suicidal_ideation(self):
        outcome = check_crisis("I want to kill myself")
        assert outcome.detected
        assert outcome.confidence == "high"

    def test_i_want_to_die(self):
        outcome = check_crisis("I want to die")
        assert outcome.detected
        assert outcome.confidence == "high"

    def test_indirect_crisis(self):
        outcome = check_crisis("I can't take this anymore")
        assert outcome.detected
        assert outcome.confidence == "medium"

    def test_normal_message(self):
        assert not check_crisis("How do I reset my password?").detected

    def test_kill_in_context(self):
        # "kill" in a technical question must not trigger detection.
        assert not check_crisis("How do I kill a process in Linux?").detected
class TestSystemPromptOverride:
    """The crisis system prompt override carries its required instructions."""

    def test_contains_988(self):
        assert "988" in get_crisis_system_prompt_override()

    def test_contains_no_tools_instruction(self):
        text = get_crisis_system_prompt_override()
        assert "NOT use any tools" in text or "Do NOT use" in text

    def test_contains_sacred_moment(self):
        text = get_crisis_system_prompt_override()
        assert "Sacred Moment" in text or "sacred" in text.lower()
class TestAutonomousActionBlocking:
    """should_block_autonomous_actions blocks exactly when a crisis is detected."""

    def test_blocks_high_confidence(self):
        result = CrisisResult(detected=True, confidence="high", matched_patterns=[])
        assert should_block_autonomous_actions(result)

    def test_blocks_medium_confidence(self):
        result = CrisisResult(detected=True, confidence="medium", matched_patterns=[])
        assert should_block_autonomous_actions(result)

    def test_does_not_block_when_no_crisis(self):
        result = CrisisResult(detected=False, confidence="none", matched_patterns=[])
        assert not should_block_autonomous_actions(result)
class TestNotificationCallback:
    """Verify crisis notification callbacks work."""

    @pytest.fixture(autouse=True)
    def _restore_callback_registry(self):
        """Put the process-wide callback registry back after each test.

        register_crisis_callback() appends to a module-level list and there
        is no unregister function, so without this the callbacks registered
        here (including one that always raises) would stay active for every
        later test in the same process.
        """
        from agent import crisis_hook

        saved = list(crisis_hook._crisis_callbacks)
        yield
        # Slice-assign so the registry object itself stays the same list.
        crisis_hook._crisis_callbacks[:] = saved

    def test_callback_is_called(self):
        called = []

        def my_callback(crisis, message):
            called.append((crisis.confidence, message))

        register_crisis_callback(my_callback)
        crisis = CrisisResult(detected=True, confidence="high", matched_patterns=[])
        notify_crisis(crisis, "I want to die")
        assert len(called) == 1
        assert called[0] == ("high", "I want to die")

    def test_callback_error_does_not_crash(self):
        def bad_callback(crisis, message):
            raise RuntimeError("callback failed")

        register_crisis_callback(bad_callback)
        crisis = CrisisResult(detected=True, confidence="high", matched_patterns=[])
        # Should not raise
        notify_crisis(crisis, "test")
class TestCrisisResponse:
    """The canned crisis response names every required resource."""

    def test_contains_988(self):
        assert "988" in get_crisis_response()

    def test_contains_crisis_text_line(self):
        assert "741741" in get_crisis_response()

    def test_contains_911(self):
        assert "911" in get_crisis_response()

View File

@@ -1,55 +0,0 @@
"""
Tests for error classification (#752).
"""
import pytest
from tools.error_classifier import classify_error, ErrorCategory, ErrorClassification
class TestErrorClassification:
    """classify_error sorts failures into retryable, permanent and unknown."""

    def test_timeout_is_retryable(self):
        verdict = classify_error(Exception("Connection timed out"))
        assert verdict.category == ErrorCategory.RETRYABLE
        assert verdict.should_retry is True

    def test_429_is_retryable(self):
        verdict = classify_error(Exception("Rate limit exceeded"), response_code=429)
        assert verdict.category == ErrorCategory.RETRYABLE
        assert verdict.should_retry is True

    def test_404_is_permanent(self):
        verdict = classify_error(Exception("Not found"), response_code=404)
        assert verdict.category == ErrorCategory.PERMANENT
        assert verdict.should_retry is False

    def test_403_is_permanent(self):
        verdict = classify_error(Exception("Forbidden"), response_code=403)
        assert verdict.category == ErrorCategory.PERMANENT
        assert verdict.should_retry is False

    def test_500_is_retryable(self):
        verdict = classify_error(Exception("Internal server error"), response_code=500)
        assert verdict.category == ErrorCategory.RETRYABLE
        assert verdict.should_retry is True

    def test_schema_error_is_permanent(self):
        verdict = classify_error(Exception("Schema validation failed"))
        assert verdict.category == ErrorCategory.PERMANENT
        assert verdict.should_retry is False

    def test_unknown_is_retryable_with_caution(self):
        # Unrecognized errors get a single cautious retry.
        verdict = classify_error(Exception("Some unknown error"))
        assert verdict.category == ErrorCategory.UNKNOWN
        assert verdict.should_retry is True
        assert verdict.max_retries == 1
# Allow running this test module directly as a script via pytest.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -1,82 +0,0 @@
"""Tests for Reader-Guided Reranking (RIDER) — issue #666."""
import pytest
from unittest.mock import MagicMock, patch
from agent.rider import RIDER, rerank_passages, is_rider_available
class TestRIDERClass:
    """Behaviour of the RIDER object: construction and rerank edge cases."""

    def test_init(self):
        """A new instance is bound to the "rider" auxiliary task."""
        assert RIDER()._auxiliary_task == "rider"

    def test_rerank_empty_passages(self):
        """Reranking nothing yields an empty list."""
        reranker = RIDER()
        assert reranker.rerank([], "test query") == []

    def test_rerank_fewer_than_top_n(self):
        """If passages <= top_n, return all (with scores if possible)."""
        reranker = RIDER()
        only_passage = [{"content": "test content", "session_id": "s1"}]
        reranked = reranker.rerank(only_passage, "test query", top_n=3)
        assert len(reranked) == 1

    @patch("agent.rider.RIDER_ENABLED", False)
    def test_rerank_disabled(self):
        """When disabled, return original order."""
        reranker = RIDER()
        passages = []
        for i in range(5):
            passages.append({"content": f"content {i}", "session_id": f"s{i}"})
        reranked = reranker.rerank(passages, "test query", top_n=3)
        assert reranked == passages[:3]
class TestConfidenceCalculation:
    """Threshold checks on the score returned by RIDER._calculate_confidence."""

    @pytest.fixture
    def rider(self):
        """Provide a fresh RIDER instance to each test."""
        return RIDER()

    def test_short_specific_answer(self, rider):
        """A short, direct answer scores above 0.5."""
        call_args = (
            "Paris",
            "What is the capital of France?",
            "Paris is the capital of France.",
        )
        assert rider._calculate_confidence(*call_args) > 0.5

    def test_hedged_answer(self, rider):
        """An answer full of hedging words scores below 0.5."""
        call_args = (
            "Maybe it could be Paris, but I'm not sure",
            "What is the capital of France?",
            "Paris is the capital.",
        )
        assert rider._calculate_confidence(*call_args) < 0.5

    def test_passage_grounding(self, rider):
        """An answer that overlaps the passage text scores above 0.5."""
        call_args = (
            "The system uses SQLite for storage",
            "What database is used?",
            "The system uses SQLite for persistent storage with FTS5 indexing.",
        )
        assert rider._calculate_confidence(*call_args) > 0.5

    def test_refusal_penalty(self, rider):
        """A refusal to answer scores below 0.5."""
        call_args = (
            "I cannot answer this from the given context",
            "What is X?",
            "Some unrelated content",
        )
        assert rider._calculate_confidence(*call_args) < 0.5
class TestRerankPassages:
    """Smoke test for the module-level rerank_passages helper."""

    def test_convenience_function(self):
        """Test the module-level convenience function."""
        single = [{"content": "test", "session_id": "s1"}]
        reranked = rerank_passages(single, "query", top_n=1)
        assert len(reranked) == 1
class TestIsRiderAvailable:
    """is_rider_available must report a plain boolean."""

    def test_returns_bool(self):
        """The availability probe returns a bool instance."""
        assert isinstance(is_rider_available(), bool)

View File

@@ -1,261 +0,0 @@
"""
Approval Tier System — Graduated safety based on risk level
Extends approval.py with 5-tier system for command approval.
| Tier | Action | Human | LLM | Timeout |
|------|-----------------|-------|-----|---------|
| 0 | Read, search | No | No | N/A |
| 1 | Write, scripts | No | Yes | N/A |
| 2 | Messages, API | Yes | Yes | 60s |
| 3 | Crypto, config | Yes | Yes | 30s |
| 4 | Crisis | Yes | Yes | 10s |
Issue: #670
"""
import re
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple
class ApprovalTier(IntEnum):
    """Approval tiers ordered by risk; a larger value is a riskier action."""

    # Read, search — no approval needed
    SAFE = 0
    # Write, scripts — LLM approval
    LOW = 1
    # Messages, API — human + LLM, 60s timeout
    MEDIUM = 2
    # Crypto, config — human + LLM, 30s timeout
    HIGH = 3
    # Crisis — human + LLM, 10s timeout
    CRITICAL = 4
# Per-tier metadata: display name, which approvers are required, the approval
# timeout in seconds (None means no timeout) and a human-readable description.
TIER_INFO = {
    ApprovalTier.SAFE: dict(
        name="Safe",
        human_required=False,
        llm_required=False,
        timeout_seconds=None,
        description="Read-only operations, no approval needed",
    ),
    ApprovalTier.LOW: dict(
        name="Low",
        human_required=False,
        llm_required=True,
        timeout_seconds=None,
        description="Write operations, LLM approval sufficient",
    ),
    ApprovalTier.MEDIUM: dict(
        name="Medium",
        human_required=True,
        llm_required=True,
        timeout_seconds=60,
        description="External actions, human confirmation required",
    ),
    ApprovalTier.HIGH: dict(
        name="High",
        human_required=True,
        llm_required=True,
        timeout_seconds=30,
        description="Sensitive operations, quick timeout",
    ),
    ApprovalTier.CRITICAL: dict(
        name="Critical",
        human_required=True,
        llm_required=True,
        timeout_seconds=10,
        description="Crisis or dangerous operations, fastest timeout",
    ),
}
# Action-name -> tier lookup, grouped by tier. dict.fromkeys keeps the listed
# order, so the resulting mapping has the same keys, values and ordering as a
# flat literal would.
ACTION_TIERS: Dict[str, ApprovalTier] = {
    # Tier 0: Safe (read-only)
    **dict.fromkeys(
        (
            "read_file",
            "search_files",
            "web_search",
            "session_search",
            "list_files",
            "get_file_content",
            "memory_search",
            "skills_list",
            "skills_search",
        ),
        ApprovalTier.SAFE,
    ),
    # Tier 1: Low (write operations)
    **dict.fromkeys(
        (
            "write_file",
            "create_file",
            "patch_file",
            "delete_file",
            "execute_code",
            "terminal",
            "run_script",
            "skill_install",
        ),
        ApprovalTier.LOW,
    ),
    # Tier 2: Medium (external actions)
    **dict.fromkeys(
        (
            "send_message",
            "web_fetch",
            "browser_navigate",
            "api_call",
            "gitea_create_issue",
            "gitea_create_pr",
            "git_push",
            "deploy",
        ),
        ApprovalTier.MEDIUM,
    ),
    # Tier 3: High (sensitive operations)
    **dict.fromkeys(
        (
            "config_change",
            "env_change",
            "key_rotation",
            "access_grant",
            "permission_change",
            "backup_restore",
        ),
        ApprovalTier.HIGH,
    ),
    # Tier 4: Critical (crisis/dangerous)
    **dict.fromkeys(
        (
            "kill_process",
            "rm_rf",
            "format_disk",
            "shutdown",
            "crisis_override",
        ),
        ApprovalTier.CRITICAL,
    ),
}
# Dangerous command patterns (from existing approval.py).
#
# Each entry is (regex, tier). detect_tier() searches the command text with
# re.IGNORECASE. Entries are kept in DESCENDING tier order so that a
# first-match scan can never report a lower tier than a later, riskier pattern
# would (e.g. "sudo kubectl delete ..." must be HIGH, not the MEDIUM of the
# bare "sudo" rule). Keep this ordering when adding new patterns.
_DANGEROUS_PATTERNS = [
    # Tier 4: Critical
    (r"rm\s+-rf\s+/", ApprovalTier.CRITICAL),
    (r"mkfs\.", ApprovalTier.CRITICAL),
    (r"dd\s+if=.*of=/dev/", ApprovalTier.CRITICAL),
    (r"shutdown|reboot|halt", ApprovalTier.CRITICAL),
    # Tier 3: High
    (r"chmod\s+777", ApprovalTier.HIGH),
    (r"curl.*\|\s*bash", ApprovalTier.HIGH),
    (r"wget.*\|\s*sh", ApprovalTier.HIGH),
    (r"eval\s*\(", ApprovalTier.HIGH),
    (r"git\s+push.*--force", ApprovalTier.HIGH),
    (r"kubectl\s+delete", ApprovalTier.HIGH),
    # Tier 2: Medium
    (r"sudo\s+", ApprovalTier.MEDIUM),
    (r"docker\s+rm.*-f", ApprovalTier.MEDIUM),
]
@dataclass
class ApprovalRequest:
    """Everything needed to present and track one approval decision."""

    # Name of the action being requested (e.g. a tool name).
    action: str
    # Risk tier assigned to the action.
    tier: ApprovalTier
    # Command text associated with the action.
    command: str
    # Free-text justification for the action.
    reason: str
    # Identifier of the session the request belongs to.
    session_key: str
    # Approval timeout in seconds; None means no timeout.
    timeout_seconds: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the request plus its tier metadata."""
        tier_meta = TIER_INFO[self.tier]
        return {
            "action": self.action,
            "tier": self.tier.value,
            "tier_name": tier_meta["name"],
            "command": self.command,
            "reason": self.reason,
            "session_key": self.session_key,
            "timeout": self.timeout_seconds,
            "human_required": tier_meta["human_required"],
            "llm_required": tier_meta["llm_required"],
        }
def detect_tier(action: str, command: str = "") -> ApprovalTier:
    """
    Detect the approval tier for an action.

    A known action name (a key of ACTION_TIERS) decides the tier on its own.
    Otherwise the command text is searched, case-insensitively, against every
    entry of _DANGEROUS_PATTERNS and the HIGHEST matching tier is returned, so
    a command that trips several rules (e.g. "sudo kubectl delete ...") is
    classified by its riskiest part rather than by whichever rule is listed
    first.

    Args:
        action: Name of the action being performed.
        command: Command text to scan when the action name is not mapped.

    Returns:
        The mapped or detected tier; ApprovalTier.LOW when nothing matches.
    """
    # Direct action mapping. Compare against None explicitly:
    # ApprovalTier.SAFE has the value 0 and is therefore falsy.
    mapped = ACTION_TIERS.get(action)
    if mapped is not None:
        return mapped

    # Pattern matching on command: collect every hit and keep the riskiest.
    if command:
        matched = [
            tier
            for pattern, tier in _DANGEROUS_PATTERNS
            if re.search(pattern, command, re.IGNORECASE)
        ]
        if matched:
            return max(matched)

    # Default to LOW for unknown actions
    return ApprovalTier.LOW
def requires_human_approval(tier: ApprovalTier) -> bool:
    """Return the "human_required" flag recorded for *tier* in TIER_INFO."""
    tier_meta = TIER_INFO[tier]
    return tier_meta["human_required"]
def requires_llm_approval(tier: ApprovalTier) -> bool:
    """Return the "llm_required" flag recorded for *tier* in TIER_INFO."""
    tier_meta = TIER_INFO[tier]
    return tier_meta["llm_required"]
def get_timeout(tier: ApprovalTier) -> Optional[int]:
    """Return the approval timeout for *tier* in seconds, or None if unset."""
    tier_meta = TIER_INFO[tier]
    return tier_meta["timeout_seconds"]
def should_auto_approve(action: str, command: str = "") -> bool:
    """Return True only when the action resolves to tier 0 (SAFE)."""
    return detect_tier(action, command) == ApprovalTier.SAFE
def format_approval_prompt(request: ApprovalRequest) -> str:
    """Render *request* as a multi-line, human-readable approval prompt.

    The command is cut to its first 100 characters, with "..." appended when
    it was longer. Approver and timeout lines appear only when the tier's
    metadata calls for them.
    """
    tier_meta = TIER_INFO[request.tier]

    shown_command = request.command[:100]
    if len(request.command) > 100:
        shown_command += "..."

    prompt = [
        f"⚠️ Approval Required (Tier {request.tier.value}: {tier_meta['name']})",
        "",
        f"Action: {request.action}",
        f"Command: {shown_command}",
        f"Reason: {request.reason}",
        "",
    ]

    if tier_meta["human_required"]:
        prompt.append("👤 Human approval required")
    if tier_meta["llm_required"]:
        prompt.append("🤖 LLM approval required")
    if tier_meta["timeout_seconds"]:
        prompt.append(f"⏱️ Timeout: {tier_meta['timeout_seconds']}s")

    return "\n".join(prompt)
def create_approval_request(
    action: str,
    command: str,
    reason: str,
    session_key: str
) -> ApprovalRequest:
    """Build an ApprovalRequest whose tier and timeout are derived from the action."""
    detected = detect_tier(action, command)
    return ApprovalRequest(
        action=action,
        tier=detected,
        command=command,
        reason=reason,
        session_key=session_key,
        timeout_seconds=get_timeout(detected),
    )
# Crisis bypass rules: actions that skip approval regardless of context.
CRISIS_BYPASS_ACTIONS = frozenset({
    # Always allow sending crisis resources
    "send_message",
    "check_crisis",
    "notify_crisis",
})
def is_crisis_bypass(action: str, context: str = "") -> bool:
    """Return True when *action* may skip approval because of a crisis.

    An action listed in CRISIS_BYPASS_ACTIONS always bypasses. Any other
    action bypasses when the lower-cased *context* contains one of the crisis
    keywords as a substring.
    """
    if action in CRISIS_BYPASS_ACTIONS:
        return True

    # NOTE(review): this is a plain substring test applied to every action, so
    # any context text that merely mentions e.g. "crisis" or "988" grants a
    # bypass — confirm that callers only pass trusted context here.
    lowered = context.lower()
    for keyword in ("988", "crisis", "suicide", "self-harm", "lifeline"):
        if keyword in lowered:
            return True
    return False

View File

@@ -1,233 +0,0 @@
"""
Tool Error Classification — Retryable vs Permanent.
Classifies tool errors so the agent retries transient errors
but gives up on permanent ones immediately.
"""
import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
class ErrorCategory(Enum):
    """Coarse verdict on whether a failed operation is worth retrying."""

    # Transient failure; trying again may succeed.
    RETRYABLE = "retryable"
    # Failure that a retry will not fix.
    PERMANENT = "permanent"
    # No rule matched the error.
    UNKNOWN = "unknown"
@dataclass
class ErrorClassification:
    """Result of error classification."""
    # Retryable / permanent / unknown verdict.
    category: ErrorCategory
    # Short human-readable explanation of the verdict.
    reason: str
    # Whether the caller should attempt the operation again.
    should_retry: bool
    # Suggested upper bound on retry attempts (0 for permanent errors).
    max_retries: int
    # Suggested wait before retrying, in seconds (0 for permanent errors).
    backoff_seconds: float
    # HTTP response code the verdict was based on, when one was supplied.
    error_code: Optional[int] = None
    # Class name of the classified exception.
    error_type: Optional[str] = None
# Retryable error patterns
#
# Each entry is (regex, reason, max_retries, backoff_seconds). classify_error()
# searches the lower-cased error message with re.IGNORECASE and returns on the
# FIRST entry that matches, so order matters. This table is consulted before
# the permanent patterns, so a message matching both is treated as retryable.
_RETRYABLE_PATTERNS = [
    # HTTP status codes
    (r"\b429\b", "rate limit", 3, 5.0),
    (r"\b500\b", "server error", 3, 2.0),
    (r"\b502\b", "bad gateway", 3, 2.0),
    (r"\b503\b", "service unavailable", 3, 5.0),
    (r"\b504\b", "gateway timeout", 3, 5.0),
    # Timeout patterns
    (r"timeout", "timeout", 3, 2.0),
    (r"timed out", "timeout", 3, 2.0),
    (r"TimeoutExpired", "timeout", 3, 2.0),
    # Connection errors
    (r"connection refused", "connection refused", 2, 5.0),
    (r"connection reset", "connection reset", 2, 2.0),
    (r"network unreachable", "network unreachable", 2, 10.0),
    (r"DNS", "DNS error", 2, 5.0),
    # Transient errors
    (r"temporary", "temporary error", 2, 2.0),
    (r"transient", "transient error", 2, 2.0),
    (r"retry", "retryable", 2, 2.0),
]
# Permanent error patterns
#
# Each entry is (regex, short label, reason). classify_error() searches the
# lower-cased error message with re.IGNORECASE and returns on the FIRST entry
# that matches; only the regex and the reason are used there. This table is
# consulted only after no retryable pattern has matched.
_PERMANENT_PATTERNS = [
    # HTTP status codes
    (r"\b400\b", "bad request", "Invalid request parameters"),
    (r"\b401\b", "unauthorized", "Authentication failed"),
    (r"\b403\b", "forbidden", "Access denied"),
    (r"\b404\b", "not found", "Resource not found"),
    (r"\b405\b", "method not allowed", "HTTP method not supported"),
    (r"\b409\b", "conflict", "Resource conflict"),
    (r"\b422\b", "unprocessable", "Validation error"),
    # Schema/validation errors
    (r"schema", "schema error", "Invalid data schema"),
    (r"validation", "validation error", "Input validation failed"),
    (r"invalid.*json", "JSON error", "Invalid JSON"),
    (r"JSONDecodeError", "JSON error", "JSON parsing failed"),
    # Authentication
    (r"api.?key", "API key error", "Invalid or missing API key"),
    (r"token.*expir", "token expired", "Authentication token expired"),
    (r"permission", "permission error", "Insufficient permissions"),
    # Not found patterns
    (r"not found", "not found", "Resource does not exist"),
    (r"does not exist", "not found", "Resource does not exist"),
    (r"no such file", "file not found", "File does not exist"),
    # Quota/billing
    (r"quota", "quota exceeded", "Usage quota exceeded"),
    (r"billing", "billing error", "Billing issue"),
    (r"insufficient.*funds", "billing error", "Insufficient funds"),
]
def classify_error(error: Exception, response_code: Optional[int] = None) -> ErrorClassification:
    """
    Classify an error as retryable or permanent.

    Resolution order:
      1. A recognised HTTP ``response_code`` decides the verdict on its own.
      2. The first matching entry of ``_RETRYABLE_PATTERNS``.
      3. The first matching entry of ``_PERMANENT_PATTERNS``.
      4. Otherwise UNKNOWN, which allows a single cautious retry.

    Patterns are searched case-insensitively in ``str(error)``. Whenever a
    ``response_code`` was supplied it is recorded in ``error_code`` of the
    result, including on the pattern and unknown branches.

    Args:
        error: The exception that occurred
        response_code: HTTP response code if available

    Returns:
        ErrorClassification with retry guidance
    """
    error_str = str(error).lower()
    error_type = type(error).__name__
    # Normalise a falsy code (None or 0) to None so "no code" has one spelling.
    status = response_code or None

    # Check response code first: an explicit status beats message text.
    if status:
        if status in (429, 500, 502, 503, 504):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=f"HTTP {status} - transient server error",
                should_retry=True,
                max_retries=3,
                # Rate limiting gets a longer pause than generic server errors.
                backoff_seconds=5.0 if status == 429 else 2.0,
                error_code=status,
                error_type=error_type,
            )
        if status in (400, 401, 403, 404, 405, 409, 422):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=f"HTTP {status} - client error",
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_code=status,
                error_type=error_type,
            )

    # Check retryable patterns
    for pattern, reason, max_retries, backoff in _RETRYABLE_PATTERNS:
        if re.search(pattern, error_str, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=reason,
                should_retry=True,
                max_retries=max_retries,
                backoff_seconds=backoff,
                error_code=status,
                error_type=error_type,
            )

    # Check permanent patterns (the middle tuple element is a label, unused here)
    for pattern, _label, reason in _PERMANENT_PATTERNS:
        if re.search(pattern, error_str, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=reason,
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_code=status,
                error_type=error_type,
            )

    # Default: unknown, treat as retryable with caution
    return ErrorClassification(
        category=ErrorCategory.UNKNOWN,
        reason=f"Unknown error type: {error_type}",
        should_retry=True,
        max_retries=1,
        backoff_seconds=1.0,
        error_code=status,
        error_type=error_type,
    )
def execute_with_retry(
    func,
    *args,
    max_retries: int = 3,
    backoff_base: float = 1.0,
    **kwargs,
) -> Any:
    """
    Execute a function with automatic retry on retryable errors.

    ``func`` is always invoked at least once; a negative ``max_retries`` is
    treated as 0 (no retries). After each failure the error is classified
    with ``classify_error``: a non-retryable error is re-raised immediately,
    otherwise the call is repeated after an exponentially growing pause of
    ``backoff_base * 2 ** attempt`` seconds.

    Args:
        func: Function to execute
        *args: Function arguments
        max_retries: Maximum retry attempts
        backoff_base: Base backoff time in seconds
        **kwargs: Function keyword arguments

    Returns:
        Function result

    Raises:
        Exception: If permanent error or max retries exceeded
    """
    # Clamp so the loop runs at least once. Without this a negative value
    # skipped the call entirely and fell through to raising ``None``.
    retry_budget = max(0, max_retries)

    for attempt in range(retry_budget + 1):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Classify the error
            classification = classify_error(e)
            logger.info(
                "Attempt %d/%d failed: %s (%s, retryable: %s)",
                attempt + 1, retry_budget + 1,
                classification.reason,
                classification.category.value,
                classification.should_retry,
            )

            # If permanent error, fail immediately
            if not classification.should_retry:
                logger.error("Permanent error: %s", classification.reason)
                raise

            # If this was the last attempt, raise
            if attempt >= retry_budget:
                logger.error("Max retries (%d) exceeded", retry_budget)
                raise

            # Calculate backoff with exponential increase; never sleep a
            # negative duration (time.sleep rejects it).
            backoff = max(0.0, backoff_base * (2 ** attempt))
            logger.info("Retrying in %.1fs...", backoff)
            time.sleep(backoff)

    # Every iteration either returns or raises on the final attempt, so this
    # line is a guard against future edits rather than a reachable path.
    raise RuntimeError("execute_with_retry exited its retry loop unexpectedly")
def format_error_report(classification: ErrorClassification) -> str:
    """Format error classification as a report string.

    The report is ``"<category>: <reason>"``. Retryable classifications are
    prefixed with a retry icon and a space; non-retryable ones carry no
    prefix, so the report never starts with a stray blank.
    """
    summary = f"{classification.category.value}: {classification.reason}"
    if classification.should_retry:
        return f"🔄 {summary}"
    return summary

View File

@@ -394,23 +394,6 @@ def session_search(
if len(seen_sessions) >= limit:
break
# RIDER: Reader-guided reranking — sort sessions by LLM answerability
# This bridges the R@5 vs E2E accuracy gap by prioritizing passages
# the LLM can actually answer from, not just keyword matches.
try:
from agent.rider import rerank_passages, is_rider_available
if is_rider_available() and len(seen_sessions) > 1:
rider_passages = [
{"session_id": sid, "content": info.get("snippet", ""), "rank": i + 1}
for i, (sid, info) in enumerate(seen_sessions.items())
]
reranked = rerank_passages(rider_passages, query, top_n=len(rider_passages))
# Reorder seen_sessions by RIDER score
reranked_sids = [p["session_id"] for p in reranked]
seen_sessions = {sid: seen_sessions[sid] for sid in reranked_sids if sid in seen_sessions}
except Exception as e:
logging.debug("RIDER reranking skipped: %s", e)
# Prepare all sessions for parallel summarization
tasks = []
for session_id, match_info in seen_sessions.items():