# forked from Rockachopa/Timmy-time-dashboard
"""SemanticMemory and MemorySearcher — vector-based search over vault content.
|
|
|
|
SemanticMemory: indexes markdown files into chunks with embeddings, supports search.
|
|
MemorySearcher: high-level multi-tier search interface.
|
|
"""
import hashlib
import json
import logging
import sqlite3
from collections.abc import Generator
from contextlib import closing, contextmanager
from datetime import UTC, datetime
from pathlib import Path

from config import settings
from timmy.memory.db import DB_PATH, VAULT_PATH, get_connection
from timmy.memory.embeddings import (
    EMBEDDING_DIM,
    _get_embedding_model,
    cosine_similarity,
    embed_text,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SemanticMemory:
|
|
"""Vector-based semantic search over vault content."""
|
|
|
|
def __init__(self) -> None:
|
|
self.db_path = DB_PATH
|
|
self.vault_path = VAULT_PATH
|
|
|
|
@contextmanager
|
|
def _get_conn(self) -> Generator[sqlite3.Connection, None, None]:
|
|
"""Get connection to the instance's db_path (backward compatibility).
|
|
|
|
Uses self.db_path if set differently from global DB_PATH,
|
|
otherwise uses the global get_connection().
|
|
"""
|
|
if self.db_path == DB_PATH:
|
|
# Use global connection (normal production path)
|
|
with get_connection() as conn:
|
|
yield conn
|
|
else:
|
|
# Use instance-specific db_path (test path)
|
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
with closing(sqlite3.connect(str(self.db_path))) as conn:
|
|
conn.row_factory = sqlite3.Row
|
|
conn.execute("PRAGMA journal_mode=WAL")
|
|
conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
|
|
# Ensure schema exists
|
|
conn.execute("""
|
|
CREATE TABLE IF NOT EXISTS memories (
|
|
id TEXT PRIMARY KEY,
|
|
content TEXT NOT NULL,
|
|
memory_type TEXT NOT NULL DEFAULT 'fact',
|
|
source TEXT NOT NULL DEFAULT 'agent',
|
|
embedding TEXT,
|
|
metadata TEXT,
|
|
source_hash TEXT,
|
|
agent_id TEXT,
|
|
task_id TEXT,
|
|
session_id TEXT,
|
|
confidence REAL NOT NULL DEFAULT 0.8,
|
|
tags TEXT NOT NULL DEFAULT '[]',
|
|
created_at TEXT NOT NULL,
|
|
last_accessed TEXT,
|
|
access_count INTEGER NOT NULL DEFAULT 0
|
|
)
|
|
""")
|
|
conn.execute(
|
|
"CREATE INDEX IF NOT EXISTS idx_memories_type ON memories(memory_type)"
|
|
)
|
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_time ON memories(created_at)")
|
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_source ON memories(source)")
|
|
conn.commit()
|
|
yield conn
|
|
|
|
def _init_db(self) -> None:
|
|
"""Initialize database at self.db_path (backward compatibility).
|
|
|
|
This method is kept for backward compatibility with existing code and tests.
|
|
Schema creation is handled by _get_conn.
|
|
"""
|
|
# Trigger schema creation via _get_conn
|
|
with self._get_conn():
|
|
pass
|
|
|
|
def index_file(self, filepath: Path) -> int:
|
|
"""Index a single file into semantic memory."""
|
|
if not filepath.exists():
|
|
return 0
|
|
|
|
content = filepath.read_text()
|
|
file_hash = hashlib.md5(content.encode()).hexdigest()
|
|
|
|
with self._get_conn() as conn:
|
|
# Check if already indexed with same hash
|
|
cursor = conn.execute(
|
|
"SELECT metadata FROM memories WHERE source = ? AND memory_type = 'vault_chunk' LIMIT 1",
|
|
(str(filepath),),
|
|
)
|
|
existing = cursor.fetchone()
|
|
if existing and existing[0]:
|
|
try:
|
|
meta = json.loads(existing[0])
|
|
if meta.get("source_hash") == file_hash:
|
|
return 0 # Already indexed
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
# Delete old chunks for this file
|
|
conn.execute(
|
|
"DELETE FROM memories WHERE source = ? AND memory_type = 'vault_chunk'",
|
|
(str(filepath),),
|
|
)
|
|
|
|
# Split into chunks (paragraphs)
|
|
chunks = self._split_into_chunks(content)
|
|
|
|
# Index each chunk
|
|
now = datetime.now(UTC).isoformat()
|
|
for i, chunk_text in enumerate(chunks):
|
|
if len(chunk_text.strip()) < 20: # Skip tiny chunks
|
|
continue
|
|
|
|
chunk_id = f"{filepath.stem}_{i}"
|
|
chunk_embedding = embed_text(chunk_text)
|
|
|
|
conn.execute(
|
|
"""INSERT INTO memories
|
|
(id, content, memory_type, source, metadata, embedding, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)""",
|
|
(
|
|
chunk_id,
|
|
chunk_text,
|
|
"vault_chunk",
|
|
str(filepath),
|
|
json.dumps({"source_hash": file_hash, "chunk_index": i}),
|
|
json.dumps(chunk_embedding),
|
|
now,
|
|
),
|
|
)
|
|
|
|
conn.commit()
|
|
|
|
logger.info("SemanticMemory: Indexed %s (%d chunks)", filepath.name, len(chunks))
|
|
return len(chunks)
|
|
|
|
def _split_into_chunks(self, text: str, max_chunk_size: int = 500) -> list[str]:
|
|
"""Split text into semantic chunks."""
|
|
# Split by paragraphs first
|
|
paragraphs = text.split("\n\n")
|
|
chunks = []
|
|
|
|
for para in paragraphs:
|
|
para = para.strip()
|
|
if not para:
|
|
continue
|
|
|
|
# If paragraph is small enough, keep as one chunk
|
|
if len(para) <= max_chunk_size:
|
|
chunks.append(para)
|
|
else:
|
|
# Split long paragraphs by sentences
|
|
sentences = para.replace(". ", ".\n").split("\n")
|
|
current_chunk = ""
|
|
|
|
for sent in sentences:
|
|
if len(current_chunk) + len(sent) < max_chunk_size:
|
|
current_chunk += " " + sent if current_chunk else sent
|
|
else:
|
|
if current_chunk:
|
|
chunks.append(current_chunk.strip())
|
|
current_chunk = sent
|
|
|
|
if current_chunk:
|
|
chunks.append(current_chunk.strip())
|
|
|
|
return chunks
|
|
|
|
def index_vault(self) -> int:
|
|
"""Index entire vault directory."""
|
|
total_chunks = 0
|
|
|
|
for md_file in self.vault_path.rglob("*.md"):
|
|
# Skip handoff file (handled separately)
|
|
if "last-session-handoff" in md_file.name:
|
|
continue
|
|
total_chunks += self.index_file(md_file)
|
|
|
|
logger.info("SemanticMemory: Indexed vault (%d total chunks)", total_chunks)
|
|
return total_chunks
|
|
|
|
def search(self, query: str, top_k: int = 5) -> list[tuple[str, float]]:
|
|
"""Search for relevant memory chunks."""
|
|
query_embedding = embed_text(query)
|
|
|
|
with self._get_conn() as conn:
|
|
conn.row_factory = sqlite3.Row
|
|
|
|
# Get all vault chunks
|
|
rows = conn.execute(
|
|
"SELECT source, content, embedding FROM memories WHERE memory_type = 'vault_chunk'"
|
|
).fetchall()
|
|
|
|
# Calculate similarities
|
|
scored = []
|
|
for row in rows:
|
|
embedding = json.loads(row["embedding"])
|
|
score = cosine_similarity(query_embedding, embedding)
|
|
scored.append((row["source"], row["content"], score))
|
|
|
|
# Sort by score descending
|
|
scored.sort(key=lambda x: x[2], reverse=True)
|
|
|
|
# Return top_k
|
|
return [(content, score) for _, content, score in scored[:top_k]]
|
|
|
|
def get_relevant_context(self, query: str, max_chars: int = 2000) -> str:
|
|
"""Get formatted context string for a query."""
|
|
results = self.search(query, top_k=3)
|
|
|
|
if not results:
|
|
return ""
|
|
|
|
parts = []
|
|
total_chars = 0
|
|
|
|
for content, score in results:
|
|
if score < 0.3: # Similarity threshold
|
|
continue
|
|
|
|
chunk = f"[Relevant memory - score {score:.2f}]: {content[:400]}..."
|
|
if total_chars + len(chunk) > max_chars:
|
|
break
|
|
|
|
parts.append(chunk)
|
|
total_chars += len(chunk)
|
|
|
|
return "\n\n".join(parts) if parts else ""
|
|
|
|
def stats(self) -> dict:
|
|
"""Get indexing statistics."""
|
|
with self._get_conn() as conn:
|
|
cursor = conn.execute(
|
|
"SELECT COUNT(*), COUNT(DISTINCT source) FROM memories WHERE memory_type = 'vault_chunk'"
|
|
)
|
|
total_chunks, total_files = cursor.fetchone()
|
|
|
|
return {
|
|
"total_chunks": total_chunks,
|
|
"total_files": total_files,
|
|
"embedding_dim": EMBEDDING_DIM if _get_embedding_model() else 128,
|
|
}
|
|
|
|
|
|
class MemorySearcher:
|
|
"""High-level interface for memory search."""
|
|
|
|
def __init__(self) -> None:
|
|
self.semantic = SemanticMemory()
|
|
|
|
def search(self, query: str, tiers: list[str] = None) -> dict:
|
|
"""Search across memory tiers.
|
|
|
|
Args:
|
|
query: Search query
|
|
tiers: List of tiers to search ["hot", "vault", "semantic"]
|
|
|
|
Returns:
|
|
Dict with results from each tier
|
|
"""
|
|
tiers = tiers or ["semantic"] # Default to semantic only
|
|
results = {}
|
|
|
|
if "semantic" in tiers:
|
|
semantic_results = self.semantic.search(query, top_k=5)
|
|
results["semantic"] = [
|
|
{"content": content, "score": score} for content, score in semantic_results
|
|
]
|
|
|
|
return results
|
|
|
|
def get_context_for_query(self, query: str) -> str:
|
|
"""Get comprehensive context for a user query."""
|
|
# Get semantic context
|
|
semantic_context = self.semantic.get_relevant_context(query)
|
|
|
|
if semantic_context:
|
|
return f"## Relevant Past Context\n\n{semantic_context}"
|
|
|
|
return ""
|
|
|
|
|
|
# Module-level singletons
|
|
semantic_memory = SemanticMemory()
|
|
memory_searcher = MemorySearcher()
|