[loop-cycle-50] refactor: replace bare sqlite3.connect() with context managers batch 2 (#157) (#180)

This commit is contained in:
2026-03-15 11:58:43 -04:00
parent bea2749158
commit bcd6d7e321
16 changed files with 512 additions and 510 deletions

View File

@@ -209,14 +209,11 @@ async def api_swarm_status():
pending_tasks = 0
try:
db = _get_db()
try:
with _get_db() as db:
row = db.execute(
"SELECT COUNT(*) as cnt FROM tasks WHERE status IN ('pending_approval','approved')"
).fetchone()
pending_tasks = row["cnt"] if row else 0
finally:
db.close()
except Exception:
pass

View File

@@ -3,7 +3,8 @@
import logging
import sqlite3
import uuid
from contextlib import closing
from collections.abc import Generator
from contextlib import closing, contextmanager
from datetime import datetime
from pathlib import Path
@@ -36,26 +37,27 @@ VALID_STATUSES = {
VALID_PRIORITIES = {"low", "normal", "high", "urgent"}
def _get_db() -> sqlite3.Connection:
@contextmanager
def _get_db() -> Generator[sqlite3.Connection, None, None]:
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
description TEXT DEFAULT '',
status TEXT DEFAULT 'pending_approval',
priority TEXT DEFAULT 'normal',
assigned_to TEXT DEFAULT '',
created_by TEXT DEFAULT 'operator',
result TEXT DEFAULT '',
created_at TEXT DEFAULT (datetime('now')),
completed_at TEXT
)
""")
conn.commit()
return conn
with closing(sqlite3.connect(str(DB_PATH))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
description TEXT DEFAULT '',
status TEXT DEFAULT 'pending_approval',
priority TEXT DEFAULT 'normal',
assigned_to TEXT DEFAULT '',
created_by TEXT DEFAULT 'operator',
result TEXT DEFAULT '',
created_at TEXT DEFAULT (datetime('now')),
completed_at TEXT
)
""")
conn.commit()
yield conn
def _row_to_dict(row: sqlite3.Row) -> dict:
@@ -102,7 +104,7 @@ class _TaskView:
@router.get("/tasks", response_class=HTMLResponse)
async def tasks_page(request: Request):
"""Render the main task queue page with 3-column layout."""
with closing(_get_db()) as db:
with _get_db() as db:
pending = [
_TaskView(_row_to_dict(r))
for r in db.execute(
@@ -143,7 +145,7 @@ async def tasks_page(request: Request):
@router.get("/tasks/pending", response_class=HTMLResponse)
async def tasks_pending(request: Request):
with closing(_get_db()) as db:
with _get_db() as db:
rows = db.execute(
"SELECT * FROM tasks WHERE status='pending_approval' ORDER BY created_at DESC"
).fetchall()
@@ -162,7 +164,7 @@ async def tasks_pending(request: Request):
@router.get("/tasks/active", response_class=HTMLResponse)
async def tasks_active(request: Request):
with closing(_get_db()) as db:
with _get_db() as db:
rows = db.execute(
"SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC"
).fetchall()
@@ -181,7 +183,7 @@ async def tasks_active(request: Request):
@router.get("/tasks/completed", response_class=HTMLResponse)
async def tasks_completed(request: Request):
with closing(_get_db()) as db:
with _get_db() as db:
rows = db.execute(
"SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50"
).fetchall()
@@ -220,7 +222,7 @@ async def create_task_form(
now = datetime.utcnow().isoformat()
priority = priority if priority in VALID_PRIORITIES else "normal"
with closing(_get_db()) as db:
with _get_db() as db:
db.execute(
"INSERT INTO tasks (id, title, description, priority, assigned_to, created_at) VALUES (?, ?, ?, ?, ?, ?)",
(task_id, title, description, priority, assigned_to, now),
@@ -269,7 +271,7 @@ async def modify_task(
title: str = Form(...),
description: str = Form(""),
):
with closing(_get_db()) as db:
with _get_db() as db:
db.execute(
"UPDATE tasks SET title=?, description=? WHERE id=?",
(title, description, task_id),
@@ -287,7 +289,7 @@ async def _set_status(request: Request, task_id: str, new_status: str):
completed_at = (
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
)
with closing(_get_db()) as db:
with _get_db() as db:
db.execute(
"UPDATE tasks SET status=?, completed_at=COALESCE(?, completed_at) WHERE id=?",
(new_status, completed_at, task_id),
@@ -319,7 +321,7 @@ async def api_create_task(request: Request):
if priority not in VALID_PRIORITIES:
priority = "normal"
with closing(_get_db()) as db:
with _get_db() as db:
db.execute(
"INSERT INTO tasks (id, title, description, priority, assigned_to, created_by, created_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?)",
@@ -342,7 +344,7 @@ async def api_create_task(request: Request):
@router.get("/api/tasks", response_class=JSONResponse)
async def api_list_tasks():
"""List all tasks as JSON."""
with closing(_get_db()) as db:
with _get_db() as db:
rows = db.execute("SELECT * FROM tasks ORDER BY created_at DESC").fetchall()
return JSONResponse([_row_to_dict(r) for r in rows])
@@ -358,7 +360,7 @@ async def api_update_status(task_id: str, request: Request):
completed_at = (
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
)
with closing(_get_db()) as db:
with _get_db() as db:
db.execute(
"UPDATE tasks SET status=?, completed_at=COALESCE(?, completed_at) WHERE id=?",
(new_status, completed_at, task_id),
@@ -373,7 +375,7 @@ async def api_update_status(task_id: str, request: Request):
@router.delete("/api/tasks/{task_id}", response_class=JSONResponse)
async def api_delete_task(task_id: str):
"""Delete a task."""
with closing(_get_db()) as db:
with _get_db() as db:
cursor = db.execute("DELETE FROM tasks WHERE id=?", (task_id,))
db.commit()
if cursor.rowcount == 0:
@@ -389,7 +391,7 @@ async def api_delete_task(task_id: str):
@router.get("/api/queue/status", response_class=JSONResponse)
async def queue_status(assigned_to: str = "default"):
"""Return queue status for the chat panel's agent status indicator."""
with closing(_get_db()) as db:
with _get_db() as db:
running = db.execute(
"SELECT * FROM tasks WHERE status='running' AND assigned_to=? LIMIT 1",
(assigned_to,),

View File

@@ -3,7 +3,8 @@
import logging
import sqlite3
import uuid
from contextlib import closing
from collections.abc import Generator
from contextlib import closing, contextmanager
from datetime import datetime
from pathlib import Path
@@ -24,28 +25,29 @@ CATEGORIES = ["bug", "feature", "suggestion", "maintenance", "security"]
VALID_STATUSES = {"submitted", "triaged", "approved", "in_progress", "completed", "rejected"}
def _get_db() -> sqlite3.Connection:
@contextmanager
def _get_db() -> Generator[sqlite3.Connection, None, None]:
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS work_orders (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
description TEXT DEFAULT '',
priority TEXT DEFAULT 'medium',
category TEXT DEFAULT 'suggestion',
submitter TEXT DEFAULT 'dashboard',
related_files TEXT DEFAULT '',
status TEXT DEFAULT 'submitted',
result TEXT DEFAULT '',
rejection_reason TEXT DEFAULT '',
created_at TEXT DEFAULT (datetime('now')),
completed_at TEXT
)
""")
conn.commit()
return conn
with closing(sqlite3.connect(str(DB_PATH))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS work_orders (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
description TEXT DEFAULT '',
priority TEXT DEFAULT 'medium',
category TEXT DEFAULT 'suggestion',
submitter TEXT DEFAULT 'dashboard',
related_files TEXT DEFAULT '',
status TEXT DEFAULT 'submitted',
result TEXT DEFAULT '',
rejection_reason TEXT DEFAULT '',
created_at TEXT DEFAULT (datetime('now')),
completed_at TEXT
)
""")
conn.commit()
yield conn
class _EnumLike:
@@ -105,7 +107,7 @@ def _query_wos(db, statuses):
@router.get("/work-orders/queue", response_class=HTMLResponse)
async def work_orders_page(request: Request):
with closing(_get_db()) as db:
with _get_db() as db:
pending = _query_wos(db, ["submitted", "triaged"])
active = _query_wos(db, ["approved", "in_progress"])
completed = _query_wos(db, ["completed"])
@@ -146,7 +148,7 @@ async def submit_work_order(
priority = priority if priority in PRIORITIES else "medium"
category = category if category in CATEGORIES else "suggestion"
with closing(_get_db()) as db:
with _get_db() as db:
db.execute(
"INSERT INTO work_orders (id, title, description, priority, category, submitter, related_files, created_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
@@ -166,7 +168,7 @@ async def submit_work_order(
@router.get("/work-orders/queue/pending", response_class=HTMLResponse)
async def pending_partial(request: Request):
with closing(_get_db()) as db:
with _get_db() as db:
wos = _query_wos(db, ["submitted", "triaged"])
if not wos:
return HTMLResponse(
@@ -185,7 +187,7 @@ async def pending_partial(request: Request):
@router.get("/work-orders/queue/active", response_class=HTMLResponse)
async def active_partial(request: Request):
with closing(_get_db()) as db:
with _get_db() as db:
wos = _query_wos(db, ["approved", "in_progress"])
if not wos:
return HTMLResponse(
@@ -211,7 +213,7 @@ async def _update_status(request: Request, wo_id: str, new_status: str, **extra)
completed_at = (
datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None
)
with closing(_get_db()) as db:
with _get_db() as db:
sets = ["status=?", "completed_at=COALESCE(?, completed_at)"]
vals = [new_status, completed_at]
for col, val in extra.items():

View File

@@ -9,6 +9,8 @@ A configurable retention policy (default 500 messages) keeps the DB lean.
import sqlite3
import threading
from collections.abc import Generator
from contextlib import closing, contextmanager
from dataclasses import dataclass
from pathlib import Path
@@ -28,24 +30,25 @@ class Message:
source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system"
def _get_conn(db_path: Path | None = None) -> sqlite3.Connection:
@contextmanager
def _get_conn(db_path: Path | None = None) -> Generator[sqlite3.Connection, None, None]:
"""Open (or create) the chat database and ensure schema exists."""
path = db_path or DB_PATH
path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(path), check_same_thread=False)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("""
CREATE TABLE IF NOT EXISTS chat_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp TEXT NOT NULL,
source TEXT NOT NULL DEFAULT 'browser'
)
""")
conn.commit()
return conn
with closing(sqlite3.connect(str(path), check_same_thread=False)) as conn:
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("""
CREATE TABLE IF NOT EXISTS chat_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp TEXT NOT NULL,
source TEXT NOT NULL DEFAULT 'browser'
)
""")
conn.commit()
yield conn
class MessageLog:
@@ -59,7 +62,23 @@ class MessageLog:
# Lazy connection — opened on first use, not at import time.
def _ensure_conn(self) -> sqlite3.Connection:
if self._conn is None:
self._conn = _get_conn(self._db_path)
# Open a persistent connection for the class instance
path = self._db_path or DB_PATH
path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(path), check_same_thread=False)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("""
CREATE TABLE IF NOT EXISTS chat_messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp TEXT NOT NULL,
source TEXT NOT NULL DEFAULT 'browser'
)
""")
conn.commit()
self._conn = conn
return self._conn
def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None:

View File

@@ -9,8 +9,8 @@ import asyncio
import json
import logging
import sqlite3
from collections.abc import Callable, Coroutine
from contextlib import closing
from collections.abc import Callable, Coroutine, Generator
from contextlib import closing, contextmanager
from dataclasses import dataclass, field
from datetime import UTC, datetime
from pathlib import Path
@@ -106,22 +106,23 @@ class EventBus:
conn.executescript(_EVENTS_SCHEMA)
conn.commit()
def _get_persistence_conn(self) -> sqlite3.Connection | None:
@contextmanager
def _get_persistence_conn(self) -> Generator[sqlite3.Connection | None, None, None]:
"""Get a connection to the persistence database."""
if self._persistence_db_path is None:
return None
conn = sqlite3.connect(str(self._persistence_db_path))
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA busy_timeout=5000")
return conn
yield None
return
with closing(sqlite3.connect(str(self._persistence_db_path))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA busy_timeout=5000")
yield conn
def _persist_event(self, event: Event) -> None:
"""Write an event to the persistence database."""
conn = self._get_persistence_conn()
if conn is None:
return
try:
with closing(conn):
with self._get_persistence_conn() as conn:
if conn is None:
return
try:
task_id = event.data.get("task_id", "")
agent_id = event.data.get("agent_id", "")
conn.execute(
@@ -139,8 +140,8 @@ class EventBus:
),
)
conn.commit()
except Exception as exc:
logger.debug("Failed to persist event: %s", exc)
except Exception as exc:
logger.debug("Failed to persist event: %s", exc)
# ── Replay ───────────────────────────────────────────────────────────
@@ -162,12 +163,11 @@ class EventBus:
Returns:
List of Event objects from persistent storage.
"""
conn = self._get_persistence_conn()
if conn is None:
return []
with self._get_persistence_conn() as conn:
if conn is None:
return []
try:
with closing(conn):
try:
conditions = []
params: list = []
@@ -197,9 +197,9 @@ class EventBus:
)
for row in rows
]
except Exception as exc:
logger.debug("Failed to replay events: %s", exc)
return []
except Exception as exc:
logger.debug("Failed to replay events: %s", exc)
return []
# ── Subscribe / Publish ──────────────────────────────────────────────

View File

@@ -11,7 +11,8 @@ model roles (student, teacher, judge/PRM) run on dedicated resources.
import logging
import sqlite3
import threading
from contextlib import closing
from collections.abc import Generator
from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime
from enum import StrEnum
@@ -61,36 +62,37 @@ class CustomModel:
self.registered_at = datetime.now(UTC).isoformat()
def _get_conn() -> sqlite3.Connection:
@contextmanager
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
conn.execute("""
CREATE TABLE IF NOT EXISTS custom_models (
name TEXT PRIMARY KEY,
format TEXT NOT NULL,
path TEXT NOT NULL,
role TEXT NOT NULL DEFAULT 'general',
context_window INTEGER NOT NULL DEFAULT 4096,
description TEXT NOT NULL DEFAULT '',
registered_at TEXT NOT NULL,
active INTEGER NOT NULL DEFAULT 1,
default_temperature REAL NOT NULL DEFAULT 0.7,
max_tokens INTEGER NOT NULL DEFAULT 2048
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS agent_model_assignments (
agent_id TEXT PRIMARY KEY,
model_name TEXT NOT NULL,
assigned_at TEXT NOT NULL,
FOREIGN KEY (model_name) REFERENCES custom_models(name)
)
""")
conn.commit()
return conn
with closing(sqlite3.connect(str(DB_PATH))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
conn.execute("""
CREATE TABLE IF NOT EXISTS custom_models (
name TEXT PRIMARY KEY,
format TEXT NOT NULL,
path TEXT NOT NULL,
role TEXT NOT NULL DEFAULT 'general',
context_window INTEGER NOT NULL DEFAULT 4096,
description TEXT NOT NULL DEFAULT '',
registered_at TEXT NOT NULL,
active INTEGER NOT NULL DEFAULT 1,
default_temperature REAL NOT NULL DEFAULT 0.7,
max_tokens INTEGER NOT NULL DEFAULT 2048
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS agent_model_assignments (
agent_id TEXT PRIMARY KEY,
model_name TEXT NOT NULL,
assigned_at TEXT NOT NULL,
FOREIGN KEY (model_name) REFERENCES custom_models(name)
)
""")
conn.commit()
yield conn
class ModelRegistry:
@@ -106,7 +108,7 @@ class ModelRegistry:
def _load_from_db(self) -> None:
"""Bootstrap cache from SQLite."""
try:
with closing(_get_conn()) as conn:
with _get_conn() as conn:
for row in conn.execute("SELECT * FROM custom_models WHERE active = 1").fetchall():
self._models[row["name"]] = CustomModel(
name=row["name"],
@@ -130,7 +132,7 @@ class ModelRegistry:
def register(self, model: CustomModel) -> CustomModel:
"""Register a new custom model."""
with self._lock:
with closing(_get_conn()) as conn:
with _get_conn() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO custom_models
@@ -161,7 +163,7 @@ class ModelRegistry:
with self._lock:
if name not in self._models:
return False
with closing(_get_conn()) as conn:
with _get_conn() as conn:
conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,))
conn.commit()
@@ -191,7 +193,7 @@ class ModelRegistry:
return False
with self._lock:
model.active = active
with closing(_get_conn()) as conn:
with _get_conn() as conn:
conn.execute(
"UPDATE custom_models SET active = ? WHERE name = ?",
(int(active), name),
@@ -207,7 +209,7 @@ class ModelRegistry:
return False
with self._lock:
now = datetime.now(UTC).isoformat()
with closing(_get_conn()) as conn:
with _get_conn() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO agent_model_assignments
@@ -226,7 +228,7 @@ class ModelRegistry:
with self._lock:
if agent_id not in self._agent_assignments:
return False
with closing(_get_conn()) as conn:
with _get_conn() as conn:
conn.execute(
"DELETE FROM agent_model_assignments WHERE agent_id = ?",
(agent_id,),

View File

@@ -16,6 +16,8 @@ import json
import logging
import sqlite3
import uuid
from collections.abc import Generator
from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
@@ -39,28 +41,31 @@ class Prediction:
evaluated_at: str | None
def _get_conn() -> sqlite3.Connection:
@contextmanager
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
conn.execute("""
CREATE TABLE IF NOT EXISTS spark_predictions (
id TEXT PRIMARY KEY,
task_id TEXT NOT NULL,
prediction_type TEXT NOT NULL,
predicted_value TEXT NOT NULL,
actual_value TEXT,
accuracy REAL,
created_at TEXT NOT NULL,
evaluated_at TEXT
with closing(sqlite3.connect(str(DB_PATH))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
conn.execute("""
CREATE TABLE IF NOT EXISTS spark_predictions (
id TEXT PRIMARY KEY,
task_id TEXT NOT NULL,
prediction_type TEXT NOT NULL,
predicted_value TEXT NOT NULL,
actual_value TEXT,
accuracy REAL,
created_at TEXT NOT NULL,
evaluated_at TEXT
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)")
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)"
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)")
conn.commit()
return conn
conn.commit()
yield conn
# ── Prediction phase ────────────────────────────────────────────────────────
@@ -119,17 +124,16 @@ def predict_task_outcome(
# Store prediction
pred_id = str(uuid.uuid4())
now = datetime.now(UTC).isoformat()
conn = _get_conn()
conn.execute(
"""
INSERT INTO spark_predictions
(id, task_id, prediction_type, predicted_value, created_at)
VALUES (?, ?, ?, ?, ?)
""",
(pred_id, task_id, "outcome", json.dumps(prediction), now),
)
conn.commit()
conn.close()
with _get_conn() as conn:
conn.execute(
"""
INSERT INTO spark_predictions
(id, task_id, prediction_type, predicted_value, created_at)
VALUES (?, ?, ?, ?, ?)
""",
(pred_id, task_id, "outcome", json.dumps(prediction), now),
)
conn.commit()
prediction["prediction_id"] = pred_id
return prediction
@@ -148,41 +152,39 @@ def evaluate_prediction(
Returns the evaluation result or None if no prediction exists.
"""
conn = _get_conn()
row = conn.execute(
"""
SELECT * FROM spark_predictions
WHERE task_id = ? AND prediction_type = 'outcome' AND evaluated_at IS NULL
ORDER BY created_at DESC LIMIT 1
""",
(task_id,),
).fetchone()
with _get_conn() as conn:
row = conn.execute(
"""
SELECT * FROM spark_predictions
WHERE task_id = ? AND prediction_type = 'outcome' AND evaluated_at IS NULL
ORDER BY created_at DESC LIMIT 1
""",
(task_id,),
).fetchone()
if not row:
conn.close()
return None
if not row:
return None
predicted = json.loads(row["predicted_value"])
actual = {
"winner": actual_winner,
"succeeded": task_succeeded,
"winning_bid": winning_bid,
}
predicted = json.loads(row["predicted_value"])
actual = {
"winner": actual_winner,
"succeeded": task_succeeded,
"winning_bid": winning_bid,
}
# Calculate accuracy
accuracy = _compute_accuracy(predicted, actual)
now = datetime.now(UTC).isoformat()
# Calculate accuracy
accuracy = _compute_accuracy(predicted, actual)
now = datetime.now(UTC).isoformat()
conn.execute(
"""
UPDATE spark_predictions
SET actual_value = ?, accuracy = ?, evaluated_at = ?
WHERE id = ?
""",
(json.dumps(actual), accuracy, now, row["id"]),
)
conn.commit()
conn.close()
conn.execute(
"""
UPDATE spark_predictions
SET actual_value = ?, accuracy = ?, evaluated_at = ?
WHERE id = ?
""",
(json.dumps(actual), accuracy, now, row["id"]),
)
conn.commit()
return {
"prediction_id": row["id"],
@@ -243,7 +245,6 @@ def get_predictions(
limit: int = 50,
) -> list[Prediction]:
"""Query stored predictions."""
conn = _get_conn()
query = "SELECT * FROM spark_predictions WHERE 1=1"
params: list = []
@@ -256,8 +257,8 @@ def get_predictions(
query += " ORDER BY created_at DESC LIMIT ?"
params.append(limit)
rows = conn.execute(query, params).fetchall()
conn.close()
with _get_conn() as conn:
rows = conn.execute(query, params).fetchall()
return [
Prediction(
id=r["id"],
@@ -275,17 +276,16 @@ def get_predictions(
def get_accuracy_stats() -> dict:
"""Return aggregate accuracy statistics for the EIDOS loop."""
conn = _get_conn()
row = conn.execute("""
SELECT
COUNT(*) AS total_predictions,
COUNT(evaluated_at) AS evaluated,
AVG(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS avg_accuracy,
MIN(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS min_accuracy,
MAX(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS max_accuracy
FROM spark_predictions
""").fetchone()
conn.close()
with _get_conn() as conn:
row = conn.execute("""
SELECT
COUNT(*) AS total_predictions,
COUNT(evaluated_at) AS evaluated,
AVG(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS avg_accuracy,
MIN(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS min_accuracy,
MAX(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS max_accuracy
FROM spark_predictions
""").fetchone()
return {
"total_predictions": row["total_predictions"] or 0,

View File

@@ -13,6 +13,8 @@ spark_memories — consolidated insights extracted from event patterns
import logging
import sqlite3
import uuid
from collections.abc import Generator
from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
@@ -55,42 +57,43 @@ class SparkMemory:
expires_at: str | None
def _get_conn() -> sqlite3.Connection:
@contextmanager
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
conn.execute("""
CREATE TABLE IF NOT EXISTS spark_events (
id TEXT PRIMARY KEY,
event_type TEXT NOT NULL,
agent_id TEXT,
task_id TEXT,
description TEXT NOT NULL DEFAULT '',
data TEXT NOT NULL DEFAULT '{}',
importance REAL NOT NULL DEFAULT 0.5,
created_at TEXT NOT NULL
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS spark_memories (
id TEXT PRIMARY KEY,
memory_type TEXT NOT NULL,
subject TEXT NOT NULL DEFAULT 'system',
content TEXT NOT NULL,
confidence REAL NOT NULL DEFAULT 0.5,
source_events INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL,
expires_at TEXT
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)")
conn.commit()
return conn
with closing(sqlite3.connect(str(DB_PATH))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
conn.execute("""
CREATE TABLE IF NOT EXISTS spark_events (
id TEXT PRIMARY KEY,
event_type TEXT NOT NULL,
agent_id TEXT,
task_id TEXT,
description TEXT NOT NULL DEFAULT '',
data TEXT NOT NULL DEFAULT '{}',
importance REAL NOT NULL DEFAULT 0.5,
created_at TEXT NOT NULL
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS spark_memories (
id TEXT PRIMARY KEY,
memory_type TEXT NOT NULL,
subject TEXT NOT NULL DEFAULT 'system',
content TEXT NOT NULL,
confidence REAL NOT NULL DEFAULT 0.5,
source_events INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL,
expires_at TEXT
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)")
conn.commit()
yield conn
# ── Importance scoring ──────────────────────────────────────────────────────
@@ -149,17 +152,16 @@ def record_event(
parsed = {}
importance = score_importance(event_type, parsed)
conn = _get_conn()
conn.execute(
"""
INSERT INTO spark_events
(id, event_type, agent_id, task_id, description, data, importance, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(event_id, event_type, agent_id, task_id, description, data, importance, now),
)
conn.commit()
conn.close()
with _get_conn() as conn:
conn.execute(
"""
INSERT INTO spark_events
(id, event_type, agent_id, task_id, description, data, importance, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(event_id, event_type, agent_id, task_id, description, data, importance, now),
)
conn.commit()
# Bridge to unified event log so all events are queryable from one place
try:
@@ -188,7 +190,6 @@ def get_events(
min_importance: float = 0.0,
) -> list[SparkEvent]:
"""Query events with optional filters."""
conn = _get_conn()
query = "SELECT * FROM spark_events WHERE importance >= ?"
params: list = [min_importance]
@@ -205,8 +206,8 @@ def get_events(
query += " ORDER BY created_at DESC LIMIT ?"
params.append(limit)
rows = conn.execute(query, params).fetchall()
conn.close()
with _get_conn() as conn:
rows = conn.execute(query, params).fetchall()
return [
SparkEvent(
id=r["id"],
@@ -224,15 +225,14 @@ def get_events(
def count_events(event_type: str | None = None) -> int:
"""Count events, optionally filtered by type."""
conn = _get_conn()
if event_type:
row = conn.execute(
"SELECT COUNT(*) FROM spark_events WHERE event_type = ?",
(event_type,),
).fetchone()
else:
row = conn.execute("SELECT COUNT(*) FROM spark_events").fetchone()
conn.close()
with _get_conn() as conn:
if event_type:
row = conn.execute(
"SELECT COUNT(*) FROM spark_events WHERE event_type = ?",
(event_type,),
).fetchone()
else:
row = conn.execute("SELECT COUNT(*) FROM spark_events").fetchone()
return row[0]
@@ -250,17 +250,16 @@ def store_memory(
"""Store a consolidated memory. Returns the memory id."""
mem_id = str(uuid.uuid4())
now = datetime.now(UTC).isoformat()
conn = _get_conn()
conn.execute(
"""
INSERT INTO spark_memories
(id, memory_type, subject, content, confidence, source_events, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(mem_id, memory_type, subject, content, confidence, source_events, now, expires_at),
)
conn.commit()
conn.close()
with _get_conn() as conn:
conn.execute(
"""
INSERT INTO spark_memories
(id, memory_type, subject, content, confidence, source_events, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(mem_id, memory_type, subject, content, confidence, source_events, now, expires_at),
)
conn.commit()
return mem_id
@@ -271,7 +270,6 @@ def get_memories(
limit: int = 50,
) -> list[SparkMemory]:
"""Query memories with optional filters."""
conn = _get_conn()
query = "SELECT * FROM spark_memories WHERE confidence >= ?"
params: list = [min_confidence]
@@ -285,8 +283,8 @@ def get_memories(
query += " ORDER BY created_at DESC LIMIT ?"
params.append(limit)
rows = conn.execute(query, params).fetchall()
conn.close()
with _get_conn() as conn:
rows = conn.execute(query, params).fetchall()
return [
SparkMemory(
id=r["id"],
@@ -304,13 +302,12 @@ def get_memories(
def count_memories(memory_type: str | None = None) -> int:
"""Count memories, optionally filtered by type."""
conn = _get_conn()
if memory_type:
row = conn.execute(
"SELECT COUNT(*) FROM spark_memories WHERE memory_type = ?",
(memory_type,),
).fetchone()
else:
row = conn.execute("SELECT COUNT(*) FROM spark_memories").fetchone()
conn.close()
with _get_conn() as conn:
if memory_type:
row = conn.execute(
"SELECT COUNT(*) FROM spark_memories WHERE memory_type = ?",
(memory_type,),
).fetchone()
else:
row = conn.execute("SELECT COUNT(*) FROM spark_memories").fetchone()
return row[0]

View File

@@ -13,7 +13,8 @@ Default is always True. The owner changes this intentionally.
import sqlite3
import uuid
from contextlib import closing
from collections.abc import Generator
from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from pathlib import Path
@@ -44,23 +45,24 @@ class ApprovalItem:
status: str # "pending" | "approved" | "rejected"
def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
@contextmanager
def _get_conn(db_path: Path = _DEFAULT_DB) -> Generator[sqlite3.Connection, None, None]:
db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS approval_items (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
description TEXT NOT NULL,
proposed_action TEXT NOT NULL,
impact TEXT NOT NULL DEFAULT 'low',
created_at TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending'
)
""")
conn.commit()
return conn
with closing(sqlite3.connect(str(db_path))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS approval_items (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
description TEXT NOT NULL,
proposed_action TEXT NOT NULL,
impact TEXT NOT NULL DEFAULT 'low',
created_at TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending'
)
""")
conn.commit()
yield conn
def _row_to_item(row: sqlite3.Row) -> ApprovalItem:
@@ -97,7 +99,7 @@ def create_item(
created_at=datetime.now(UTC),
status="pending",
)
with closing(_get_conn(db_path)) as conn:
with _get_conn(db_path) as conn:
conn.execute(
"""
INSERT INTO approval_items
@@ -120,7 +122,7 @@ def create_item(
def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
"""Return all pending approval items, newest first."""
with closing(_get_conn(db_path)) as conn:
with _get_conn(db_path) as conn:
rows = conn.execute(
"SELECT * FROM approval_items WHERE status = 'pending' ORDER BY created_at DESC"
).fetchall()
@@ -129,20 +131,20 @@ def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
def list_all(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
"""Return all approval items regardless of status, newest first."""
with closing(_get_conn(db_path)) as conn:
with _get_conn(db_path) as conn:
rows = conn.execute("SELECT * FROM approval_items ORDER BY created_at DESC").fetchall()
return [_row_to_item(r) for r in rows]
def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
with closing(_get_conn(db_path)) as conn:
with _get_conn(db_path) as conn:
row = conn.execute("SELECT * FROM approval_items WHERE id = ?", (item_id,)).fetchone()
return _row_to_item(row) if row else None
def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
"""Mark an approval item as approved."""
with closing(_get_conn(db_path)) as conn:
with _get_conn(db_path) as conn:
conn.execute("UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,))
conn.commit()
return get_item(item_id, db_path)
@@ -150,7 +152,7 @@ def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
"""Mark an approval item as rejected."""
with closing(_get_conn(db_path)) as conn:
with _get_conn(db_path) as conn:
conn.execute("UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,))
conn.commit()
return get_item(item_id, db_path)
@@ -159,7 +161,7 @@ def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
def expire_old(db_path: Path = _DEFAULT_DB) -> int:
"""Auto-expire pending items older than EXPIRY_DAYS. Returns count removed."""
cutoff = (datetime.now(UTC) - timedelta(days=_EXPIRY_DAYS)).isoformat()
with closing(_get_conn(db_path)) as conn:
with _get_conn(db_path) as conn:
cursor = conn.execute(
"DELETE FROM approval_items WHERE status = 'pending' AND created_at < ?",
(cutoff,),

View File

@@ -10,7 +10,8 @@ regenerates the briefing every 6 hours.
import logging
import sqlite3
from contextlib import closing
from collections.abc import Generator
from contextlib import closing, contextmanager
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from pathlib import Path
@@ -57,25 +58,26 @@ class Briefing:
# ---------------------------------------------------------------------------
def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
@contextmanager
def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> Generator[sqlite3.Connection, None, None]:
db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS briefings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
generated_at TEXT NOT NULL,
period_start TEXT NOT NULL,
period_end TEXT NOT NULL,
summary TEXT NOT NULL
)
""")
conn.commit()
return conn
with closing(sqlite3.connect(str(db_path))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS briefings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
generated_at TEXT NOT NULL,
period_start TEXT NOT NULL,
period_end TEXT NOT NULL,
summary TEXT NOT NULL
)
""")
conn.commit()
yield conn
def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None:
with closing(_get_cache_conn(db_path)) as conn:
with _get_cache_conn(db_path) as conn:
conn.execute(
"""
INSERT INTO briefings (generated_at, period_start, period_end, summary)
@@ -93,7 +95,7 @@ def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None:
def _load_latest(db_path: Path = _DEFAULT_DB) -> Briefing | None:
"""Load the most-recently cached briefing, or None if there is none."""
with closing(_get_cache_conn(db_path)) as conn:
with _get_cache_conn(db_path) as conn:
row = conn.execute("SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1").fetchone()
if row is None:
return None

View File

@@ -11,6 +11,8 @@ All three tables live in ``data/memory.db``. Existing APIs in
import logging
import sqlite3
from collections.abc import Generator
from contextlib import closing, contextmanager
from pathlib import Path
logger = logging.getLogger(__name__)
@@ -18,15 +20,16 @@ logger = logging.getLogger(__name__)
DB_PATH = Path(__file__).parent.parent.parent.parent / "data" / "memory.db"
def get_connection() -> sqlite3.Connection:
@contextmanager
def get_connection() -> Generator[sqlite3.Connection, None, None]:
"""Open (and lazily create) the unified memory database."""
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
_ensure_schema(conn)
return conn
with closing(sqlite3.connect(str(DB_PATH))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
_ensure_schema(conn)
yield conn
def _ensure_schema(conn: sqlite3.Connection) -> None:

View File

@@ -8,6 +8,8 @@ import json
import logging
import sqlite3
import uuid
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import UTC, datetime
@@ -54,11 +56,13 @@ class MemoryEntry:
relevance_score: float | None = None # Set during search
def _get_conn() -> sqlite3.Connection:
@contextmanager
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
"""Get database connection to unified memory.db."""
from timmy.memory.unified import get_connection
return get_connection()
with get_connection() as conn:
yield conn
def store_memory(
@@ -101,29 +105,28 @@ def store_memory(
embedding=embedding,
)
conn = _get_conn()
conn.execute(
"""
INSERT INTO episodes
(id, content, source, context_type, agent_id, task_id, session_id,
metadata, embedding, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
entry.id,
entry.content,
entry.source,
entry.context_type,
entry.agent_id,
entry.task_id,
entry.session_id,
json.dumps(metadata) if metadata else None,
json.dumps(embedding) if embedding else None,
entry.timestamp,
),
)
conn.commit()
conn.close()
with _get_conn() as conn:
conn.execute(
"""
INSERT INTO episodes
(id, content, source, context_type, agent_id, task_id, session_id,
metadata, embedding, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
entry.id,
entry.content,
entry.source,
entry.context_type,
entry.agent_id,
entry.task_id,
entry.session_id,
json.dumps(metadata) if metadata else None,
json.dumps(embedding) if embedding else None,
entry.timestamp,
),
)
conn.commit()
return entry
@@ -151,8 +154,6 @@ def search_memories(
"""
query_embedding = _compute_embedding(query)
conn = _get_conn()
# Build query with filters
conditions = []
params = []
@@ -179,8 +180,8 @@ def search_memories(
"""
params.append(limit * 3) # Get more candidates for ranking
rows = conn.execute(query_sql, params).fetchall()
conn.close()
with _get_conn() as conn:
rows = conn.execute(query_sql, params).fetchall()
# Compute similarity scores
results = []
@@ -275,58 +276,54 @@ def recall_personal_facts(agent_id: str | None = None) -> list[str]:
Returns:
List of fact strings
"""
conn = _get_conn()
with _get_conn() as conn:
if agent_id:
rows = conn.execute(
"""
SELECT content FROM episodes
WHERE context_type = 'fact' AND agent_id = ?
ORDER BY timestamp DESC
LIMIT 100
""",
(agent_id,),
).fetchall()
else:
rows = conn.execute(
"""
SELECT content FROM episodes
WHERE context_type = 'fact'
ORDER BY timestamp DESC
LIMIT 100
""",
).fetchall()
if agent_id:
rows = conn.execute(
"""
SELECT content FROM episodes
WHERE context_type = 'fact' AND agent_id = ?
ORDER BY timestamp DESC
LIMIT 100
""",
(agent_id,),
).fetchall()
else:
rows = conn.execute(
"""
SELECT content FROM episodes
WHERE context_type = 'fact'
ORDER BY timestamp DESC
LIMIT 100
""",
).fetchall()
conn.close()
return [r["content"] for r in rows]
def recall_personal_facts_with_ids(agent_id: str | None = None) -> list[dict]:
"""Recall personal facts with their IDs for edit/delete operations."""
conn = _get_conn()
if agent_id:
rows = conn.execute(
"SELECT id, content FROM episodes WHERE context_type = 'fact' AND agent_id = ? ORDER BY timestamp DESC LIMIT 100",
(agent_id,),
).fetchall()
else:
rows = conn.execute(
"SELECT id, content FROM episodes WHERE context_type = 'fact' ORDER BY timestamp DESC LIMIT 100",
).fetchall()
conn.close()
with _get_conn() as conn:
if agent_id:
rows = conn.execute(
"SELECT id, content FROM episodes WHERE context_type = 'fact' AND agent_id = ? ORDER BY timestamp DESC LIMIT 100",
(agent_id,),
).fetchall()
else:
rows = conn.execute(
"SELECT id, content FROM episodes WHERE context_type = 'fact' ORDER BY timestamp DESC LIMIT 100",
).fetchall()
return [{"id": r["id"], "content": r["content"]} for r in rows]
def update_personal_fact(memory_id: str, new_content: str) -> bool:
"""Update a personal fact's content."""
conn = _get_conn()
cursor = conn.execute(
"UPDATE episodes SET content = ? WHERE id = ? AND context_type = 'fact'",
(new_content, memory_id),
)
conn.commit()
updated = cursor.rowcount > 0
conn.close()
with _get_conn() as conn:
cursor = conn.execute(
"UPDATE episodes SET content = ? WHERE id = ? AND context_type = 'fact'",
(new_content, memory_id),
)
conn.commit()
updated = cursor.rowcount > 0
return updated
@@ -355,14 +352,13 @@ def delete_memory(memory_id: str) -> bool:
Returns:
True if deleted, False if not found
"""
conn = _get_conn()
cursor = conn.execute(
"DELETE FROM episodes WHERE id = ?",
(memory_id,),
)
conn.commit()
deleted = cursor.rowcount > 0
conn.close()
with _get_conn() as conn:
cursor = conn.execute(
"DELETE FROM episodes WHERE id = ?",
(memory_id,),
)
conn.commit()
deleted = cursor.rowcount > 0
return deleted
@@ -372,22 +368,19 @@ def get_memory_stats() -> dict:
Returns:
Dict with counts by type, total entries, etc.
"""
conn = _get_conn()
with _get_conn() as conn:
total = conn.execute("SELECT COUNT(*) as count FROM episodes").fetchone()["count"]
total = conn.execute("SELECT COUNT(*) as count FROM episodes").fetchone()["count"]
by_type = {}
rows = conn.execute(
"SELECT context_type, COUNT(*) as count FROM episodes GROUP BY context_type"
).fetchall()
for row in rows:
by_type[row["context_type"]] = row["count"]
by_type = {}
rows = conn.execute(
"SELECT context_type, COUNT(*) as count FROM episodes GROUP BY context_type"
).fetchall()
for row in rows:
by_type[row["context_type"]] = row["count"]
with_embeddings = conn.execute(
"SELECT COUNT(*) as count FROM episodes WHERE embedding IS NOT NULL"
).fetchone()["count"]
conn.close()
with_embeddings = conn.execute(
"SELECT COUNT(*) as count FROM episodes WHERE embedding IS NOT NULL"
).fetchone()["count"]
return {
"total_entries": total,
@@ -411,24 +404,22 @@ def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
cutoff = (datetime.now(UTC) - timedelta(days=older_than_days)).isoformat()
conn = _get_conn()
with _get_conn() as conn:
if keep_facts:
cursor = conn.execute(
"""
DELETE FROM episodes
WHERE timestamp < ? AND context_type != 'fact'
""",
(cutoff,),
)
else:
cursor = conn.execute(
"DELETE FROM episodes WHERE timestamp < ?",
(cutoff,),
)
if keep_facts:
cursor = conn.execute(
"""
DELETE FROM episodes
WHERE timestamp < ? AND context_type != 'fact'
""",
(cutoff,),
)
else:
cursor = conn.execute(
"DELETE FROM episodes WHERE timestamp < ?",
(cutoff,),
)
deleted = cursor.rowcount
conn.commit()
conn.close()
deleted = cursor.rowcount
conn.commit()
return deleted

View File

@@ -21,7 +21,8 @@ import logging
import random
import sqlite3
import uuid
from contextlib import closing
from collections.abc import Generator
from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from difflib import SequenceMatcher
@@ -169,23 +170,24 @@ class Thought:
created_at: str
def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
@contextmanager
def _get_conn(db_path: Path = _DEFAULT_DB) -> Generator[sqlite3.Connection, None, None]:
"""Get a SQLite connection with the thoughts table created."""
db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS thoughts (
id TEXT PRIMARY KEY,
content TEXT NOT NULL,
seed_type TEXT NOT NULL,
parent_id TEXT,
created_at TEXT NOT NULL
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)")
conn.commit()
return conn
with closing(sqlite3.connect(str(db_path))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE IF NOT EXISTS thoughts (
id TEXT PRIMARY KEY,
content TEXT NOT NULL,
seed_type TEXT NOT NULL,
parent_id TEXT,
created_at TEXT NOT NULL
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)")
conn.commit()
yield conn
def _row_to_thought(row: sqlite3.Row) -> Thought:
@@ -321,7 +323,7 @@ class ThinkingEngine:
def get_recent_thoughts(self, limit: int = 20) -> list[Thought]:
"""Retrieve the most recent thoughts."""
with closing(_get_conn(self._db_path)) as conn:
with _get_conn(self._db_path) as conn:
rows = conn.execute(
"SELECT * FROM thoughts ORDER BY created_at DESC LIMIT ?",
(limit,),
@@ -330,7 +332,7 @@ class ThinkingEngine:
def get_thought(self, thought_id: str) -> Thought | None:
"""Retrieve a single thought by ID."""
with closing(_get_conn(self._db_path)) as conn:
with _get_conn(self._db_path) as conn:
row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (thought_id,)).fetchone()
return _row_to_thought(row) if row else None
@@ -342,7 +344,7 @@ class ThinkingEngine:
chain = []
current_id: str | None = thought_id
with closing(_get_conn(self._db_path)) as conn:
with _get_conn(self._db_path) as conn:
for _ in range(max_depth):
if not current_id:
break
@@ -357,7 +359,7 @@ class ThinkingEngine:
def count_thoughts(self) -> int:
"""Return total number of stored thoughts."""
with closing(_get_conn(self._db_path)) as conn:
with _get_conn(self._db_path) as conn:
count = conn.execute("SELECT COUNT(*) as c FROM thoughts").fetchone()["c"]
return count
@@ -366,7 +368,7 @@ class ThinkingEngine:
Returns the number of deleted rows.
"""
with closing(_get_conn(self._db_path)) as conn:
with _get_conn(self._db_path) as conn:
try:
total = conn.execute("SELECT COUNT(*) as c FROM thoughts").fetchone()["c"]
if total <= keep_min:
@@ -603,7 +605,7 @@ class ThinkingEngine:
# Thought count today (cheap DB query)
try:
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
with closing(_get_conn(self._db_path)) as conn:
with _get_conn(self._db_path) as conn:
count = conn.execute(
"SELECT COUNT(*) as c FROM thoughts WHERE created_at >= ?",
(today_start.isoformat(),),
@@ -960,7 +962,7 @@ class ThinkingEngine:
created_at=datetime.now(UTC).isoformat(),
)
with closing(_get_conn(self._db_path)) as conn:
with _get_conn(self._db_path) as conn:
conn.execute(
"""
INSERT INTO thoughts (id, content, seed_type, parent_id, created_at)

View File

@@ -216,21 +216,15 @@ class TestWALMode:
with patch("infrastructure.models.registry.DB_PATH", db):
from infrastructure.models.registry import _get_conn
conn = _get_conn()
try:
with _get_conn() as conn:
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
assert mode == "wal"
finally:
conn.close()
def test_registry_db_busy_timeout(self, tmp_path):
db = tmp_path / "wal_test.db"
with patch("infrastructure.models.registry.DB_PATH", db):
from infrastructure.models.registry import _get_conn
conn = _get_conn()
try:
with _get_conn() as conn:
timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
assert timeout == 5000
finally:
conn.close()

View File

@@ -481,29 +481,20 @@ class TestWALMode:
def test_spark_memory_uses_wal(self):
from spark.memory import _get_conn
conn = _get_conn()
try:
with _get_conn() as conn:
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
assert mode == "wal", f"Expected WAL mode, got {mode}"
finally:
conn.close()
def test_spark_eidos_uses_wal(self):
from spark.eidos import _get_conn
conn = _get_conn()
try:
with _get_conn() as conn:
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
assert mode == "wal", f"Expected WAL mode, got {mode}"
finally:
conn.close()
def test_spark_memory_busy_timeout(self):
from spark.memory import _get_conn
conn = _get_conn()
try:
with _get_conn() as conn:
timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
assert timeout == 5000
finally:
conn.close()

View File

@@ -142,14 +142,13 @@ class TestExpireOld:
# Create item and manually backdate it
item = create_item("Old", "D", "A", db_path=db_path)
conn = _get_conn(db_path)
old_date = (datetime.now(UTC) - timedelta(days=30)).isoformat()
conn.execute(
"UPDATE approval_items SET created_at = ? WHERE id = ?",
(old_date, item.id),
)
conn.commit()
conn.close()
with _get_conn(db_path) as conn:
conn.execute(
"UPDATE approval_items SET created_at = ? WHERE id = ?",
(old_date, item.id),
)
conn.commit()
count = expire_old(db_path)
assert count == 1
@@ -169,14 +168,13 @@ class TestExpireOld:
approve(item.id, db_path)
# Backdate it
conn = _get_conn(db_path)
old_date = (datetime.now(UTC) - timedelta(days=30)).isoformat()
conn.execute(
"UPDATE approval_items SET created_at = ? WHERE id = ?",
(old_date, item.id),
)
conn.commit()
conn.close()
with _get_conn(db_path) as conn:
conn.execute(
"UPDATE approval_items SET created_at = ? WHERE id = ?",
(old_date, item.id),
)
conn.commit()
count = expire_old(db_path)
assert count == 0 # approved items not expired