Compare commits

...

4 Commits

Author SHA1 Message Date
kimi
ddb9c7d8ca refactor: break up search_memories() into focused helpers
Extract _build_memory_filter(), _fetch_memory_candidates(),
_row_to_entry(), and _score_and_rank() from the 82-line
search_memories() function for better readability and testability.

Fixes #554

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 21:17:49 -04:00
f361893fdd [loop-cycle-951] refactor: break up _migrate_schema() (#552) (#558) 2026-03-19 21:11:02 -04:00
7ad0ee17b6 refactor: break up shell.py::run() into helpers (#551)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-19 21:04:10 -04:00
29220b6bdd refactor: break up api_chat() into helpers (#547)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-19 21:02:04 -04:00
2 changed files with 218 additions and 176 deletions

View File

@@ -144,46 +144,25 @@ class ShellHand:
         return None
 
-    async def run(
-        self,
-        command: str,
-        timeout: int | None = None,
-        working_dir: str | None = None,
-        env: dict | None = None,
-    ) -> ShellResult:
-        """Execute a shell command in a sandboxed environment.
-
-        Args:
-            command: The shell command to execute.
-            timeout: Override default timeout (seconds).
-            working_dir: Override default working directory.
-            env: Additional environment variables to set.
-
-        Returns:
-            ShellResult with stdout/stderr or error details.
-        """
-        start = time.time()
-        # Validate
-        validation_error = self._validate_command(command)
-        if validation_error:
-            return ShellResult(
-                command=command,
-                success=False,
-                error=validation_error,
-                latency_ms=(time.time() - start) * 1000,
-            )
-        effective_timeout = timeout or self._default_timeout
-        cwd = working_dir or self._working_dir
-        try:
+    @staticmethod
+    def _build_run_env(env: dict | None) -> dict:
+        """Merge *env* overrides into a copy of the current environment."""
+        import os
+
+        run_env = os.environ.copy()
+        if env:
+            run_env.update(env)
+        return run_env
+
+    async def _execute_subprocess(
+        self,
+        command: str,
+        effective_timeout: int,
+        cwd: str | None,
+        run_env: dict,
+        start: float,
+    ) -> ShellResult:
+        """Run *command* as a subprocess with timeout enforcement."""
         proc = await asyncio.create_subprocess_shell(
             command,
             stdout=asyncio.subprocess.PIPE,
@@ -224,6 +203,41 @@ class ShellHand:
             latency_ms=latency,
         )
 
+    async def run(
+        self,
+        command: str,
+        timeout: int | None = None,
+        working_dir: str | None = None,
+        env: dict | None = None,
+    ) -> ShellResult:
+        """Execute a shell command in a sandboxed environment.
+
+        Args:
+            command: The shell command to execute.
+            timeout: Override default timeout (seconds).
+            working_dir: Override default working directory.
+            env: Additional environment variables to set.
+
+        Returns:
+            ShellResult with stdout/stderr or error details.
+        """
+        start = time.time()
+        validation_error = self._validate_command(command)
+        if validation_error:
+            return ShellResult(
+                command=command,
+                success=False,
+                error=validation_error,
+                latency_ms=(time.time() - start) * 1000,
+            )
+        effective_timeout = timeout or self._default_timeout
+        cwd = working_dir or self._working_dir
+        try:
+            run_env = self._build_run_env(env)
+            return await self._execute_subprocess(command, effective_timeout, cwd, run_env, start)
+        except Exception as exc:
+            latency = (time.time() - start) * 1000
+            logger.warning("Shell command failed: %s: %s", command, exc)

View File

@@ -98,29 +98,8 @@ def _get_table_columns(conn: sqlite3.Connection, table_name: str) -> set[str]:
     return {row[1] for row in cursor.fetchall()}
 
-def _migrate_schema(conn: sqlite3.Connection) -> None:
-    """Migrate from old three-table schema to unified memories table.
-
-    Migration paths:
-    - episodes table -> memories (context_type -> memory_type)
-    - chunks table -> memories with memory_type='vault_chunk'
-    - facts table -> dropped (unused, 0 rows expected)
-    """
-    cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
-    tables = {row[0] for row in cursor.fetchall()}
-    has_memories = "memories" in tables
-    has_episodes = "episodes" in tables
-    has_chunks = "chunks" in tables
-    has_facts = "facts" in tables
-    # Check if we need to migrate (old schema exists)
-    if not has_memories and (has_episodes or has_chunks or has_facts):
-        logger.info("Migration: Creating unified memories table")
-        # Schema will be created by _ensure_schema above
-    # Migrate episodes -> memories
-    if has_episodes and has_memories:
+def _migrate_episodes(conn: sqlite3.Connection) -> None:
+    """Migrate episodes table rows into the unified memories table."""
     logger.info("Migration: Converting episodes table to memories")
     try:
         cols = _get_table_columns(conn, "episodes")
@@ -146,8 +125,9 @@ def _migrate_schema(conn: sqlite3.Connection) -> None:
     except sqlite3.Error as exc:
         logger.warning("Migration: Failed to migrate episodes: %s", exc)
 
-    # Migrate chunks -> memories as vault_chunk
-    if has_chunks and has_memories:
+def _migrate_chunks(conn: sqlite3.Connection) -> None:
+    """Migrate chunks table rows into the unified memories table."""
     logger.info("Migration: Converting chunks table to memories")
     try:
         cols = _get_table_columns(conn, "chunks")
@@ -175,13 +155,38 @@ def _migrate_schema(conn: sqlite3.Connection) -> None:
     except sqlite3.Error as exc:
         logger.warning("Migration: Failed to migrate chunks: %s", exc)
 
-    # Drop old tables
-    if has_facts:
+def _drop_legacy_table(conn: sqlite3.Connection, table: str) -> None:
+    """Drop a legacy table if it exists."""
     try:
-        conn.execute("DROP TABLE facts")
-        logger.info("Migration: Dropped old facts table")
+        conn.execute(f"DROP TABLE {table}")  # noqa: S608
+        logger.info("Migration: Dropped old %s table", table)
     except sqlite3.Error as exc:
-        logger.warning("Migration: Failed to drop facts: %s", exc)
+        logger.warning("Migration: Failed to drop %s: %s", table, exc)
+
+def _migrate_schema(conn: sqlite3.Connection) -> None:
+    """Migrate from old three-table schema to unified memories table.
+
+    Migration paths:
+    - episodes table -> memories (context_type -> memory_type)
+    - chunks table -> memories with memory_type='vault_chunk'
+    - facts table -> dropped (unused, 0 rows expected)
+    """
+    cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
+    tables = {row[0] for row in cursor.fetchall()}
+    has_memories = "memories" in tables
+    if not has_memories and (tables & {"episodes", "chunks", "facts"}):
+        logger.info("Migration: Creating unified memories table")
+    if "episodes" in tables and has_memories:
+        _migrate_episodes(conn)
+    if "chunks" in tables and has_memories:
+        _migrate_chunks(conn)
+    if "facts" in tables:
+        _drop_legacy_table(conn, "facts")
+    conn.commit()
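With each migration path in its own helper, the new entry point can be exercised end-to-end against a throwaway database. A rough sketch (assuming _migrate_schema is importable from this module; in the real module the memories schema itself is created by _ensure_schema before this runs):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE facts (id INTEGER PRIMARY KEY)")  # legacy table only
    _migrate_schema(conn)  # no memories table yet -> logs, drops facts, commits
    tables = {row[0] for row in conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table'")}
    assert "facts" not in tables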
@@ -298,6 +303,86 @@ def store_memory(
     return entry
 
+def _build_memory_filter(
+    context_type: str | None,
+    agent_id: str | None,
+    session_id: str | None,
+) -> tuple[str, list]:
+    """Build WHERE clause and params for memory queries."""
+    conditions: list[str] = []
+    params: list = []
+    if context_type:
+        conditions.append("memory_type = ?")
+        params.append(context_type)
+    if agent_id:
+        conditions.append("agent_id = ?")
+        params.append(agent_id)
+    if session_id:
+        conditions.append("session_id = ?")
+        params.append(session_id)
+    where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
+    return where_clause, params
+
+def _fetch_memory_candidates(
+    where_clause: str, params: list, candidate_limit: int
+) -> list[sqlite3.Row]:
+    """Fetch candidate memory rows from the database."""
+    query_sql = f"""
+        SELECT * FROM memories
+        {where_clause}
+        ORDER BY created_at DESC
+        LIMIT ?
+    """
+    params.append(candidate_limit)
+    with get_connection() as conn:
+        return conn.execute(query_sql, params).fetchall()
+
+def _row_to_entry(row: sqlite3.Row) -> MemoryEntry:
+    """Convert a database row to a MemoryEntry."""
+    return MemoryEntry(
+        id=row["id"],
+        content=row["content"],
+        source=row["source"],
+        context_type=row["memory_type"],  # DB column -> API field
+        agent_id=row["agent_id"],
+        task_id=row["task_id"],
+        session_id=row["session_id"],
+        metadata=json.loads(row["metadata"]) if row["metadata"] else None,
+        embedding=json.loads(row["embedding"]) if row["embedding"] else None,
+        timestamp=row["created_at"],
+    )
+
+def _score_and_rank(
+    rows: list[sqlite3.Row],
+    query: str,
+    query_embedding: list[float],
+    min_relevance: float,
+    limit: int,
+) -> list[MemoryEntry]:
+    """Score candidates by similarity and return top results."""
+    results = []
+    for row in rows:
+        entry = _row_to_entry(row)
+        if entry.embedding:
+            score = cosine_similarity(query_embedding, entry.embedding)
+        else:
+            score = _keyword_overlap(query, entry.content)
+        entry.relevance_score = score
+        if score >= min_relevance:
+            results.append(entry)
+    results.sort(key=lambda x: x.relevance_score or 0, reverse=True)
+    return results[:limit]
+
 def search_memories(
     query: str,
     limit: int = 10,
@@ -320,66 +405,9 @@ def search_memories(
         List of MemoryEntry objects sorted by relevance
     """
     query_embedding = embed_text(query)
-    # Build query with filters
-    conditions = []
-    params = []
-    if context_type:
-        conditions.append("memory_type = ?")
-        params.append(context_type)
-    if agent_id:
-        conditions.append("agent_id = ?")
-        params.append(agent_id)
-    if session_id:
-        conditions.append("session_id = ?")
-        params.append(session_id)
-    where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
-
-    # Fetch candidates (we'll do in-memory similarity for now)
-    query_sql = f"""
-        SELECT * FROM memories
-        {where_clause}
-        ORDER BY created_at DESC
-        LIMIT ?
-    """
-    params.append(limit * 3)  # Get more candidates for ranking
-    with get_connection() as conn:
-        rows = conn.execute(query_sql, params).fetchall()
-
-    # Compute similarity scores
-    results = []
-    for row in rows:
-        entry = MemoryEntry(
-            id=row["id"],
-            content=row["content"],
-            source=row["source"],
-            context_type=row["memory_type"],  # DB column -> API field
-            agent_id=row["agent_id"],
-            task_id=row["task_id"],
-            session_id=row["session_id"],
-            metadata=json.loads(row["metadata"]) if row["metadata"] else None,
-            embedding=json.loads(row["embedding"]) if row["embedding"] else None,
-            timestamp=row["created_at"],
-        )
-        if entry.embedding:
-            score = cosine_similarity(query_embedding, entry.embedding)
-            entry.relevance_score = score
-            if score >= min_relevance:
-                results.append(entry)
-        else:
-            # Fallback: check for keyword overlap
-            score = _keyword_overlap(query, entry.content)
-            entry.relevance_score = score
-            if score >= min_relevance:
-                results.append(entry)
-
-    # Sort by relevance and return top results
-    results.sort(key=lambda x: x.relevance_score or 0, reverse=True)
-    return results[:limit]
+    where_clause, params = _build_memory_filter(context_type, agent_id, session_id)
+    rows = _fetch_memory_candidates(where_clause, params, limit * 3)
+    return _score_and_rank(rows, query, query_embedding, min_relevance, limit)
 
 def delete_memory(memory_id: str) -> bool:
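_build_memory_filter() is the easiest of the new helpers to pin down with tests, since it is a pure function of its arguments. An illustrative check (the "episode" type and "sess-42" id are made-up values, but the expected output follows directly from the code above):

    where, params = _build_memory_filter("episode", None, "sess-42")
    assert where == "WHERE memory_type = ? AND session_id = ?"
    assert params == ["episode", "sess-42"]

    where, params = _build_memory_filter(None, None, None)
    assert where == "" and params == []   # no filters -> no WHERE clause

One caveat worth noting: _fetch_memory_candidates() appends the candidate limit to the caller's params list in place. That is harmless in search_memories(), which does not reuse the list afterwards, but any future caller that keeps a reference to params should pass a copy.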