diff --git a/src/dashboard/routes/system.py b/src/dashboard/routes/system.py index 73352590..bbb9480b 100644 --- a/src/dashboard/routes/system.py +++ b/src/dashboard/routes/system.py @@ -209,14 +209,11 @@ async def api_swarm_status(): pending_tasks = 0 try: - db = _get_db() - try: + with _get_db() as db: row = db.execute( "SELECT COUNT(*) as cnt FROM tasks WHERE status IN ('pending_approval','approved')" ).fetchone() pending_tasks = row["cnt"] if row else 0 - finally: - db.close() except Exception: pass diff --git a/src/dashboard/routes/tasks.py b/src/dashboard/routes/tasks.py index cec14779..027d0490 100644 --- a/src/dashboard/routes/tasks.py +++ b/src/dashboard/routes/tasks.py @@ -3,7 +3,8 @@ import logging import sqlite3 import uuid -from contextlib import closing +from collections.abc import Generator +from contextlib import closing, contextmanager from datetime import datetime from pathlib import Path @@ -36,26 +37,27 @@ VALID_STATUSES = { VALID_PRIORITIES = {"low", "normal", "high", "urgent"} -def _get_db() -> sqlite3.Connection: +@contextmanager +def _get_db() -> Generator[sqlite3.Connection, None, None]: DB_PATH.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(DB_PATH)) - conn.row_factory = sqlite3.Row - conn.execute(""" - CREATE TABLE IF NOT EXISTS tasks ( - id TEXT PRIMARY KEY, - title TEXT NOT NULL, - description TEXT DEFAULT '', - status TEXT DEFAULT 'pending_approval', - priority TEXT DEFAULT 'normal', - assigned_to TEXT DEFAULT '', - created_by TEXT DEFAULT 'operator', - result TEXT DEFAULT '', - created_at TEXT DEFAULT (datetime('now')), - completed_at TEXT - ) - """) - conn.commit() - return conn + with closing(sqlite3.connect(str(DB_PATH))) as conn: + conn.row_factory = sqlite3.Row + conn.execute(""" + CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + description TEXT DEFAULT '', + status TEXT DEFAULT 'pending_approval', + priority TEXT DEFAULT 'normal', + assigned_to TEXT DEFAULT '', + created_by TEXT DEFAULT 'operator', + result TEXT DEFAULT '', + created_at TEXT DEFAULT (datetime('now')), + completed_at TEXT + ) + """) + conn.commit() + yield conn def _row_to_dict(row: sqlite3.Row) -> dict: @@ -102,7 +104,7 @@ class _TaskView: @router.get("/tasks", response_class=HTMLResponse) async def tasks_page(request: Request): """Render the main task queue page with 3-column layout.""" - with closing(_get_db()) as db: + with _get_db() as db: pending = [ _TaskView(_row_to_dict(r)) for r in db.execute( @@ -143,7 +145,7 @@ async def tasks_page(request: Request): @router.get("/tasks/pending", response_class=HTMLResponse) async def tasks_pending(request: Request): - with closing(_get_db()) as db: + with _get_db() as db: rows = db.execute( "SELECT * FROM tasks WHERE status='pending_approval' ORDER BY created_at DESC" ).fetchall() @@ -162,7 +164,7 @@ async def tasks_pending(request: Request): @router.get("/tasks/active", response_class=HTMLResponse) async def tasks_active(request: Request): - with closing(_get_db()) as db: + with _get_db() as db: rows = db.execute( "SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC" ).fetchall() @@ -181,7 +183,7 @@ async def tasks_active(request: Request): @router.get("/tasks/completed", response_class=HTMLResponse) async def tasks_completed(request: Request): - with closing(_get_db()) as db: + with _get_db() as db: rows = db.execute( "SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50" ).fetchall() @@ -220,7 +222,7 @@ async def create_task_form( now = datetime.utcnow().isoformat() priority = priority if priority in VALID_PRIORITIES else "normal" - with closing(_get_db()) as db: + with _get_db() as db: db.execute( "INSERT INTO tasks (id, title, description, priority, assigned_to, created_at) VALUES (?, ?, ?, ?, ?, ?)", (task_id, title, description, priority, assigned_to, now), @@ -269,7 +271,7 @@ async def modify_task( title: str = Form(...), description: str = Form(""), ): - with closing(_get_db()) as db: + with _get_db() as db: db.execute( "UPDATE tasks SET title=?, description=? WHERE id=?", (title, description, task_id), @@ -287,7 +289,7 @@ async def _set_status(request: Request, task_id: str, new_status: str): completed_at = ( datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None ) - with closing(_get_db()) as db: + with _get_db() as db: db.execute( "UPDATE tasks SET status=?, completed_at=COALESCE(?, completed_at) WHERE id=?", (new_status, completed_at, task_id), @@ -319,7 +321,7 @@ async def api_create_task(request: Request): if priority not in VALID_PRIORITIES: priority = "normal" - with closing(_get_db()) as db: + with _get_db() as db: db.execute( "INSERT INTO tasks (id, title, description, priority, assigned_to, created_by, created_at) " "VALUES (?, ?, ?, ?, ?, ?, ?)", @@ -342,7 +344,7 @@ async def api_create_task(request: Request): @router.get("/api/tasks", response_class=JSONResponse) async def api_list_tasks(): """List all tasks as JSON.""" - with closing(_get_db()) as db: + with _get_db() as db: rows = db.execute("SELECT * FROM tasks ORDER BY created_at DESC").fetchall() return JSONResponse([_row_to_dict(r) for r in rows]) @@ -358,7 +360,7 @@ async def api_update_status(task_id: str, request: Request): completed_at = ( datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None ) - with closing(_get_db()) as db: + with _get_db() as db: db.execute( "UPDATE tasks SET status=?, completed_at=COALESCE(?, completed_at) WHERE id=?", (new_status, completed_at, task_id), @@ -373,7 +375,7 @@ async def api_update_status(task_id: str, request: Request): @router.delete("/api/tasks/{task_id}", response_class=JSONResponse) async def api_delete_task(task_id: str): """Delete a task.""" - with closing(_get_db()) as db: + with _get_db() as db: cursor = db.execute("DELETE FROM tasks WHERE id=?", (task_id,)) db.commit() if cursor.rowcount == 0: @@ -389,7 +391,7 @@ async def api_delete_task(task_id: str): @router.get("/api/queue/status", response_class=JSONResponse) async def queue_status(assigned_to: str = "default"): """Return queue status for the chat panel's agent status indicator.""" - with closing(_get_db()) as db: + with _get_db() as db: running = db.execute( "SELECT * FROM tasks WHERE status='running' AND assigned_to=? LIMIT 1", (assigned_to,), diff --git a/src/dashboard/routes/work_orders.py b/src/dashboard/routes/work_orders.py index c2cd2a0b..c400acda 100644 --- a/src/dashboard/routes/work_orders.py +++ b/src/dashboard/routes/work_orders.py @@ -3,7 +3,8 @@ import logging import sqlite3 import uuid -from contextlib import closing +from collections.abc import Generator +from contextlib import closing, contextmanager from datetime import datetime from pathlib import Path @@ -24,28 +25,29 @@ CATEGORIES = ["bug", "feature", "suggestion", "maintenance", "security"] VALID_STATUSES = {"submitted", "triaged", "approved", "in_progress", "completed", "rejected"} -def _get_db() -> sqlite3.Connection: +@contextmanager +def _get_db() -> Generator[sqlite3.Connection, None, None]: DB_PATH.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(DB_PATH)) - conn.row_factory = sqlite3.Row - conn.execute(""" - CREATE TABLE IF NOT EXISTS work_orders ( - id TEXT PRIMARY KEY, - title TEXT NOT NULL, - description TEXT DEFAULT '', - priority TEXT DEFAULT 'medium', - category TEXT DEFAULT 'suggestion', - submitter TEXT DEFAULT 'dashboard', - related_files TEXT DEFAULT '', - status TEXT DEFAULT 'submitted', - result TEXT DEFAULT '', - rejection_reason TEXT DEFAULT '', - created_at TEXT DEFAULT (datetime('now')), - completed_at TEXT - ) - """) - conn.commit() - return conn + with closing(sqlite3.connect(str(DB_PATH))) as conn: + conn.row_factory = sqlite3.Row + conn.execute(""" + CREATE TABLE IF NOT EXISTS work_orders ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + description TEXT DEFAULT '', + priority TEXT DEFAULT 'medium', + category TEXT DEFAULT 'suggestion', + submitter TEXT DEFAULT 'dashboard', + related_files TEXT DEFAULT '', + status TEXT DEFAULT 'submitted', + result TEXT DEFAULT '', + rejection_reason TEXT DEFAULT '', + created_at TEXT DEFAULT (datetime('now')), + completed_at TEXT + ) + """) + conn.commit() + yield conn class _EnumLike: @@ -105,7 +107,7 @@ def _query_wos(db, statuses): @router.get("/work-orders/queue", response_class=HTMLResponse) async def work_orders_page(request: Request): - with closing(_get_db()) as db: + with _get_db() as db: pending = _query_wos(db, ["submitted", "triaged"]) active = _query_wos(db, ["approved", "in_progress"]) completed = _query_wos(db, ["completed"]) @@ -146,7 +148,7 @@ async def submit_work_order( priority = priority if priority in PRIORITIES else "medium" category = category if category in CATEGORIES else "suggestion" - with closing(_get_db()) as db: + with _get_db() as db: db.execute( "INSERT INTO work_orders (id, title, description, priority, category, submitter, related_files, created_at) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", @@ -166,7 +168,7 @@ async def submit_work_order( @router.get("/work-orders/queue/pending", response_class=HTMLResponse) async def pending_partial(request: Request): - with closing(_get_db()) as db: + with _get_db() as db: wos = _query_wos(db, ["submitted", "triaged"]) if not wos: return HTMLResponse( @@ -185,7 +187,7 @@ async def pending_partial(request: Request): @router.get("/work-orders/queue/active", response_class=HTMLResponse) async def active_partial(request: Request): - with closing(_get_db()) as db: + with _get_db() as db: wos = _query_wos(db, ["approved", "in_progress"]) if not wos: return HTMLResponse( @@ -211,7 +213,7 @@ async def _update_status(request: Request, wo_id: str, new_status: str, **extra) completed_at = ( datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None ) - with closing(_get_db()) as db: + with _get_db() as db: sets = ["status=?", "completed_at=COALESCE(?, completed_at)"] vals = [new_status, completed_at] for col, val in extra.items(): diff --git a/src/dashboard/store.py b/src/dashboard/store.py index ffd97e63..7ad51a78 100644 --- a/src/dashboard/store.py +++ b/src/dashboard/store.py @@ -9,6 +9,8 @@ A configurable retention policy (default 500 messages) keeps the DB lean. import sqlite3 import threading +from collections.abc import Generator +from contextlib import closing, contextmanager from dataclasses import dataclass from pathlib import Path @@ -28,24 +30,25 @@ class Message: source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system" -def _get_conn(db_path: Path | None = None) -> sqlite3.Connection: +@contextmanager +def _get_conn(db_path: Path | None = None) -> Generator[sqlite3.Connection, None, None]: """Open (or create) the chat database and ensure schema exists.""" path = db_path or DB_PATH path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(path), check_same_thread=False) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute(""" - CREATE TABLE IF NOT EXISTS chat_messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - role TEXT NOT NULL, - content TEXT NOT NULL, - timestamp TEXT NOT NULL, - source TEXT NOT NULL DEFAULT 'browser' - ) - """) - conn.commit() - return conn + with closing(sqlite3.connect(str(path), check_same_thread=False)) as conn: + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute(""" + CREATE TABLE IF NOT EXISTS chat_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TEXT NOT NULL, + source TEXT NOT NULL DEFAULT 'browser' + ) + """) + conn.commit() + yield conn class MessageLog: @@ -59,7 +62,23 @@ class MessageLog: # Lazy connection — opened on first use, not at import time. def _ensure_conn(self) -> sqlite3.Connection: if self._conn is None: - self._conn = _get_conn(self._db_path) + # Open a persistent connection for the class instance + path = self._db_path or DB_PATH + path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(path), check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute(""" + CREATE TABLE IF NOT EXISTS chat_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TEXT NOT NULL, + source TEXT NOT NULL DEFAULT 'browser' + ) + """) + conn.commit() + self._conn = conn return self._conn def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None: diff --git a/src/infrastructure/events/bus.py b/src/infrastructure/events/bus.py index fab61f0e..f9e81f2f 100644 --- a/src/infrastructure/events/bus.py +++ b/src/infrastructure/events/bus.py @@ -9,8 +9,8 @@ import asyncio import json import logging import sqlite3 -from collections.abc import Callable, Coroutine -from contextlib import closing +from collections.abc import Callable, Coroutine, Generator +from contextlib import closing, contextmanager from dataclasses import dataclass, field from datetime import UTC, datetime from pathlib import Path @@ -106,22 +106,23 @@ class EventBus: conn.executescript(_EVENTS_SCHEMA) conn.commit() - def _get_persistence_conn(self) -> sqlite3.Connection | None: + @contextmanager + def _get_persistence_conn(self) -> Generator[sqlite3.Connection | None, None, None]: """Get a connection to the persistence database.""" if self._persistence_db_path is None: - return None - conn = sqlite3.connect(str(self._persistence_db_path)) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA busy_timeout=5000") - return conn + yield None + return + with closing(sqlite3.connect(str(self._persistence_db_path))) as conn: + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA busy_timeout=5000") + yield conn def _persist_event(self, event: Event) -> None: """Write an event to the persistence database.""" - conn = self._get_persistence_conn() - if conn is None: - return - try: - with closing(conn): + with self._get_persistence_conn() as conn: + if conn is None: + return + try: task_id = event.data.get("task_id", "") agent_id = event.data.get("agent_id", "") conn.execute( @@ -139,8 +140,8 @@ class EventBus: ), ) conn.commit() - except Exception as exc: - logger.debug("Failed to persist event: %s", exc) + except Exception as exc: + logger.debug("Failed to persist event: %s", exc) # ── Replay ─────────────────────────────────────────────────────────── @@ -162,12 +163,11 @@ class EventBus: Returns: List of Event objects from persistent storage. """ - conn = self._get_persistence_conn() - if conn is None: - return [] + with self._get_persistence_conn() as conn: + if conn is None: + return [] - try: - with closing(conn): + try: conditions = [] params: list = [] @@ -197,9 +197,9 @@ class EventBus: ) for row in rows ] - except Exception as exc: - logger.debug("Failed to replay events: %s", exc) - return [] + except Exception as exc: + logger.debug("Failed to replay events: %s", exc) + return [] # ── Subscribe / Publish ────────────────────────────────────────────── diff --git a/src/infrastructure/models/registry.py b/src/infrastructure/models/registry.py index f5b97311..0386eaad 100644 --- a/src/infrastructure/models/registry.py +++ b/src/infrastructure/models/registry.py @@ -11,7 +11,8 @@ model roles (student, teacher, judge/PRM) run on dedicated resources. import logging import sqlite3 import threading -from contextlib import closing +from collections.abc import Generator +from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import UTC, datetime from enum import StrEnum @@ -61,36 +62,37 @@ class CustomModel: self.registered_at = datetime.now(UTC).isoformat() -def _get_conn() -> sqlite3.Connection: +@contextmanager +def _get_conn() -> Generator[sqlite3.Connection, None, None]: DB_PATH.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(DB_PATH)) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA busy_timeout=5000") - conn.execute(""" - CREATE TABLE IF NOT EXISTS custom_models ( - name TEXT PRIMARY KEY, - format TEXT NOT NULL, - path TEXT NOT NULL, - role TEXT NOT NULL DEFAULT 'general', - context_window INTEGER NOT NULL DEFAULT 4096, - description TEXT NOT NULL DEFAULT '', - registered_at TEXT NOT NULL, - active INTEGER NOT NULL DEFAULT 1, - default_temperature REAL NOT NULL DEFAULT 0.7, - max_tokens INTEGER NOT NULL DEFAULT 2048 - ) - """) - conn.execute(""" - CREATE TABLE IF NOT EXISTS agent_model_assignments ( - agent_id TEXT PRIMARY KEY, - model_name TEXT NOT NULL, - assigned_at TEXT NOT NULL, - FOREIGN KEY (model_name) REFERENCES custom_models(name) - ) - """) - conn.commit() - return conn + with closing(sqlite3.connect(str(DB_PATH))) as conn: + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=5000") + conn.execute(""" + CREATE TABLE IF NOT EXISTS custom_models ( + name TEXT PRIMARY KEY, + format TEXT NOT NULL, + path TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'general', + context_window INTEGER NOT NULL DEFAULT 4096, + description TEXT NOT NULL DEFAULT '', + registered_at TEXT NOT NULL, + active INTEGER NOT NULL DEFAULT 1, + default_temperature REAL NOT NULL DEFAULT 0.7, + max_tokens INTEGER NOT NULL DEFAULT 2048 + ) + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS agent_model_assignments ( + agent_id TEXT PRIMARY KEY, + model_name TEXT NOT NULL, + assigned_at TEXT NOT NULL, + FOREIGN KEY (model_name) REFERENCES custom_models(name) + ) + """) + conn.commit() + yield conn class ModelRegistry: @@ -106,7 +108,7 @@ class ModelRegistry: def _load_from_db(self) -> None: """Bootstrap cache from SQLite.""" try: - with closing(_get_conn()) as conn: + with _get_conn() as conn: for row in conn.execute("SELECT * FROM custom_models WHERE active = 1").fetchall(): self._models[row["name"]] = CustomModel( name=row["name"], @@ -130,7 +132,7 @@ class ModelRegistry: def register(self, model: CustomModel) -> CustomModel: """Register a new custom model.""" with self._lock: - with closing(_get_conn()) as conn: + with _get_conn() as conn: conn.execute( """ INSERT OR REPLACE INTO custom_models @@ -161,7 +163,7 @@ class ModelRegistry: with self._lock: if name not in self._models: return False - with closing(_get_conn()) as conn: + with _get_conn() as conn: conn.execute("DELETE FROM custom_models WHERE name = ?", (name,)) conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)) conn.commit() @@ -191,7 +193,7 @@ class ModelRegistry: return False with self._lock: model.active = active - with closing(_get_conn()) as conn: + with _get_conn() as conn: conn.execute( "UPDATE custom_models SET active = ? WHERE name = ?", (int(active), name), @@ -207,7 +209,7 @@ class ModelRegistry: return False with self._lock: now = datetime.now(UTC).isoformat() - with closing(_get_conn()) as conn: + with _get_conn() as conn: conn.execute( """ INSERT OR REPLACE INTO agent_model_assignments @@ -226,7 +228,7 @@ class ModelRegistry: with self._lock: if agent_id not in self._agent_assignments: return False - with closing(_get_conn()) as conn: + with _get_conn() as conn: conn.execute( "DELETE FROM agent_model_assignments WHERE agent_id = ?", (agent_id,), diff --git a/src/spark/eidos.py b/src/spark/eidos.py index 28e7a5ce..c9f2a014 100644 --- a/src/spark/eidos.py +++ b/src/spark/eidos.py @@ -16,6 +16,8 @@ import json import logging import sqlite3 import uuid +from collections.abc import Generator +from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path @@ -39,28 +41,31 @@ class Prediction: evaluated_at: str | None -def _get_conn() -> sqlite3.Connection: +@contextmanager +def _get_conn() -> Generator[sqlite3.Connection, None, None]: DB_PATH.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(DB_PATH)) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA busy_timeout=5000") - conn.execute(""" - CREATE TABLE IF NOT EXISTS spark_predictions ( - id TEXT PRIMARY KEY, - task_id TEXT NOT NULL, - prediction_type TEXT NOT NULL, - predicted_value TEXT NOT NULL, - actual_value TEXT, - accuracy REAL, - created_at TEXT NOT NULL, - evaluated_at TEXT + with closing(sqlite3.connect(str(DB_PATH))) as conn: + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=5000") + conn.execute(""" + CREATE TABLE IF NOT EXISTS spark_predictions ( + id TEXT PRIMARY KEY, + task_id TEXT NOT NULL, + prediction_type TEXT NOT NULL, + predicted_value TEXT NOT NULL, + actual_value TEXT, + accuracy REAL, + created_at TEXT NOT NULL, + evaluated_at TEXT + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)" ) - """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)") - conn.commit() - return conn + conn.commit() + yield conn # ── Prediction phase ──────────────────────────────────────────────────────── @@ -119,17 +124,16 @@ def predict_task_outcome( # Store prediction pred_id = str(uuid.uuid4()) now = datetime.now(UTC).isoformat() - conn = _get_conn() - conn.execute( - """ - INSERT INTO spark_predictions - (id, task_id, prediction_type, predicted_value, created_at) - VALUES (?, ?, ?, ?, ?) - """, - (pred_id, task_id, "outcome", json.dumps(prediction), now), - ) - conn.commit() - conn.close() + with _get_conn() as conn: + conn.execute( + """ + INSERT INTO spark_predictions + (id, task_id, prediction_type, predicted_value, created_at) + VALUES (?, ?, ?, ?, ?) + """, + (pred_id, task_id, "outcome", json.dumps(prediction), now), + ) + conn.commit() prediction["prediction_id"] = pred_id return prediction @@ -148,41 +152,39 @@ def evaluate_prediction( Returns the evaluation result or None if no prediction exists. """ - conn = _get_conn() - row = conn.execute( - """ - SELECT * FROM spark_predictions - WHERE task_id = ? AND prediction_type = 'outcome' AND evaluated_at IS NULL - ORDER BY created_at DESC LIMIT 1 - """, - (task_id,), - ).fetchone() + with _get_conn() as conn: + row = conn.execute( + """ + SELECT * FROM spark_predictions + WHERE task_id = ? AND prediction_type = 'outcome' AND evaluated_at IS NULL + ORDER BY created_at DESC LIMIT 1 + """, + (task_id,), + ).fetchone() - if not row: - conn.close() - return None + if not row: + return None - predicted = json.loads(row["predicted_value"]) - actual = { - "winner": actual_winner, - "succeeded": task_succeeded, - "winning_bid": winning_bid, - } + predicted = json.loads(row["predicted_value"]) + actual = { + "winner": actual_winner, + "succeeded": task_succeeded, + "winning_bid": winning_bid, + } - # Calculate accuracy - accuracy = _compute_accuracy(predicted, actual) - now = datetime.now(UTC).isoformat() + # Calculate accuracy + accuracy = _compute_accuracy(predicted, actual) + now = datetime.now(UTC).isoformat() - conn.execute( - """ - UPDATE spark_predictions - SET actual_value = ?, accuracy = ?, evaluated_at = ? - WHERE id = ? - """, - (json.dumps(actual), accuracy, now, row["id"]), - ) - conn.commit() - conn.close() + conn.execute( + """ + UPDATE spark_predictions + SET actual_value = ?, accuracy = ?, evaluated_at = ? + WHERE id = ? + """, + (json.dumps(actual), accuracy, now, row["id"]), + ) + conn.commit() return { "prediction_id": row["id"], @@ -243,7 +245,6 @@ def get_predictions( limit: int = 50, ) -> list[Prediction]: """Query stored predictions.""" - conn = _get_conn() query = "SELECT * FROM spark_predictions WHERE 1=1" params: list = [] @@ -256,8 +257,8 @@ def get_predictions( query += " ORDER BY created_at DESC LIMIT ?" params.append(limit) - rows = conn.execute(query, params).fetchall() - conn.close() + with _get_conn() as conn: + rows = conn.execute(query, params).fetchall() return [ Prediction( id=r["id"], @@ -275,17 +276,16 @@ def get_predictions( def get_accuracy_stats() -> dict: """Return aggregate accuracy statistics for the EIDOS loop.""" - conn = _get_conn() - row = conn.execute(""" - SELECT - COUNT(*) AS total_predictions, - COUNT(evaluated_at) AS evaluated, - AVG(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS avg_accuracy, - MIN(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS min_accuracy, - MAX(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS max_accuracy - FROM spark_predictions - """).fetchone() - conn.close() + with _get_conn() as conn: + row = conn.execute(""" + SELECT + COUNT(*) AS total_predictions, + COUNT(evaluated_at) AS evaluated, + AVG(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS avg_accuracy, + MIN(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS min_accuracy, + MAX(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS max_accuracy + FROM spark_predictions + """).fetchone() return { "total_predictions": row["total_predictions"] or 0, diff --git a/src/spark/memory.py b/src/spark/memory.py index 973f816d..032a175f 100644 --- a/src/spark/memory.py +++ b/src/spark/memory.py @@ -13,6 +13,8 @@ spark_memories — consolidated insights extracted from event patterns import logging import sqlite3 import uuid +from collections.abc import Generator +from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path @@ -55,42 +57,43 @@ class SparkMemory: expires_at: str | None -def _get_conn() -> sqlite3.Connection: +@contextmanager +def _get_conn() -> Generator[sqlite3.Connection, None, None]: DB_PATH.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(DB_PATH)) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA busy_timeout=5000") - conn.execute(""" - CREATE TABLE IF NOT EXISTS spark_events ( - id TEXT PRIMARY KEY, - event_type TEXT NOT NULL, - agent_id TEXT, - task_id TEXT, - description TEXT NOT NULL DEFAULT '', - data TEXT NOT NULL DEFAULT '{}', - importance REAL NOT NULL DEFAULT 0.5, - created_at TEXT NOT NULL - ) - """) - conn.execute(""" - CREATE TABLE IF NOT EXISTS spark_memories ( - id TEXT PRIMARY KEY, - memory_type TEXT NOT NULL, - subject TEXT NOT NULL DEFAULT 'system', - content TEXT NOT NULL, - confidence REAL NOT NULL DEFAULT 0.5, - source_events INTEGER NOT NULL DEFAULT 0, - created_at TEXT NOT NULL, - expires_at TEXT - ) - """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)") - conn.commit() - return conn + with closing(sqlite3.connect(str(DB_PATH))) as conn: + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=5000") + conn.execute(""" + CREATE TABLE IF NOT EXISTS spark_events ( + id TEXT PRIMARY KEY, + event_type TEXT NOT NULL, + agent_id TEXT, + task_id TEXT, + description TEXT NOT NULL DEFAULT '', + data TEXT NOT NULL DEFAULT '{}', + importance REAL NOT NULL DEFAULT 0.5, + created_at TEXT NOT NULL + ) + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS spark_memories ( + id TEXT PRIMARY KEY, + memory_type TEXT NOT NULL, + subject TEXT NOT NULL DEFAULT 'system', + content TEXT NOT NULL, + confidence REAL NOT NULL DEFAULT 0.5, + source_events INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + expires_at TEXT + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)") + conn.commit() + yield conn # ── Importance scoring ────────────────────────────────────────────────────── @@ -149,17 +152,16 @@ def record_event( parsed = {} importance = score_importance(event_type, parsed) - conn = _get_conn() - conn.execute( - """ - INSERT INTO spark_events - (id, event_type, agent_id, task_id, description, data, importance, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - (event_id, event_type, agent_id, task_id, description, data, importance, now), - ) - conn.commit() - conn.close() + with _get_conn() as conn: + conn.execute( + """ + INSERT INTO spark_events + (id, event_type, agent_id, task_id, description, data, importance, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (event_id, event_type, agent_id, task_id, description, data, importance, now), + ) + conn.commit() # Bridge to unified event log so all events are queryable from one place try: @@ -188,7 +190,6 @@ def get_events( min_importance: float = 0.0, ) -> list[SparkEvent]: """Query events with optional filters.""" - conn = _get_conn() query = "SELECT * FROM spark_events WHERE importance >= ?" params: list = [min_importance] @@ -205,8 +206,8 @@ def get_events( query += " ORDER BY created_at DESC LIMIT ?" params.append(limit) - rows = conn.execute(query, params).fetchall() - conn.close() + with _get_conn() as conn: + rows = conn.execute(query, params).fetchall() return [ SparkEvent( id=r["id"], @@ -224,15 +225,14 @@ def get_events( def count_events(event_type: str | None = None) -> int: """Count events, optionally filtered by type.""" - conn = _get_conn() - if event_type: - row = conn.execute( - "SELECT COUNT(*) FROM spark_events WHERE event_type = ?", - (event_type,), - ).fetchone() - else: - row = conn.execute("SELECT COUNT(*) FROM spark_events").fetchone() - conn.close() + with _get_conn() as conn: + if event_type: + row = conn.execute( + "SELECT COUNT(*) FROM spark_events WHERE event_type = ?", + (event_type,), + ).fetchone() + else: + row = conn.execute("SELECT COUNT(*) FROM spark_events").fetchone() return row[0] @@ -250,17 +250,16 @@ def store_memory( """Store a consolidated memory. Returns the memory id.""" mem_id = str(uuid.uuid4()) now = datetime.now(UTC).isoformat() - conn = _get_conn() - conn.execute( - """ - INSERT INTO spark_memories - (id, memory_type, subject, content, confidence, source_events, created_at, expires_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - (mem_id, memory_type, subject, content, confidence, source_events, now, expires_at), - ) - conn.commit() - conn.close() + with _get_conn() as conn: + conn.execute( + """ + INSERT INTO spark_memories + (id, memory_type, subject, content, confidence, source_events, created_at, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (mem_id, memory_type, subject, content, confidence, source_events, now, expires_at), + ) + conn.commit() return mem_id @@ -271,7 +270,6 @@ def get_memories( limit: int = 50, ) -> list[SparkMemory]: """Query memories with optional filters.""" - conn = _get_conn() query = "SELECT * FROM spark_memories WHERE confidence >= ?" params: list = [min_confidence] @@ -285,8 +283,8 @@ def get_memories( query += " ORDER BY created_at DESC LIMIT ?" params.append(limit) - rows = conn.execute(query, params).fetchall() - conn.close() + with _get_conn() as conn: + rows = conn.execute(query, params).fetchall() return [ SparkMemory( id=r["id"], @@ -304,13 +302,12 @@ def get_memories( def count_memories(memory_type: str | None = None) -> int: """Count memories, optionally filtered by type.""" - conn = _get_conn() - if memory_type: - row = conn.execute( - "SELECT COUNT(*) FROM spark_memories WHERE memory_type = ?", - (memory_type,), - ).fetchone() - else: - row = conn.execute("SELECT COUNT(*) FROM spark_memories").fetchone() - conn.close() + with _get_conn() as conn: + if memory_type: + row = conn.execute( + "SELECT COUNT(*) FROM spark_memories WHERE memory_type = ?", + (memory_type,), + ).fetchone() + else: + row = conn.execute("SELECT COUNT(*) FROM spark_memories").fetchone() return row[0] diff --git a/src/timmy/approvals.py b/src/timmy/approvals.py index f093dac7..8cca50ff 100644 --- a/src/timmy/approvals.py +++ b/src/timmy/approvals.py @@ -13,7 +13,8 @@ Default is always True. The owner changes this intentionally. import sqlite3 import uuid -from contextlib import closing +from collections.abc import Generator +from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import UTC, datetime, timedelta from pathlib import Path @@ -44,23 +45,24 @@ class ApprovalItem: status: str # "pending" | "approved" | "rejected" -def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection: +@contextmanager +def _get_conn(db_path: Path = _DEFAULT_DB) -> Generator[sqlite3.Connection, None, None]: db_path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - conn.execute(""" - CREATE TABLE IF NOT EXISTS approval_items ( - id TEXT PRIMARY KEY, - title TEXT NOT NULL, - description TEXT NOT NULL, - proposed_action TEXT NOT NULL, - impact TEXT NOT NULL DEFAULT 'low', - created_at TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending' - ) - """) - conn.commit() - return conn + with closing(sqlite3.connect(str(db_path))) as conn: + conn.row_factory = sqlite3.Row + conn.execute(""" + CREATE TABLE IF NOT EXISTS approval_items ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + proposed_action TEXT NOT NULL, + impact TEXT NOT NULL DEFAULT 'low', + created_at TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending' + ) + """) + conn.commit() + yield conn def _row_to_item(row: sqlite3.Row) -> ApprovalItem: @@ -97,7 +99,7 @@ def create_item( created_at=datetime.now(UTC), status="pending", ) - with closing(_get_conn(db_path)) as conn: + with _get_conn(db_path) as conn: conn.execute( """ INSERT INTO approval_items @@ -120,7 +122,7 @@ def create_item( def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]: """Return all pending approval items, newest first.""" - with closing(_get_conn(db_path)) as conn: + with _get_conn(db_path) as conn: rows = conn.execute( "SELECT * FROM approval_items WHERE status = 'pending' ORDER BY created_at DESC" ).fetchall() @@ -129,20 +131,20 @@ def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]: def list_all(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]: """Return all approval items regardless of status, newest first.""" - with closing(_get_conn(db_path)) as conn: + with _get_conn(db_path) as conn: rows = conn.execute("SELECT * FROM approval_items ORDER BY created_at DESC").fetchall() return [_row_to_item(r) for r in rows] def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None: - with closing(_get_conn(db_path)) as conn: + with _get_conn(db_path) as conn: row = conn.execute("SELECT * FROM approval_items WHERE id = ?", (item_id,)).fetchone() return _row_to_item(row) if row else None def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None: """Mark an approval item as approved.""" - with closing(_get_conn(db_path)) as conn: + with _get_conn(db_path) as conn: conn.execute("UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,)) conn.commit() return get_item(item_id, db_path) @@ -150,7 +152,7 @@ def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None: def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None: """Mark an approval item as rejected.""" - with closing(_get_conn(db_path)) as conn: + with _get_conn(db_path) as conn: conn.execute("UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,)) conn.commit() return get_item(item_id, db_path) @@ -159,7 +161,7 @@ def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None: def expire_old(db_path: Path = _DEFAULT_DB) -> int: """Auto-expire pending items older than EXPIRY_DAYS. Returns count removed.""" cutoff = (datetime.now(UTC) - timedelta(days=_EXPIRY_DAYS)).isoformat() - with closing(_get_conn(db_path)) as conn: + with _get_conn(db_path) as conn: cursor = conn.execute( "DELETE FROM approval_items WHERE status = 'pending' AND created_at < ?", (cutoff,), diff --git a/src/timmy/briefing.py b/src/timmy/briefing.py index 0527f155..76b08617 100644 --- a/src/timmy/briefing.py +++ b/src/timmy/briefing.py @@ -10,7 +10,8 @@ regenerates the briefing every 6 hours. import logging import sqlite3 -from contextlib import closing +from collections.abc import Generator +from contextlib import closing, contextmanager from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta from pathlib import Path @@ -57,25 +58,26 @@ class Briefing: # --------------------------------------------------------------------------- -def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection: +@contextmanager +def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> Generator[sqlite3.Connection, None, None]: db_path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - conn.execute(""" - CREATE TABLE IF NOT EXISTS briefings ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - generated_at TEXT NOT NULL, - period_start TEXT NOT NULL, - period_end TEXT NOT NULL, - summary TEXT NOT NULL - ) - """) - conn.commit() - return conn + with closing(sqlite3.connect(str(db_path))) as conn: + conn.row_factory = sqlite3.Row + conn.execute(""" + CREATE TABLE IF NOT EXISTS briefings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + generated_at TEXT NOT NULL, + period_start TEXT NOT NULL, + period_end TEXT NOT NULL, + summary TEXT NOT NULL + ) + """) + conn.commit() + yield conn def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None: - with closing(_get_cache_conn(db_path)) as conn: + with _get_cache_conn(db_path) as conn: conn.execute( """ INSERT INTO briefings (generated_at, period_start, period_end, summary) @@ -93,7 +95,7 @@ def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None: def _load_latest(db_path: Path = _DEFAULT_DB) -> Briefing | None: """Load the most-recently cached briefing, or None if there is none.""" - with closing(_get_cache_conn(db_path)) as conn: + with _get_cache_conn(db_path) as conn: row = conn.execute("SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1").fetchone() if row is None: return None diff --git a/src/timmy/memory/unified.py b/src/timmy/memory/unified.py index 1023cf33..bd8633fa 100644 --- a/src/timmy/memory/unified.py +++ b/src/timmy/memory/unified.py @@ -11,6 +11,8 @@ All three tables live in ``data/memory.db``. Existing APIs in import logging import sqlite3 +from collections.abc import Generator +from contextlib import closing, contextmanager from pathlib import Path logger = logging.getLogger(__name__) @@ -18,15 +20,16 @@ logger = logging.getLogger(__name__) DB_PATH = Path(__file__).parent.parent.parent.parent / "data" / "memory.db" -def get_connection() -> sqlite3.Connection: +@contextmanager +def get_connection() -> Generator[sqlite3.Connection, None, None]: """Open (and lazily create) the unified memory database.""" DB_PATH.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(DB_PATH)) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA busy_timeout=5000") - _ensure_schema(conn) - return conn + with closing(sqlite3.connect(str(DB_PATH))) as conn: + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=5000") + _ensure_schema(conn) + yield conn def _ensure_schema(conn: sqlite3.Connection) -> None: diff --git a/src/timmy/memory/vector_store.py b/src/timmy/memory/vector_store.py index 2c802909..b828c05f 100644 --- a/src/timmy/memory/vector_store.py +++ b/src/timmy/memory/vector_store.py @@ -8,6 +8,8 @@ import json import logging import sqlite3 import uuid +from collections.abc import Generator +from contextlib import contextmanager from dataclasses import dataclass, field from datetime import UTC, datetime @@ -54,11 +56,13 @@ class MemoryEntry: relevance_score: float | None = None # Set during search -def _get_conn() -> sqlite3.Connection: +@contextmanager +def _get_conn() -> Generator[sqlite3.Connection, None, None]: """Get database connection to unified memory.db.""" from timmy.memory.unified import get_connection - return get_connection() + with get_connection() as conn: + yield conn def store_memory( @@ -101,29 +105,28 @@ def store_memory( embedding=embedding, ) - conn = _get_conn() - conn.execute( - """ - INSERT INTO episodes - (id, content, source, context_type, agent_id, task_id, session_id, - metadata, embedding, timestamp) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - entry.id, - entry.content, - entry.source, - entry.context_type, - entry.agent_id, - entry.task_id, - entry.session_id, - json.dumps(metadata) if metadata else None, - json.dumps(embedding) if embedding else None, - entry.timestamp, - ), - ) - conn.commit() - conn.close() + with _get_conn() as conn: + conn.execute( + """ + INSERT INTO episodes + (id, content, source, context_type, agent_id, task_id, session_id, + metadata, embedding, timestamp) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + entry.id, + entry.content, + entry.source, + entry.context_type, + entry.agent_id, + entry.task_id, + entry.session_id, + json.dumps(metadata) if metadata else None, + json.dumps(embedding) if embedding else None, + entry.timestamp, + ), + ) + conn.commit() return entry @@ -151,8 +154,6 @@ def search_memories( """ query_embedding = _compute_embedding(query) - conn = _get_conn() - # Build query with filters conditions = [] params = [] @@ -179,8 +180,8 @@ def search_memories( """ params.append(limit * 3) # Get more candidates for ranking - rows = conn.execute(query_sql, params).fetchall() - conn.close() + with _get_conn() as conn: + rows = conn.execute(query_sql, params).fetchall() # Compute similarity scores results = [] @@ -275,58 +276,54 @@ def recall_personal_facts(agent_id: str | None = None) -> list[str]: Returns: List of fact strings """ - conn = _get_conn() + with _get_conn() as conn: + if agent_id: + rows = conn.execute( + """ + SELECT content FROM episodes + WHERE context_type = 'fact' AND agent_id = ? + ORDER BY timestamp DESC + LIMIT 100 + """, + (agent_id,), + ).fetchall() + else: + rows = conn.execute( + """ + SELECT content FROM episodes + WHERE context_type = 'fact' + ORDER BY timestamp DESC + LIMIT 100 + """, + ).fetchall() - if agent_id: - rows = conn.execute( - """ - SELECT content FROM episodes - WHERE context_type = 'fact' AND agent_id = ? - ORDER BY timestamp DESC - LIMIT 100 - """, - (agent_id,), - ).fetchall() - else: - rows = conn.execute( - """ - SELECT content FROM episodes - WHERE context_type = 'fact' - ORDER BY timestamp DESC - LIMIT 100 - """, - ).fetchall() - - conn.close() return [r["content"] for r in rows] def recall_personal_facts_with_ids(agent_id: str | None = None) -> list[dict]: """Recall personal facts with their IDs for edit/delete operations.""" - conn = _get_conn() - if agent_id: - rows = conn.execute( - "SELECT id, content FROM episodes WHERE context_type = 'fact' AND agent_id = ? ORDER BY timestamp DESC LIMIT 100", - (agent_id,), - ).fetchall() - else: - rows = conn.execute( - "SELECT id, content FROM episodes WHERE context_type = 'fact' ORDER BY timestamp DESC LIMIT 100", - ).fetchall() - conn.close() + with _get_conn() as conn: + if agent_id: + rows = conn.execute( + "SELECT id, content FROM episodes WHERE context_type = 'fact' AND agent_id = ? ORDER BY timestamp DESC LIMIT 100", + (agent_id,), + ).fetchall() + else: + rows = conn.execute( + "SELECT id, content FROM episodes WHERE context_type = 'fact' ORDER BY timestamp DESC LIMIT 100", + ).fetchall() return [{"id": r["id"], "content": r["content"]} for r in rows] def update_personal_fact(memory_id: str, new_content: str) -> bool: """Update a personal fact's content.""" - conn = _get_conn() - cursor = conn.execute( - "UPDATE episodes SET content = ? WHERE id = ? AND context_type = 'fact'", - (new_content, memory_id), - ) - conn.commit() - updated = cursor.rowcount > 0 - conn.close() + with _get_conn() as conn: + cursor = conn.execute( + "UPDATE episodes SET content = ? WHERE id = ? AND context_type = 'fact'", + (new_content, memory_id), + ) + conn.commit() + updated = cursor.rowcount > 0 return updated @@ -355,14 +352,13 @@ def delete_memory(memory_id: str) -> bool: Returns: True if deleted, False if not found """ - conn = _get_conn() - cursor = conn.execute( - "DELETE FROM episodes WHERE id = ?", - (memory_id,), - ) - conn.commit() - deleted = cursor.rowcount > 0 - conn.close() + with _get_conn() as conn: + cursor = conn.execute( + "DELETE FROM episodes WHERE id = ?", + (memory_id,), + ) + conn.commit() + deleted = cursor.rowcount > 0 return deleted @@ -372,22 +368,19 @@ def get_memory_stats() -> dict: Returns: Dict with counts by type, total entries, etc. """ - conn = _get_conn() + with _get_conn() as conn: + total = conn.execute("SELECT COUNT(*) as count FROM episodes").fetchone()["count"] - total = conn.execute("SELECT COUNT(*) as count FROM episodes").fetchone()["count"] + by_type = {} + rows = conn.execute( + "SELECT context_type, COUNT(*) as count FROM episodes GROUP BY context_type" + ).fetchall() + for row in rows: + by_type[row["context_type"]] = row["count"] - by_type = {} - rows = conn.execute( - "SELECT context_type, COUNT(*) as count FROM episodes GROUP BY context_type" - ).fetchall() - for row in rows: - by_type[row["context_type"]] = row["count"] - - with_embeddings = conn.execute( - "SELECT COUNT(*) as count FROM episodes WHERE embedding IS NOT NULL" - ).fetchone()["count"] - - conn.close() + with_embeddings = conn.execute( + "SELECT COUNT(*) as count FROM episodes WHERE embedding IS NOT NULL" + ).fetchone()["count"] return { "total_entries": total, @@ -411,24 +404,22 @@ def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int: cutoff = (datetime.now(UTC) - timedelta(days=older_than_days)).isoformat() - conn = _get_conn() + with _get_conn() as conn: + if keep_facts: + cursor = conn.execute( + """ + DELETE FROM episodes + WHERE timestamp < ? AND context_type != 'fact' + """, + (cutoff,), + ) + else: + cursor = conn.execute( + "DELETE FROM episodes WHERE timestamp < ?", + (cutoff,), + ) - if keep_facts: - cursor = conn.execute( - """ - DELETE FROM episodes - WHERE timestamp < ? AND context_type != 'fact' - """, - (cutoff,), - ) - else: - cursor = conn.execute( - "DELETE FROM episodes WHERE timestamp < ?", - (cutoff,), - ) - - deleted = cursor.rowcount - conn.commit() - conn.close() + deleted = cursor.rowcount + conn.commit() return deleted diff --git a/src/timmy/thinking.py b/src/timmy/thinking.py index 1c987999..7a955cc8 100644 --- a/src/timmy/thinking.py +++ b/src/timmy/thinking.py @@ -21,7 +21,8 @@ import logging import random import sqlite3 import uuid -from contextlib import closing +from collections.abc import Generator +from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import UTC, datetime, timedelta from difflib import SequenceMatcher @@ -169,23 +170,24 @@ class Thought: created_at: str -def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection: +@contextmanager +def _get_conn(db_path: Path = _DEFAULT_DB) -> Generator[sqlite3.Connection, None, None]: """Get a SQLite connection with the thoughts table created.""" db_path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - conn.execute(""" - CREATE TABLE IF NOT EXISTS thoughts ( - id TEXT PRIMARY KEY, - content TEXT NOT NULL, - seed_type TEXT NOT NULL, - parent_id TEXT, - created_at TEXT NOT NULL - ) - """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)") - conn.commit() - return conn + with closing(sqlite3.connect(str(db_path))) as conn: + conn.row_factory = sqlite3.Row + conn.execute(""" + CREATE TABLE IF NOT EXISTS thoughts ( + id TEXT PRIMARY KEY, + content TEXT NOT NULL, + seed_type TEXT NOT NULL, + parent_id TEXT, + created_at TEXT NOT NULL + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)") + conn.commit() + yield conn def _row_to_thought(row: sqlite3.Row) -> Thought: @@ -321,7 +323,7 @@ class ThinkingEngine: def get_recent_thoughts(self, limit: int = 20) -> list[Thought]: """Retrieve the most recent thoughts.""" - with closing(_get_conn(self._db_path)) as conn: + with _get_conn(self._db_path) as conn: rows = conn.execute( "SELECT * FROM thoughts ORDER BY created_at DESC LIMIT ?", (limit,), @@ -330,7 +332,7 @@ class ThinkingEngine: def get_thought(self, thought_id: str) -> Thought | None: """Retrieve a single thought by ID.""" - with closing(_get_conn(self._db_path)) as conn: + with _get_conn(self._db_path) as conn: row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (thought_id,)).fetchone() return _row_to_thought(row) if row else None @@ -342,7 +344,7 @@ class ThinkingEngine: chain = [] current_id: str | None = thought_id - with closing(_get_conn(self._db_path)) as conn: + with _get_conn(self._db_path) as conn: for _ in range(max_depth): if not current_id: break @@ -357,7 +359,7 @@ class ThinkingEngine: def count_thoughts(self) -> int: """Return total number of stored thoughts.""" - with closing(_get_conn(self._db_path)) as conn: + with _get_conn(self._db_path) as conn: count = conn.execute("SELECT COUNT(*) as c FROM thoughts").fetchone()["c"] return count @@ -366,7 +368,7 @@ class ThinkingEngine: Returns the number of deleted rows. """ - with closing(_get_conn(self._db_path)) as conn: + with _get_conn(self._db_path) as conn: try: total = conn.execute("SELECT COUNT(*) as c FROM thoughts").fetchone()["c"] if total <= keep_min: @@ -603,7 +605,7 @@ class ThinkingEngine: # Thought count today (cheap DB query) try: today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) - with closing(_get_conn(self._db_path)) as conn: + with _get_conn(self._db_path) as conn: count = conn.execute( "SELECT COUNT(*) as c FROM thoughts WHERE created_at >= ?", (today_start.isoformat(),), @@ -960,7 +962,7 @@ class ThinkingEngine: created_at=datetime.now(UTC).isoformat(), ) - with closing(_get_conn(self._db_path)) as conn: + with _get_conn(self._db_path) as conn: conn.execute( """ INSERT INTO thoughts (id, content, seed_type, parent_id, created_at) diff --git a/tests/infrastructure/test_model_registry.py b/tests/infrastructure/test_model_registry.py index a93933a4..68972180 100644 --- a/tests/infrastructure/test_model_registry.py +++ b/tests/infrastructure/test_model_registry.py @@ -216,21 +216,15 @@ class TestWALMode: with patch("infrastructure.models.registry.DB_PATH", db): from infrastructure.models.registry import _get_conn - conn = _get_conn() - try: + with _get_conn() as conn: mode = conn.execute("PRAGMA journal_mode").fetchone()[0] assert mode == "wal" - finally: - conn.close() def test_registry_db_busy_timeout(self, tmp_path): db = tmp_path / "wal_test.db" with patch("infrastructure.models.registry.DB_PATH", db): from infrastructure.models.registry import _get_conn - conn = _get_conn() - try: + with _get_conn() as conn: timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0] assert timeout == 5000 - finally: - conn.close() diff --git a/tests/spark/test_spark.py b/tests/spark/test_spark.py index 1345fdbf..313e64e4 100644 --- a/tests/spark/test_spark.py +++ b/tests/spark/test_spark.py @@ -481,29 +481,20 @@ class TestWALMode: def test_spark_memory_uses_wal(self): from spark.memory import _get_conn - conn = _get_conn() - try: + with _get_conn() as conn: mode = conn.execute("PRAGMA journal_mode").fetchone()[0] assert mode == "wal", f"Expected WAL mode, got {mode}" - finally: - conn.close() def test_spark_eidos_uses_wal(self): from spark.eidos import _get_conn - conn = _get_conn() - try: + with _get_conn() as conn: mode = conn.execute("PRAGMA journal_mode").fetchone()[0] assert mode == "wal", f"Expected WAL mode, got {mode}" - finally: - conn.close() def test_spark_memory_busy_timeout(self): from spark.memory import _get_conn - conn = _get_conn() - try: + with _get_conn() as conn: timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0] assert timeout == 5000 - finally: - conn.close() diff --git a/tests/timmy/test_approvals.py b/tests/timmy/test_approvals.py index 237b6663..3b0f6875 100644 --- a/tests/timmy/test_approvals.py +++ b/tests/timmy/test_approvals.py @@ -142,14 +142,13 @@ class TestExpireOld: # Create item and manually backdate it item = create_item("Old", "D", "A", db_path=db_path) - conn = _get_conn(db_path) old_date = (datetime.now(UTC) - timedelta(days=30)).isoformat() - conn.execute( - "UPDATE approval_items SET created_at = ? WHERE id = ?", - (old_date, item.id), - ) - conn.commit() - conn.close() + with _get_conn(db_path) as conn: + conn.execute( + "UPDATE approval_items SET created_at = ? WHERE id = ?", + (old_date, item.id), + ) + conn.commit() count = expire_old(db_path) assert count == 1 @@ -169,14 +168,13 @@ class TestExpireOld: approve(item.id, db_path) # Backdate it - conn = _get_conn(db_path) old_date = (datetime.now(UTC) - timedelta(days=30)).isoformat() - conn.execute( - "UPDATE approval_items SET created_at = ? WHERE id = ?", - (old_date, item.id), - ) - conn.commit() - conn.close() + with _get_conn(db_path) as conn: + conn.execute( + "UPDATE approval_items SET created_at = ? WHERE id = ?", + (old_date, item.id), + ) + conn.commit() count = expire_old(db_path) assert count == 0 # approved items not expired