Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bafec50be |
238
agent/hybrid_search.py
Normal file
238
agent/hybrid_search.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
hybrid_search.py — Hybrid search combining FTS5, vector, and HRR.
|
||||
|
||||
Three-backend search router:
|
||||
1. FTS5 (SQLite full-text) — fast keyword matching, always available
|
||||
2. Vector search (Qdrant/ChromaDB) — semantic similarity, optional
|
||||
3. HRR (Holographic Reduced Representations) — compositional recall, optional
|
||||
|
||||
Graceful degradation: if vector or HRR backends are unavailable,
|
||||
falls back to FTS5-only.
|
||||
|
||||
Usage:
|
||||
from agent.hybrid_search import hybrid_search
|
||||
|
||||
results = hybrid_search(query, db=session_db, limit=10)
|
||||
# Returns merged, deduplicated, ranked results from all available backends
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""Single search result from any backend."""
|
||||
session_id: str
|
||||
message_content: str
|
||||
score: float
|
||||
source: str # "fts5", "vector", "hrr"
|
||||
role: str = ""
|
||||
timestamp: str = ""
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HybridSearchConfig:
|
||||
"""Configuration for hybrid search."""
|
||||
fts5_enabled: bool = True
|
||||
vector_enabled: bool = False
|
||||
hrr_enabled: bool = False
|
||||
vector_weight: float = 0.4
|
||||
fts5_weight: float = 0.4
|
||||
hrr_weight: float = 0.2
|
||||
dedup_threshold: float = 0.9 # similarity threshold for dedup
|
||||
|
||||
|
||||
def search_fts5(query: str, db, limit: int = 50, role_filter: list = None) -> List[SearchResult]:
|
||||
"""Search using FTS5 full-text search."""
|
||||
try:
|
||||
raw = db.search_messages(
|
||||
query=query,
|
||||
role_filter=role_filter,
|
||||
limit=limit,
|
||||
offset=0,
|
||||
)
|
||||
results = []
|
||||
for r in raw:
|
||||
results.append(SearchResult(
|
||||
session_id=r.get("session_id", ""),
|
||||
message_content=r.get("content", ""),
|
||||
score=r.get("rank", 0.0),
|
||||
source="fts5",
|
||||
role=r.get("role", ""),
|
||||
timestamp=str(r.get("timestamp", "")),
|
||||
))
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"FTS5 search failed: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def search_vector(query: str, limit: int = 50) -> List[SearchResult]:
|
||||
"""Search using vector similarity (Qdrant/ChromaDB).
|
||||
|
||||
Returns empty list if vector backend unavailable.
|
||||
"""
|
||||
try:
|
||||
# Try ChromaDB first
|
||||
import chromadb
|
||||
client = chromadb.PersistentClient(path="~/.hermes/memory/chroma")
|
||||
collection = client.get_or_create_collection("sessions")
|
||||
results = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=limit,
|
||||
)
|
||||
search_results = []
|
||||
for i, doc in enumerate(results.get("documents", [[]])[0]):
|
||||
metadata = results.get("metadatas", [[]])[0]
|
||||
meta = metadata[i] if i < len(metadata) else {}
|
||||
distance = results.get("distances", [[]])[0]
|
||||
score = 1.0 - (distance[i] if i < len(distance) else 1.0)
|
||||
search_results.append(SearchResult(
|
||||
session_id=meta.get("session_id", ""),
|
||||
message_content=doc,
|
||||
score=score,
|
||||
source="vector",
|
||||
role=meta.get("role", ""),
|
||||
timestamp=meta.get("timestamp", ""),
|
||||
))
|
||||
return search_results
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Try Qdrant
|
||||
from qdrant_client import QdrantClient
|
||||
client = QdrantClient(host="localhost", port=6333)
|
||||
results = client.query_points(
|
||||
collection_name="sessions",
|
||||
query_text=query,
|
||||
limit=limit,
|
||||
)
|
||||
return [
|
||||
SearchResult(
|
||||
session_id=pt.payload.get("session_id", ""),
|
||||
message_content=pt.payload.get("content", ""),
|
||||
score=pt.score,
|
||||
source="vector",
|
||||
role=pt.payload.get("role", ""),
|
||||
)
|
||||
for pt in results.points
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def search_hrr(query: str, limit: int = 50) -> List[SearchResult]:
|
||||
"""Search using Holographic Reduced Representations.
|
||||
|
||||
Returns empty list if HRR backend unavailable.
|
||||
"""
|
||||
try:
|
||||
from agent.holographic_memory import holographic_recall
|
||||
results = holographic_recall(query, limit=limit)
|
||||
return [
|
||||
SearchResult(
|
||||
session_id=r.get("session_id", ""),
|
||||
message_content=r.get("content", ""),
|
||||
score=r.get("binding_score", 0.0),
|
||||
source="hrr",
|
||||
role=r.get("role", ""),
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
|
||||
def merge_results(
|
||||
fts5_results: List[SearchResult],
|
||||
vector_results: List[SearchResult],
|
||||
hrr_results: List[SearchResult],
|
||||
config: HybridSearchConfig,
|
||||
limit: int = 10,
|
||||
) -> List[SearchResult]:
|
||||
"""Merge results from multiple backends with weighted scoring."""
|
||||
all_results = []
|
||||
|
||||
# Apply weights
|
||||
for r in fts5_results:
|
||||
r.score *= config.fts5_weight
|
||||
all_results.append(r)
|
||||
for r in vector_results:
|
||||
r.score *= config.vector_weight
|
||||
all_results.append(r)
|
||||
for r in hrr_results:
|
||||
r.score *= config.hrr_weight
|
||||
all_results.append(r)
|
||||
|
||||
# Sort by weighted score
|
||||
all_results.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
# Deduplicate by session_id + content similarity
|
||||
seen = set()
|
||||
deduped = []
|
||||
for r in all_results:
|
||||
key = f"{r.session_id}:{r.message_content[:100]}"
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(r)
|
||||
|
||||
return deduped[:limit]
|
||||
|
||||
|
||||
def hybrid_search(
|
||||
query: str,
|
||||
db=None,
|
||||
limit: int = 10,
|
||||
role_filter: list = None,
|
||||
config: HybridSearchConfig = None,
|
||||
) -> List[SearchResult]:
|
||||
"""Hybrid search across FTS5, vector, and HRR backends.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
db: Session database (for FTS5)
|
||||
limit: Max results
|
||||
role_filter: Filter by message role
|
||||
config: Hybrid search configuration
|
||||
|
||||
Returns:
|
||||
List of SearchResult, ranked by weighted score
|
||||
"""
|
||||
if config is None:
|
||||
config = HybridSearchConfig()
|
||||
|
||||
fts5_results = []
|
||||
vector_results = []
|
||||
hrr_results = []
|
||||
|
||||
# FTS5 (always available if db provided)
|
||||
if config.fts5_enabled and db:
|
||||
fts5_results = search_fts5(query, db, limit=50, role_filter=role_filter)
|
||||
logger.debug(f"FTS5: {len(fts5_results)} results")
|
||||
|
||||
# Vector search (optional)
|
||||
if config.vector_enabled:
|
||||
vector_results = search_vector(query, limit=50)
|
||||
logger.debug(f"Vector: {len(vector_results)} results")
|
||||
|
||||
# HRR (optional)
|
||||
if config.hrr_enabled:
|
||||
hrr_results = search_hrr(query, limit=50)
|
||||
logger.debug(f"HRR: {len(hrr_results)} results")
|
||||
|
||||
# If only FTS5 available, just return those
|
||||
if not vector_results and not hrr_results:
|
||||
return fts5_results[:limit]
|
||||
|
||||
# Merge and rank
|
||||
return merge_results(fts5_results, vector_results, hrr_results, config, limit)
|
||||
58
tests/agent/test_hybrid_search.py
Normal file
58
tests/agent/test_hybrid_search.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Tests for hybrid search router."""
|
||||
|
||||
import pytest
|
||||
from agent.hybrid_search import (
|
||||
SearchResult,
|
||||
HybridSearchConfig,
|
||||
merge_results,
|
||||
hybrid_search,
|
||||
search_fts5,
|
||||
)
|
||||
|
||||
|
||||
class TestSearchResult:
|
||||
def test_creation(self):
|
||||
r = SearchResult(session_id="s1", message_content="hello", score=0.9, source="fts5")
|
||||
assert r.session_id == "s1"
|
||||
assert r.source == "fts5"
|
||||
|
||||
|
||||
class TestMergeResults:
|
||||
def test_merges_and_ranks(self):
|
||||
fts5 = [SearchResult("s1", "alpha content", 1.0, "fts5")]
|
||||
vec = [SearchResult("s2", "beta content", 0.9, "vector")]
|
||||
hrr = [SearchResult("s3", "gamma content", 0.5, "hrr")]
|
||||
config = HybridSearchConfig(fts5_weight=0.4, vector_weight=0.4, hrr_weight=0.2)
|
||||
results = merge_results(fts5, vec, hrr, config, limit=10)
|
||||
assert len(results) == 3
|
||||
# s1: 1.0*0.4=0.4, s2: 0.9*0.4=0.36, s3: 0.5*0.2=0.1
|
||||
assert results[0].session_id == "s1"
|
||||
|
||||
def test_deduplicates(self):
|
||||
fts5 = [SearchResult("s1", "same content", 1.0, "fts5")]
|
||||
vec = [SearchResult("s1", "same content", 0.8, "vector")]
|
||||
config = HybridSearchConfig()
|
||||
results = merge_results(fts5, vec, [], config, limit=10)
|
||||
assert len(results) == 1
|
||||
|
||||
def test_respects_limit(self):
|
||||
fts5 = [SearchResult(f"s{i}", f"content {i}", 1.0/i, "fts5") for i in range(1, 20)]
|
||||
results = merge_results(fts5, [], [], HybridSearchConfig(), limit=5)
|
||||
assert len(results) == 5
|
||||
|
||||
def test_empty_inputs(self):
|
||||
results = merge_results([], [], [], HybridSearchConfig())
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
class TestHybridSearchFallback:
|
||||
def test_falls_back_to_fts5_only(self):
|
||||
"""When vector and HRR unavailable, returns FTS5 results."""
|
||||
# Mock db
|
||||
class MockDB:
|
||||
def search_messages(self, **kwargs):
|
||||
return [{"session_id": "s1", "content": "test", "rank": 1.0, "role": "user"}]
|
||||
|
||||
results = hybrid_search("test", db=MockDB(), config=HybridSearchConfig(vector_enabled=False, hrr_enabled=False))
|
||||
assert len(results) == 1
|
||||
assert results[0].source == "fts5"
|
||||
Reference in New Issue
Block a user