# Changelog:
# - add lexical, semantic, graph, and temporal retrieval lanes with RRF fusion
# - store retrieval traces on fused searches and expose them through the provider
# - add benchmark helpers for prompt-matrix before/after evaluation
"""Multi-path retrieval for the holographic memory store.
|
|
|
|
Combines independent lexical, semantic, graph-aware, and temporal retrieval
|
|
lanes, then fuses them with Reciprocal Rank Fusion (RRF). The pipeline keeps a
|
|
trace of per-lane contributions so recall failures can be inspected after a
|
|
run, and includes a benchmark helper to compare the fused pipeline against the
|
|
legacy Hermes-native search path.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
import re
|
|
from collections import defaultdict
|
|
from datetime import datetime, timezone
|
|
from typing import TYPE_CHECKING, Any, Iterable
|
|
|
|
if TYPE_CHECKING:
|
|
from .store import MemoryStore
|
|
|
|
try:
|
|
from . import holographic as hrr
|
|
except ImportError:
|
|
import holographic as hrr # type: ignore[no-redef]
|
|
|
|
_RRF_K = 60
|
|
_DEFAULT_TEMPORAL_HALF_LIFE = 30
|
|
_SEMANTIC_SYNONYMS = {
|
|
"automation": {"ansible", "playbooks", "orchestration"},
|
|
"ansible": {"automation", "playbooks", "orchestration"},
|
|
"deploy": {"deployment", "deploys", "rollout", "rollouts", "ship", "shipping", "release", "releases"},
|
|
"deploys": {"deploy", "deployment", "rollout", "rollouts", "ship", "shipping"},
|
|
"deployment": {"deploy", "deploys", "rollout", "rollouts"},
|
|
"rollout": {"deploy", "deploys", "deployment", "release", "shipping"},
|
|
"rollouts": {"deploy", "deploys", "deployment", "release", "shipping"},
|
|
"provider": {"model", "engine", "runtime"},
|
|
"model": {"provider", "engine", "runtime"},
|
|
"lane": {"queue", "router", "track"},
|
|
"forge": {"review", "triage", "pull-request"},
|
|
}
|
|
_ENTITY_TOKEN_RE = re.compile(r"\b([A-Z][\w-]*(?:\s+[A-Z][\w-]*)*)\b")
|
|
|
|
|
|
def format_benchmark_report(report: dict) -> str:
    """Render a benchmark/evaluation report as plain text."""
    out: list[str] = []
    out.append("Prompt matrix benchmark")
    out.append(f"- baseline_top1_hits: {report.get('baseline_top1_hits', 0)}")
    out.append(f"- fused_top1_hits: {report.get('fused_top1_hits', 0)}")
    out.append(f"- improvement: {report.get('improvement', 0)}")
    out.append("")
    for case in report.get("cases", []):
        baseline_status = "PASS" if case["baseline_hit"] else "FAIL"
        fused_status = "PASS" if case["fused_hit"] else "FAIL"
        out.append(
            f"- {case['name']}: baseline={baseline_status}, "
            f"fused={fused_status}, expected={case['expected_substring']}"
        )
        out.append(f" baseline_top: {case['baseline_top']}")
        out.append(f" fused_top: {case['fused_top']}")
    return "\n".join(out).strip()
|
|
|
|
|
|
class FactRetriever:
    """Multi-path fact retrieval with RRF fusion, traceability, and benchmarking."""

    def __init__(
        self,
        store: MemoryStore,
        temporal_decay_half_life: int = 0,
        fts_weight: float = 0.4,
        jaccard_weight: float = 0.3,
        hrr_weight: float = 0.3,
        hrr_dim: int = 1024,
        retrieval_lanes: list[str] | tuple[str, ...] | None = None,
        rrf_k: int = _RRF_K,
        enable_rerank: bool = True,
        rerank_min_candidates: int = 3,
        rerank_margin: float = 0.035,
    ):
        """Configure the retriever.

        Args:
            store: Backing memory store providing the SQLite connection.
            temporal_decay_half_life: Half-life (days) for temporal decay;
                0 disables decay in the legacy path (the temporal lane then
                falls back to ``_DEFAULT_TEMPORAL_HALF_LIFE``).
            fts_weight: Legacy-path weight of the normalized FTS rank.
            jaccard_weight: Legacy-path weight of token Jaccard overlap.
            hrr_weight: Legacy-path weight of HRR vector similarity.
            hrr_dim: Dimensionality of the HRR phase vectors.
            retrieval_lanes: Default lane subset; None means all lanes.
            rrf_k: Reciprocal Rank Fusion smoothing constant.
            enable_rerank: Default for the optional rerank stage.
            rerank_min_candidates: Minimum fused candidates before reranking.
            rerank_margin: Fused-score gap above which the top result is
                considered decisive and reranking is skipped.
        """
        self.store = store
        self.half_life = temporal_decay_half_life
        self.hrr_dim = hrr_dim
        self.available_lanes = ("lexical", "semantic", "graph", "temporal")
        self.default_lanes = tuple(retrieval_lanes or self.available_lanes)
        self.rrf_k = rrf_k
        self.enable_rerank = enable_rerank
        self.rerank_min_candidates = rerank_min_candidates
        self.rerank_margin = rerank_margin
        # Trace of the most recent fused search (see search_with_trace).
        self.last_trace: dict[str, Any] = {}

        # Auto-redistribute weights if numpy unavailable.
        if hrr_weight > 0 and not hrr._HAS_NUMPY:
            fts_weight = 0.6
            jaccard_weight = 0.4
            hrr_weight = 0.0

        self.fts_weight = fts_weight
        self.jaccard_weight = jaccard_weight
        self.hrr_weight = hrr_weight
|
|
|
|
# ------------------------------------------------------------------
|
|
# Public API
|
|
# ------------------------------------------------------------------
|
|
|
|
def search(
|
|
self,
|
|
query: str,
|
|
category: str | None = None,
|
|
min_trust: float = 0.3,
|
|
limit: int = 10,
|
|
*,
|
|
lanes: Iterable[str] | None = None,
|
|
rerank: bool | None = None,
|
|
) -> list[dict]:
|
|
"""Run the fused multi-path retrieval pipeline and return only results."""
|
|
|
|
payload = self.search_with_trace(
|
|
query,
|
|
category=category,
|
|
min_trust=min_trust,
|
|
limit=limit,
|
|
lanes=lanes,
|
|
rerank=rerank,
|
|
)
|
|
return payload["results"]
|
|
|
|
    def search_with_trace(
        self,
        query: str,
        category: str | None = None,
        min_trust: float = 0.3,
        limit: int = 10,
        *,
        lanes: Iterable[str] | None = None,
        rerank: bool | None = None,
    ) -> dict[str, Any]:
        """Run the fused retrieval pipeline and return results plus an audit trace.

        Args:
            query: Free-text query; surrounding whitespace is stripped.
            category: Optional category filter applied inside each lane.
            min_trust: Minimum trust_score a fact must carry.
            limit: Maximum number of fused results returned.
            lanes: Optional subset of lanes to run; unknown names are dropped
                and an empty/all-invalid selection falls back to every lane.
            rerank: Tri-state override; None defers to ``self.enable_rerank``.

        Returns:
            ``{"results": [...], "trace": {...}}``. The trace records lanes
            run, per-lane hit counts and top fact ids, fusion parameters, the
            rerank decision, and the top fused results. The same trace is
            stored on ``self.last_trace`` for post-run inspection.
        """
        normalized_query = (query or "").strip()
        selected_lanes = self._resolve_lanes(lanes)
        # Over-fetch per lane so fusion has enough distinct candidates to rank.
        candidate_limit = max(limit * 5, 12)
        lane_results: dict[str, list[dict[str, Any]]] = {}

        for lane_name in selected_lanes:
            # Lane implementations follow the `_<name>_lane` naming convention.
            lane_fn = getattr(self, f"_{lane_name}_lane")
            lane_results[lane_name] = lane_fn(
                normalized_query,
                category=category,
                min_trust=min_trust,
                limit=candidate_limit,
            )

        fused = self._rrf_fuse(lane_results)
        rerank_requested = self.enable_rerank if rerank is None else rerank
        rerank_applied = False
        rerank_reason = "disabled"
        if rerank_requested:
            # Rerank only when the fused ordering is genuinely ambiguous.
            should_rerank, rerank_reason = self._should_rerank(fused, lane_results)
            if should_rerank:
                fused = self._rerank(normalized_query, fused)
                rerank_applied = True

        results = fused[:limit]
        trace = {
            "query": normalized_query,
            "lanes_run": list(selected_lanes),
            "lane_hits": {lane: len(items) for lane, items in lane_results.items()},
            "lane_top_fact_ids": {
                lane: [item["fact_id"] for item in items[:5]]
                for lane, items in lane_results.items()
            },
            "fused_count": len(fused),
            "rrf_k": self.rrf_k,
            "rerank_requested": rerank_requested,
            "rerank_applied": rerank_applied,
            "rerank_reason": rerank_reason,
            "top_results": [
                {
                    "fact_id": item["fact_id"],
                    "content": item["content"],
                    "fused_score": round(item.get("fused_score", 0.0), 6),
                    "lane_contributions": item.get("lane_contributions", {}),
                    "matched_lanes": item.get("matched_lanes", []),
                }
                for item in results
            ],
        }
        self.last_trace = trace
        return {"results": results, "trace": trace}
|
|
|
|
    def baseline_search(
        self,
        query: str,
        category: str | None = None,
        min_trust: float = 0.3,
        limit: int = 10,
    ) -> list[dict]:
        """Legacy Hermes-native retrieval path used for before/after benchmarks.

        Scores FTS candidates with a weighted blend of normalized FTS rank,
        token Jaccard overlap, and (when a vector is present) HRR similarity,
        then multiplies by trust and — if a half-life is configured — by
        temporal decay.
        """
        # Over-fetch so the post-FTS re-scoring has room to reorder.
        candidates = self._fts_candidates(query, category, min_trust, limit * 3)
        if not candidates:
            return []

        query_tokens = self._tokenize(query)
        scored = []
        for fact in candidates:
            content_tokens = self._tokenize(fact["content"])
            tag_tokens = self._tokenize(fact.get("tags", ""))
            all_tokens = content_tokens | tag_tokens
            jaccard = self._jaccard_similarity(query_tokens, all_tokens)
            fts_score = fact.get("fts_rank", 0.0)

            if self.hrr_weight > 0 and fact.get("hrr_vector"):
                fact_vec = hrr.bytes_to_phases(fact["hrr_vector"])
                query_vec = hrr.encode_text(query, self.hrr_dim)
                # Map similarity from [-1, 1] into [0, 1].
                hrr_sim = (hrr.similarity(query_vec, fact_vec) + 1.0) / 2.0
            else:
                # Neutral midpoint so vectorless facts are not penalized.
                hrr_sim = 0.5

            relevance = (
                self.fts_weight * fts_score
                + self.jaccard_weight * jaccard
                + self.hrr_weight * hrr_sim
            )
            score = relevance * fact["trust_score"]
            if self.half_life > 0:
                score *= self._temporal_decay(fact.get("updated_at") or fact.get("created_at"))
            legacy = self._clean_fact(fact)
            legacy["score"] = score
            scored.append(legacy)

        scored.sort(key=lambda item: item["score"], reverse=True)
        return scored[:limit]
|
|
|
|
    def benchmark_prompt_matrix(
        self,
        cases: list[dict[str, Any]],
        *,
        category: str | None = None,
        min_trust: float = 0.0,
        limit: int = 5,
        lanes: Iterable[str] | None = None,
        rerank: bool | None = None,
    ) -> dict[str, Any]:
        """Benchmark fused retrieval against the legacy Hermes-native path.

        Args:
            cases: Each case is a dict with ``query`` (required),
                ``expected_substring`` (matched case-insensitively against
                result content), and optionally ``name`` and ``top_k``
                (how many top results to inspect; defaults to 1).
            category / min_trust / limit / lanes / rerank: Passed through to
                both retrieval paths.

        Returns:
            A report dict with aggregate hit counts, the improvement delta,
            and one row per case (including the fused trace) suitable for
            ``format_benchmark_report``.
        """
        baseline_hits = 0
        fused_hits = 0
        rows: list[dict[str, Any]] = []
        for case in cases:
            query = case["query"]
            top_k = int(case.get("top_k", 1))
            baseline = self.baseline_search(query, category=category, min_trust=min_trust, limit=limit)
            fused_payload = self.search_with_trace(
                query,
                category=category,
                min_trust=min_trust,
                limit=limit,
                lanes=lanes,
                rerank=rerank,
            )
            fused_results = fused_payload["results"]

            # A "hit" means the expected substring appears in the top-k content.
            baseline_hit = self._matches_expected(baseline[:top_k], case)
            fused_hit = self._matches_expected(fused_results[:top_k], case)
            baseline_hits += int(baseline_hit)
            fused_hits += int(fused_hit)

            rows.append(
                {
                    "name": case.get("name", query),
                    "query": query,
                    "expected_substring": case.get("expected_substring", ""),
                    "baseline_hit": baseline_hit,
                    "fused_hit": fused_hit,
                    "baseline_top": baseline[0]["content"] if baseline else "",
                    "fused_top": fused_results[0]["content"] if fused_results else "",
                    "trace": fused_payload["trace"],
                }
            )

        return {
            "baseline_top1_hits": baseline_hits,
            "fused_top1_hits": fused_hits,
            "improvement": fused_hits - baseline_hits,
            "cases": rows,
        }
|
|
|
|
# ------------------------------------------------------------------
|
|
# Fusion pipeline
|
|
# ------------------------------------------------------------------
|
|
|
|
def _resolve_lanes(self, lanes: Iterable[str] | None) -> tuple[str, ...]:
|
|
selected = tuple(lanes or self.default_lanes)
|
|
valid = [lane for lane in selected if lane in self.available_lanes]
|
|
return tuple(valid or self.available_lanes)
|
|
|
|
def _rrf_fuse(self, lane_results: dict[str, list[dict[str, Any]]]) -> list[dict[str, Any]]:
|
|
fused: dict[int, dict[str, Any]] = {}
|
|
for lane_name, ranked in lane_results.items():
|
|
for rank, fact in enumerate(ranked, start=1):
|
|
fact_id = int(fact["fact_id"])
|
|
contribution = 1.0 / (self.rrf_k + rank)
|
|
entry = fused.get(fact_id)
|
|
if entry is None:
|
|
entry = self._clean_fact(fact)
|
|
entry["fused_score"] = 0.0
|
|
entry["lane_contributions"] = {}
|
|
entry["lane_raw_scores"] = {}
|
|
entry["matched_lanes"] = []
|
|
fused[fact_id] = entry
|
|
entry["fused_score"] += contribution
|
|
entry["lane_contributions"][lane_name] = round(contribution, 6)
|
|
entry["lane_raw_scores"][lane_name] = round(float(fact.get("lane_score", fact.get("score", 0.0))), 6)
|
|
if lane_name not in entry["matched_lanes"]:
|
|
entry["matched_lanes"].append(lane_name)
|
|
ranked = sorted(
|
|
fused.values(),
|
|
key=lambda item: (item["fused_score"], item.get("trust_score", 0.0)),
|
|
reverse=True,
|
|
)
|
|
for item in ranked:
|
|
item["fused_score"] = round(item["fused_score"], 6)
|
|
return ranked
|
|
|
|
def _should_rerank(
|
|
self,
|
|
fused: list[dict[str, Any]],
|
|
lane_results: dict[str, list[dict[str, Any]]],
|
|
) -> tuple[bool, str]:
|
|
if len(fused) < self.rerank_min_candidates:
|
|
return False, "not enough candidates"
|
|
active_lanes = sum(1 for items in lane_results.values() if items)
|
|
if active_lanes < 2:
|
|
return False, "single active lane"
|
|
margin = fused[0]["fused_score"] - fused[1]["fused_score"]
|
|
if margin > self.rerank_margin:
|
|
return False, f"decisive fused margin {margin:.3f}"
|
|
if len(fused[0].get("matched_lanes", [])) < 2 and len(fused[1].get("matched_lanes", [])) < 2:
|
|
return False, "insufficient lane disagreement"
|
|
return True, f"close fused margin {margin:.3f} across multiple lanes"
|
|
|
|
def _rerank(self, query: str, fused: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
query_tokens = self._expand_semantic_tokens(self._tokenize(query))
|
|
query_entities = {entity.lower() for entity in self._extract_query_entities(query)}
|
|
reranked = []
|
|
for fact in fused:
|
|
fact_tokens = self._expand_semantic_tokens(self._tokenize(fact["content"] + " " + fact.get("tags", "")))
|
|
token_overlap = self._jaccard_similarity(query_tokens, fact_tokens)
|
|
entity_overlap = self._entity_overlap_score(fact["content"], query_entities)
|
|
trust = float(fact.get("trust_score", 0.0))
|
|
rerank_score = (
|
|
0.5 * float(fact.get("fused_score", 0.0))
|
|
+ 0.25 * token_overlap
|
|
+ 0.15 * entity_overlap
|
|
+ 0.10 * trust
|
|
)
|
|
reranked_fact = dict(fact)
|
|
reranked_fact["rerank_score"] = round(rerank_score, 6)
|
|
reranked.append(reranked_fact)
|
|
reranked.sort(key=lambda item: (item["rerank_score"], item["fused_score"]), reverse=True)
|
|
return reranked
|
|
|
|
# ------------------------------------------------------------------
|
|
# Retrieval lanes
|
|
# ------------------------------------------------------------------
|
|
|
|
def _lexical_lane(
|
|
self,
|
|
query: str,
|
|
*,
|
|
category: str | None,
|
|
min_trust: float,
|
|
limit: int,
|
|
) -> list[dict[str, Any]]:
|
|
candidates = self._fts_candidates(query, category, min_trust, limit)
|
|
query_tokens = self._tokenize(query)
|
|
scored = []
|
|
for fact in candidates:
|
|
fact_tokens = self._tokenize(fact["content"] + " " + fact.get("tags", ""))
|
|
jaccard = self._jaccard_similarity(query_tokens, fact_tokens)
|
|
lane_score = 0.7 * float(fact.get("fts_rank", 0.0)) + 0.3 * jaccard
|
|
if lane_score <= 0:
|
|
continue
|
|
lane_fact = self._clean_fact(fact)
|
|
lane_fact["lane_score"] = round(lane_score, 6)
|
|
lane_fact["lane_name"] = "lexical"
|
|
scored.append(lane_fact)
|
|
scored.sort(key=lambda item: (item["lane_score"], item.get("trust_score", 0.0)), reverse=True)
|
|
return scored[:limit]
|
|
|
|
    def _semantic_lane(
        self,
        query: str,
        *,
        category: str | None,
        min_trust: float,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Semantic lane: synonym-expanded token overlap, blended with HRR similarity.

        Facts with zero token overlap are kept only when their HRR similarity
        is strong (>= 0.65), so the lane can still surface paraphrases.
        """
        query_tokens = self._expand_semantic_tokens(self._tokenize(query))
        rows = self._fetch_rows(category=category, min_trust=min_trust, require_vectors=False)
        scored = []
        for fact in rows:
            fact_tokens = self._expand_semantic_tokens(self._tokenize(fact["content"] + " " + fact.get("tags", "")))
            token_sim = self._jaccard_similarity(query_tokens, fact_tokens)
            hrr_sim = 0.0
            if self.hrr_weight > 0 and fact.get("hrr_vector"):
                query_vec = hrr.encode_text(query, self.hrr_dim)
                fact_vec = hrr.bytes_to_phases(fact["hrr_vector"])
                # Map similarity from [-1, 1] into [0, 1].
                hrr_sim = (hrr.similarity(query_vec, fact_vec) + 1.0) / 2.0
                lane_score = 0.55 * token_sim + 0.45 * hrr_sim
            else:
                lane_score = token_sim
            # No token overlap is only acceptable with strong vector evidence.
            if token_sim <= 0 and hrr_sim < 0.65:
                continue
            if lane_score <= 0:
                continue
            lane_fact = self._clean_fact(fact)
            lane_fact["lane_score"] = round(lane_score, 6)
            lane_fact["lane_name"] = "semantic"
            scored.append(lane_fact)
        scored.sort(key=lambda item: (item["lane_score"], item.get("trust_score", 0.0)), reverse=True)
        return scored[:limit]
