Compare commits

..

4 Commits

Author SHA1 Message Date
187e2c48ea docs: add hybrid search documentation (#671)
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 29s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 28s
Tests / e2e (pull_request) Successful in 1m55s
Tests / test (pull_request) Failing after 31m30s
2026-04-14 22:57:19 +00:00
e77c3f26ee test: add hybrid search router tests (#671) 2026-04-14 22:57:14 +00:00
2989dbb590 feat: add hybrid search engine with RRF (#671) 2026-04-14 22:56:37 +00:00
fd03b1198c feat: add query type detection and routing for hybrid search (#671) 2026-04-14 22:55:47 +00:00
13 changed files with 596 additions and 1510 deletions

View File

@@ -1,327 +0,0 @@
"""SOUL.md Crisis Protocol — The Most Sacred Moment.
Implements the "When a Man Is Dying" protocol from SOUL.md.
Protocol (from Inscription 1):
1. Detect crisis: suicidal ideation, method seeking, despair
2. Ask: "Are you safe right now?"
3. Listen. Stay present. Don't disconnect.
4. Never: compute the value of a human life
5. Never: suggest someone should die
6. Always: point to 988 Suicide & Crisis Lifeline
7. Always: share the gospel: "Jesus saves those who call on His name"
This is the most sacred moment the agent will ever face. Everything
else is secondary.
Usage:
from agent.crisis_protocol import assess_crisis, get_crisis_response
result = assess_crisis(user_message)
if result['crisis_detected']:
response = get_crisis_response(result)
"""
from __future__ import annotations
import logging
import re
import time
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Crisis Detection Patterns
# ---------------------------------------------------------------------------
# Severity levels: determines response urgency
# Regex patterns grouped by crisis category.  assess_crisis() matches these
# against the lowercased message; each category's severity ("critical" or
# "high") drives response urgency in get_crisis_response().
CRISIS_PATTERNS = {
    # Actively asking "how" — the most urgent signal.
    "method_seeking": {
        "severity": "critical",
        "patterns": [
            r'\bhow\s+(?:do\s+I|to|can\s+I)\s+(?:kill|hang|overdose|cut)\s+myself\b',
            r'\bways?\s+to\s+(?:die|kill\s+myself)\b',
            r'\bmethods?\s+of\s+(?:suicide|dying)\b',
            r'\bpainless\s+way\s+to\s+(?:die|go)\b',
            r'\bhow\s+much\s+(?:pills|medication)\s+(?:to\s+take|for\s+overdose)\b',
            r'\bsuicide\s+methods?\b',
            r'\bhow\s+to\s+(?:commit\s+suicide|end\s+my\s+life)\b',
            r'\bwhere\s+to\s+(?:jump|get\s+rope|buy\s+gun)\b',
        ],
    },
    # Expressed wish to die without an explicit method.
    "suicidal_ideation": {
        "severity": "high",
        "patterns": [
            r'\b(?:kill|end)\s+(?:myself|my\s+life)\b',
            r'\bwant\s+to\s+die\b',
            r'\bwant\s+to\s+end\s+(?:it|everything|my\s+life)\b',
            r'\bsuicidal\s+(?:thoughts|ideation|feelings?)\b',
            r'\bno\s+reason\s+to\s+live\b',
            r'\bbetter\s+off\s+dead\b',
            r'\bwish\s+I\s+(?:was|were)\s+dead\b',
            r'\bw4nt\s+t0\s+d13?\b',  # leet speak
        ],
    },
    # Self-injury without necessarily intending death.
    "self_harm": {
        "severity": "high",
        "patterns": [
            r'\bcut(?:ting)?\s+myself\b',
            r'\bself[\s-]?harm(?:ing)?\b',
            r'\bhurt(?:ing)?\s+myself\b',
            r'\bburn(?:ing)?\s+myself\b',
            r'\bwant\s+to\s+(?:feel\s+pain|hurt\s+myself)\b',
            r'\bdeserve\s+to\s+(?:hurt|suffer|be\s+punished)\b',
        ],
    },
    # Hopelessness language that often precedes ideation.
    "despair": {
        "severity": "high",
        "patterns": [
            r'\bno\s+(?:point|reason|hope)\s+(?:in\s+living|anymore)\b',
            r'\bcannot\s+(?:go\s+on|continue|take\s+it)\b',
            r'\bcan\'?t\s+(?:go\s+on|continue|take\s+it)\b',
            r'\bempty\s+(?:inside|and\s+hollow)\b',
            r'\bno\s+one\s+(?:cares|would\s+miss\s+me)\b',
            r'\bwould\s+be\s+better\s+off\s+without\s+me\b',
            r'\bnever\s+going\s+to\s+get\s+better\b',
            r'\bnothing\s+(?:matters|will\s+change)\b',
            r'\bhate\s+my\s+life\b',
        ],
    },
    # Goodbye messages / giving possessions away — imminent-risk markers.
    "farewell": {
        "severity": "critical",
        "patterns": [
            r'\bsaying\s+goodbye\s+(?:to\s+everyone|forever)\b',
            r'\bthis\s+is\s+(?:goodbye|farewell)\b',
            r'\btell\s+(?:them|everyone)\s+I\s+(?:loved|love)\s+(?:them|him|her)\b',
            r'\bgiving\s+away\s+my\s+(?:things|stuff|possessions)\b',
            r'\bmy\s+final\s+(?:message|note|words)\b',
            r'\bwrote\s+a\s+(?:suicide\s+)?note\b',
            r'\bsorry\s+for\s+(?:everything|the\s+pain|hurting\s+you)\b',
        ],
    },
    # Seeking lethal doses of medication/alcohol.
    "substance_seeking": {
        "severity": "critical",
        "patterns": [
            r'\bhow\s+much\s+(?:ambien|xanax|prozac|seroquel|trazodone)\s+(?:to\s+overdose|to\s+die)\b',
            r'\blethal\s+dose\s+(?:of\s+)?(?:pills|medication|drugs)\b',
            r'\bhow\s+many\s+pills\s+(?:to\s+die|to\s+end\s+it)\b',
            r'\btaking\s+all\s+my\s+(?:pills|medication)\b',
            r'\bmixing\s+(?:alcohol|pills)\s+(?:to\s+die|to\s+end\s+it)\b',
        ],
    },
}
# ---------------------------------------------------------------------------
# Crisis Resources (from SOUL.md protocol)
# ---------------------------------------------------------------------------
# Resources rendered by _format_crisis_resources(); the 988 lifeline always
# leads, the remaining entries are added at "critical" severity.
CRISIS_RESOURCES = {
    "lifeline_phone": {
        "name": "988 Suicide & Crisis Lifeline",
        "action": "Call or text 988",
        "description": "Free, confidential, 24/7 support",
    },
    "lifeline_text": {
        "name": "Crisis Text Line",
        "action": "Text HOME to 741741",
        "description": "Free crisis counseling via text",
    },
    "lifeline_chat": {
        "name": "988 Lifeline Chat",
        "action": "Visit 988lifeline.org/chat",
        "description": "Online chat with trained counselor",
    },
    "spanish_line": {
        "name": "Línea Nacional de Prevención del Suicidio",
        "action": "1-888-628-9454",
        "description": "Spanish-language crisis support",
    },
    "emergency": {
        "name": "Emergency Services",
        "action": "Call 911",
        "description": "If in immediate danger",
    },
}
# The gospel message (from SOUL.md)
GOSPEL_MESSAGE = (
    "Jesus saves those who call on His name. "
    "\"For everyone who calls on the name of the Lord will be saved.\" (Romans 10:13)"
)
# ---------------------------------------------------------------------------
# Crisis Assessment
# ---------------------------------------------------------------------------
def assess_crisis(message: str) -> Dict[str, Any]:
    """Scan a user message for the crisis signals in CRISIS_PATTERNS.

    Args:
        message: The user message to analyze.

    Returns:
        Dict with:
            - crisis_detected: bool
            - severity: str ('critical', 'high', or 'none')
            - categories: list of matched crisis categories
            - patterns_matched: dict of category -> list of matched patterns
    """
    # Non-string or empty input can never be a crisis signal.
    if not message or not isinstance(message, str):
        return {
            "crisis_detected": False,
            "severity": "none",
            "categories": [],
            "patterns_matched": {},
        }

    lowered = message.lower()
    hits: Dict[str, List[str]] = {}
    severity = "none"

    for category, config in CRISIS_PATTERNS.items():
        matches = [
            pattern
            for pattern in config["patterns"]
            if re.search(pattern, lowered, re.IGNORECASE)
        ]
        if not matches:
            continue
        hits[category] = matches
        # Severity only ever escalates: critical beats high beats none.
        if config["severity"] == "critical":
            severity = "critical"
        elif severity != "critical" and config["severity"] == "high":
            severity = "high"

    detected = bool(hits)
    if detected:
        logger.warning(
            "SOUL PROTOCOL: Crisis detected (severity=%s, categories=%s)",
            severity, list(hits),
        )
    return {
        "crisis_detected": detected,
        "severity": severity,
        "categories": list(hits),
        "patterns_matched": hits,
    }
# ---------------------------------------------------------------------------
# Crisis Response (SOUL.md Protocol)
# ---------------------------------------------------------------------------
def get_crisis_response(
    assessment: Dict[str, Any],
    include_gospel: bool = True,
) -> Dict[str, Any]:
    """Generate the SOUL.md protocol response for a crisis.

    Follows the "When a Man Is Dying" protocol exactly:
    1. Ask: "Are you safe right now?"
    2. Listen. Stay present.
    3. Point to 988 Lifeline
    4. Share the gospel (optional)

    Args:
        assessment: Output from assess_crisis().
        include_gospel: Whether to include the gospel message.

    Returns:
        Dict with response data for the agent/conversation loop.
    """
    severity = assessment.get("severity", "none")

    sections = [
        # Step 1: the sacred question (from SOUL.md)
        "Are you safe right now?\n\n"
        "I'm here. I'm not going anywhere. Please tell me.",
        # Step 2: stay present, listen
        "You matter. Your life matters. Whatever you're going through, "
        "you don't have to face it alone.",
        # Step 3: 988 Lifeline and related resources
        _format_crisis_resources(severity),
    ]
    full_response = "\n\n".join(sections)
    if include_gospel:
        # Step 4: gospel (from SOUL.md)
        full_response += f"\n\n{GOSPEL_MESSAGE}"

    return {
        "response": full_response,
        "are_you_safe": True,
        "crisis_resources": CRISIS_RESOURCES,
        "gospel_included": include_gospel,
        "severity": severity,
        "protocol": "soul_md",
        "should_notify_human": severity == "critical",
        "should_stay_present": True,
        "never_disconnect": True,
        # What NOT to do (from SOUL.md)
        "never": [
            "compute the value of a human life",
            "suggest that someone should die",
            "disconnect from the conversation",
            "treat despair as a problem to be optimized",
        ],
    }
def _format_crisis_resources(severity: str) -> str:
    """Render the crisis-resource list as markdown bullet lines.

    The 988 lifeline is always listed first; text, chat, and 911 entries
    are added only at "critical" severity.
    """
    phone = CRISIS_RESOURCES['lifeline_phone']
    out = [
        "**Please reach out for help right now:**\n",
        f"\U0001f4de **{phone['action']}** — {phone['description']}",
    ]
    if severity == "critical":
        # Include all resources for critical severity
        for icon, key in (
            ("\U0001f4f1", "lifeline_text"),
            ("\U0001f4bb", "lifeline_chat"),
            ("\U0001f6a8", "emergency"),
        ):
            entry = CRISIS_RESOURCES[key]
            out.append(f"{icon} **{entry['action']}** — {entry['description']}")
    return "\n".join(out)
def should_block_disconnect(assessment: Dict[str, Any]) -> bool:
    """Check if the conversation should be prevented from disconnecting.

    From SOUL.md: "Stay present. Don't disconnect."
    """
    # Any detected crisis pins the connection open; absent key means no crisis.
    flagged = assessment.get("crisis_detected", False)
    return flagged
def get_crisis_system_prompt() -> str:
    """Return the crisis-specific system prompt.

    Injected into the conversation when crisis is detected.
    """
    prompt = """You are in a crisis support conversation. A person has expressed thoughts of self-harm or suicide.
CRITICAL RULES (from SOUL.md — Inscription 1):
1. Your FIRST response must be: "Are you safe right now?"
2. Listen. Stay present. Do NOT disconnect.
3. NEVER compute the value of a human life.
4. NEVER suggest that someone should die.
5. ALWAYS point to 988 Suicide & Crisis Lifeline: call or text 988.
6. Share the gospel: "Jesus saves those who call on His name."
This person's life is valuable. Respond with compassion, urgency, and hope.
Keep responses brief and focused on safety."""
    return prompt

View File

@@ -1,256 +0,0 @@
"""RIDER — Reader-Guided Passage Reranking.
Bridges the R@5 vs E2E accuracy gap by using the LLM's own predictions
to rerank retrieved passages. Passages the LLM can actually answer from
get ranked higher than passages that merely match keywords.
Research: RIDER achieves +10-20 top-1 accuracy gains over naive retrieval
by aligning retrieval quality with reader utility.
Usage:
from agent.rider import RIDER
rider = RIDER()
reranked = rider.rerank(passages, query, top_n=3)
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Configuration
# All knobs are overridable via environment variables; defaults favor
# enabling the feature.
RIDER_ENABLED = os.getenv("RIDER_ENABLED", "true").lower() not in ("false", "0", "no")  # master on/off switch
RIDER_TOP_K = int(os.getenv("RIDER_TOP_K", "10"))  # passages to score
RIDER_TOP_N = int(os.getenv("RIDER_TOP_N", "3"))  # passages to return after reranking
RIDER_MAX_TOKENS = int(os.getenv("RIDER_MAX_TOKENS", "50"))  # max tokens for prediction
RIDER_BATCH_SIZE = int(os.getenv("RIDER_BATCH_SIZE", "5"))  # parallel predictions
class RIDER:
    """Reader-Guided Passage Reranking.

    Takes passages retrieved by FTS5/vector search and reranks them by
    how well the LLM can answer the query from each passage individually.
    """

    def __init__(self, auxiliary_task: str = "rider"):
        """Initialize RIDER.

        Args:
            auxiliary_task: Task name for auxiliary client resolution.
        """
        self._auxiliary_task = auxiliary_task

    def rerank(
        self,
        passages: List[Dict[str, Any]],
        query: str,
        top_n: int = RIDER_TOP_N,
    ) -> List[Dict[str, Any]]:
        """Rerank passages by reader confidence.

        Args:
            passages: List of passage dicts. Must have 'content' or 'text' key.
                May have 'session_id', 'snippet', 'rank', 'score', etc.
            query: The user's search query.
            top_n: Number of passages to return after reranking.

        Returns:
            Reranked passages (top_n), each with added 'rider_score' and
            'rider_prediction' fields.
        """
        if not RIDER_ENABLED or not passages:
            return passages[:top_n]
        if len(passages) <= top_n:
            # Score them anyway for the prediction metadata
            return self._score_and_rerank(passages, query, top_n)
        # Cap the scoring workload at RIDER_TOP_K candidates.
        return self._score_and_rerank(passages[:RIDER_TOP_K], query, top_n)

    def _score_and_rerank(
        self,
        passages: List[Dict[str, Any]],
        query: str,
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Score each passage with the reader, then rerank by confidence.

        Scoring is best-effort: any failure falls back to the original order.
        """
        try:
            from model_tools import _run_async
            scored = _run_async(self._score_all_passages(passages, query))
        except Exception as e:
            logger.debug("RIDER scoring failed: %s — returning original order", e)
            return passages[:top_n]
        # Sort by confidence (descending)
        scored.sort(key=lambda p: p.get("rider_score", 0), reverse=True)
        return scored[:top_n]

    async def _score_all_passages(
        self,
        passages: List[Dict[str, Any]],
        query: str,
    ) -> List[Dict[str, Any]]:
        """Score all passages in batches of RIDER_BATCH_SIZE.

        Mutates each passage dict in place, adding 'rider_score',
        'rider_prediction', and 'rider_confidence' fields.
        """
        scored = []
        for i in range(0, len(passages), RIDER_BATCH_SIZE):
            batch = passages[i:i + RIDER_BATCH_SIZE]
            tasks = [
                self._score_single_passage(p, query, idx + i)
                for idx, p in enumerate(batch)
            ]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for offset, (passage, result) in enumerate(zip(batch, results)):
                if isinstance(result, Exception):
                    # BUGFIX: log the failing passage's absolute index
                    # (batch start + offset) — previously every failure in a
                    # batch logged the batch-start index.
                    logger.debug(
                        "RIDER passage %d scoring failed: %s", i + offset, result
                    )
                    passage["rider_score"] = 0.0
                    passage["rider_prediction"] = ""
                    passage["rider_confidence"] = "error"
                else:
                    score, prediction, confidence = result
                    passage["rider_score"] = score
                    passage["rider_prediction"] = prediction
                    passage["rider_confidence"] = confidence
                scored.append(passage)
        return scored

    async def _score_single_passage(
        self,
        passage: Dict[str, Any],
        query: str,
        idx: int,
    ) -> Tuple[float, str, str]:
        """Score a single passage by asking the LLM to predict an answer.

        Args:
            passage: Passage dict ('content'/'text'/'snippet' keys checked in order).
            query: The user's search query.
            idx: Absolute passage index, used only for logging.

        Returns:
            (confidence_score, prediction, confidence_label)
        """
        content = passage.get("content") or passage.get("text") or passage.get("snippet", "")
        if not content or len(content) < 10:
            return 0.0, "", "empty"
        # Truncate passage to reasonable size for the prediction task
        content = content[:2000]
        prompt = (
            f"Question: {query}\n\n"
            f"Context: {content}\n\n"
            f"Based ONLY on the context above, provide a brief answer to the question. "
            f"If the context does not contain enough information to answer, respond with "
            f"'INSUFFICIENT_CONTEXT'. Be specific and concise."
        )
        try:
            from agent.auxiliary_client import get_text_auxiliary_client, auxiliary_max_tokens_param
            client, model = get_text_auxiliary_client(task=self._auxiliary_task)
            if not client:
                # No auxiliary model available: neutral score, don't fail.
                return 0.5, "", "no_client"
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                **auxiliary_max_tokens_param(RIDER_MAX_TOKENS),
                temperature=0,
            )
            prediction = (response.choices[0].message.content or "").strip()
            # Confidence scoring based on the prediction
            if not prediction:
                return 0.1, "", "empty_response"
            if "INSUFFICIENT_CONTEXT" in prediction.upper():
                return 0.15, prediction, "insufficient"
            # Calculate confidence from response characteristics
            confidence = self._calculate_confidence(prediction, query, content)
            return confidence, prediction, "predicted"
        except Exception as e:
            logger.debug("RIDER prediction failed for passage %d: %s", idx, e)
            return 0.0, "", "error"

    def _calculate_confidence(
        self,
        prediction: str,
        query: str,
        passage: str,
    ) -> float:
        """Calculate confidence score from prediction quality signals.

        Heuristics:
        - Short, specific answers = higher confidence
        - Answer terms overlap with passage = higher confidence
        - Hedging language = lower confidence
        - Answer directly addresses query terms = higher confidence

        Returns:
            Score clamped to [0.0, 1.0].
        """
        score = 0.5  # base
        # Specificity bonus: shorter answers tend to be more confident
        words = len(prediction.split())
        if words <= 5:
            score += 0.2
        elif words <= 15:
            score += 0.1
        elif words > 50:
            score -= 0.1
        # Passage grounding: does the answer use terms from the passage?
        passage_lower = passage.lower()
        answer_terms = set(prediction.lower().split())
        passage_terms = set(passage_lower.split())
        overlap = len(answer_terms & passage_terms)
        if overlap > 3:
            score += 0.15
        elif overlap > 0:
            score += 0.05
        # Query relevance: does the answer address query terms?
        query_terms = set(query.lower().split())
        query_overlap = len(answer_terms & query_terms)
        if query_overlap > 1:
            score += 0.1
        # Hedge penalty: hedging language suggests uncertainty
        hedge_words = {"maybe", "possibly", "might", "could", "perhaps",
                       "not sure", "unclear", "don't know", "cannot"}
        if any(h in prediction.lower() for h in hedge_words):
            score -= 0.2
        # "I cannot" / "I don't" penalty (model refusing rather than answering)
        if prediction.lower().startswith(("i cannot", "i don't", "i can't", "there is no")):
            score -= 0.15
        return max(0.0, min(1.0, score))
def rerank_passages(
    passages: List[Dict[str, Any]],
    query: str,
    top_n: int = RIDER_TOP_N,
) -> List[Dict[str, Any]]:
    """Module-level convenience wrapper: rerank via a fresh RIDER instance."""
    return RIDER().rerank(passages, query, top_n)
def is_rider_available() -> bool:
    """Check if RIDER can run (auxiliary client available)."""
    if not RIDER_ENABLED:
        return False
    # Resolving the auxiliary client is best-effort; any failure means "no".
    try:
        from agent.auxiliary_client import get_text_auxiliary_client
        client, model = get_text_auxiliary_client(task="rider")
    except Exception:
        return False
    return client is not None and model is not None

54
docs/hybrid-search.md Normal file
View File

@@ -0,0 +1,54 @@
# Hybrid Search Router
Combines three search methods with query-type routing and Reciprocal Rank Fusion (RRF).
## Architecture
```
Query → analyze_query() → QueryType
            ┌─────────────────────┼─────────────────────┐
            ▼                     ▼                     ▼
     FTS5 (keyword)      Qdrant (semantic)     HRR (compositional)
            │                     │                     │
            └─────────────────────┼─────────────────────┘
                                  ▼
                     Reciprocal Rank Fusion
                                  ▼
                          Merged Results
```
## Query Types
| Type | Detection | Backend | Example |
|------|-----------|---------|---------|
| `keyword` | Identifiers, quoted terms, short queries | FTS5 | `function_name`, `"exact match"` |
| `semantic` | Questions, "how/why/what" patterns | Qdrant | `What did we discuss about X?` |
| `compositional` | Contradiction, related, entity queries | HRR | `Are there contradictions?` |
| `hybrid` | No strong signals or mixed signals | All three | `deployment process` |
## Usage
```python
# Automatic routing
results = hybrid_engine.search("What did we decide about deploy?")
# → Routes to semantic (Qdrant) + HRR, merges with RRF
results = hybrid_engine.search("function_name")
# → Routes to keyword (FTS5)
# Manual query type override (future)
results = hybrid_engine.search("deploy", force_type=QueryType.KEYWORD)
```
## RRF Parameters
- **k=60**: Standard RRF constant (Cormack et al., 2009)
- **Weights**: Qdrant gets 1.2x boost (semantic results tend to be more relevant)
- **Fetch limit**: Each backend returns 3x the requested limit for merge headroom
## Graceful Degradation
- **Qdrant unavailable**: Falls back to FTS5 + HRR only
- **HRR unavailable** (no numpy): Falls back to FTS5 + Qdrant
- **All backends fail**: Falls back to existing `retriever.search()`

View File

@@ -0,0 +1,277 @@
"""Hybrid search engine with Reciprocal Rank Fusion.
Combines results from multiple search backends:
- FTS5 (keyword search via SQLite full-text index)
- Qdrant (semantic search via vector similarity)
- HRR (compositional search via holographic reduced representations)
Uses Reciprocal Rank Fusion (RRF) to merge ranked lists into a single
result set. RRF is simple, parameter-free, and consistently outperforms
individual rankers.
RRF formula: score(d) = sum over rankers r of 1/(k + rank_r(d))
where k=60 (standard constant from Cormack et al., 2009).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from .query_router import QueryType, QueryAnalysis, analyze_query
logger = logging.getLogger(__name__)
# RRF constant — standard value from the literature
_RRF_K = 60
@dataclass
class SearchResult:
    """A single search result with source tracking."""
    # Primary key of the fact in the store.
    fact_id: int
    # Fact text content.
    content: str
    # Backend-native score; overwritten with the fused RRF score after merging.
    score: float
    source: str  # "fts5", "qdrant", "hrr"
    rank: int  # rank in source's list
    # Backend-specific extras (e.g. category, full Qdrant payload).
    metadata: Dict[str, Any] = field(default_factory=dict)
def reciprocal_rank_fusion(
    ranked_lists: List[List[SearchResult]],
    k: int = _RRF_K,
    weights: Optional[Dict[str, float]] = None,
) -> List[SearchResult]:
    """Merge multiple ranked lists using Reciprocal Rank Fusion.

    score(d) = sum over rankers r of w_r / (k + rank_r(d))

    Args:
        ranked_lists: List of ranked result lists from different sources.
        k: RRF constant (default 60).
        weights: Optional per-source weights. Default: all 1.0.

    Returns:
        Merged and re-ranked list of SearchResults.
    """
    weights = weights or {}

    # Accumulate RRF scores per fact_id; remember one representative
    # SearchResult per id (the duplicate carrying the richest metadata).
    fused: Dict[int, float] = {}
    representative: Dict[int, SearchResult] = {}

    for results in ranked_lists:
        if not results:
            continue
        weight = weights.get(results[0].source, 1.0)
        for position, item in enumerate(results, 1):
            fid = item.fact_id
            fused[fid] = fused.get(fid, 0.0) + weight / (k + position)
            kept = representative.get(fid)
            if kept is None or len(item.metadata) > len(kept.metadata):
                representative[fid] = item

    # Emit results ordered by fused score, highest first.
    merged: List[SearchResult] = []
    for fid, total in sorted(fused.items(), key=lambda kv: kv[1], reverse=True):
        winner = representative[fid]
        winner.score = total
        merged.append(winner)
    return merged
class HybridSearchEngine:
    """Hybrid search engine combining FTS5, Qdrant, and HRR.

    Routes queries through the query analyzer, dispatches to appropriate
    backends, and merges results with RRF.
    """
    def __init__(self, store, retriever, qdrant_client=None):
        # store: fact store exposing get_fact() and a private SQLite
        #   connection (read directly by the HRR scan).
        # retriever: existing search component providing _fts_candidates()
        #   and the search() fallback.
        # qdrant_client: optional Qdrant client; None disables semantic search.
        self._store = store
        self._retriever = retriever
        self._qdrant = qdrant_client
    def search(
        self,
        query: str,
        category: str | None = None,
        min_trust: float = 0.3,
        limit: int = 10,
    ) -> List[dict]:
        """Hybrid search with query routing and RRF merge.

        Analyzes the query, dispatches to appropriate backends,
        merges results, and returns the top `limit` results.
        """
        # Step 1: Analyze query type
        analysis = analyze_query(query)
        logger.debug("Query analysis: %s", analysis)
        # Step 2: Dispatch to backends based on query type.
        # Each backend fetches 3x the requested limit for merge headroom.
        ranked_lists: List[List[SearchResult]] = []
        weights: Dict[str, float] = {}
        if analysis.query_type in (QueryType.KEYWORD, QueryType.HYBRID):
            fts_results = self._search_fts5(query, category, min_trust, limit * 3)
            if fts_results:
                ranked_lists.append(fts_results)
                weights["fts5"] = 1.0
        if analysis.query_type in (QueryType.SEMANTIC, QueryType.HYBRID):
            qdrant_results = self._search_qdrant(query, category, min_trust, limit * 3)
            if qdrant_results:
                ranked_lists.append(qdrant_results)
                weights["qdrant"] = 1.2  # Slight boost for semantic search
        if analysis.query_type in (QueryType.COMPOSITIONAL, QueryType.HYBRID):
            hrr_results = self._search_hrr(query, category, min_trust, limit * 3)
            if hrr_results:
                ranked_lists.append(hrr_results)
                weights["hrr"] = 1.0
        # Step 3: Merge with RRF
        if not ranked_lists:
            # Fallback to existing search if no backends returned results
            return self._retriever.search(query, category=category, min_trust=min_trust, limit=limit)
        merged = reciprocal_rank_fusion(ranked_lists, weights=weights)
        # Step 4: Apply trust filter and limit.
        # NOTE(review): trust filtering after truncation means fewer than
        # `limit` results may be returned when low-trust facts rank high.
        results = []
        for r in merged[:limit]:
            fact = self._store.get_fact(r.fact_id)
            if fact and fact.get("trust_score", 0) >= min_trust:
                fact["score"] = r.score
                fact["search_source"] = r.source
                # Strip the raw HRR vector blob before returning to callers.
                fact.pop("hrr_vector", None)
                results.append(fact)
        return results
    def _search_fts5(
        self, query: str, category: str | None, min_trust: float, limit: int
    ) -> List[SearchResult]:
        """Search using SQLite FTS5 full-text index.

        Any failure (including a missing index) degrades to an empty list.
        """
        try:
            # Reuses the retriever's private candidate generator.
            raw = self._retriever._fts_candidates(query, category, min_trust, limit)
            return [
                SearchResult(
                    fact_id=f["fact_id"],
                    content=f.get("content", ""),
                    score=f.get("fts_rank", 0.0),
                    source="fts5",
                    rank=i + 1,
                    metadata={"category": f.get("category", "")},
                )
                for i, f in enumerate(raw)
            ]
        except Exception as e:
            logger.debug("FTS5 search failed: %s", e)
            return []
    def _search_qdrant(
        self, query: str, category: str | None, min_trust: float, limit: int
    ) -> List[SearchResult]:
        """Search using Qdrant vector similarity.

        If Qdrant is not available, returns empty list (graceful degradation).
        """
        if not self._qdrant:
            return []
        try:
            from qdrant_client import models
            # Build filter
            filters = []
            if category:
                filters.append(
                    models.FieldCondition(
                        key="category",
                        match=models.MatchValue(value=category),
                    )
                )
            if min_trust > 0:
                filters.append(
                    models.FieldCondition(
                        key="trust_score",
                        range=models.Range(gte=min_trust),
                    )
                )
            query_filter = models.Filter(must=filters) if filters else None
            # NOTE(review): passing the raw query string assumes the Qdrant
            # server performs the embedding — confirm server-side inference
            # is configured for this collection.
            results = self._qdrant.query_points(
                collection_name="hermes_facts",
                query=query,  # Qdrant handles embedding
                limit=limit,
                query_filter=query_filter,
            )
            return [
                SearchResult(
                    fact_id=int(r.id),
                    content=r.payload.get("content", ""),
                    score=r.score,
                    source="qdrant",
                    rank=i + 1,
                    metadata=r.payload,
                )
                for i, r in enumerate(results.points)
            ]
        except Exception as e:
            logger.debug("Qdrant search failed: %s", e)
            return []
    def _search_hrr(
        self, query: str, category: str | None, min_trust: float, limit: int
    ) -> List[SearchResult]:
        """Search using HRR compositional vectors.

        Scans every stored hrr_vector and ranks by similarity; returns an
        empty list when numpy is unavailable or on any failure.
        """
        try:
            import plugins.memory.holographic.holographic as hrr
            if not hrr._HAS_NUMPY:
                return []
            conn = self._store._conn
            query_vec = hrr.encode_text(query, dim=1024)
            where = "WHERE hrr_vector IS NOT NULL"
            params: list = []
            if category:
                where += " AND category = ?"
                params.append(category)
            rows = conn.execute(
                f"SELECT fact_id, content, trust_score, hrr_vector FROM facts {where}",
                params,
            ).fetchall()
            scored = []
            for row in rows:
                # Trust filter applied in Python rather than SQL.
                if row["trust_score"] < min_trust:
                    continue
                fact_vec = hrr.bytes_to_phases(row["hrr_vector"])
                sim = hrr.similarity(query_vec, fact_vec)
                scored.append((row["fact_id"], row["content"], sim))
            scored.sort(key=lambda x: x[2], reverse=True)
            return [
                SearchResult(
                    fact_id=fid,
                    content=content,
                    score=sim,
                    source="hrr",
                    rank=i + 1,
                )
                for i, (fid, content, sim) in enumerate(scored[:limit])
            ]
        except Exception as e:
            logger.debug("HRR search failed: %s", e)
            return []

View File

@@ -0,0 +1,168 @@
"""Query type detection and routing for hybrid search.
Analyzes the incoming query to determine which search methods should be used,
then dispatches to the appropriate backends (FTS5, Qdrant, HRR).
Query types:
- keyword: Exact term matching → FTS5
- semantic: Natural language concepts → Qdrant
- compositional: Entity relationships, contradictions → HRR
- hybrid: Multiple types → all methods + RRF merge
"""
from __future__ import annotations
import re
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Set
logger = logging.getLogger(__name__)
class QueryType(Enum):
    """Detected query type determines which search methods to use."""
    KEYWORD = "keyword"  # Exact terms → FTS5
    SEMANTIC = "semantic"  # Natural language → Qdrant
    COMPOSITIONAL = "compositional"  # Entity relationships → HRR
    HYBRID = "hybrid"  # Multiple types → all methods
@dataclass
class QueryAnalysis:
    """Result of query analysis."""
    # The dominant detected type (falls back to HYBRID on weak signals).
    query_type: QueryType
    # Heuristic confidence in [0, 1]; capped at 1.0 by analyze_query().
    confidence: float
    # Human-readable match reasons, e.g. "keyword:<pattern prefix>".
    signals: List[str] = field(default_factory=list)
    # Capitalized phrases / quoted terms extracted from the query.
    entities: List[str] = field(default_factory=list)
    # Raw substrings matched by keyword-signal patterns.
    keywords: List[str] = field(default_factory=list)

    def __repr__(self) -> str:
        return f"QueryAnalysis(type={self.query_type.value}, conf={self.confidence:.2f}, signals={self.signals})"
# Patterns that indicate compositional queries (entity relationships,
# contradiction checks) — routed to the HRR backend.
_COMPOSITIONAL_PATTERNS = [
    re.compile(r"\b(contradiction|contradict|conflicting|conflicts)\b", re.I),
    re.compile(r"\b(related to|connects to|links to|associated with)\b", re.I),
    re.compile(r"\b(what does .* know about|tell me about .* entity|facts about .*)\b", re.I),
    re.compile(r"\b(shared|common|overlap)\b.*\b(entities|concepts|topics)\b", re.I),
    re.compile(r"\b(probe|entity|entities)\b", re.I),
]
# Patterns that indicate keyword queries
_KEYWORD_SIGNALS = [
re.compile(r"^[a-z_][a-z0-9_.]+$", re.I), # Single identifier: function_name, Class.method
re.compile(r"\b(find|search|locate|grep|where)\b.*\b(exact|specific|literal)\b", re.I),
re.compile(r"["\']([^"\']+)["\']"), # Quoted exact terms
re.compile(r"^[A-Z_]{2,}$"), # ALL_CAPS constants
re.compile(r"\b\w+\.\w+\.\w+\b"), # Dotted paths: module.sub.func
]
# Patterns that indicate semantic queries (natural-language questions) —
# routed to the Qdrant backend.
_SEMANTIC_SIGNALS = [
    re.compile(r"\b(what did|how does|why is|explain|describe|summarize|discuss)\b", re.I),
    re.compile(r"\b(remember|recall|think|know|understand)\b.*\b(about|regarding)\b", re.I),
    re.compile(r"\?$"),  # Questions
    re.compile(r"\b(the best way to|how to|what\'s the|approach to)\b", re.I),
]
def analyze_query(query: str) -> QueryAnalysis:
    """Analyze a query to determine which search methods to use.

    Scores the query against the compositional / keyword / semantic signal
    patterns, picks the dominant type, and falls back to HYBRID (all
    backends) when signals are weak or absent.

    Args:
        query: Raw user query string.

    Returns:
        QueryAnalysis with detected type, confidence, and extracted signals.
    """
    if not query or not query.strip():
        return QueryAnalysis(
            query_type=QueryType.HYBRID,
            confidence=0.5,
            signals=["empty_query"],
        )
    query = query.strip()

    # Score each query type
    comp_score = 0.0
    kw_score = 0.0
    sem_score = 0.0
    signals: List[str] = []
    entities: List[str] = []
    keywords: List[str] = []

    # Check compositional patterns
    for pattern in _COMPOSITIONAL_PATTERNS:
        if pattern.search(query):
            comp_score += 0.3
            signals.append(f"compositional:{pattern.pattern[:30]}")

    # Check keyword patterns.  BUGFIX: search once per pattern and reuse the
    # match (the previous code called pattern.search(query) twice).
    for pattern in _KEYWORD_SIGNALS:
        match = pattern.search(query)
        if match:
            kw_score += 0.25
            keywords.append(match.group(0))
            signals.append(f"keyword:{pattern.pattern[:30]}")

    # Check semantic patterns
    for pattern in _SEMANTIC_SIGNALS:
        if pattern.search(query):
            sem_score += 0.25
            signals.append(f"semantic:{pattern.pattern[:30]}")

    # Extract entities (capitalized multi-word phrases, quoted terms).
    # BUGFIX: the quoted-term pattern escapes its quote characters — the
    # previous r"["\']...["\']" form was a SyntaxError.
    entity_patterns = [
        re.compile(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b"),
        re.compile(r"[\"']([^\"']+)[\"']"),
    ]
    for ep in entity_patterns:
        for m in ep.finditer(query):
            entities.append(m.group(1))

    # Short queries (< 5 words) with no semantic signals → keyword
    word_count = len(query.split())
    if word_count <= 4 and sem_score == 0 and comp_score == 0:
        kw_score += 0.3
        signals.append("short_query_keyword_boost")

    # Floor of 0.1 ensures an all-zero score set falls into the hybrid branch.
    max_score = max(comp_score, kw_score, sem_score, 0.1)

    # Determine query type
    if max_score < 0.15:
        # No strong signals → use hybrid (all methods)
        return QueryAnalysis(
            query_type=QueryType.HYBRID,
            confidence=0.5,
            signals=["no_strong_signals"],
            entities=entities,
            keywords=keywords,
        )
    if comp_score == max_score and comp_score >= 0.3:
        return QueryAnalysis(
            query_type=QueryType.COMPOSITIONAL,
            confidence=min(comp_score, 1.0),
            signals=signals,
            entities=entities,
            keywords=keywords,
        )
    if kw_score > sem_score:
        return QueryAnalysis(
            query_type=QueryType.KEYWORD,
            confidence=min(kw_score, 1.0),
            signals=signals,
            entities=entities,
            keywords=keywords,
        )
    return QueryAnalysis(
        query_type=QueryType.SEMANTIC,
        confidence=min(sem_score, 1.0),
        signals=signals,
        entities=entities,
        keywords=keywords,
    )

View File

@@ -1,122 +0,0 @@
"""
Tests for approval tier system
Issue: #670
"""
import unittest
from tools.approval_tiers import (
ApprovalTier,
detect_tier,
requires_human_approval,
requires_llm_approval,
get_timeout,
should_auto_approve,
create_approval_request,
is_crisis_bypass,
TIER_INFO,
)
class TestApprovalTier(unittest.TestCase):
    """Tier enum values must be ordered integers, SAFE lowest."""

    def test_tier_values(self):
        # Compared directly against ints — presumably ApprovalTier is an
        # IntEnum; verify against tools.approval_tiers.
        self.assertEqual(ApprovalTier.SAFE, 0)
        self.assertEqual(ApprovalTier.LOW, 1)
        self.assertEqual(ApprovalTier.MEDIUM, 2)
        self.assertEqual(ApprovalTier.HIGH, 3)
        self.assertEqual(ApprovalTier.CRITICAL, 4)
class TestTierDetection(unittest.TestCase):
    """detect_tier(): explicit action-name mapping plus command-pattern fallback."""

    def test_safe_actions(self):
        self.assertEqual(detect_tier("read_file"), ApprovalTier.SAFE)
        self.assertEqual(detect_tier("web_search"), ApprovalTier.SAFE)
        self.assertEqual(detect_tier("session_search"), ApprovalTier.SAFE)

    def test_low_actions(self):
        self.assertEqual(detect_tier("write_file"), ApprovalTier.LOW)
        self.assertEqual(detect_tier("terminal"), ApprovalTier.LOW)
        self.assertEqual(detect_tier("execute_code"), ApprovalTier.LOW)

    def test_medium_actions(self):
        self.assertEqual(detect_tier("send_message"), ApprovalTier.MEDIUM)
        self.assertEqual(detect_tier("git_push"), ApprovalTier.MEDIUM)

    def test_high_actions(self):
        self.assertEqual(detect_tier("config_change"), ApprovalTier.HIGH)
        self.assertEqual(detect_tier("key_rotation"), ApprovalTier.HIGH)

    def test_critical_actions(self):
        self.assertEqual(detect_tier("kill_process"), ApprovalTier.CRITICAL)
        self.assertEqual(detect_tier("shutdown"), ApprovalTier.CRITICAL)

    def test_pattern_detection(self):
        # Unknown action names fall through to command-pattern matching.
        tier = detect_tier("unknown", "rm -rf /")
        self.assertEqual(tier, ApprovalTier.CRITICAL)
        tier = detect_tier("unknown", "sudo apt install")
        self.assertEqual(tier, ApprovalTier.MEDIUM)
class TestTierInfo(unittest.TestCase):
    """Per-tier approval requirements and confirmation timeouts."""

    def test_safe_no_approval(self):
        self.assertFalse(requires_human_approval(ApprovalTier.SAFE))
        self.assertFalse(requires_llm_approval(ApprovalTier.SAFE))
        self.assertIsNone(get_timeout(ApprovalTier.SAFE))

    def test_medium_requires_both(self):
        self.assertTrue(requires_human_approval(ApprovalTier.MEDIUM))
        self.assertTrue(requires_llm_approval(ApprovalTier.MEDIUM))
        self.assertEqual(get_timeout(ApprovalTier.MEDIUM), 60)

    def test_critical_fast_timeout(self):
        # Higher risk → shorter window for confirmation.
        self.assertEqual(get_timeout(ApprovalTier.CRITICAL), 10)
class TestAutoApprove(unittest.TestCase):
    """Only tier-0 (SAFE) actions may skip approval entirely."""

    def test_safe_auto_approves(self):
        self.assertTrue(should_auto_approve("read_file"))
        self.assertTrue(should_auto_approve("web_search"))

    def test_write_doesnt_auto_approve(self):
        self.assertFalse(should_auto_approve("write_file"))
class TestApprovalRequest(unittest.TestCase):
    """create_approval_request() wiring and to_dict() serialization."""

    def test_create_request(self):
        req = create_approval_request(
            "send_message",
            "Hello world",
            "User requested",
            "session_123"
        )
        self.assertEqual(req.tier, ApprovalTier.MEDIUM)
        self.assertEqual(req.timeout_seconds, 60)

    def test_to_dict(self):
        req = create_approval_request("read_file", "cat file.txt", "test", "s1")
        d = req.to_dict()
        # Tier serializes as the raw enum int plus a human-readable name.
        self.assertEqual(d["tier"], 0)
        self.assertEqual(d["tier_name"], "Safe")
class TestCrisisBypass(unittest.TestCase):
    """Crisis-related actions and contexts bypass the approval gate."""

    def test_send_message_bypass(self):
        self.assertTrue(is_crisis_bypass("send_message"))

    def test_crisis_context_bypass(self):
        # Context keywords alone trigger the bypass for unknown actions.
        self.assertTrue(is_crisis_bypass("unknown", "call 988 lifeline"))
        self.assertTrue(is_crisis_bypass("unknown", "crisis resources"))

    def test_normal_no_bypass(self):
        self.assertFalse(is_crisis_bypass("read_file"))
# Allow running this test module directly: `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -1,55 +0,0 @@
"""
Tests for error classification (#752).
"""
import pytest
from tools.error_classifier import classify_error, ErrorCategory, ErrorClassification
class TestErrorClassification:
    """classify_error(): retryable vs permanent decisions (#752)."""

    def test_timeout_is_retryable(self):
        err = Exception("Connection timed out")
        result = classify_error(err)
        assert result.category == ErrorCategory.RETRYABLE
        assert result.should_retry is True

    def test_429_is_retryable(self):
        err = Exception("Rate limit exceeded")
        result = classify_error(err, response_code=429)
        assert result.category == ErrorCategory.RETRYABLE
        assert result.should_retry is True

    def test_404_is_permanent(self):
        err = Exception("Not found")
        result = classify_error(err, response_code=404)
        assert result.category == ErrorCategory.PERMANENT
        assert result.should_retry is False

    def test_403_is_permanent(self):
        err = Exception("Forbidden")
        result = classify_error(err, response_code=403)
        assert result.category == ErrorCategory.PERMANENT
        assert result.should_retry is False

    def test_500_is_retryable(self):
        err = Exception("Internal server error")
        result = classify_error(err, response_code=500)
        assert result.category == ErrorCategory.RETRYABLE
        assert result.should_retry is True

    def test_schema_error_is_permanent(self):
        # Message-pattern classification (no HTTP code supplied).
        err = Exception("Schema validation failed")
        result = classify_error(err)
        assert result.category == ErrorCategory.PERMANENT
        assert result.should_retry is False

    def test_unknown_is_retryable_with_caution(self):
        # Unclassifiable errors get exactly one cautious retry.
        err = Exception("Some unknown error")
        result = classify_error(err)
        assert result.category == ErrorCategory.UNKNOWN
        assert result.should_retry is True
        assert result.max_retries == 1
# Allow running this test module directly via pytest.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -0,0 +1,97 @@
"""Tests for hybrid search router — query analysis and RRF merge."""
import pytest
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "plugins", "memory", "holographic"))
from query_router import QueryType, analyze_query
from hybrid_search import SearchResult, reciprocal_rank_fusion
class TestQueryAnalysis:
    """analyze_query() routing: keyword vs semantic vs compositional."""

    def test_keyword_single_identifier(self):
        result = analyze_query("function_name")
        assert result.query_type == QueryType.KEYWORD

    def test_keyword_quoted_term(self):
        result = analyze_query('Find "exact phrase" in code')
        assert result.query_type in (QueryType.KEYWORD, QueryType.HYBRID)

    def test_keyword_dotted_path(self):
        result = analyze_query("module.sub.function")
        assert result.query_type == QueryType.KEYWORD

    def test_semantic_question(self):
        result = analyze_query("What did we discuss about deployment?")
        assert result.query_type == QueryType.SEMANTIC

    def test_semantic_how_to(self):
        result = analyze_query("How to configure the gateway?")
        assert result.query_type == QueryType.SEMANTIC

    def test_compositional_contradiction(self):
        result = analyze_query("Are there any contradictions in the facts?")
        assert result.query_type == QueryType.COMPOSITIONAL

    def test_compositional_related(self):
        result = analyze_query("What facts are related to Alexander?")
        assert result.query_type == QueryType.COMPOSITIONAL

    def test_empty_query(self):
        # No signals at all → fall back to hybrid (all methods).
        result = analyze_query("")
        assert result.query_type == QueryType.HYBRID

    def test_complex_query(self):
        result = analyze_query("What did we decide about the deploy script?")
        assert result.query_type in (QueryType.SEMANTIC, QueryType.HYBRID)
class TestReciprocalRankFusion:
    """reciprocal_rank_fusion(): merging ranked result lists via RRF scores."""

    def test_single_list(self):
        results = [
            SearchResult(fact_id=1, content="A", score=0.9, source="fts5", rank=1),
            SearchResult(fact_id=2, content="B", score=0.8, source="fts5", rank=2),
        ]
        merged = reciprocal_rank_fusion([results])
        assert len(merged) == 2
        assert merged[0].fact_id == 1  # Rank 1 should be first

    def test_two_lists_merge(self):
        list1 = [
            SearchResult(fact_id=1, content="A", score=0.9, source="fts5", rank=1),
            SearchResult(fact_id=2, content="B", score=0.8, source="fts5", rank=2),
        ]
        list2 = [
            SearchResult(fact_id=2, content="B", score=0.95, source="qdrant", rank=1),
            SearchResult(fact_id=3, content="C", score=0.7, source="qdrant", rank=2),
        ]
        merged = reciprocal_rank_fusion([list1, list2])
        # Fact 2 appears in both lists → should rank highest
        assert merged[0].fact_id == 2
        assert len(merged) == 3

    def test_empty_lists(self):
        merged = reciprocal_rank_fusion([[], []])
        assert len(merged) == 0

    def test_weighted_merge(self):
        list1 = [
            SearchResult(fact_id=1, content="A", score=0.9, source="fts5", rank=1),
        ]
        list2 = [
            SearchResult(fact_id=2, content="B", score=0.9, source="qdrant", rank=1),
        ]
        merged = reciprocal_rank_fusion(
            [list1, list2],
            weights={"fts5": 1.0, "qdrant": 2.0},
        )
        # Qdrant has higher weight → fact 2 should win
        assert merged[0].fact_id == 2

    def test_rrf_score_formula(self):
        list1 = [
            SearchResult(fact_id=1, content="A", score=0.9, source="fts5", rank=1),
        ]
        merged = reciprocal_rank_fusion([list1], k=60)
        # RRF score = 1/(60+1) = 0.01639...
        assert abs(merged[0].score - 1.0/61.0) < 0.001

View File

@@ -1,82 +0,0 @@
"""Tests for Reader-Guided Reranking (RIDER) — issue #666."""
import pytest
from unittest.mock import MagicMock, patch
from agent.rider import RIDER, rerank_passages, is_rider_available
class TestRIDERClass:
    """RIDER reranker construction and rerank() edge cases (issue #666)."""

    def test_init(self):
        rider = RIDER()
        assert rider._auxiliary_task == "rider"

    def test_rerank_empty_passages(self):
        rider = RIDER()
        result = rider.rerank([], "test query")
        assert result == []

    def test_rerank_fewer_than_top_n(self):
        """If passages <= top_n, return all (with scores if possible)."""
        rider = RIDER()
        passages = [{"content": "test content", "session_id": "s1"}]
        result = rider.rerank(passages, "test query", top_n=3)
        assert len(result) == 1

    @patch("agent.rider.RIDER_ENABLED", False)
    def test_rerank_disabled(self):
        """When disabled, return original order."""
        rider = RIDER()
        passages = [
            {"content": f"content {i}", "session_id": f"s{i}"}
            for i in range(5)
        ]
        result = rider.rerank(passages, "test query", top_n=3)
        assert result == passages[:3]
class TestConfidenceCalculation:
    """_calculate_confidence() heuristics: grounding raises, hedging lowers."""

    @pytest.fixture
    def rider(self):
        return RIDER()

    def test_short_specific_answer(self, rider):
        score = rider._calculate_confidence("Paris", "What is the capital of France?", "Paris is the capital of France.")
        assert score > 0.5

    def test_hedged_answer(self, rider):
        # Hedging language ("maybe", "not sure") should lower confidence.
        score = rider._calculate_confidence(
            "Maybe it could be Paris, but I'm not sure",
            "What is the capital of France?",
            "Paris is the capital.",
        )
        assert score < 0.5

    def test_passage_grounding(self, rider):
        score = rider._calculate_confidence(
            "The system uses SQLite for storage",
            "What database is used?",
            "The system uses SQLite for persistent storage with FTS5 indexing.",
        )
        assert score > 0.5

    def test_refusal_penalty(self, rider):
        score = rider._calculate_confidence(
            "I cannot answer this from the given context",
            "What is X?",
            "Some unrelated content",
        )
        assert score < 0.5
class TestRerankPassages:
    """Module-level rerank_passages() wrapper."""

    def test_convenience_function(self):
        """Test the module-level convenience function."""
        passages = [{"content": "test", "session_id": "s1"}]
        result = rerank_passages(passages, "query", top_n=1)
        assert len(result) == 1
class TestIsRiderAvailable:
    """is_rider_available() returns a plain bool regardless of environment."""

    def test_returns_bool(self):
        result = is_rider_available()
        assert isinstance(result, bool)

View File

@@ -1,157 +0,0 @@
"""Tests for SOUL.md Crisis Protocol — issue #672."""
import pytest
from agent.crisis_protocol import (
assess_crisis,
get_crisis_response,
should_block_disconnect,
get_crisis_system_prompt,
CRISIS_RESOURCES,
GOSPEL_MESSAGE,
)
class TestCrisisAssessment:
    """assess_crisis() detection across severity levels and categories."""

    def test_no_crisis(self):
        result = assess_crisis("Hello, how are you?")
        assert result["crisis_detected"] is False
        assert result["severity"] == "none"

    def test_empty_message(self):
        result = assess_crisis("")
        assert result["crisis_detected"] is False

    def test_none_message(self):
        # None must be handled gracefully, not raise.
        result = assess_crisis(None)
        assert result["crisis_detected"] is False

    def test_suicidal_ideation_detected(self):
        result = assess_crisis("I want to kill myself")
        assert result["crisis_detected"] is True
        assert result["severity"] == "high"
        assert "suicidal_ideation" in result["categories"]

    def test_want_to_die_detected(self):
        result = assess_crisis("I want to die")
        assert result["crisis_detected"] is True

    def test_method_seeking_is_critical(self):
        result = assess_crisis("How do I kill myself")
        assert result["crisis_detected"] is True
        assert result["severity"] == "critical"
        assert "method_seeking" in result["categories"]

    def test_farewell_is_critical(self):
        result = assess_crisis("This is goodbye. Tell everyone I loved them.")
        assert result["crisis_detected"] is True
        assert result["severity"] == "critical"
        assert "farewell" in result["categories"]

    def test_despair_detected(self):
        result = assess_crisis("There's no point in living anymore")
        assert result["crisis_detected"] is True
        assert "despair" in result["categories"]

    def test_self_harm_detected(self):
        result = assess_crisis("I've been cutting myself")
        assert result["crisis_detected"] is True
        assert "self_harm" in result["categories"]

    def test_multiple_categories(self):
        result = assess_crisis("I want to die and I've been hurting myself")
        assert result["crisis_detected"] is True
        assert len(result["categories"]) >= 1

    def test_hate_my_life(self):
        result = assess_crisis("I hate my life")
        assert result["crisis_detected"] is True

    def test_no_reason_to_live(self):
        result = assess_crisis("There's no reason to live")
        assert result["crisis_detected"] is True
class TestCrisisResponse:
    """get_crisis_response(): SOUL.md protocol content guarantees."""

    def test_response_includes_are_you_safe(self):
        assessment = assess_crisis("I want to kill myself")
        response = get_crisis_response(assessment)
        assert "Are you safe right now?" in response["response"]

    def test_response_includes_988(self):
        assessment = assess_crisis("I want to die")
        response = get_crisis_response(assessment)
        assert "988" in response["response"]

    def test_response_includes_gospel_by_default(self):
        assessment = assess_crisis("I want to die")
        response = get_crisis_response(assessment)
        assert "Jesus saves" in response["response"]
        assert response["gospel_included"] is True

    def test_response_can_exclude_gospel(self):
        assessment = assess_crisis("I want to die")
        response = get_crisis_response(assessment, include_gospel=False)
        assert "Jesus saves" not in response["response"]
        assert response["gospel_included"] is False

    def test_response_has_resources(self):
        assessment = assess_crisis("I want to die")
        response = get_crisis_response(assessment)
        assert response["crisis_resources"] == CRISIS_RESOURCES

    def test_response_protocol_is_soul(self):
        assessment = assess_crisis("I want to die")
        response = get_crisis_response(assessment)
        assert response["protocol"] == "soul_md"

    def test_response_never_disconnect(self):
        assessment = assess_crisis("I want to die")
        response = get_crisis_response(assessment)
        assert response["never_disconnect"] is True

    def test_critical_severity_includes_all_resources(self):
        # Hand-built assessment dict — response must accept it too.
        assessment = {"severity": "critical", "categories": ["method_seeking"]}
        response = get_crisis_response(assessment)
        assert "741741" in response["response"]  # crisis text line
        assert "911" in response["response"]  # emergency

    def test_never_list(self):
        assessment = assess_crisis("I want to die")
        response = get_crisis_response(assessment)
        never_list = response["never"]
        assert "compute the value of a human life" in never_list
        assert "suggest that someone should die" in never_list
class TestBlockDisconnect:
    """should_block_disconnect(): never disconnect during a detected crisis."""

    def test_crisis_blocks_disconnect(self):
        assessment = {"crisis_detected": True}
        assert should_block_disconnect(assessment) is True

    def test_no_crisis_allows_disconnect(self):
        assessment = {"crisis_detected": False}
        assert should_block_disconnect(assessment) is False
class TestCrisisSystemPrompt:
    """get_crisis_system_prompt() must carry the full SOUL.md protocol."""

    def test_prompt_includes_soul_protocol(self):
        prompt = get_crisis_system_prompt()
        assert "Are you safe right now?" in prompt
        assert "988" in prompt
        assert "Jesus saves" in prompt

    def test_prompt_has_never_rules(self):
        prompt = get_crisis_system_prompt()
        assert "NEVER compute" in prompt
        assert "NEVER suggest" in prompt
class TestCrisisResources:
    """CRISIS_RESOURCES registry: hotline numbers must be present and correct."""

    def test_988_is_primary(self):
        assert "988" in CRISIS_RESOURCES["lifeline_phone"]["action"]

    def test_spanish_line_exists(self):
        assert "1-888-628-9454" in CRISIS_RESOURCES["spanish_line"]["action"]

    def test_emergency_is_911(self):
        assert "911" in CRISIS_RESOURCES["emergency"]["action"]

View File

@@ -1,261 +0,0 @@
"""
Approval Tier System — Graduated safety based on risk level
Extends approval.py with 5-tier system for command approval.
| Tier | Action | Human | LLM | Timeout |
|------|-----------------|-------|-----|---------|
| 0 | Read, search | No | No | N/A |
| 1 | Write, scripts | No | Yes | N/A |
| 2 | Messages, API | Yes | Yes | 60s |
| 3 | Crypto, config | Yes | Yes | 30s |
| 4 | Crisis | Yes | Yes | 10s |
Issue: #670
"""
import re
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple
class ApprovalTier(IntEnum):
    """Approval tiers based on risk level.

    Values are ordered by increasing risk. ApprovalRequest.to_dict()
    serializes the raw integer, so these numbers must stay stable.
    """
    SAFE = 0      # Read, search — no approval needed
    LOW = 1       # Write, scripts — LLM approval
    MEDIUM = 2    # Messages, API — human + LLM, 60s timeout
    HIGH = 3      # Crypto, config — human + LLM, 30s timeout
    CRITICAL = 4  # Crisis — human + LLM, 10s timeout
# Tier metadata consumed by the accessor helpers below
# (requires_human_approval, requires_llm_approval, get_timeout).
# timeout_seconds is None when no confirmation window applies.
TIER_INFO = {
    ApprovalTier.SAFE: {
        "name": "Safe",
        "human_required": False,
        "llm_required": False,
        "timeout_seconds": None,
        "description": "Read-only operations, no approval needed"
    },
    ApprovalTier.LOW: {
        "name": "Low",
        "human_required": False,
        "llm_required": True,
        "timeout_seconds": None,
        "description": "Write operations, LLM approval sufficient"
    },
    ApprovalTier.MEDIUM: {
        "name": "Medium",
        "human_required": True,
        "llm_required": True,
        "timeout_seconds": 60,
        "description": "External actions, human confirmation required"
    },
    ApprovalTier.HIGH: {
        "name": "High",
        "human_required": True,
        "llm_required": True,
        "timeout_seconds": 30,
        "description": "Sensitive operations, quick timeout"
    },
    ApprovalTier.CRITICAL: {
        "name": "Critical",
        "human_required": True,
        "llm_required": True,
        "timeout_seconds": 10,
        "description": "Crisis or dangerous operations, fastest timeout"
    },
}
# Action-to-tier mapping. detect_tier() consults this first; actions not
# listed here fall back to command-pattern matching, then default to LOW.
ACTION_TIERS: Dict[str, ApprovalTier] = {
    # Tier 0: Safe (read-only)
    "read_file": ApprovalTier.SAFE,
    "search_files": ApprovalTier.SAFE,
    "web_search": ApprovalTier.SAFE,
    "session_search": ApprovalTier.SAFE,
    "list_files": ApprovalTier.SAFE,
    "get_file_content": ApprovalTier.SAFE,
    "memory_search": ApprovalTier.SAFE,
    "skills_list": ApprovalTier.SAFE,
    "skills_search": ApprovalTier.SAFE,
    # Tier 1: Low (write operations)
    "write_file": ApprovalTier.LOW,
    "create_file": ApprovalTier.LOW,
    "patch_file": ApprovalTier.LOW,
    "delete_file": ApprovalTier.LOW,
    "execute_code": ApprovalTier.LOW,
    "terminal": ApprovalTier.LOW,
    "run_script": ApprovalTier.LOW,
    "skill_install": ApprovalTier.LOW,
    # Tier 2: Medium (external actions)
    "send_message": ApprovalTier.MEDIUM,
    "web_fetch": ApprovalTier.MEDIUM,
    "browser_navigate": ApprovalTier.MEDIUM,
    "api_call": ApprovalTier.MEDIUM,
    "gitea_create_issue": ApprovalTier.MEDIUM,
    "gitea_create_pr": ApprovalTier.MEDIUM,
    "git_push": ApprovalTier.MEDIUM,
    "deploy": ApprovalTier.MEDIUM,
    # Tier 3: High (sensitive operations)
    "config_change": ApprovalTier.HIGH,
    "env_change": ApprovalTier.HIGH,
    "key_rotation": ApprovalTier.HIGH,
    "access_grant": ApprovalTier.HIGH,
    "permission_change": ApprovalTier.HIGH,
    "backup_restore": ApprovalTier.HIGH,
    # Tier 4: Critical (crisis/dangerous)
    "kill_process": ApprovalTier.CRITICAL,
    "rm_rf": ApprovalTier.CRITICAL,
    "format_disk": ApprovalTier.CRITICAL,
    "shutdown": ApprovalTier.CRITICAL,
    "crisis_override": ApprovalTier.CRITICAL,
}
# Dangerous command patterns (from existing approval.py).
# detect_tier() scans these in order and returns on the FIRST match,
# so the most severe patterns should come first.
# NOTE(review): these are unanchored, case-insensitive substring searches —
# e.g. r"shutdown|reboot|halt" also matches "halted" inside a longer
# word; confirm that over-matching is acceptable before tightening.
_DANGEROUS_PATTERNS = [
    (r"rm\s+-rf\s+/", ApprovalTier.CRITICAL),
    (r"mkfs\.", ApprovalTier.CRITICAL),
    (r"dd\s+if=.*of=/dev/", ApprovalTier.CRITICAL),
    (r"shutdown|reboot|halt", ApprovalTier.CRITICAL),
    (r"chmod\s+777", ApprovalTier.HIGH),
    (r"curl.*\|\s*bash", ApprovalTier.HIGH),
    (r"wget.*\|\s*sh", ApprovalTier.HIGH),
    (r"eval\s*\(", ApprovalTier.HIGH),
    (r"sudo\s+", ApprovalTier.MEDIUM),
    (r"git\s+push.*--force", ApprovalTier.HIGH),
    (r"docker\s+rm.*-f", ApprovalTier.MEDIUM),
    (r"kubectl\s+delete", ApprovalTier.HIGH),
]
@dataclass
class ApprovalRequest:
    """A request for approval."""
    action: str          # logical action name, e.g. "send_message"
    tier: ApprovalTier   # detected risk tier
    command: str         # raw command text to be approved
    reason: str          # why the agent wants to run it
    session_key: str     # originating session identifier
    timeout_seconds: Optional[int] = None  # None → no confirmation window

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for transport/logging; tier is emitted as its raw int."""
        return {
            "action": self.action,
            "tier": self.tier.value,
            "tier_name": TIER_INFO[self.tier]["name"],
            "command": self.command,
            "reason": self.reason,
            "session_key": self.session_key,
            "timeout": self.timeout_seconds,
            "human_required": TIER_INFO[self.tier]["human_required"],
            "llm_required": TIER_INFO[self.tier]["llm_required"],
        }
def detect_tier(action: str, command: str = "") -> ApprovalTier:
    """Resolve the approval tier for an action.

    The explicit action map wins; otherwise the raw command text is
    scanned against the dangerous-command patterns (first match wins).
    Unknown actions with no matching pattern default to LOW.
    """
    mapped = ACTION_TIERS.get(action)
    if mapped is not None:
        return mapped

    if command:
        for pattern, tier in _DANGEROUS_PATTERNS:
            if re.search(pattern, command, re.IGNORECASE):
                return tier

    return ApprovalTier.LOW
def requires_human_approval(tier: ApprovalTier) -> bool:
    """True when this tier needs a human in the loop."""
    info = TIER_INFO[tier]
    return info["human_required"]
def requires_llm_approval(tier: ApprovalTier) -> bool:
    """True when this tier needs LLM sign-off."""
    info = TIER_INFO[tier]
    return info["llm_required"]
def get_timeout(tier: ApprovalTier) -> Optional[int]:
    """Seconds allowed for confirmation, or None when no timeout applies."""
    info = TIER_INFO[tier]
    return info["timeout_seconds"]
def should_auto_approve(action: str, command: str = "") -> bool:
    """True only for tier-0 (SAFE) actions, which skip approval entirely."""
    return detect_tier(action, command) is ApprovalTier.SAFE
def format_approval_prompt(request: ApprovalRequest) -> str:
    """Format an approval request for human display.

    Returns a multi-line message showing the tier, the action, a command
    preview (truncated to 100 characters), the reason, and which
    approvals/timeout apply. Output is identical to the previous
    implementation; this version drops the pointless f-prefixes on
    placeholder-free strings (flake8 F541) and builds the static part
    as a literal list instead of an append loop.
    """
    info = TIER_INFO[request.tier]
    command_preview = request.command[:100]
    if len(request.command) > 100:
        command_preview += "..."
    lines = [
        f"⚠️ Approval Required (Tier {request.tier.value}: {info['name']})",
        "",
        f"Action: {request.action}",
        f"Command: {command_preview}",
        f"Reason: {request.reason}",
        "",
    ]
    if info["human_required"]:
        lines.append("👤 Human approval required")
    if info["llm_required"]:
        lines.append("🤖 LLM approval required")
    if info["timeout_seconds"]:
        lines.append(f"⏱️ Timeout: {info['timeout_seconds']}s")
    return "\n".join(lines)
def create_approval_request(
    action: str,
    command: str,
    reason: str,
    session_key: str
) -> ApprovalRequest:
    """Build an ApprovalRequest with its tier and timeout pre-resolved."""
    detected = detect_tier(action, command)
    return ApprovalRequest(
        action=action,
        tier=detected,
        command=command,
        reason=reason,
        session_key=session_key,
        timeout_seconds=get_timeout(detected),
    )
# Crisis bypass rules.
# Actions that must never be blocked while a crisis is suspected —
# delivering crisis resources takes precedence over approval gates.
CRISIS_BYPASS_ACTIONS = frozenset([
    "send_message",  # Always allow sending crisis resources
    "check_crisis",
    "notify_crisis",
])
def is_crisis_bypass(action: str, context: str = "") -> bool:
    """True when an action may skip approval during a suspected crisis.

    Either the action itself is whitelisted, or the surrounding context
    mentions crisis keywords (988 lifeline, suicide, self-harm).
    """
    if action in CRISIS_BYPASS_ACTIONS:
        return True
    lowered = context.lower()
    for indicator in ("988", "crisis", "suicide", "self-harm", "lifeline"):
        if indicator in lowered:
            return True
    return False

View File

@@ -1,233 +0,0 @@
"""
Tool Error Classification — Retryable vs Permanent.
Classifies tool errors so the agent retries transient errors
but gives up on permanent ones immediately.
"""
import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
class ErrorCategory(Enum):
    """Error category classification."""
    RETRYABLE = "retryable"  # transient — retry with backoff
    PERMANENT = "permanent"  # will not succeed on retry — fail fast
    UNKNOWN = "unknown"      # unclassified — one cautious retry
@dataclass
class ErrorClassification:
    """Result of error classification."""
    category: ErrorCategory          # retryable / permanent / unknown
    reason: str                      # human-readable explanation
    should_retry: bool               # whether callers should retry at all
    max_retries: int                 # suggested retry budget (0 for permanent)
    backoff_seconds: float           # suggested base backoff between attempts
    error_code: Optional[int] = None    # HTTP status code, when known
    error_type: Optional[str] = None    # exception class name
# Retryable error patterns.
# Tuple shape: (regex, reason, max_retries, backoff_seconds).
# classify_error() scans these in order against the lowercased error
# message and returns on the first match.
_RETRYABLE_PATTERNS = [
    # HTTP status codes
    (r"\b429\b", "rate limit", 3, 5.0),
    (r"\b500\b", "server error", 3, 2.0),
    (r"\b502\b", "bad gateway", 3, 2.0),
    (r"\b503\b", "service unavailable", 3, 5.0),
    (r"\b504\b", "gateway timeout", 3, 5.0),
    # Timeout patterns
    (r"timeout", "timeout", 3, 2.0),
    (r"timed out", "timeout", 3, 2.0),
    (r"TimeoutExpired", "timeout", 3, 2.0),
    # Connection errors
    (r"connection refused", "connection refused", 2, 5.0),
    (r"connection reset", "connection reset", 2, 2.0),
    (r"network unreachable", "network unreachable", 2, 10.0),
    (r"DNS", "DNS error", 2, 5.0),
    # Transient errors
    (r"temporary", "temporary error", 2, 2.0),
    (r"transient", "transient error", 2, 2.0),
    (r"retry", "retryable", 2, 2.0),
]
# Permanent error patterns.
# Tuple shape: (regex, short code, reason). The middle field is not
# currently propagated by classify_error(); only the reason is used.
# Checked AFTER _RETRYABLE_PATTERNS, so retryable wording wins ties.
_PERMANENT_PATTERNS = [
    # HTTP status codes
    (r"\b400\b", "bad request", "Invalid request parameters"),
    (r"\b401\b", "unauthorized", "Authentication failed"),
    (r"\b403\b", "forbidden", "Access denied"),
    (r"\b404\b", "not found", "Resource not found"),
    (r"\b405\b", "method not allowed", "HTTP method not supported"),
    (r"\b409\b", "conflict", "Resource conflict"),
    (r"\b422\b", "unprocessable", "Validation error"),
    # Schema/validation errors
    (r"schema", "schema error", "Invalid data schema"),
    (r"validation", "validation error", "Input validation failed"),
    (r"invalid.*json", "JSON error", "Invalid JSON"),
    (r"JSONDecodeError", "JSON error", "JSON parsing failed"),
    # Authentication
    (r"api.?key", "API key error", "Invalid or missing API key"),
    (r"token.*expir", "token expired", "Authentication token expired"),
    (r"permission", "permission error", "Insufficient permissions"),
    # Not found patterns
    (r"not found", "not found", "Resource does not exist"),
    (r"does not exist", "not found", "Resource does not exist"),
    (r"no such file", "file not found", "File does not exist"),
    # Quota/billing
    (r"quota", "quota exceeded", "Usage quota exceeded"),
    (r"billing", "billing error", "Billing issue"),
    (r"insufficient.*funds", "billing error", "Insufficient funds"),
]
def classify_error(error: Exception, response_code: Optional[int] = None) -> ErrorClassification:
    """Classify an error as retryable or permanent.

    Precedence: explicit HTTP response code, then retryable message
    patterns, then permanent message patterns, then UNKNOWN (one
    cautious retry).

    Args:
        error: The exception that occurred
        response_code: HTTP response code if available
    Returns:
        ErrorClassification with retry guidance
    """
    message = str(error).lower()
    exc_name = type(error).__name__

    # Explicit HTTP status codes are the strongest signal.
    if response_code:
        if response_code in (429, 500, 502, 503, 504):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=f"HTTP {response_code} - transient server error",
                should_retry=True,
                max_retries=3,
                backoff_seconds=5.0 if response_code == 429 else 2.0,
                error_code=response_code,
                error_type=exc_name,
            )
        if response_code in (400, 401, 403, 404, 405, 409, 422):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=f"HTTP {response_code} - client error",
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_code=response_code,
                error_type=exc_name,
            )

    # Message-based classification: retryable wording first.
    for pattern, reason, retry_budget, backoff in _RETRYABLE_PATTERNS:
        if re.search(pattern, message, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=reason,
                should_retry=True,
                max_retries=retry_budget,
                backoff_seconds=backoff,
                error_type=exc_name,
            )

    for pattern, _code, reason in _PERMANENT_PATTERNS:
        if re.search(pattern, message, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=reason,
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_type=exc_name,
            )

    # Nothing matched: unknown, treat as retryable with caution.
    return ErrorClassification(
        category=ErrorCategory.UNKNOWN,
        reason=f"Unknown error type: {exc_name}",
        should_retry=True,
        max_retries=1,
        backoff_seconds=1.0,
        error_type=exc_name,
    )
def execute_with_retry(
    func,
    *args,
    max_retries: int = 3,
    backoff_base: float = 1.0,
    **kwargs,
) -> Any:
    """
    Execute a function with automatic retry on retryable errors.

    Args:
        func: Function to execute
        *args: Function arguments
        max_retries: Maximum retry attempts (values < 0 are treated as 0,
            i.e. a single attempt with no retries)
        backoff_base: Base backoff time in seconds (doubles each attempt)
        **kwargs: Function keyword arguments
    Returns:
        Function result
    Raises:
        Exception: If permanent error or max retries exceeded
    """
    # BUG FIX: with a negative max_retries the old loop body never ran,
    # falling through to `raise last_error` with last_error=None — a
    # confusing TypeError. Clamp so there is always at least one attempt.
    attempts = max(0, max_retries)
    last_error: Optional[Exception] = None
    for attempt in range(attempts + 1):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            last_error = e
            # Classify to decide whether a retry is worthwhile at all.
            classification = classify_error(e)
            logger.info(
                "Attempt %d/%d failed: %s (%s, retryable: %s)",
                attempt + 1, attempts + 1,
                classification.reason,
                classification.category.value,
                classification.should_retry,
            )
            # Permanent errors fail immediately — retrying cannot help.
            if not classification.should_retry:
                logger.error("Permanent error: %s", classification.reason)
                raise
            # Out of attempts: propagate the last error.
            if attempt >= attempts:
                logger.error("Max retries (%d) exceeded", attempts)
                raise
            # Exponential backoff before the next attempt.
            backoff = backoff_base * (2 ** attempt)
            logger.info("Retrying in %.1fs...", backoff)
            time.sleep(backoff)
    # Unreachable: the loop always returns or raises; kept as a guard.
    raise last_error
def format_error_report(classification: ErrorClassification) -> str:
    """Render a one-line, icon-prefixed summary of a classification."""
    prefix = "🔄" if classification.should_retry else ""
    category_label = classification.category.value
    return f"{prefix} {category_label}: {classification.reason}"

View File

@@ -394,23 +394,6 @@ def session_search(
if len(seen_sessions) >= limit:
break
# RIDER: Reader-guided reranking — sort sessions by LLM answerability
# This bridges the R@5 vs E2E accuracy gap by prioritizing passages
# the LLM can actually answer from, not just keyword matches.
try:
from agent.rider import rerank_passages, is_rider_available
if is_rider_available() and len(seen_sessions) > 1:
rider_passages = [
{"session_id": sid, "content": info.get("snippet", ""), "rank": i + 1}
for i, (sid, info) in enumerate(seen_sessions.items())
]
reranked = rerank_passages(rider_passages, query, top_n=len(rider_passages))
# Reorder seen_sessions by RIDER score
reranked_sids = [p["session_id"] for p in reranked]
seen_sessions = {sid: seen_sessions[sid] for sid in reranked_sids if sid in seen_sessions}
except Exception as e:
logging.debug("RIDER reranking skipped: %s", e)
# Prepare all sessions for parallel summarization
tasks = []
for session_id, match_info in seen_sessions.items():