diff --git a/src/timmy/memory_system.py b/src/timmy/memory_system.py index d9be52a..fe01d5e 100644 --- a/src/timmy/memory_system.py +++ b/src/timmy/memory_system.py @@ -303,6 +303,85 @@ def store_memory( return entry +def _build_search_filters( + context_type: str | None, + agent_id: str | None, + session_id: str | None, +) -> tuple[str, list]: + """Build SQL WHERE clause and params from search filters.""" + conditions: list[str] = [] + params: list = [] + + if context_type: + conditions.append("memory_type = ?") + params.append(context_type) + if agent_id: + conditions.append("agent_id = ?") + params.append(agent_id) + if session_id: + conditions.append("session_id = ?") + params.append(session_id) + + where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" + return where_clause, params + + +def _fetch_memory_candidates( + where_clause: str, params: list, candidate_limit: int +) -> list[sqlite3.Row]: + """Fetch candidate memory rows from the database.""" + query_sql = f""" + SELECT * FROM memories + {where_clause} + ORDER BY created_at DESC + LIMIT ? + """ + params.append(candidate_limit) + + with get_connection() as conn: + return conn.execute(query_sql, params).fetchall() + + +def _row_to_entry(row: sqlite3.Row) -> MemoryEntry: + """Convert a database row to a MemoryEntry.""" + return MemoryEntry( + id=row["id"], + content=row["content"], + source=row["source"], + context_type=row["memory_type"], # DB column -> API field + agent_id=row["agent_id"], + task_id=row["task_id"], + session_id=row["session_id"], + metadata=json.loads(row["metadata"]) if row["metadata"] else None, + embedding=json.loads(row["embedding"]) if row["embedding"] else None, + timestamp=row["created_at"], + ) + + +def _score_and_filter( + rows: list[sqlite3.Row], + query: str, + query_embedding: list[float], + min_relevance: float, +) -> list[MemoryEntry]: + """Score candidate rows by similarity and filter by min_relevance.""" + results = [] + for row in rows: + entry = _row_to_entry(row) + + if entry.embedding: + score = cosine_similarity(query_embedding, entry.embedding) + else: + score = _keyword_overlap(query, entry.content) + + entry.relevance_score = score + if score >= min_relevance: + results.append(entry) + + results.sort(key=lambda x: x.relevance_score or 0, reverse=True) + return results + + def search_memories( query: str, limit: int = 10, @@ -325,65 +404,9 @@ def search_memories( List of MemoryEntry objects sorted by relevance """ query_embedding = embed_text(query) - - # Build query with filters - conditions = [] - params = [] - - if context_type: - conditions.append("memory_type = ?") - params.append(context_type) - if agent_id: - conditions.append("agent_id = ?") - params.append(agent_id) - if session_id: - conditions.append("session_id = ?") - params.append(session_id) - - where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" - - # Fetch candidates (we'll do in-memory similarity for now) - query_sql = f""" - SELECT * FROM memories - {where_clause} - ORDER BY created_at DESC - LIMIT ? - """ - params.append(limit * 3) # Get more candidates for ranking - - with get_connection() as conn: - rows = conn.execute(query_sql, params).fetchall() - - # Compute similarity scores - results = [] - for row in rows: - entry = MemoryEntry( - id=row["id"], - content=row["content"], - source=row["source"], - context_type=row["memory_type"], # DB column -> API field - agent_id=row["agent_id"], - task_id=row["task_id"], - session_id=row["session_id"], - metadata=json.loads(row["metadata"]) if row["metadata"] else None, - embedding=json.loads(row["embedding"]) if row["embedding"] else None, - timestamp=row["created_at"], - ) - - if entry.embedding: - score = cosine_similarity(query_embedding, entry.embedding) - entry.relevance_score = score - if score >= min_relevance: - results.append(entry) - else: - # Fallback: check for keyword overlap - score = _keyword_overlap(query, entry.content) - entry.relevance_score = score - if score >= min_relevance: - results.append(entry) - - # Sort by relevance and return top results - results.sort(key=lambda x: x.relevance_score or 0, reverse=True) + where_clause, params = _build_search_filters(context_type, agent_id, session_id) + rows = _fetch_memory_candidates(where_clause, params, limit * 3) + results = _score_and_filter(rows, query, query_embedding, min_relevance) return results[:limit]