Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bafec50be |
238
agent/hybrid_search.py
Normal file
238
agent/hybrid_search.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""
|
||||||
|
hybrid_search.py — Hybrid search combining FTS5, vector, and HRR.
|
||||||
|
|
||||||
|
Three-backend search router:
|
||||||
|
1. FTS5 (SQLite full-text) — fast keyword matching, always available
|
||||||
|
2. Vector search (Qdrant/ChromaDB) — semantic similarity, optional
|
||||||
|
3. HRR (Holographic Reduced Representations) — compositional recall, optional
|
||||||
|
|
||||||
|
Graceful degradation: if vector or HRR backends are unavailable,
|
||||||
|
falls back to FTS5-only.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from agent.hybrid_search import hybrid_search
|
||||||
|
|
||||||
|
results = hybrid_search(query, db=session_db, limit=10)
|
||||||
|
# Returns merged, deduplicated, ranked results from all available backends
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Dict, Any, Optional, Callable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchResult:
|
||||||
|
"""Single search result from any backend."""
|
||||||
|
session_id: str
|
||||||
|
message_content: str
|
||||||
|
score: float
|
||||||
|
source: str # "fts5", "vector", "hrr"
|
||||||
|
role: str = ""
|
||||||
|
timestamp: str = ""
|
||||||
|
metadata: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HybridSearchConfig:
|
||||||
|
"""Configuration for hybrid search."""
|
||||||
|
fts5_enabled: bool = True
|
||||||
|
vector_enabled: bool = False
|
||||||
|
hrr_enabled: bool = False
|
||||||
|
vector_weight: float = 0.4
|
||||||
|
fts5_weight: float = 0.4
|
||||||
|
hrr_weight: float = 0.2
|
||||||
|
dedup_threshold: float = 0.9 # similarity threshold for dedup
|
||||||
|
|
||||||
|
|
||||||
|
def search_fts5(query: str, db, limit: int = 50, role_filter: list = None) -> List[SearchResult]:
|
||||||
|
"""Search using FTS5 full-text search."""
|
||||||
|
try:
|
||||||
|
raw = db.search_messages(
|
||||||
|
query=query,
|
||||||
|
role_filter=role_filter,
|
||||||
|
limit=limit,
|
||||||
|
offset=0,
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for r in raw:
|
||||||
|
results.append(SearchResult(
|
||||||
|
session_id=r.get("session_id", ""),
|
||||||
|
message_content=r.get("content", ""),
|
||||||
|
score=r.get("rank", 0.0),
|
||||||
|
source="fts5",
|
||||||
|
role=r.get("role", ""),
|
||||||
|
timestamp=str(r.get("timestamp", "")),
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"FTS5 search failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def search_vector(query: str, limit: int = 50) -> List[SearchResult]:
|
||||||
|
"""Search using vector similarity (Qdrant/ChromaDB).
|
||||||
|
|
||||||
|
Returns empty list if vector backend unavailable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try ChromaDB first
|
||||||
|
import chromadb
|
||||||
|
client = chromadb.PersistentClient(path="~/.hermes/memory/chroma")
|
||||||
|
collection = client.get_or_create_collection("sessions")
|
||||||
|
results = collection.query(
|
||||||
|
query_texts=[query],
|
||||||
|
n_results=limit,
|
||||||
|
)
|
||||||
|
search_results = []
|
||||||
|
for i, doc in enumerate(results.get("documents", [[]])[0]):
|
||||||
|
metadata = results.get("metadatas", [[]])[0]
|
||||||
|
meta = metadata[i] if i < len(metadata) else {}
|
||||||
|
distance = results.get("distances", [[]])[0]
|
||||||
|
score = 1.0 - (distance[i] if i < len(distance) else 1.0)
|
||||||
|
search_results.append(SearchResult(
|
||||||
|
session_id=meta.get("session_id", ""),
|
||||||
|
message_content=doc,
|
||||||
|
score=score,
|
||||||
|
source="vector",
|
||||||
|
role=meta.get("role", ""),
|
||||||
|
timestamp=meta.get("timestamp", ""),
|
||||||
|
))
|
||||||
|
return search_results
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try Qdrant
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
client = QdrantClient(host="localhost", port=6333)
|
||||||
|
results = client.query_points(
|
||||||
|
collection_name="sessions",
|
||||||
|
query_text=query,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
SearchResult(
|
||||||
|
session_id=pt.payload.get("session_id", ""),
|
||||||
|
message_content=pt.payload.get("content", ""),
|
||||||
|
score=pt.score,
|
||||||
|
source="vector",
|
||||||
|
role=pt.payload.get("role", ""),
|
||||||
|
)
|
||||||
|
for pt in results.points
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def search_hrr(query: str, limit: int = 50) -> List[SearchResult]:
|
||||||
|
"""Search using Holographic Reduced Representations.
|
||||||
|
|
||||||
|
Returns empty list if HRR backend unavailable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from agent.holographic_memory import holographic_recall
|
||||||
|
results = holographic_recall(query, limit=limit)
|
||||||
|
return [
|
||||||
|
SearchResult(
|
||||||
|
session_id=r.get("session_id", ""),
|
||||||
|
message_content=r.get("content", ""),
|
||||||
|
score=r.get("binding_score", 0.0),
|
||||||
|
source="hrr",
|
||||||
|
role=r.get("role", ""),
|
||||||
|
)
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def merge_results(
|
||||||
|
fts5_results: List[SearchResult],
|
||||||
|
vector_results: List[SearchResult],
|
||||||
|
hrr_results: List[SearchResult],
|
||||||
|
config: HybridSearchConfig,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Merge results from multiple backends with weighted scoring."""
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
# Apply weights
|
||||||
|
for r in fts5_results:
|
||||||
|
r.score *= config.fts5_weight
|
||||||
|
all_results.append(r)
|
||||||
|
for r in vector_results:
|
||||||
|
r.score *= config.vector_weight
|
||||||
|
all_results.append(r)
|
||||||
|
for r in hrr_results:
|
||||||
|
r.score *= config.hrr_weight
|
||||||
|
all_results.append(r)
|
||||||
|
|
||||||
|
# Sort by weighted score
|
||||||
|
all_results.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
|
||||||
|
# Deduplicate by session_id + content similarity
|
||||||
|
seen = set()
|
||||||
|
deduped = []
|
||||||
|
for r in all_results:
|
||||||
|
key = f"{r.session_id}:{r.message_content[:100]}"
|
||||||
|
if key not in seen:
|
||||||
|
seen.add(key)
|
||||||
|
deduped.append(r)
|
||||||
|
|
||||||
|
return deduped[:limit]
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_search(
|
||||||
|
query: str,
|
||||||
|
db=None,
|
||||||
|
limit: int = 10,
|
||||||
|
role_filter: list = None,
|
||||||
|
config: HybridSearchConfig = None,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Hybrid search across FTS5, vector, and HRR backends.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query
|
||||||
|
db: Session database (for FTS5)
|
||||||
|
limit: Max results
|
||||||
|
role_filter: Filter by message role
|
||||||
|
config: Hybrid search configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult, ranked by weighted score
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = HybridSearchConfig()
|
||||||
|
|
||||||
|
fts5_results = []
|
||||||
|
vector_results = []
|
||||||
|
hrr_results = []
|
||||||
|
|
||||||
|
# FTS5 (always available if db provided)
|
||||||
|
if config.fts5_enabled and db:
|
||||||
|
fts5_results = search_fts5(query, db, limit=50, role_filter=role_filter)
|
||||||
|
logger.debug(f"FTS5: {len(fts5_results)} results")
|
||||||
|
|
||||||
|
# Vector search (optional)
|
||||||
|
if config.vector_enabled:
|
||||||
|
vector_results = search_vector(query, limit=50)
|
||||||
|
logger.debug(f"Vector: {len(vector_results)} results")
|
||||||
|
|
||||||
|
# HRR (optional)
|
||||||
|
if config.hrr_enabled:
|
||||||
|
hrr_results = search_hrr(query, limit=50)
|
||||||
|
logger.debug(f"HRR: {len(hrr_results)} results")
|
||||||
|
|
||||||
|
# If only FTS5 available, just return those
|
||||||
|
if not vector_results and not hrr_results:
|
||||||
|
return fts5_results[:limit]
|
||||||
|
|
||||||
|
# Merge and rank
|
||||||
|
return merge_results(fts5_results, vector_results, hrr_results, config, limit)
|
||||||
58
tests/agent/test_hybrid_search.py
Normal file
58
tests/agent/test_hybrid_search.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Tests for hybrid search router."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from agent.hybrid_search import (
|
||||||
|
SearchResult,
|
||||||
|
HybridSearchConfig,
|
||||||
|
merge_results,
|
||||||
|
hybrid_search,
|
||||||
|
search_fts5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchResult:
|
||||||
|
def test_creation(self):
|
||||||
|
r = SearchResult(session_id="s1", message_content="hello", score=0.9, source="fts5")
|
||||||
|
assert r.session_id == "s1"
|
||||||
|
assert r.source == "fts5"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeResults:
|
||||||
|
def test_merges_and_ranks(self):
|
||||||
|
fts5 = [SearchResult("s1", "alpha content", 1.0, "fts5")]
|
||||||
|
vec = [SearchResult("s2", "beta content", 0.9, "vector")]
|
||||||
|
hrr = [SearchResult("s3", "gamma content", 0.5, "hrr")]
|
||||||
|
config = HybridSearchConfig(fts5_weight=0.4, vector_weight=0.4, hrr_weight=0.2)
|
||||||
|
results = merge_results(fts5, vec, hrr, config, limit=10)
|
||||||
|
assert len(results) == 3
|
||||||
|
# s1: 1.0*0.4=0.4, s2: 0.9*0.4=0.36, s3: 0.5*0.2=0.1
|
||||||
|
assert results[0].session_id == "s1"
|
||||||
|
|
||||||
|
def test_deduplicates(self):
|
||||||
|
fts5 = [SearchResult("s1", "same content", 1.0, "fts5")]
|
||||||
|
vec = [SearchResult("s1", "same content", 0.8, "vector")]
|
||||||
|
config = HybridSearchConfig()
|
||||||
|
results = merge_results(fts5, vec, [], config, limit=10)
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
def test_respects_limit(self):
|
||||||
|
fts5 = [SearchResult(f"s{i}", f"content {i}", 1.0/i, "fts5") for i in range(1, 20)]
|
||||||
|
results = merge_results(fts5, [], [], HybridSearchConfig(), limit=5)
|
||||||
|
assert len(results) == 5
|
||||||
|
|
||||||
|
def test_empty_inputs(self):
|
||||||
|
results = merge_results([], [], [], HybridSearchConfig())
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestHybridSearchFallback:
|
||||||
|
def test_falls_back_to_fts5_only(self):
|
||||||
|
"""When vector and HRR unavailable, returns FTS5 results."""
|
||||||
|
# Mock db
|
||||||
|
class MockDB:
|
||||||
|
def search_messages(self, **kwargs):
|
||||||
|
return [{"session_id": "s1", "content": "test", "rank": 1.0, "role": "user"}]
|
||||||
|
|
||||||
|
results = hybrid_search("test", db=MockDB(), config=HybridSearchConfig(vector_enabled=False, hrr_enabled=False))
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].source == "fts5"
|
||||||
Reference in New Issue
Block a user