forked from Rockachopa/Timmy-time-dashboard
[loop-cycle-50] refactor: replace bare sqlite3.connect() with context managers batch 2 (#157) (#180)
This commit is contained in:
@@ -11,6 +11,8 @@ All three tables live in ``data/memory.db``. Existing APIs in
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,15 +20,16 @@ logger = logging.getLogger(__name__)
|
||||
DB_PATH = Path(__file__).parent.parent.parent.parent / "data" / "memory.db"
|
||||
|
||||
|
||||
def get_connection() -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def get_connection() -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Open (and lazily create) the unified memory database."""
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
_ensure_schema(conn)
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(DB_PATH))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
_ensure_schema(conn)
|
||||
yield conn
|
||||
|
||||
|
||||
def _ensure_schema(conn: sqlite3.Connection) -> None:
|
||||
|
||||
@@ -8,6 +8,8 @@ import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
|
||||
@@ -54,11 +56,13 @@ class MemoryEntry:
|
||||
relevance_score: float | None = None # Set during search
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Get database connection to unified memory.db."""
|
||||
from timmy.memory.unified import get_connection
|
||||
|
||||
return get_connection()
|
||||
with get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
def store_memory(
|
||||
@@ -101,29 +105,28 @@ def store_memory(
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO episodes
|
||||
(id, content, source, context_type, agent_id, task_id, session_id,
|
||||
metadata, embedding, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
entry.id,
|
||||
entry.content,
|
||||
entry.source,
|
||||
entry.context_type,
|
||||
entry.agent_id,
|
||||
entry.task_id,
|
||||
entry.session_id,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(embedding) if embedding else None,
|
||||
entry.timestamp,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO episodes
|
||||
(id, content, source, context_type, agent_id, task_id, session_id,
|
||||
metadata, embedding, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
entry.id,
|
||||
entry.content,
|
||||
entry.source,
|
||||
entry.context_type,
|
||||
entry.agent_id,
|
||||
entry.task_id,
|
||||
entry.session_id,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(embedding) if embedding else None,
|
||||
entry.timestamp,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return entry
|
||||
|
||||
@@ -151,8 +154,6 @@ def search_memories(
|
||||
"""
|
||||
query_embedding = _compute_embedding(query)
|
||||
|
||||
conn = _get_conn()
|
||||
|
||||
# Build query with filters
|
||||
conditions = []
|
||||
params = []
|
||||
@@ -179,8 +180,8 @@ def search_memories(
|
||||
"""
|
||||
params.append(limit * 3) # Get more candidates for ranking
|
||||
|
||||
rows = conn.execute(query_sql, params).fetchall()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute(query_sql, params).fetchall()
|
||||
|
||||
# Compute similarity scores
|
||||
results = []
|
||||
@@ -275,58 +276,54 @@ def recall_personal_facts(agent_id: str | None = None) -> list[str]:
|
||||
Returns:
|
||||
List of fact strings
|
||||
"""
|
||||
conn = _get_conn()
|
||||
with _get_conn() as conn:
|
||||
if agent_id:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT content FROM episodes
|
||||
WHERE context_type = 'fact' AND agent_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
""",
|
||||
(agent_id,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT content FROM episodes
|
||||
WHERE context_type = 'fact'
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
""",
|
||||
).fetchall()
|
||||
|
||||
if agent_id:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT content FROM episodes
|
||||
WHERE context_type = 'fact' AND agent_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
""",
|
||||
(agent_id,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT content FROM episodes
|
||||
WHERE context_type = 'fact'
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
""",
|
||||
).fetchall()
|
||||
|
||||
conn.close()
|
||||
return [r["content"] for r in rows]
|
||||
|
||||
|
||||
def recall_personal_facts_with_ids(agent_id: str | None = None) -> list[dict]:
|
||||
"""Recall personal facts with their IDs for edit/delete operations."""
|
||||
conn = _get_conn()
|
||||
if agent_id:
|
||||
rows = conn.execute(
|
||||
"SELECT id, content FROM episodes WHERE context_type = 'fact' AND agent_id = ? ORDER BY timestamp DESC LIMIT 100",
|
||||
(agent_id,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"SELECT id, content FROM episodes WHERE context_type = 'fact' ORDER BY timestamp DESC LIMIT 100",
|
||||
).fetchall()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
if agent_id:
|
||||
rows = conn.execute(
|
||||
"SELECT id, content FROM episodes WHERE context_type = 'fact' AND agent_id = ? ORDER BY timestamp DESC LIMIT 100",
|
||||
(agent_id,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"SELECT id, content FROM episodes WHERE context_type = 'fact' ORDER BY timestamp DESC LIMIT 100",
|
||||
).fetchall()
|
||||
return [{"id": r["id"], "content": r["content"]} for r in rows]
|
||||
|
||||
|
||||
def update_personal_fact(memory_id: str, new_content: str) -> bool:
|
||||
"""Update a personal fact's content."""
|
||||
conn = _get_conn()
|
||||
cursor = conn.execute(
|
||||
"UPDATE episodes SET content = ? WHERE id = ? AND context_type = 'fact'",
|
||||
(new_content, memory_id),
|
||||
)
|
||||
conn.commit()
|
||||
updated = cursor.rowcount > 0
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
cursor = conn.execute(
|
||||
"UPDATE episodes SET content = ? WHERE id = ? AND context_type = 'fact'",
|
||||
(new_content, memory_id),
|
||||
)
|
||||
conn.commit()
|
||||
updated = cursor.rowcount > 0
|
||||
return updated
|
||||
|
||||
|
||||
@@ -355,14 +352,13 @@ def delete_memory(memory_id: str) -> bool:
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
conn = _get_conn()
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM episodes WHERE id = ?",
|
||||
(memory_id,),
|
||||
)
|
||||
conn.commit()
|
||||
deleted = cursor.rowcount > 0
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM episodes WHERE id = ?",
|
||||
(memory_id,),
|
||||
)
|
||||
conn.commit()
|
||||
deleted = cursor.rowcount > 0
|
||||
return deleted
|
||||
|
||||
|
||||
@@ -372,22 +368,19 @@ def get_memory_stats() -> dict:
|
||||
Returns:
|
||||
Dict with counts by type, total entries, etc.
|
||||
"""
|
||||
conn = _get_conn()
|
||||
with _get_conn() as conn:
|
||||
total = conn.execute("SELECT COUNT(*) as count FROM episodes").fetchone()["count"]
|
||||
|
||||
total = conn.execute("SELECT COUNT(*) as count FROM episodes").fetchone()["count"]
|
||||
by_type = {}
|
||||
rows = conn.execute(
|
||||
"SELECT context_type, COUNT(*) as count FROM episodes GROUP BY context_type"
|
||||
).fetchall()
|
||||
for row in rows:
|
||||
by_type[row["context_type"]] = row["count"]
|
||||
|
||||
by_type = {}
|
||||
rows = conn.execute(
|
||||
"SELECT context_type, COUNT(*) as count FROM episodes GROUP BY context_type"
|
||||
).fetchall()
|
||||
for row in rows:
|
||||
by_type[row["context_type"]] = row["count"]
|
||||
|
||||
with_embeddings = conn.execute(
|
||||
"SELECT COUNT(*) as count FROM episodes WHERE embedding IS NOT NULL"
|
||||
).fetchone()["count"]
|
||||
|
||||
conn.close()
|
||||
with_embeddings = conn.execute(
|
||||
"SELECT COUNT(*) as count FROM episodes WHERE embedding IS NOT NULL"
|
||||
).fetchone()["count"]
|
||||
|
||||
return {
|
||||
"total_entries": total,
|
||||
@@ -411,24 +404,22 @@ def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
|
||||
|
||||
cutoff = (datetime.now(UTC) - timedelta(days=older_than_days)).isoformat()
|
||||
|
||||
conn = _get_conn()
|
||||
with _get_conn() as conn:
|
||||
if keep_facts:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
DELETE FROM episodes
|
||||
WHERE timestamp < ? AND context_type != 'fact'
|
||||
""",
|
||||
(cutoff,),
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM episodes WHERE timestamp < ?",
|
||||
(cutoff,),
|
||||
)
|
||||
|
||||
if keep_facts:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
DELETE FROM episodes
|
||||
WHERE timestamp < ? AND context_type != 'fact'
|
||||
""",
|
||||
(cutoff,),
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM episodes WHERE timestamp < ?",
|
||||
(cutoff,),
|
||||
)
|
||||
|
||||
deleted = cursor.rowcount
|
||||
conn.commit()
|
||||
conn.close()
|
||||
deleted = cursor.rowcount
|
||||
conn.commit()
|
||||
|
||||
return deleted
|
||||
|
||||
Reference in New Issue
Block a user