|
|
|
|
    def _graph_lane(
        self,
        query: str,
        *,
        category: str | None,
        min_trust: float,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Graph lane: score facts linked to query entities via the entity graph.

        Facts directly attached to a matched entity score 1.0 per entity;
        facts reached through a shared (one-hop) entity score 0.65; a small
        lexical-overlap bonus (0.2x) breaks ties.
        """
        conn = self.store._conn
        query_entities = self._extract_query_entities(query)
        if not query_entities:
            return []

        # Resolve entity mentions (including aliases) to entity ids.
        matched_entity_ids: set[int] = set()
        for entity in query_entities:
            matched_entity_ids.update(self._lookup_entity_ids(entity))
        if not matched_entity_ids:
            return []

        direct_scores: defaultdict[int, float] = defaultdict(float)
        bridge_scores: defaultdict[int, float] = defaultdict(float)
        related_entity_ids: set[int] = set()

        # Pass 1: facts directly linked to a matched entity; also collect
        # co-mentioned entities for the bridging pass below.
        for entity_id in matched_entity_ids:
            fact_rows = conn.execute(
                "SELECT fact_id FROM fact_entities WHERE entity_id = ?",
                (entity_id,),
            ).fetchall()
            for row in fact_rows:
                fact_id = int(row["fact_id"])
                direct_scores[fact_id] += 1.0
                neighbor_rows = conn.execute(
                    "SELECT entity_id FROM fact_entities WHERE fact_id = ? AND entity_id != ?",
                    (fact_id, entity_id),
                ).fetchall()
                related_entity_ids.update(int(neighbor["entity_id"]) for neighbor in neighbor_rows)

        # Pass 2: facts reachable via a shared entity (one hop) at a discount;
        # facts already hit directly are not double-counted.
        for entity_id in related_entity_ids - matched_entity_ids:
            fact_rows = conn.execute(
                "SELECT fact_id FROM fact_entities WHERE entity_id = ?",
                (entity_id,),
            ).fetchall()
            for row in fact_rows:
                fact_id = int(row["fact_id"])
                if fact_id in direct_scores:
                    continue
                bridge_scores[fact_id] += 0.65

        candidate_ids = list({*direct_scores.keys(), *bridge_scores.keys()})
        if not candidate_ids:
            return []

        rows = self._fetch_rows_by_ids(candidate_ids, category=category, min_trust=min_trust, require_vectors=False)
        query_tokens = self._expand_semantic_tokens(self._tokenize(query))
        scored = []
        for fact in rows:
            fact_id = int(fact["fact_id"])
            fact_tokens = self._expand_semantic_tokens(self._tokenize(fact["content"] + " " + fact.get("tags", "")))
            lexical_support = self._jaccard_similarity(query_tokens, fact_tokens)
            lane_score = direct_scores[fact_id] + bridge_scores[fact_id] + 0.2 * lexical_support
            if lane_score <= 0:
                continue
            lane_fact = self._clean_fact(fact)
            lane_fact["lane_score"] = round(lane_score, 6)
            lane_fact["lane_name"] = "graph"
            scored.append(lane_fact)
        scored.sort(key=lambda item: (item["lane_score"], item.get("trust_score", 0.0)), reverse=True)
        return scored[:limit]
|
|
|
|
def _temporal_lane(
|
|
self,
|
|
query: str,
|
|
*,
|
|
category: str | None,
|
|
min_trust: float,
|
|
limit: int,
|
|
) -> list[dict[str, Any]]:
|
|
rows = self._fetch_rows(category=category, min_trust=min_trust, require_vectors=False)
|
|
query_tokens = self._expand_semantic_tokens(self._tokenize(query))
|
|
scored = []
|
|
for fact in rows:
|
|
fact_tokens = self._expand_semantic_tokens(self._tokenize(fact["content"] + " " + fact.get("tags", "")))
|
|
lexical_support = self._jaccard_similarity(query_tokens, fact_tokens)
|
|
if lexical_support <= 0:
|
|
continue
|
|
timestamp = fact.get("updated_at") or fact.get("created_at")
|
|
recency = self._temporal_decay(timestamp, half_life=self.half_life or _DEFAULT_TEMPORAL_HALF_LIFE)
|
|
lane_score = 0.7 * recency + 0.3 * lexical_support
|
|
lane_fact = self._clean_fact(fact)
|
|
lane_fact["lane_score"] = round(lane_score, 6)
|
|
lane_fact["lane_name"] = "temporal"
|
|
scored.append(lane_fact)
|
|
scored.sort(key=lambda item: (item["lane_score"], item.get("trust_score", 0.0)), reverse=True)
|
|
return scored[:limit]
|
|
|
|
# ------------------------------------------------------------------
|
|
# Existing algebraic retrieval APIs
|
|
# ------------------------------------------------------------------
|
|
|
|
def probe(
|
|
self,
|
|
entity: str,
|
|
category: str | None = None,
|
|
limit: int = 10,
|
|
) -> list[dict]:
|
|
"""Compositional entity query using HRR algebra."""
|
|
if not hrr._HAS_NUMPY:
|
|
return self.search(entity, category=category, limit=limit, lanes=["lexical", "graph"])
|
|
|
|
conn = self.store._conn
|
|
role_entity = hrr.encode_atom("__hrr_role_entity__", self.hrr_dim)
|
|
entity_vec = hrr.encode_atom(entity.lower(), self.hrr_dim)
|
|
probe_key = hrr.bind(entity_vec, role_entity)
|
|
|
|
if category:
|
|
bank_name = f"cat:{category}"
|
|
bank_row = conn.execute(
|
|
"SELECT vector FROM memory_banks WHERE bank_name = ?",
|
|
(bank_name,),
|
|
).fetchone()
|
|
if bank_row:
|
|
bank_vec = hrr.bytes_to_phases(bank_row["vector"])
|
|
extracted = hrr.unbind(bank_vec, probe_key)
|
|
return self._score_facts_by_vector(extracted, category=category, limit=limit)
|
|
|
|
rows = self._fetch_rows(category=category, min_trust=0.0, require_vectors=True)
|
|
if not rows:
|
|
return self.search(entity, category=category, limit=limit, lanes=["lexical", "graph"])
|
|
|
|
scored = []
|
|
for fact in rows:
|
|
fact_vec = hrr.bytes_to_phases(fact["hrr_vector"])
|
|
residual = hrr.unbind(fact_vec, probe_key)
|
|
role_content = hrr.encode_atom("__hrr_role_content__", self.hrr_dim)
|
|
content_vec = hrr.bind(hrr.encode_text(fact["content"], self.hrr_dim), role_content)
|
|
sim = hrr.similarity(residual, content_vec)
|
|
item = self._clean_fact(fact)
|
|
item["score"] = (sim + 1.0) / 2.0 * item["trust_score"]
|
|
scored.append(item)
|
|
|
|
scored.sort(key=lambda item: item["score"], reverse=True)
|
|
return scored[:limit]
|
|
|
|
def related(
|
|
self,
|
|
entity: str,
|
|
category: str | None = None,
|
|
limit: int = 10,
|
|
) -> list[dict]:
|
|
"""Discover facts that share structural connections with an entity."""
|
|
if not hrr._HAS_NUMPY:
|
|
return self.search(entity, category=category, limit=limit, lanes=["graph", "lexical"])
|
|
|
|
entity_vec = hrr.encode_atom(entity.lower(), self.hrr_dim)
|
|
rows = self._fetch_rows(category=category, min_trust=0.0, require_vectors=True)
|
|
if not rows:
|
|
return self.search(entity, category=category, limit=limit, lanes=["graph", "lexical"])
|
|
|
|
scored = []
|
|
for fact in rows:
|
|
fact_vec = hrr.bytes_to_phases(fact["hrr_vector"])
|
|
residual = hrr.unbind(fact_vec, entity_vec)
|
|
role_entity = hrr.encode_atom("__hrr_role_entity__", self.hrr_dim)
|
|
role_content = hrr.encode_atom("__hrr_role_content__", self.hrr_dim)
|
|
best_sim = max(hrr.similarity(residual, role_entity), hrr.similarity(residual, role_content))
|
|
item = self._clean_fact(fact)
|
|
item["score"] = (best_sim + 1.0) / 2.0 * item["trust_score"]
|
|
scored.append(item)
|
|
|
|
scored.sort(key=lambda item: item["score"], reverse=True)
|
|
return scored[:limit]
|
|
|
|
    def reason(
        self,
        entities: list[str],
        category: str | None = None,
        limit: int = 10,
    ) -> list[dict]:
        """Multi-entity compositional query — vector-space AND semantics.

        Each fact is scored by its *weakest* per-entity unbinding similarity
        (min), so a fact must relate to every listed entity to rank highly.
        Falls back to fused search when numpy is unavailable, the entity list
        is empty, or no vectors exist.
        """
        if not hrr._HAS_NUMPY or not entities:
            return self.search(" ".join(entities), category=category, limit=limit, lanes=["graph", "semantic", "lexical"])

        role_entity = hrr.encode_atom("__hrr_role_entity__", self.hrr_dim)
        probe_keys = [hrr.bind(hrr.encode_atom(entity.lower(), self.hrr_dim), role_entity) for entity in entities]
        rows = self._fetch_rows(category=category, min_trust=0.0, require_vectors=True)
        if not rows:
            return self.search(" ".join(entities), category=category, limit=limit, lanes=["graph", "semantic", "lexical"])

        role_content = hrr.encode_atom("__hrr_role_content__", self.hrr_dim)
        scored = []
        for fact in rows:
            fact_vec = hrr.bytes_to_phases(fact["hrr_vector"])
            entity_scores = []
            for probe_key in probe_keys:
                residual = hrr.unbind(fact_vec, probe_key)
                entity_scores.append(hrr.similarity(residual, role_content))
            item = self._clean_fact(fact)
            # min(): the AND — the fact's score is its weakest entity match,
            # mapped from [-1, 1] into [0, 1] and weighted by trust.
            item["score"] = (min(entity_scores) + 1.0) / 2.0 * item["trust_score"]
            scored.append(item)

        scored.sort(key=lambda item: item["score"], reverse=True)
        return scored[:limit]
|
|
|
|
    def contradict(
        self,
        category: str | None = None,
        threshold: float = 0.3,
        limit: int = 10,
    ) -> list[dict]:
        """Find potentially contradictory facts via entity overlap + content divergence.

        Two facts look contradictory when they mention largely the same
        entities (Jaccard >= 0.3) but their content vectors diverge. The
        contradiction score is ``entity_overlap * (1 - normalized_similarity)``.
        Requires numpy; returns [] otherwise.
        """
        if not hrr._HAS_NUMPY:
            return []

        conn = self.store._conn
        rows = conn.execute(
            f"""
            SELECT f.fact_id, f.content, f.category, f.tags, f.trust_score,
                   f.created_at, f.updated_at, f.hrr_vector
            FROM facts f
            WHERE f.hrr_vector IS NOT NULL
            {'AND f.category = ?' if category else ''}
            """,
            [category] if category else [],
        ).fetchall()
        if len(rows) < 2:
            return []

        # The pair scan below is O(n^2); cap the working set at the 500
        # most recently touched facts to bound the cost.
        max_facts = 500
        if len(rows) > max_facts:
            rows = sorted(rows, key=lambda row: row["updated_at"] or row["created_at"], reverse=True)[:max_facts]

        # Pre-load each fact's entity names for the overlap test.
        fact_entities: dict[int, set[str]] = {}
        for row in rows:
            entity_rows = conn.execute(
                """
                SELECT e.name FROM entities e
                JOIN fact_entities fe ON fe.entity_id = e.entity_id
                WHERE fe.fact_id = ?
                """,
                (row["fact_id"],),
            ).fetchall()
            fact_entities[int(row["fact_id"])] = {entity["name"].lower() for entity in entity_rows}

        contradictions = []
        facts = [dict(row) for row in rows]
        for i in range(len(facts)):
            for j in range(i + 1, len(facts)):
                fact_a, fact_b = facts[i], facts[j]
                entities_a = fact_entities.get(int(fact_a["fact_id"]), set())
                entities_b = fact_entities.get(int(fact_b["fact_id"]), set())
                if not entities_a or not entities_b:
                    continue
                entity_overlap = len(entities_a & entities_b) / len(entities_a | entities_b) if (entities_a | entities_b) else 0.0
                # Pairs about mostly different entities cannot contradict.
                if entity_overlap < 0.3:
                    continue
                similarity = hrr.similarity(hrr.bytes_to_phases(fact_a["hrr_vector"]), hrr.bytes_to_phases(fact_b["hrr_vector"]))
                # High entity overlap + low content similarity => contradiction.
                contradiction_score = entity_overlap * (1.0 - (similarity + 1.0) / 2.0)
                if contradiction_score < threshold:
                    continue
                contradictions.append(
                    {
                        "fact_a": self._clean_fact(fact_a),
                        "fact_b": self._clean_fact(fact_b),
                        "entity_overlap": round(entity_overlap, 3),
                        "content_similarity": round(similarity, 3),
                        "contradiction_score": round(contradiction_score, 3),
                        "shared_entities": sorted(entities_a & entities_b),
                    }
                )

        contradictions.sort(key=lambda item: item["contradiction_score"], reverse=True)
        return contradictions[:limit]
|
|
|
|
# ------------------------------------------------------------------
|
|
# Helpers
|
|
# ------------------------------------------------------------------
|
|
|
|
def _score_facts_by_vector(
|
|
self,
|
|
target_vec: "np.ndarray",
|
|
category: str | None = None,
|
|
limit: int = 10,
|
|
) -> list[dict]:
|
|
rows = self._fetch_rows(category=category, min_trust=0.0, require_vectors=True)
|
|
scored = []
|
|
for fact in rows:
|
|
sim = hrr.similarity(target_vec, hrr.bytes_to_phases(fact["hrr_vector"]))
|
|
item = self._clean_fact(fact)
|
|
item["score"] = (sim + 1.0) / 2.0 * item["trust_score"]
|
|
scored.append(item)
|
|
scored.sort(key=lambda item: item["score"], reverse=True)
|
|
return scored[:limit]
|
|
|
|
def _fetch_rows(
|
|
self,
|
|
*,
|
|
category: str | None,
|
|
min_trust: float,
|
|
require_vectors: bool,
|
|
) -> list[dict[str, Any]]:
|
|
conn = self.store._conn
|
|
where = ["trust_score >= ?"]
|
|
params: list[Any] = [min_trust]
|
|
if category:
|
|
where.append("category = ?")
|
|
params.append(category)
|
|
if require_vectors:
|
|
where.append("hrr_vector IS NOT NULL")
|
|
sql = f"SELECT * FROM facts WHERE {' AND '.join(where)}"
|
|
return [dict(row) for row in conn.execute(sql, params).fetchall()]
|
|
|
|
def _fetch_rows_by_ids(
|
|
self,
|
|
fact_ids: list[int],
|
|
*,
|
|
category: str | None,
|
|
min_trust: float,
|
|
require_vectors: bool,
|
|
) -> list[dict[str, Any]]:
|
|
if not fact_ids:
|
|
return []
|
|
conn = self.store._conn
|
|
placeholders = ",".join("?" * len(fact_ids))
|
|
where = [f"fact_id IN ({placeholders})", "trust_score >= ?"]
|
|
params: list[Any] = list(fact_ids) + [min_trust]
|
|
if category:
|
|
where.append("category = ?")
|
|
params.append(category)
|
|
if require_vectors:
|
|
where.append("hrr_vector IS NOT NULL")
|
|
sql = f"SELECT * FROM facts WHERE {' AND '.join(where)}"
|
|
return [dict(row) for row in conn.execute(sql, params).fetchall()]
|
|
|
|
    def _fts_candidates(
        self,
        query: str,
        category: str | None,
        min_trust: float,
        limit: int,
    ) -> list[dict]:
        """Fetch FTS-matched fact rows with a normalized ``fts_rank`` in [0, 1].

        Falls back to a LIKE scan when the MATCH query raises (e.g. FTS syntax
        characters in the raw query).
        """
        conn = self.store._conn
        params: list[Any] = [query]
        where_clauses = ["facts_fts MATCH ?"]
        if category:
            where_clauses.append("f.category = ?")
            params.append(category)
        where_clauses.append("f.trust_score >= ?")
        params.append(min_trust)
        sql = f"""
            SELECT f.*, facts_fts.rank AS fts_rank_raw
            FROM facts_fts
            JOIN facts f ON f.fact_id = facts_fts.rowid
            WHERE {' AND '.join(where_clauses)}
            ORDER BY facts_fts.rank
            LIMIT ?
        """
        params.append(limit)
        try:
            rows = conn.execute(sql, params).fetchall()
        except Exception:
            # MATCH can reject user-typed queries; degrade to LIKE matching.
            rows = self._like_fallback_candidates(query, category=category, min_trust=min_trust, limit=limit)
            return rows
        if not rows:
            return []

        # Normalize |rank| against the worst candidate so fts_rank lands in
        # [0, 1]. NOTE(review): with FTS5 bm25 ranks (negative, more negative
        # = better) this appears to give the best match the LOWEST normalized
        # value — acknowledged below as a deliberately preserved legacy quirk.
        raw_ranks = [abs(row["fts_rank_raw"]) for row in rows]
        max_rank = max(raw_ranks) if raw_ranks else 1.0
        max_rank = max(max_rank, 1e-6)
        results = []
        for row, raw_rank in zip(rows, raw_ranks):
            fact = dict(row)
            fact.pop("fts_rank_raw", None)
            # Higher is better. The legacy path keeps the old normalization bug;
            # the multi-path lexical lane uses the corrected value.
            fact["fts_rank"] = max(0.0, 1.0 - (raw_rank / max_rank))
            results.append(fact)
        return results
|
|
|
|
def _like_fallback_candidates(
|
|
self,
|
|
query: str,
|
|
*,
|
|
category: str | None,
|
|
min_trust: float,
|
|
limit: int,
|
|
) -> list[dict[str, Any]]:
|
|
terms = [term for term in self._tokenize(query) if len(term) > 2]
|
|
if not terms:
|
|
return []
|
|
conn = self.store._conn
|
|
where = ["trust_score >= ?"]
|
|
params: list[Any] = [min_trust]
|
|
if category:
|
|
where.append("category = ?")
|
|
params.append(category)
|
|
like_clauses = []
|
|
for term in terms[:4]:
|
|
like_clauses.append("(lower(content) LIKE ? OR lower(tags) LIKE ?)")
|
|
params.extend([f"%{term}%", f"%{term}%"])
|
|
where.append("(" + " OR ".join(like_clauses) + ")")
|
|
sql = f"SELECT * FROM facts WHERE {' AND '.join(where)} ORDER BY trust_score DESC, updated_at DESC LIMIT ?"
|
|
params.append(limit)
|
|
rows = [dict(row) for row in conn.execute(sql, params).fetchall()]
|
|
for rank, fact in enumerate(rows, start=1):
|
|
fact["fts_rank"] = 1.0 / (rank + 1)
|
|
return rows
|
|
|
|
    def _lookup_entity_ids(self, name: str) -> set[int]:
        """Resolve an entity name (or alias) to entity ids, case-insensitively.

        The aliases column is matched by wrapping both the stored CSV list and
        the search term in commas, so ``LIKE '%,term,%'`` only hits whole
        alias entries rather than substrings.
        """
        lowered = name.lower()
        conn = self.store._conn
        rows = conn.execute(
            """
            SELECT entity_id FROM entities
            WHERE lower(name) = ?
            OR ',' || lower(aliases) || ',' LIKE '%,' || ? || ',%'
            """,
            (lowered, lowered),
        ).fetchall()
        return {int(row["entity_id"]) for row in rows}
|
|
|
|
def _extract_query_entities(self, query: str) -> list[str]:
|
|
entities = []
|
|
seen: set[str] = set()
|
|
for match in _ENTITY_TOKEN_RE.finditer(query):
|
|
candidate = match.group(1).strip()
|
|
lowered = candidate.lower()
|
|
if len(candidate) < 3 or lowered in {"what", "which", "when", "where", "who", "how", "the"}:
|
|
continue
|
|
if lowered not in seen:
|
|
seen.add(lowered)
|
|
entities.append(candidate)
|
|
return entities
|
|
|
|
def _expand_semantic_tokens(self, tokens: set[str]) -> set[str]:
|
|
expanded = set(tokens)
|
|
for token in list(tokens):
|
|
expanded.update(_SEMANTIC_SYNONYMS.get(token, set()))
|
|
return expanded
|
|
|
|
def _entity_overlap_score(self, content: str, query_entities: set[str]) -> float:
|
|
if not query_entities:
|
|
return 0.0
|
|
content_entities = {entity.lower() for entity in self._extract_query_entities(content)}
|
|
if not content_entities:
|
|
return 0.0
|
|
return self._jaccard_similarity(query_entities, content_entities)
|
|
|
|
def _matches_expected(self, results: list[dict[str, Any]], case: dict[str, Any]) -> bool:
|
|
expected = str(case.get("expected_substring", "")).lower()
|
|
if not expected:
|
|
return False
|
|
return any(expected in result.get("content", "").lower() for result in results)
|
|
|
|
def _clean_fact(self, fact: dict[str, Any]) -> dict[str, Any]:
|
|
cleaned = {key: value for key, value in dict(fact).items() if key != "hrr_vector"}
|
|
cleaned.pop("fts_rank", None)
|
|
return cleaned
|
|
|
|
@staticmethod
|
|
def _tokenize(text: str) -> set[str]:
|
|
if not text:
|
|
return set()
|
|
tokens = set()
|
|
for word in text.lower().split():
|
|
cleaned = word.strip(".,;:!?\"'()[]{}#@<>")
|
|
if cleaned:
|
|
tokens.add(cleaned)
|
|
return tokens
|
|
|
|
@staticmethod
|
|
def _jaccard_similarity(set_a: set, set_b: set) -> float:
|
|
if not set_a or not set_b:
|
|
return 0.0
|
|
intersection = len(set_a & set_b)
|
|
union = len(set_a | set_b)
|
|
return intersection / union if union > 0 else 0.0
|
|
|
|
def _temporal_decay(self, timestamp_str: str | None, *, half_life: int | None = None) -> float:
|
|
decay_half_life = half_life if half_life is not None else self.half_life
|
|
if not decay_half_life or not timestamp_str:
|
|
return 1.0
|
|
try:
|
|
ts = datetime.fromisoformat(str(timestamp_str).replace("Z", "+00:00"))
|
|
if ts.tzinfo is None:
|
|
ts = ts.replace(tzinfo=timezone.utc)
|
|
age_days = (datetime.now(timezone.utc) - ts).total_seconds() / 86400
|
|
if age_days < 0:
|
|
return 1.0
|
|
return math.pow(0.5, age_days / decay_half_life)
|
|
except (ValueError, TypeError):
|
|
return 1.0
|