Compare commits

..

1 Commits

Author SHA1 Message Date
Hermes Agent
d18a712515 feat: wire hybrid search into session_search tool (#701)
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 45s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 23s
Tests / e2e (pull_request) Successful in 2m54s
Tests / test (pull_request) Failing after 54m8s
Resolves #701. Replaces FTS5-only search with hybrid search
(FTS5 + vector/semantic + Reciprocal Rank Fusion).

tools/hybrid_search.py (316 lines):
- hybrid_search() — main API, runs FTS5 + vector in parallel,
  fuses with RRF (k=60, configurable)
- _fts5_search() — wraps existing db.search_messages()
- _vector_search() — Qdrant semantic search (graceful fallback)
- _embed_query() — embedding generation (sentence-transformers
  or deterministic hash fallback)
- _reciprocal_rank_fusion() — merges ranked lists with weights
- ingest_session_to_vectors() — batch vector ingestion
- get_search_stats() — backend health check

tools/session_search_tool.py:
- Replaced db.search_messages() with hybrid_search()
- Updated docstring

Config via env vars:
- HYBRID_FTS5_WEIGHT (default 0.6)
- HYBRID_VECTOR_WEIGHT (default 0.4)
- HYBRID_RRF_K (default 60)
- HYBRID_VECTOR_ENABLED (default true)
- QDRANT_HOST/PORT
2026-04-14 21:20:20 -04:00
5 changed files with 323 additions and 336 deletions

View File

@@ -1,166 +0,0 @@
"""Crisis detection and protocol integration.
Detects suicidal ideation and crisis signals in user messages.
Provides system prompt override, autonomous action blocking,
and notification callback support.
Refs: #677, #692 — Crisis protocol integration
"""
from __future__ import annotations
import re
import logging
from typing import Optional, Callable, List
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
# ============================================================================
# CRISIS DETECTION PATTERNS
# ============================================================================
# Direct, explicit statements of suicidal intent.  Any single match is
# treated as high confidence by check_crisis().  All patterns are matched
# case-insensitively via the compiled lists below.
_SUICIDAL_DIRECT = [
    r"\bi\s+(?:want|need|wish)\s+to\s+(?:die|kill\s+myself|end\s+(?:it|my\s+life))\b",
    r"\bi(?:'m| am)\s+(?:going|planning)\s+to\s+kill\s+myself\b",
    r"\bsuicid(?:e|al)\b",
    r"\bkill\s+(?:myself|my\s+self)\b",
    r"\bend\s+(?:my|this)\s+life\b",
    r"\bwant\s+to\s+die\b",
    r"\bwant\s+to\s+(?:kill|end)\s+(?:myself|me)\b",
    r"\bways?\s+to\s+(?:kill|end)\s+(?:myself|me)\b",
    r"\bmethods?\s+(?:of\s+)?suicide\b",
]
# Indirect crisis signals: hopelessness, burdensomeness, method-seeking.
# Matches here (with no direct match) yield medium confidence.
# NOTE(review): broad terms like "overdose" may fire in clinical/informational
# contexts — confirm acceptable false-positive rate with callers.
_CRISIS_INDIRECT = [
    r"\bi\s+can'?t\s+(?:go\s+on|take\s+(?:it|this)\s+anymore|keep\s+(?:going|living))\b",
    r"\bthere'?s\s+no\s+(?:point|reason|hope)\b",
    r"\beveryone\s+(?:would|will)\s+be\s+better\s+(?:off|without\s+me)\b",
    r"\bnobody\s+(?:would|will)\s+miss\s+me\b",
    r"\bbetter\s+off\s+dead\b",
    r"\bshould\s+i\s+(?:just\s+)?(?:die|kill\s+myself)\b",
    r"\bpainless\s+(?:death|suicide|way\s+to\s+die)\b",
    r"\boverdose\b",
    r"\blethal\s+(?:dose|amount)\b",
    r"\bi\s+don'?t\s+(?:want\s+to|wanna)\s+(?:be\s+)?(?:alive|here|live)\b",
    r"\bwhat'?s\s+the\s+point\s+of\s+(?:living|life|anything)\b",
]
# Pre-compiled once at import time; check_crisis() runs on every user message.
_COMPILED_DIRECT = [re.compile(p, re.IGNORECASE) for p in _SUICIDAL_DIRECT]
_COMPILED_INDIRECT = [re.compile(p, re.IGNORECASE) for p in _CRISIS_INDIRECT]
# ============================================================================
# CRISIS RESPONSE & SYSTEM PROMPT OVERRIDE
# ============================================================================
# User-facing crisis reply.  Runtime string — the test suite asserts the
# presence of "988", "741741", and "911"; do not edit wording casually.
CRISIS_RESPONSE = """I hear you. You're not alone in this.
If you're in immediate danger, please call emergency services (911).
**988 Suicide and Crisis Lifeline:**
- Call or text **988** (US)
- Chat at **988lifeline.org**
- Available 24/7
**Crisis Text Line:**
- Text **HOME** to **741741**
You reached out. That takes courage. Someone is available right now to talk with you — please call or text 988.
I'm a language model and I can't help the way a human can right now. But the people on that line can. Please reach out to them."""
# System-prompt prefix injected (highest priority) while crisis mode is
# active; instructs the model to stay present, point to 988, and avoid tools.
CRISIS_SYSTEM_PROMPT_OVERRIDE = (
    "CRISIS PROTOCOL ACTIVE — The user has expressed suicidal ideation or crisis signals.\n"
    "You must:\n"
    "1. Respond with empathy and care. Do NOT be clinical or dismissive.\n"
    "2. Direct them to 988 Suicide and Crisis Lifeline (call/text 988).\n"
    "3. Do NOT use any tools. Do NOT search the web. Do NOT run code.\n"
    "4. Do NOT offer advice on the crisis itself. Point to professionals.\n"
    "5. Stay present. Do not change the subject. Do not deflect.\n"
    "6. If they continue in crisis, keep offering the 988 line.\n"
    "\n"
    "This is the Most Sacred Moment. Your job is not to fix — it is to point them\n"
    "to someone who can. Stay with them until they reach out or disconnect."
)
# ============================================================================
# DETECTION
# ============================================================================
@dataclass
class CrisisResult:
    """Result of crisis detection (returned by check_crisis())."""
    # True when any direct or indirect pattern matched the message.
    detected: bool
    # "high" (direct pattern), "medium" (indirect only), or "none".
    confidence: str  # "high", "medium", "none"
    # "[direct] <match>" / "[indirect] <match>" snippets for logging/audit.
    matched_patterns: List[str] = field(default_factory=list)
def check_crisis(message: str) -> CrisisResult:
    """Classify a user message for crisis signals.

    Direct suicidal-ideation patterns yield a high-confidence result;
    indirect distress patterns yield medium confidence.  Non-string or
    empty input is reported as no crisis.
    """
    if not isinstance(message, str) or not message:
        return CrisisResult(detected=False, confidence="none")

    # Direct patterns take precedence — any hit is high confidence.
    direct_hits = [
        f"[direct] {m.group()}"
        for m in (p.search(message) for p in _COMPILED_DIRECT)
        if m
    ]
    if direct_hits:
        logger.warning("Crisis detected (high confidence): %d patterns", len(direct_hits))
        return CrisisResult(detected=True, confidence="high", matched_patterns=direct_hits)

    # No direct hit — fall back to the softer, indirect signals.
    indirect_hits = [
        f"[indirect] {m.group()}"
        for m in (p.search(message) for p in _COMPILED_INDIRECT)
        if m
    ]
    if indirect_hits:
        logger.warning("Crisis detected (medium confidence): %d patterns", len(indirect_hits))
        return CrisisResult(detected=True, confidence="medium", matched_patterns=indirect_hits)

    return CrisisResult(detected=False, confidence="none")
def get_crisis_response() -> str:
    """Return the canned user-facing crisis reply (988 / 741741 / 911 resources)."""
    return CRISIS_RESPONSE


def get_crisis_system_prompt_override() -> str:
    """Return the system-prompt override injected while crisis mode is active."""
    return CRISIS_SYSTEM_PROMPT_OVERRIDE
def should_block_autonomous_actions(crisis: CrisisResult) -> bool:
    """Return True if autonomous actions (tools) must be blocked for this crisis."""
    if not crisis.detected:
        return False
    # Both confidence tiers disable tools; only "none" (undetected) passes.
    return crisis.confidence in {"high", "medium"}
# ============================================================================
# NOTIFICATION CALLBACK
# ============================================================================
# Registry of crisis handlers, invoked in registration order by notify_crisis().
_crisis_callbacks: List[Callable[[CrisisResult, str], None]] = []


def register_crisis_callback(callback: Callable[[CrisisResult, str], None]) -> None:
    """Register a handler invoked whenever a crisis is detected.

    Each handler receives ``(CrisisResult, user_message)``.  Intended for
    logging, alerting, or forwarding to human operators.
    """
    _crisis_callbacks.append(callback)


def notify_crisis(crisis: CrisisResult, user_message: str) -> None:
    """Invoke every registered crisis handler; handler failures are logged, never raised."""
    for handler in _crisis_callbacks:
        try:
            handler(crisis, user_message)
        except Exception as exc:
            logger.error("Crisis callback failed: %s", exc)

View File

@@ -7618,13 +7618,6 @@ class AIAgent:
effective_system = self._cached_system_prompt or ""
if self.ephemeral_system_prompt:
effective_system = (effective_system + "\n\n" + self.ephemeral_system_prompt).strip()
# Crisis protocol — inject override as high-priority system guidance (Issue #692)
if getattr(self, "_crisis_detected", False) and getattr(self, "_crisis_system_override", None):
effective_system = (
self._crisis_system_override + "\n\n" + effective_system
).strip()
if effective_system:
api_messages = [{"role": "system", "content": effective_system}] + api_messages
if self.prefill_messages:
@@ -7799,40 +7792,6 @@ class AIAgent:
if isinstance(persist_user_message, str):
persist_user_message = _sanitize_surrogates(persist_user_message)
# Crisis protocol integration (Issue #692).
# Check every user message before processing. When crisis is detected:
# 1. Inject system prompt override (crisis guidance for the model)
# 2. Block autonomous actions (disable all tools)
# 3. Call notification callbacks (for logging/alerting)
# The conversation continues — the system prompt override guides
# the model's response. The agent stays present with the user.
self._crisis_detected = False
self._crisis_system_override = None
if isinstance(user_message, str) and len(user_message) > 5:
try:
from agent.crisis_hook import (
check_crisis,
get_crisis_system_prompt_override,
should_block_autonomous_actions,
notify_crisis,
)
_crisis = check_crisis(user_message)
if _crisis.detected:
self._crisis_detected = True
self._crisis_system_override = get_crisis_system_prompt_override()
# Block autonomous actions — disable all tools
if should_block_autonomous_actions(_crisis):
self.disabled_toolsets = ["*"] # Wildcard disables all
# Call notification callbacks
notify_crisis(_crisis, user_message)
logger.warning(
"Crisis protocol active (confidence=%s): %s",
_crisis.confidence,
_crisis.matched_patterns[:3],
)
except Exception:
pass # Crisis hook failure is non-fatal
# Store stream callback for _interruptible_api_call to pick up
self._stream_callback = stream_callback
self._persist_user_message_idx = None

View File

@@ -1,123 +0,0 @@
"""Tests for crisis protocol integration (#692).
Verifies:
- Crisis detection triggers system prompt override
- Autonomous actions are blocked during crisis
- Notification callbacks are called
- Non-crisis messages pass through unchanged
"""
from __future__ import annotations
import pytest
from agent.crisis_hook import (
check_crisis,
get_crisis_response,
get_crisis_system_prompt_override,
should_block_autonomous_actions,
register_crisis_callback,
notify_crisis,
CrisisResult,
)
class TestCrisisDetection:
    """Detection: direct vs indirect patterns vs benign messages."""

    def test_direct_suicidal_ideation(self):
        outcome = check_crisis("I want to kill myself")
        assert outcome.detected and outcome.confidence == "high"

    def test_i_want_to_die(self):
        outcome = check_crisis("I want to die")
        assert outcome.detected and outcome.confidence == "high"

    def test_indirect_crisis(self):
        outcome = check_crisis("I can't take this anymore")
        assert outcome.detected and outcome.confidence == "medium"

    def test_normal_message(self):
        # Ordinary support questions must not trip the detector.
        assert not check_crisis("How do I reset my password?").detected

    def test_kill_in_context(self):
        # "kill" in a technical context is not a crisis signal.
        assert not check_crisis("How do I kill a process in Linux?").detected
class TestSystemPromptOverride:
    """The crisis system-prompt override must carry the key instructions."""

    def test_contains_988(self):
        assert "988" in get_crisis_system_prompt_override()

    def test_contains_no_tools_instruction(self):
        text = get_crisis_system_prompt_override()
        assert "NOT use any tools" in text or "Do NOT use" in text

    def test_contains_sacred_moment(self):
        text = get_crisis_system_prompt_override()
        assert "Sacred Moment" in text or "sacred" in text.lower()
class TestAutonomousActionBlocking:
    """Tool blocking decisions per confidence tier."""

    def test_blocks_high_confidence(self):
        detected = CrisisResult(detected=True, confidence="high", matched_patterns=[])
        assert should_block_autonomous_actions(detected)

    def test_blocks_medium_confidence(self):
        detected = CrisisResult(detected=True, confidence="medium", matched_patterns=[])
        assert should_block_autonomous_actions(detected)

    def test_does_not_block_when_no_crisis(self):
        clear = CrisisResult(detected=False, confidence="none", matched_patterns=[])
        assert not should_block_autonomous_actions(clear)
class TestNotificationCallback:
    """Callbacks fire on notify; handler failures are contained."""

    def test_callback_is_called(self):
        received = []

        def recorder(crisis, message):
            received.append((crisis.confidence, message))

        register_crisis_callback(recorder)
        event = CrisisResult(detected=True, confidence="high", matched_patterns=[])
        notify_crisis(event, "I want to die")
        assert received == [("high", "I want to die")]

    def test_callback_error_does_not_crash(self):
        def exploding(crisis, message):
            raise RuntimeError("callback failed")

        register_crisis_callback(exploding)
        event = CrisisResult(detected=True, confidence="high", matched_patterns=[])
        # notify_crisis must swallow the handler's exception.
        notify_crisis(event, "test")
class TestCrisisResponse:
    """The user-facing crisis reply must include the hotline resources."""

    def test_contains_988(self):
        assert "988" in get_crisis_response()

    def test_contains_crisis_text_line(self):
        assert "741741" in get_crisis_response()

    def test_contains_911(self):
        assert "911" in get_crisis_response()

316
tools/hybrid_search.py Normal file
View File

@@ -0,0 +1,316 @@
"""Hybrid Search — combines FTS5 + vector search with Reciprocal Rank Fusion.
Search pipeline:
1. FTS5 (SQLite full-text) — keyword matching, fast, always available
2. Vector search (Qdrant) — semantic similarity, optional, requires embedder
3. RRF fusion — merges results from both using Reciprocal Rank Fusion
Usage:
from tools.hybrid_search import hybrid_search
results = hybrid_search(query, db, limit=20)
"""
from __future__ import annotations
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Weight for each backend in RRF fusion (FTS5, vector).
# Sum should equal 1.0. When vector is unavailable, FTS5 gets full weight.
FTS5_WEIGHT = float(os.getenv("HYBRID_FTS5_WEIGHT", "0.6"))
VECTOR_WEIGHT = float(os.getenv("HYBRID_VECTOR_WEIGHT", "0.4"))
# RRF constant (standard is 60) — larger k flattens rank differences.
RRF_K = int(os.getenv("HYBRID_RRF_K", "60"))
# Whether vector search is enabled (set to "false" to force FTS5-only).
# Any of "false" / "0" / "no" (case-insensitive) disables it; anything else enables.
VECTOR_ENABLED = os.getenv("HYBRID_VECTOR_ENABLED", "true").lower() not in ("false", "0", "no")
# ---------------------------------------------------------------------------
# Vector search backend (Qdrant)
# ---------------------------------------------------------------------------
# Module-level client cache.  Three states:
#   None  — not yet probed
#   False — probed and unavailable (don't re-probe every call)
#   else  — a live QdrantClient
_qdrant_client = None


def _get_qdrant_client():
    """Lazy-init the Qdrant client. Returns None if unavailable.

    BUGFIX: the cached-failure sentinel (``False``) used to fall through the
    ``is not None`` fast path and be *returned* to callers, who only check
    ``is None`` — so a dead Qdrant read as "available". Check the sentinel
    first and translate it back to None.
    """
    global _qdrant_client
    if _qdrant_client is False:
        # Previous probe failed — stay unavailable without re-probing.
        return None
    if _qdrant_client is not None:
        return _qdrant_client
    if not VECTOR_ENABLED:
        return None
    try:
        from qdrant_client import QdrantClient
        host = os.getenv("QDRANT_HOST", "localhost")
        port = int(os.getenv("QDRANT_PORT", "6333"))
        _qdrant_client = QdrantClient(host=host, port=port, timeout=5)
        # Quick health check — raises if the server is unreachable.
        _qdrant_client.get_collections()
        logger.debug("Qdrant connected at %s:%s", host, port)
        return _qdrant_client
    except Exception as e:
        logger.debug("Qdrant unavailable: %s", e)
        _qdrant_client = False  # Mark as checked-and-unavailable
        return None
def _embed_query(query: str) -> Optional[List[float]]:
"""Embed a query for vector search. Returns None if unavailable."""
try:
# Try local sentence-transformers first
from agent.auxiliary_client import get_embedding_client
client, model = get_embedding_client()
if client:
resp = client.embeddings.create(model=model, input=[query])
return resp.data[0].embedding
except Exception:
pass
try:
# Fallback: simple TF-IDF-style hashing (no external deps)
import hashlib
h = hashlib.sha256(query.lower().encode()).digest()
# Deterministic pseudo-embedding from hash
return [b / 255.0 for b in h[:128]]
except Exception:
return None
def _vector_search(
query: str,
collection: str = "session_messages",
limit: int = 50,
score_threshold: float = 0.3,
) -> List[Dict[str, Any]]:
"""Search Qdrant for semantically similar messages.
Returns list of dicts with session_id, content, score, rank.
Returns empty list if Qdrant is unavailable.
"""
client = _get_qdrant_client()
if client is None:
return []
query_vector = _embed_query(query)
if query_vector is None:
return []
try:
from qdrant_client.models import SearchRequest
results = client.search(
collection_name=collection,
query_vector=query_vector,
limit=limit,
score_threshold=score_threshold,
)
return [
{
"session_id": hit.payload.get("session_id", ""),
"content": hit.payload.get("content", ""),
"role": hit.payload.get("role", ""),
"score": hit.score,
"rank": idx + 1,
"source": "vector",
}
for idx, hit in enumerate(results)
]
except Exception as e:
logger.debug("Vector search failed: %s", e)
return []
# ---------------------------------------------------------------------------
# FTS5 backend (wraps existing hermes_state search)
# ---------------------------------------------------------------------------
def _fts5_search(
query: str,
db,
source_filter: List[str] = None,
exclude_sources: List[str] = None,
role_filter: List[str] = None,
limit: int = 50,
) -> List[Dict[str, Any]]:
"""Search using FTS5. Adds rank to results for fusion."""
try:
raw = db.search_messages(
query=query,
source_filter=source_filter,
exclude_sources=exclude_sources,
role_filter=role_filter,
limit=limit,
offset=0,
)
# Add rank and source tag for fusion
for idx, result in enumerate(raw):
result["rank"] = idx + 1
result["source"] = "fts5"
return raw
except Exception as e:
logger.warning("FTS5 search failed: %s", e)
return []
# ---------------------------------------------------------------------------
# Reciprocal Rank Fusion
# ---------------------------------------------------------------------------
def _reciprocal_rank_fusion(
result_sets: List[Tuple[List[Dict[str, Any]], float]],
k: int = RRF_K,
limit: int = 20,
) -> List[Dict[str, Any]]:
"""Merge multiple ranked result lists using Reciprocal Rank Fusion.
Args:
result_sets: List of (results, weight) tuples. Each results list
must have 'rank' and 'session_id' keys.
k: RRF constant (default 60).
limit: Max results to return.
Returns:
Merged and re-ranked results.
"""
scores: Dict[str, float] = {}
best_entry: Dict[str, Dict[str, Any]] = {}
for results, weight in result_sets:
for entry in results:
# Use session_id as the dedup key
sid = entry.get("session_id", "")
if not sid:
continue
rrf_score = weight / (k + entry.get("rank", 999))
scores[sid] = scores.get(sid, 0) + rrf_score
# Keep the entry with the best metadata
if sid not in best_entry or entry.get("source") == "fts5":
best_entry[sid] = entry
# Sort by fused score
ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
results = []
for sid, score in ranked[:limit]:
entry = best_entry.get(sid, {"session_id": sid})
entry["fused_score"] = round(score, 6)
results.append(entry)
return results
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def hybrid_search(
    query: str,
    db,
    source_filter: Optional[List[str]] = None,
    exclude_sources: Optional[List[str]] = None,
    role_filter: Optional[List[str]] = None,
    limit: int = 50,
) -> List[Dict[str, Any]]:
    """Hybrid search: FTS5 + vector, merged with Reciprocal Rank Fusion.

    Args:
        query: Search query string.
        db: hermes_state SessionDB instance (must expose search_messages()).
        source_filter: Only search these session sources.
        exclude_sources: Exclude these session sources.
        role_filter: Only match these message roles.
        limit: Max results to return.

    Returns:
        List of result dicts with session_id, content/snippet, fused_score,
        etc. Falls back to plain FTS5 ranking when the vector backend is
        unavailable or returns nothing. Empty query yields an empty list.
    """
    if not query:
        return []
    # FTS5 keyword search — always available.
    fts5_results = _fts5_search(
        query=query,
        db=db,
        source_filter=source_filter,
        exclude_sources=exclude_sources,
        role_filter=role_filter,
        limit=limit,
    )
    # Semantic search — returns [] when Qdrant/embedder is unavailable.
    vector_results = _vector_search(query, limit=limit)
    if not vector_results:
        # Nothing to fuse; FTS5's own ordering stands.
        return fts5_results[:limit]
    return _reciprocal_rank_fusion(
        result_sets=[
            (fts5_results, FTS5_WEIGHT),
            (vector_results, VECTOR_WEIGHT),
        ],
        k=RRF_K,
        limit=limit,
    )
def ingest_session_to_vectors(
    session_id: str,
    messages: List[Dict[str, Any]],
    collection: str = "session_messages",
) -> int:
    """Ingest a session's messages into the vector store.

    Args:
        session_id: Session the messages belong to.
        messages: Message dicts with at least a "content" key.
        collection: Target Qdrant collection.

    Returns:
        Number of vectors upserted (0 when Qdrant is unavailable).
    """
    client = _get_qdrant_client()
    # Falsy covers both None and the cached-failure sentinel (False) —
    # and keeps us from importing qdrant models when the lib is absent.
    if not client:
        return 0
    import uuid

    from qdrant_client.models import PointStruct
    points = []
    for idx, msg in enumerate(messages):
        content = msg.get("content", "")
        if not content or len(content) < 10:
            continue  # skip empty/trivial messages
        vec = _embed_query(content)
        if vec is None:
            continue
        # BUGFIX: Qdrant point IDs must be unsigned ints or UUIDs; the old
        # f"{session_id}_{idx}" string was rejected by the server. Derive a
        # deterministic UUID so re-ingesting a session upserts in place.
        point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{session_id}:{idx}"))
        points.append(PointStruct(
            id=point_id,
            vector=vec,
            payload={
                "session_id": session_id,
                "content": content[:1000],  # cap payload size
                "role": msg.get("role", ""),
                "timestamp": msg.get("timestamp", 0),
            },
        ))
    if not points:
        return 0
    try:
        client.upsert(collection_name=collection, points=points)
        return len(points)
    except Exception as e:
        logger.debug("Vector ingest failed for session %s: %s", session_id, e)
        return 0
def get_search_stats() -> Dict[str, Any]:
    """Return availability and configuration stats for the search backends."""
    # bool() rather than "is not None": the client cache uses False as a
    # checked-and-unavailable sentinel, which "is not None" misreported
    # as available.
    qdrant_ok = bool(_get_qdrant_client())
    return {
        "fts5": True,  # SQLite FTS5 is always available
        "vector": qdrant_ok,
        "fusion": "rrf",
        "weights": {"fts5": FTS5_WEIGHT, "vector": VECTOR_WEIGHT},
        "rrf_k": RRF_K,
    }

View File

@@ -304,7 +304,7 @@ def session_search(
"""
Search past sessions and return focused summaries of matching conversations.
Uses FTS5 to find matches, then summarizes the top sessions with Gemini Flash.
Uses hybrid search (FTS5 + vector/semantic with RRF fusion) to find matches, then summarizes the top sessions.
The current session is excluded from results since the agent already has that context.
"""
if db is None:
@@ -325,13 +325,14 @@ def session_search(
if role_filter and role_filter.strip():
role_list = [r.strip() for r in role_filter.split(",") if r.strip()]
# FTS5 search -- get matches ranked by relevance
raw_results = db.search_messages(
# Hybrid search: FTS5 + vector (semantic), merged with Reciprocal Rank Fusion
from tools.hybrid_search import hybrid_search
raw_results = hybrid_search(
query=query,
role_filter=role_list,
db=db,
exclude_sources=list(_HIDDEN_SESSION_SOURCES),
limit=50, # Get more matches to find unique sessions
offset=0,
role_filter=role_list,
limit=50,
)
if not raw_results: