forked from Rockachopa/Timmy-time-dashboard
[loop-cycle-50] refactor: replace bare sqlite3.connect() with context managers batch 2 (#157) (#180)
This commit is contained in:
@@ -209,14 +209,11 @@ async def api_swarm_status():
|
||||
|
||||
pending_tasks = 0
|
||||
try:
|
||||
db = _get_db()
|
||||
try:
|
||||
with _get_db() as db:
|
||||
row = db.execute(
|
||||
"SELECT COUNT(*) as cnt FROM tasks WHERE status IN ('pending_approval','approved')"
|
||||
).fetchone()
|
||||
pending_tasks = row["cnt"] if row else 0
|
||||
finally:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import closing
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -36,26 +37,27 @@ VALID_STATUSES = {
|
||||
VALID_PRIORITIES = {"low", "normal", "high", "urgent"}
|
||||
|
||||
|
||||
def _get_db() -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_db() -> Generator[sqlite3.Connection, None, None]:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
status TEXT DEFAULT 'pending_approval',
|
||||
priority TEXT DEFAULT 'normal',
|
||||
assigned_to TEXT DEFAULT '',
|
||||
created_by TEXT DEFAULT 'operator',
|
||||
result TEXT DEFAULT '',
|
||||
created_at TEXT DEFAULT (datetime('now')),
|
||||
completed_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(DB_PATH))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
status TEXT DEFAULT 'pending_approval',
|
||||
priority TEXT DEFAULT 'normal',
|
||||
assigned_to TEXT DEFAULT '',
|
||||
created_by TEXT DEFAULT 'operator',
|
||||
result TEXT DEFAULT '',
|
||||
created_at TEXT DEFAULT (datetime('now')),
|
||||
completed_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
def _row_to_dict(row: sqlite3.Row) -> dict:
|
||||
@@ -102,7 +104,7 @@ class _TaskView:
|
||||
@router.get("/tasks", response_class=HTMLResponse)
|
||||
async def tasks_page(request: Request):
|
||||
"""Render the main task queue page with 3-column layout."""
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
pending = [
|
||||
_TaskView(_row_to_dict(r))
|
||||
for r in db.execute(
|
||||
@@ -143,7 +145,7 @@ async def tasks_page(request: Request):
|
||||
|
||||
@router.get("/tasks/pending", response_class=HTMLResponse)
|
||||
async def tasks_pending(request: Request):
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
rows = db.execute(
|
||||
"SELECT * FROM tasks WHERE status='pending_approval' ORDER BY created_at DESC"
|
||||
).fetchall()
|
||||
@@ -162,7 +164,7 @@ async def tasks_pending(request: Request):
|
||||
|
||||
@router.get("/tasks/active", response_class=HTMLResponse)
|
||||
async def tasks_active(request: Request):
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
rows = db.execute(
|
||||
"SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC"
|
||||
).fetchall()
|
||||
@@ -181,7 +183,7 @@ async def tasks_active(request: Request):
|
||||
|
||||
@router.get("/tasks/completed", response_class=HTMLResponse)
|
||||
async def tasks_completed(request: Request):
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
rows = db.execute(
|
||||
"SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50"
|
||||
).fetchall()
|
||||
@@ -220,7 +222,7 @@ async def create_task_form(
|
||||
now = datetime.utcnow().isoformat()
|
||||
priority = priority if priority in VALID_PRIORITIES else "normal"
|
||||
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
db.execute(
|
||||
"INSERT INTO tasks (id, title, description, priority, assigned_to, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(task_id, title, description, priority, assigned_to, now),
|
||||
@@ -269,7 +271,7 @@ async def modify_task(
|
||||
title: str = Form(...),
|
||||
description: str = Form(""),
|
||||
):
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
db.execute(
|
||||
"UPDATE tasks SET title=?, description=? WHERE id=?",
|
||||
(title, description, task_id),
|
||||
@@ -287,7 +289,7 @@ async def _set_status(request: Request, task_id: str, new_status: str):
|
||||
completed_at = (
|
||||
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
||||
)
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
db.execute(
|
||||
"UPDATE tasks SET status=?, completed_at=COALESCE(?, completed_at) WHERE id=?",
|
||||
(new_status, completed_at, task_id),
|
||||
@@ -319,7 +321,7 @@ async def api_create_task(request: Request):
|
||||
if priority not in VALID_PRIORITIES:
|
||||
priority = "normal"
|
||||
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
db.execute(
|
||||
"INSERT INTO tasks (id, title, description, priority, assigned_to, created_by, created_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
@@ -342,7 +344,7 @@ async def api_create_task(request: Request):
|
||||
@router.get("/api/tasks", response_class=JSONResponse)
|
||||
async def api_list_tasks():
|
||||
"""List all tasks as JSON."""
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
rows = db.execute("SELECT * FROM tasks ORDER BY created_at DESC").fetchall()
|
||||
return JSONResponse([_row_to_dict(r) for r in rows])
|
||||
|
||||
@@ -358,7 +360,7 @@ async def api_update_status(task_id: str, request: Request):
|
||||
completed_at = (
|
||||
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
||||
)
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
db.execute(
|
||||
"UPDATE tasks SET status=?, completed_at=COALESCE(?, completed_at) WHERE id=?",
|
||||
(new_status, completed_at, task_id),
|
||||
@@ -373,7 +375,7 @@ async def api_update_status(task_id: str, request: Request):
|
||||
@router.delete("/api/tasks/{task_id}", response_class=JSONResponse)
|
||||
async def api_delete_task(task_id: str):
|
||||
"""Delete a task."""
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
cursor = db.execute("DELETE FROM tasks WHERE id=?", (task_id,))
|
||||
db.commit()
|
||||
if cursor.rowcount == 0:
|
||||
@@ -389,7 +391,7 @@ async def api_delete_task(task_id: str):
|
||||
@router.get("/api/queue/status", response_class=JSONResponse)
|
||||
async def queue_status(assigned_to: str = "default"):
|
||||
"""Return queue status for the chat panel's agent status indicator."""
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
running = db.execute(
|
||||
"SELECT * FROM tasks WHERE status='running' AND assigned_to=? LIMIT 1",
|
||||
(assigned_to,),
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import closing
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -24,28 +25,29 @@ CATEGORIES = ["bug", "feature", "suggestion", "maintenance", "security"]
|
||||
VALID_STATUSES = {"submitted", "triaged", "approved", "in_progress", "completed", "rejected"}
|
||||
|
||||
|
||||
def _get_db() -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_db() -> Generator[sqlite3.Connection, None, None]:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS work_orders (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
priority TEXT DEFAULT 'medium',
|
||||
category TEXT DEFAULT 'suggestion',
|
||||
submitter TEXT DEFAULT 'dashboard',
|
||||
related_files TEXT DEFAULT '',
|
||||
status TEXT DEFAULT 'submitted',
|
||||
result TEXT DEFAULT '',
|
||||
rejection_reason TEXT DEFAULT '',
|
||||
created_at TEXT DEFAULT (datetime('now')),
|
||||
completed_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(DB_PATH))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS work_orders (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
priority TEXT DEFAULT 'medium',
|
||||
category TEXT DEFAULT 'suggestion',
|
||||
submitter TEXT DEFAULT 'dashboard',
|
||||
related_files TEXT DEFAULT '',
|
||||
status TEXT DEFAULT 'submitted',
|
||||
result TEXT DEFAULT '',
|
||||
rejection_reason TEXT DEFAULT '',
|
||||
created_at TEXT DEFAULT (datetime('now')),
|
||||
completed_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
class _EnumLike:
|
||||
@@ -105,7 +107,7 @@ def _query_wos(db, statuses):
|
||||
|
||||
@router.get("/work-orders/queue", response_class=HTMLResponse)
|
||||
async def work_orders_page(request: Request):
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
pending = _query_wos(db, ["submitted", "triaged"])
|
||||
active = _query_wos(db, ["approved", "in_progress"])
|
||||
completed = _query_wos(db, ["completed"])
|
||||
@@ -146,7 +148,7 @@ async def submit_work_order(
|
||||
priority = priority if priority in PRIORITIES else "medium"
|
||||
category = category if category in CATEGORIES else "suggestion"
|
||||
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
db.execute(
|
||||
"INSERT INTO work_orders (id, title, description, priority, category, submitter, related_files, created_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
@@ -166,7 +168,7 @@ async def submit_work_order(
|
||||
|
||||
@router.get("/work-orders/queue/pending", response_class=HTMLResponse)
|
||||
async def pending_partial(request: Request):
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
wos = _query_wos(db, ["submitted", "triaged"])
|
||||
if not wos:
|
||||
return HTMLResponse(
|
||||
@@ -185,7 +187,7 @@ async def pending_partial(request: Request):
|
||||
|
||||
@router.get("/work-orders/queue/active", response_class=HTMLResponse)
|
||||
async def active_partial(request: Request):
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
wos = _query_wos(db, ["approved", "in_progress"])
|
||||
if not wos:
|
||||
return HTMLResponse(
|
||||
@@ -211,7 +213,7 @@ async def _update_status(request: Request, wo_id: str, new_status: str, **extra)
|
||||
completed_at = (
|
||||
datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None
|
||||
)
|
||||
with closing(_get_db()) as db:
|
||||
with _get_db() as db:
|
||||
sets = ["status=?", "completed_at=COALESCE(?, completed_at)"]
|
||||
vals = [new_status, completed_at]
|
||||
for col, val in extra.items():
|
||||
|
||||
@@ -9,6 +9,8 @@ A configurable retention policy (default 500 messages) keeps the DB lean.
|
||||
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
@@ -28,24 +30,25 @@ class Message:
|
||||
source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system"
|
||||
|
||||
|
||||
def _get_conn(db_path: Path | None = None) -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_conn(db_path: Path | None = None) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Open (or create) the chat database and ensure schema exists."""
|
||||
path = db_path or DB_PATH
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(path), check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
source TEXT NOT NULL DEFAULT 'browser'
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(path), check_same_thread=False)) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
source TEXT NOT NULL DEFAULT 'browser'
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
class MessageLog:
|
||||
@@ -59,7 +62,23 @@ class MessageLog:
|
||||
# Lazy connection — opened on first use, not at import time.
|
||||
def _ensure_conn(self) -> sqlite3.Connection:
|
||||
if self._conn is None:
|
||||
self._conn = _get_conn(self._db_path)
|
||||
# Open a persistent connection for the class instance
|
||||
path = self._db_path or DB_PATH
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(path), check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
source TEXT NOT NULL DEFAULT 'browser'
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
self._conn = conn
|
||||
return self._conn
|
||||
|
||||
def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None:
|
||||
|
||||
@@ -9,8 +9,8 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from collections.abc import Callable, Coroutine
|
||||
from contextlib import closing
|
||||
from collections.abc import Callable, Coroutine, Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
@@ -106,22 +106,23 @@ class EventBus:
|
||||
conn.executescript(_EVENTS_SCHEMA)
|
||||
conn.commit()
|
||||
|
||||
def _get_persistence_conn(self) -> sqlite3.Connection | None:
|
||||
@contextmanager
|
||||
def _get_persistence_conn(self) -> Generator[sqlite3.Connection | None, None, None]:
|
||||
"""Get a connection to the persistence database."""
|
||||
if self._persistence_db_path is None:
|
||||
return None
|
||||
conn = sqlite3.connect(str(self._persistence_db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
return conn
|
||||
yield None
|
||||
return
|
||||
with closing(sqlite3.connect(str(self._persistence_db_path))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
yield conn
|
||||
|
||||
def _persist_event(self, event: Event) -> None:
|
||||
"""Write an event to the persistence database."""
|
||||
conn = self._get_persistence_conn()
|
||||
if conn is None:
|
||||
return
|
||||
try:
|
||||
with closing(conn):
|
||||
with self._get_persistence_conn() as conn:
|
||||
if conn is None:
|
||||
return
|
||||
try:
|
||||
task_id = event.data.get("task_id", "")
|
||||
agent_id = event.data.get("agent_id", "")
|
||||
conn.execute(
|
||||
@@ -139,8 +140,8 @@ class EventBus:
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to persist event: %s", exc)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to persist event: %s", exc)
|
||||
|
||||
# ── Replay ───────────────────────────────────────────────────────────
|
||||
|
||||
@@ -162,12 +163,11 @@ class EventBus:
|
||||
Returns:
|
||||
List of Event objects from persistent storage.
|
||||
"""
|
||||
conn = self._get_persistence_conn()
|
||||
if conn is None:
|
||||
return []
|
||||
with self._get_persistence_conn() as conn:
|
||||
if conn is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
with closing(conn):
|
||||
try:
|
||||
conditions = []
|
||||
params: list = []
|
||||
|
||||
@@ -197,9 +197,9 @@ class EventBus:
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to replay events: %s", exc)
|
||||
return []
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to replay events: %s", exc)
|
||||
return []
|
||||
|
||||
# ── Subscribe / Publish ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -11,7 +11,8 @@ model roles (student, teacher, judge/PRM) run on dedicated resources.
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from contextlib import closing
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
@@ -61,36 +62,37 @@ class CustomModel:
|
||||
self.registered_at = datetime.now(UTC).isoformat()
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS custom_models (
|
||||
name TEXT PRIMARY KEY,
|
||||
format TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
role TEXT NOT NULL DEFAULT 'general',
|
||||
context_window INTEGER NOT NULL DEFAULT 4096,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
registered_at TEXT NOT NULL,
|
||||
active INTEGER NOT NULL DEFAULT 1,
|
||||
default_temperature REAL NOT NULL DEFAULT 0.7,
|
||||
max_tokens INTEGER NOT NULL DEFAULT 2048
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS agent_model_assignments (
|
||||
agent_id TEXT PRIMARY KEY,
|
||||
model_name TEXT NOT NULL,
|
||||
assigned_at TEXT NOT NULL,
|
||||
FOREIGN KEY (model_name) REFERENCES custom_models(name)
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(DB_PATH))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS custom_models (
|
||||
name TEXT PRIMARY KEY,
|
||||
format TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
role TEXT NOT NULL DEFAULT 'general',
|
||||
context_window INTEGER NOT NULL DEFAULT 4096,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
registered_at TEXT NOT NULL,
|
||||
active INTEGER NOT NULL DEFAULT 1,
|
||||
default_temperature REAL NOT NULL DEFAULT 0.7,
|
||||
max_tokens INTEGER NOT NULL DEFAULT 2048
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS agent_model_assignments (
|
||||
agent_id TEXT PRIMARY KEY,
|
||||
model_name TEXT NOT NULL,
|
||||
assigned_at TEXT NOT NULL,
|
||||
FOREIGN KEY (model_name) REFERENCES custom_models(name)
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
@@ -106,7 +108,7 @@ class ModelRegistry:
|
||||
def _load_from_db(self) -> None:
|
||||
"""Bootstrap cache from SQLite."""
|
||||
try:
|
||||
with closing(_get_conn()) as conn:
|
||||
with _get_conn() as conn:
|
||||
for row in conn.execute("SELECT * FROM custom_models WHERE active = 1").fetchall():
|
||||
self._models[row["name"]] = CustomModel(
|
||||
name=row["name"],
|
||||
@@ -130,7 +132,7 @@ class ModelRegistry:
|
||||
def register(self, model: CustomModel) -> CustomModel:
|
||||
"""Register a new custom model."""
|
||||
with self._lock:
|
||||
with closing(_get_conn()) as conn:
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO custom_models
|
||||
@@ -161,7 +163,7 @@ class ModelRegistry:
|
||||
with self._lock:
|
||||
if name not in self._models:
|
||||
return False
|
||||
with closing(_get_conn()) as conn:
|
||||
with _get_conn() as conn:
|
||||
conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
|
||||
conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,))
|
||||
conn.commit()
|
||||
@@ -191,7 +193,7 @@ class ModelRegistry:
|
||||
return False
|
||||
with self._lock:
|
||||
model.active = active
|
||||
with closing(_get_conn()) as conn:
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"UPDATE custom_models SET active = ? WHERE name = ?",
|
||||
(int(active), name),
|
||||
@@ -207,7 +209,7 @@ class ModelRegistry:
|
||||
return False
|
||||
with self._lock:
|
||||
now = datetime.now(UTC).isoformat()
|
||||
with closing(_get_conn()) as conn:
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO agent_model_assignments
|
||||
@@ -226,7 +228,7 @@ class ModelRegistry:
|
||||
with self._lock:
|
||||
if agent_id not in self._agent_assignments:
|
||||
return False
|
||||
with closing(_get_conn()) as conn:
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM agent_model_assignments WHERE agent_id = ?",
|
||||
(agent_id,),
|
||||
|
||||
@@ -16,6 +16,8 @@ import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
@@ -39,28 +41,31 @@ class Prediction:
|
||||
evaluated_at: str | None
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS spark_predictions (
|
||||
id TEXT PRIMARY KEY,
|
||||
task_id TEXT NOT NULL,
|
||||
prediction_type TEXT NOT NULL,
|
||||
predicted_value TEXT NOT NULL,
|
||||
actual_value TEXT,
|
||||
accuracy REAL,
|
||||
created_at TEXT NOT NULL,
|
||||
evaluated_at TEXT
|
||||
with closing(sqlite3.connect(str(DB_PATH))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS spark_predictions (
|
||||
id TEXT PRIMARY KEY,
|
||||
task_id TEXT NOT NULL,
|
||||
prediction_type TEXT NOT NULL,
|
||||
predicted_value TEXT NOT NULL,
|
||||
actual_value TEXT,
|
||||
accuracy REAL,
|
||||
created_at TEXT NOT NULL,
|
||||
evaluated_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)")
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)"
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)")
|
||||
conn.commit()
|
||||
return conn
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
# ── Prediction phase ────────────────────────────────────────────────────────
|
||||
@@ -119,17 +124,16 @@ def predict_task_outcome(
|
||||
# Store prediction
|
||||
pred_id = str(uuid.uuid4())
|
||||
now = datetime.now(UTC).isoformat()
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO spark_predictions
|
||||
(id, task_id, prediction_type, predicted_value, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(pred_id, task_id, "outcome", json.dumps(prediction), now),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO spark_predictions
|
||||
(id, task_id, prediction_type, predicted_value, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(pred_id, task_id, "outcome", json.dumps(prediction), now),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
prediction["prediction_id"] = pred_id
|
||||
return prediction
|
||||
@@ -148,41 +152,39 @@ def evaluate_prediction(
|
||||
|
||||
Returns the evaluation result or None if no prediction exists.
|
||||
"""
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT * FROM spark_predictions
|
||||
WHERE task_id = ? AND prediction_type = 'outcome' AND evaluated_at IS NULL
|
||||
ORDER BY created_at DESC LIMIT 1
|
||||
""",
|
||||
(task_id,),
|
||||
).fetchone()
|
||||
with _get_conn() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT * FROM spark_predictions
|
||||
WHERE task_id = ? AND prediction_type = 'outcome' AND evaluated_at IS NULL
|
||||
ORDER BY created_at DESC LIMIT 1
|
||||
""",
|
||||
(task_id,),
|
||||
).fetchone()
|
||||
|
||||
if not row:
|
||||
conn.close()
|
||||
return None
|
||||
if not row:
|
||||
return None
|
||||
|
||||
predicted = json.loads(row["predicted_value"])
|
||||
actual = {
|
||||
"winner": actual_winner,
|
||||
"succeeded": task_succeeded,
|
||||
"winning_bid": winning_bid,
|
||||
}
|
||||
predicted = json.loads(row["predicted_value"])
|
||||
actual = {
|
||||
"winner": actual_winner,
|
||||
"succeeded": task_succeeded,
|
||||
"winning_bid": winning_bid,
|
||||
}
|
||||
|
||||
# Calculate accuracy
|
||||
accuracy = _compute_accuracy(predicted, actual)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
# Calculate accuracy
|
||||
accuracy = _compute_accuracy(predicted, actual)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE spark_predictions
|
||||
SET actual_value = ?, accuracy = ?, evaluated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(json.dumps(actual), accuracy, now, row["id"]),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE spark_predictions
|
||||
SET actual_value = ?, accuracy = ?, evaluated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(json.dumps(actual), accuracy, now, row["id"]),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return {
|
||||
"prediction_id": row["id"],
|
||||
@@ -243,7 +245,6 @@ def get_predictions(
|
||||
limit: int = 50,
|
||||
) -> list[Prediction]:
|
||||
"""Query stored predictions."""
|
||||
conn = _get_conn()
|
||||
query = "SELECT * FROM spark_predictions WHERE 1=1"
|
||||
params: list = []
|
||||
|
||||
@@ -256,8 +257,8 @@ def get_predictions(
|
||||
query += " ORDER BY created_at DESC LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [
|
||||
Prediction(
|
||||
id=r["id"],
|
||||
@@ -275,17 +276,16 @@ def get_predictions(
|
||||
|
||||
def get_accuracy_stats() -> dict:
|
||||
"""Return aggregate accuracy statistics for the EIDOS loop."""
|
||||
conn = _get_conn()
|
||||
row = conn.execute("""
|
||||
SELECT
|
||||
COUNT(*) AS total_predictions,
|
||||
COUNT(evaluated_at) AS evaluated,
|
||||
AVG(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS avg_accuracy,
|
||||
MIN(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS min_accuracy,
|
||||
MAX(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS max_accuracy
|
||||
FROM spark_predictions
|
||||
""").fetchone()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
row = conn.execute("""
|
||||
SELECT
|
||||
COUNT(*) AS total_predictions,
|
||||
COUNT(evaluated_at) AS evaluated,
|
||||
AVG(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS avg_accuracy,
|
||||
MIN(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS min_accuracy,
|
||||
MAX(CASE WHEN accuracy IS NOT NULL THEN accuracy END) AS max_accuracy
|
||||
FROM spark_predictions
|
||||
""").fetchone()
|
||||
|
||||
return {
|
||||
"total_predictions": row["total_predictions"] or 0,
|
||||
|
||||
@@ -13,6 +13,8 @@ spark_memories — consolidated insights extracted from event patterns
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
@@ -55,42 +57,43 @@ class SparkMemory:
|
||||
expires_at: str | None
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS spark_events (
|
||||
id TEXT PRIMARY KEY,
|
||||
event_type TEXT NOT NULL,
|
||||
agent_id TEXT,
|
||||
task_id TEXT,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
data TEXT NOT NULL DEFAULT '{}',
|
||||
importance REAL NOT NULL DEFAULT 0.5,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS spark_memories (
|
||||
id TEXT PRIMARY KEY,
|
||||
memory_type TEXT NOT NULL,
|
||||
subject TEXT NOT NULL DEFAULT 'system',
|
||||
content TEXT NOT NULL,
|
||||
confidence REAL NOT NULL DEFAULT 0.5,
|
||||
source_events INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)")
|
||||
conn.commit()
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(DB_PATH))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS spark_events (
|
||||
id TEXT PRIMARY KEY,
|
||||
event_type TEXT NOT NULL,
|
||||
agent_id TEXT,
|
||||
task_id TEXT,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
data TEXT NOT NULL DEFAULT '{}',
|
||||
importance REAL NOT NULL DEFAULT 0.5,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS spark_memories (
|
||||
id TEXT PRIMARY KEY,
|
||||
memory_type TEXT NOT NULL,
|
||||
subject TEXT NOT NULL DEFAULT 'system',
|
||||
content TEXT NOT NULL,
|
||||
confidence REAL NOT NULL DEFAULT 0.5,
|
||||
source_events INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL,
|
||||
expires_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)")
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
# ── Importance scoring ──────────────────────────────────────────────────────
|
||||
@@ -149,17 +152,16 @@ def record_event(
|
||||
parsed = {}
|
||||
importance = score_importance(event_type, parsed)
|
||||
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO spark_events
|
||||
(id, event_type, agent_id, task_id, description, data, importance, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(event_id, event_type, agent_id, task_id, description, data, importance, now),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO spark_events
|
||||
(id, event_type, agent_id, task_id, description, data, importance, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(event_id, event_type, agent_id, task_id, description, data, importance, now),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# Bridge to unified event log so all events are queryable from one place
|
||||
try:
|
||||
@@ -188,7 +190,6 @@ def get_events(
|
||||
min_importance: float = 0.0,
|
||||
) -> list[SparkEvent]:
|
||||
"""Query events with optional filters."""
|
||||
conn = _get_conn()
|
||||
query = "SELECT * FROM spark_events WHERE importance >= ?"
|
||||
params: list = [min_importance]
|
||||
|
||||
@@ -205,8 +206,8 @@ def get_events(
|
||||
query += " ORDER BY created_at DESC LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [
|
||||
SparkEvent(
|
||||
id=r["id"],
|
||||
@@ -224,15 +225,14 @@ def get_events(
|
||||
|
||||
def count_events(event_type: str | None = None) -> int:
|
||||
"""Count events, optionally filtered by type."""
|
||||
conn = _get_conn()
|
||||
if event_type:
|
||||
row = conn.execute(
|
||||
"SELECT COUNT(*) FROM spark_events WHERE event_type = ?",
|
||||
(event_type,),
|
||||
).fetchone()
|
||||
else:
|
||||
row = conn.execute("SELECT COUNT(*) FROM spark_events").fetchone()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
if event_type:
|
||||
row = conn.execute(
|
||||
"SELECT COUNT(*) FROM spark_events WHERE event_type = ?",
|
||||
(event_type,),
|
||||
).fetchone()
|
||||
else:
|
||||
row = conn.execute("SELECT COUNT(*) FROM spark_events").fetchone()
|
||||
return row[0]
|
||||
|
||||
|
||||
@@ -250,17 +250,16 @@ def store_memory(
|
||||
"""Store a consolidated memory. Returns the memory id."""
|
||||
mem_id = str(uuid.uuid4())
|
||||
now = datetime.now(UTC).isoformat()
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO spark_memories
|
||||
(id, memory_type, subject, content, confidence, source_events, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(mem_id, memory_type, subject, content, confidence, source_events, now, expires_at),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO spark_memories
|
||||
(id, memory_type, subject, content, confidence, source_events, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(mem_id, memory_type, subject, content, confidence, source_events, now, expires_at),
|
||||
)
|
||||
conn.commit()
|
||||
return mem_id
|
||||
|
||||
|
||||
@@ -271,7 +270,6 @@ def get_memories(
|
||||
limit: int = 50,
|
||||
) -> list[SparkMemory]:
|
||||
"""Query memories with optional filters."""
|
||||
conn = _get_conn()
|
||||
query = "SELECT * FROM spark_memories WHERE confidence >= ?"
|
||||
params: list = [min_confidence]
|
||||
|
||||
@@ -285,8 +283,8 @@ def get_memories(
|
||||
query += " ORDER BY created_at DESC LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [
|
||||
SparkMemory(
|
||||
id=r["id"],
|
||||
@@ -304,13 +302,12 @@ def get_memories(
|
||||
|
||||
def count_memories(memory_type: str | None = None) -> int:
|
||||
"""Count memories, optionally filtered by type."""
|
||||
conn = _get_conn()
|
||||
if memory_type:
|
||||
row = conn.execute(
|
||||
"SELECT COUNT(*) FROM spark_memories WHERE memory_type = ?",
|
||||
(memory_type,),
|
||||
).fetchone()
|
||||
else:
|
||||
row = conn.execute("SELECT COUNT(*) FROM spark_memories").fetchone()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
if memory_type:
|
||||
row = conn.execute(
|
||||
"SELECT COUNT(*) FROM spark_memories WHERE memory_type = ?",
|
||||
(memory_type,),
|
||||
).fetchone()
|
||||
else:
|
||||
row = conn.execute("SELECT COUNT(*) FROM spark_memories").fetchone()
|
||||
return row[0]
|
||||
|
||||
@@ -13,7 +13,8 @@ Default is always True. The owner changes this intentionally.
|
||||
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import closing
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
@@ -44,23 +45,24 @@ class ApprovalItem:
|
||||
status: str # "pending" | "approved" | "rejected"
|
||||
|
||||
|
||||
def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_conn(db_path: Path = _DEFAULT_DB) -> Generator[sqlite3.Connection, None, None]:
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS approval_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
proposed_action TEXT NOT NULL,
|
||||
impact TEXT NOT NULL DEFAULT 'low',
|
||||
created_at TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending'
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(db_path))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS approval_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
proposed_action TEXT NOT NULL,
|
||||
impact TEXT NOT NULL DEFAULT 'low',
|
||||
created_at TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending'
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
def _row_to_item(row: sqlite3.Row) -> ApprovalItem:
|
||||
@@ -97,7 +99,7 @@ def create_item(
|
||||
created_at=datetime.now(UTC),
|
||||
status="pending",
|
||||
)
|
||||
with closing(_get_conn(db_path)) as conn:
|
||||
with _get_conn(db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO approval_items
|
||||
@@ -120,7 +122,7 @@ def create_item(
|
||||
|
||||
def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
|
||||
"""Return all pending approval items, newest first."""
|
||||
with closing(_get_conn(db_path)) as conn:
|
||||
with _get_conn(db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM approval_items WHERE status = 'pending' ORDER BY created_at DESC"
|
||||
).fetchall()
|
||||
@@ -129,20 +131,20 @@ def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
|
||||
|
||||
def list_all(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
|
||||
"""Return all approval items regardless of status, newest first."""
|
||||
with closing(_get_conn(db_path)) as conn:
|
||||
with _get_conn(db_path) as conn:
|
||||
rows = conn.execute("SELECT * FROM approval_items ORDER BY created_at DESC").fetchall()
|
||||
return [_row_to_item(r) for r in rows]
|
||||
|
||||
|
||||
def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
|
||||
with closing(_get_conn(db_path)) as conn:
|
||||
with _get_conn(db_path) as conn:
|
||||
row = conn.execute("SELECT * FROM approval_items WHERE id = ?", (item_id,)).fetchone()
|
||||
return _row_to_item(row) if row else None
|
||||
|
||||
|
||||
def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
|
||||
"""Mark an approval item as approved."""
|
||||
with closing(_get_conn(db_path)) as conn:
|
||||
with _get_conn(db_path) as conn:
|
||||
conn.execute("UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,))
|
||||
conn.commit()
|
||||
return get_item(item_id, db_path)
|
||||
@@ -150,7 +152,7 @@ def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
|
||||
|
||||
def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
|
||||
"""Mark an approval item as rejected."""
|
||||
with closing(_get_conn(db_path)) as conn:
|
||||
with _get_conn(db_path) as conn:
|
||||
conn.execute("UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,))
|
||||
conn.commit()
|
||||
return get_item(item_id, db_path)
|
||||
@@ -159,7 +161,7 @@ def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> ApprovalItem | None:
|
||||
def expire_old(db_path: Path = _DEFAULT_DB) -> int:
|
||||
"""Auto-expire pending items older than EXPIRY_DAYS. Returns count removed."""
|
||||
cutoff = (datetime.now(UTC) - timedelta(days=_EXPIRY_DAYS)).isoformat()
|
||||
with closing(_get_conn(db_path)) as conn:
|
||||
with _get_conn(db_path) as conn:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM approval_items WHERE status = 'pending' AND created_at < ?",
|
||||
(cutoff,),
|
||||
|
||||
@@ -10,7 +10,8 @@ regenerates the briefing every 6 hours.
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from contextlib import closing
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
@@ -57,25 +58,26 @@ class Briefing:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> Generator[sqlite3.Connection, None, None]:
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS briefings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
generated_at TEXT NOT NULL,
|
||||
period_start TEXT NOT NULL,
|
||||
period_end TEXT NOT NULL,
|
||||
summary TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(db_path))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS briefings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
generated_at TEXT NOT NULL,
|
||||
period_start TEXT NOT NULL,
|
||||
period_end TEXT NOT NULL,
|
||||
summary TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None:
|
||||
with closing(_get_cache_conn(db_path)) as conn:
|
||||
with _get_cache_conn(db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO briefings (generated_at, period_start, period_end, summary)
|
||||
@@ -93,7 +95,7 @@ def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None:
|
||||
|
||||
def _load_latest(db_path: Path = _DEFAULT_DB) -> Briefing | None:
|
||||
"""Load the most-recently cached briefing, or None if there is none."""
|
||||
with closing(_get_cache_conn(db_path)) as conn:
|
||||
with _get_cache_conn(db_path) as conn:
|
||||
row = conn.execute("SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1").fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
@@ -11,6 +11,8 @@ All three tables live in ``data/memory.db``. Existing APIs in
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,15 +20,16 @@ logger = logging.getLogger(__name__)
|
||||
DB_PATH = Path(__file__).parent.parent.parent.parent / "data" / "memory.db"
|
||||
|
||||
|
||||
def get_connection() -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def get_connection() -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Open (and lazily create) the unified memory database."""
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
_ensure_schema(conn)
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(DB_PATH))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
_ensure_schema(conn)
|
||||
yield conn
|
||||
|
||||
|
||||
def _ensure_schema(conn: sqlite3.Connection) -> None:
|
||||
|
||||
@@ -8,6 +8,8 @@ import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
|
||||
@@ -54,11 +56,13 @@ class MemoryEntry:
|
||||
relevance_score: float | None = None # Set during search
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Get database connection to unified memory.db."""
|
||||
from timmy.memory.unified import get_connection
|
||||
|
||||
return get_connection()
|
||||
with get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
def store_memory(
|
||||
@@ -101,29 +105,28 @@ def store_memory(
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO episodes
|
||||
(id, content, source, context_type, agent_id, task_id, session_id,
|
||||
metadata, embedding, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
entry.id,
|
||||
entry.content,
|
||||
entry.source,
|
||||
entry.context_type,
|
||||
entry.agent_id,
|
||||
entry.task_id,
|
||||
entry.session_id,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(embedding) if embedding else None,
|
||||
entry.timestamp,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO episodes
|
||||
(id, content, source, context_type, agent_id, task_id, session_id,
|
||||
metadata, embedding, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
entry.id,
|
||||
entry.content,
|
||||
entry.source,
|
||||
entry.context_type,
|
||||
entry.agent_id,
|
||||
entry.task_id,
|
||||
entry.session_id,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
json.dumps(embedding) if embedding else None,
|
||||
entry.timestamp,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return entry
|
||||
|
||||
@@ -151,8 +154,6 @@ def search_memories(
|
||||
"""
|
||||
query_embedding = _compute_embedding(query)
|
||||
|
||||
conn = _get_conn()
|
||||
|
||||
# Build query with filters
|
||||
conditions = []
|
||||
params = []
|
||||
@@ -179,8 +180,8 @@ def search_memories(
|
||||
"""
|
||||
params.append(limit * 3) # Get more candidates for ranking
|
||||
|
||||
rows = conn.execute(query_sql, params).fetchall()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute(query_sql, params).fetchall()
|
||||
|
||||
# Compute similarity scores
|
||||
results = []
|
||||
@@ -275,58 +276,54 @@ def recall_personal_facts(agent_id: str | None = None) -> list[str]:
|
||||
Returns:
|
||||
List of fact strings
|
||||
"""
|
||||
conn = _get_conn()
|
||||
with _get_conn() as conn:
|
||||
if agent_id:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT content FROM episodes
|
||||
WHERE context_type = 'fact' AND agent_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
""",
|
||||
(agent_id,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT content FROM episodes
|
||||
WHERE context_type = 'fact'
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
""",
|
||||
).fetchall()
|
||||
|
||||
if agent_id:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT content FROM episodes
|
||||
WHERE context_type = 'fact' AND agent_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
""",
|
||||
(agent_id,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT content FROM episodes
|
||||
WHERE context_type = 'fact'
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
""",
|
||||
).fetchall()
|
||||
|
||||
conn.close()
|
||||
return [r["content"] for r in rows]
|
||||
|
||||
|
||||
def recall_personal_facts_with_ids(agent_id: str | None = None) -> list[dict]:
|
||||
"""Recall personal facts with their IDs for edit/delete operations."""
|
||||
conn = _get_conn()
|
||||
if agent_id:
|
||||
rows = conn.execute(
|
||||
"SELECT id, content FROM episodes WHERE context_type = 'fact' AND agent_id = ? ORDER BY timestamp DESC LIMIT 100",
|
||||
(agent_id,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"SELECT id, content FROM episodes WHERE context_type = 'fact' ORDER BY timestamp DESC LIMIT 100",
|
||||
).fetchall()
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
if agent_id:
|
||||
rows = conn.execute(
|
||||
"SELECT id, content FROM episodes WHERE context_type = 'fact' AND agent_id = ? ORDER BY timestamp DESC LIMIT 100",
|
||||
(agent_id,),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"SELECT id, content FROM episodes WHERE context_type = 'fact' ORDER BY timestamp DESC LIMIT 100",
|
||||
).fetchall()
|
||||
return [{"id": r["id"], "content": r["content"]} for r in rows]
|
||||
|
||||
|
||||
def update_personal_fact(memory_id: str, new_content: str) -> bool:
|
||||
"""Update a personal fact's content."""
|
||||
conn = _get_conn()
|
||||
cursor = conn.execute(
|
||||
"UPDATE episodes SET content = ? WHERE id = ? AND context_type = 'fact'",
|
||||
(new_content, memory_id),
|
||||
)
|
||||
conn.commit()
|
||||
updated = cursor.rowcount > 0
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
cursor = conn.execute(
|
||||
"UPDATE episodes SET content = ? WHERE id = ? AND context_type = 'fact'",
|
||||
(new_content, memory_id),
|
||||
)
|
||||
conn.commit()
|
||||
updated = cursor.rowcount > 0
|
||||
return updated
|
||||
|
||||
|
||||
@@ -355,14 +352,13 @@ def delete_memory(memory_id: str) -> bool:
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
conn = _get_conn()
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM episodes WHERE id = ?",
|
||||
(memory_id,),
|
||||
)
|
||||
conn.commit()
|
||||
deleted = cursor.rowcount > 0
|
||||
conn.close()
|
||||
with _get_conn() as conn:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM episodes WHERE id = ?",
|
||||
(memory_id,),
|
||||
)
|
||||
conn.commit()
|
||||
deleted = cursor.rowcount > 0
|
||||
return deleted
|
||||
|
||||
|
||||
@@ -372,22 +368,19 @@ def get_memory_stats() -> dict:
|
||||
Returns:
|
||||
Dict with counts by type, total entries, etc.
|
||||
"""
|
||||
conn = _get_conn()
|
||||
with _get_conn() as conn:
|
||||
total = conn.execute("SELECT COUNT(*) as count FROM episodes").fetchone()["count"]
|
||||
|
||||
total = conn.execute("SELECT COUNT(*) as count FROM episodes").fetchone()["count"]
|
||||
by_type = {}
|
||||
rows = conn.execute(
|
||||
"SELECT context_type, COUNT(*) as count FROM episodes GROUP BY context_type"
|
||||
).fetchall()
|
||||
for row in rows:
|
||||
by_type[row["context_type"]] = row["count"]
|
||||
|
||||
by_type = {}
|
||||
rows = conn.execute(
|
||||
"SELECT context_type, COUNT(*) as count FROM episodes GROUP BY context_type"
|
||||
).fetchall()
|
||||
for row in rows:
|
||||
by_type[row["context_type"]] = row["count"]
|
||||
|
||||
with_embeddings = conn.execute(
|
||||
"SELECT COUNT(*) as count FROM episodes WHERE embedding IS NOT NULL"
|
||||
).fetchone()["count"]
|
||||
|
||||
conn.close()
|
||||
with_embeddings = conn.execute(
|
||||
"SELECT COUNT(*) as count FROM episodes WHERE embedding IS NOT NULL"
|
||||
).fetchone()["count"]
|
||||
|
||||
return {
|
||||
"total_entries": total,
|
||||
@@ -411,24 +404,22 @@ def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
|
||||
|
||||
cutoff = (datetime.now(UTC) - timedelta(days=older_than_days)).isoformat()
|
||||
|
||||
conn = _get_conn()
|
||||
with _get_conn() as conn:
|
||||
if keep_facts:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
DELETE FROM episodes
|
||||
WHERE timestamp < ? AND context_type != 'fact'
|
||||
""",
|
||||
(cutoff,),
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM episodes WHERE timestamp < ?",
|
||||
(cutoff,),
|
||||
)
|
||||
|
||||
if keep_facts:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
DELETE FROM episodes
|
||||
WHERE timestamp < ? AND context_type != 'fact'
|
||||
""",
|
||||
(cutoff,),
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM episodes WHERE timestamp < ?",
|
||||
(cutoff,),
|
||||
)
|
||||
|
||||
deleted = cursor.rowcount
|
||||
conn.commit()
|
||||
conn.close()
|
||||
deleted = cursor.rowcount
|
||||
conn.commit()
|
||||
|
||||
return deleted
|
||||
|
||||
@@ -21,7 +21,8 @@ import logging
|
||||
import random
|
||||
import sqlite3
|
||||
import uuid
|
||||
from contextlib import closing
|
||||
from collections.abc import Generator
|
||||
from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from difflib import SequenceMatcher
|
||||
@@ -169,23 +170,24 @@ class Thought:
|
||||
created_at: str
|
||||
|
||||
|
||||
def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
||||
@contextmanager
|
||||
def _get_conn(db_path: Path = _DEFAULT_DB) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Get a SQLite connection with the thoughts table created."""
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS thoughts (
|
||||
id TEXT PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
seed_type TEXT NOT NULL,
|
||||
parent_id TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)")
|
||||
conn.commit()
|
||||
return conn
|
||||
with closing(sqlite3.connect(str(db_path))) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS thoughts (
|
||||
id TEXT PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
seed_type TEXT NOT NULL,
|
||||
parent_id TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)")
|
||||
conn.commit()
|
||||
yield conn
|
||||
|
||||
|
||||
def _row_to_thought(row: sqlite3.Row) -> Thought:
|
||||
@@ -321,7 +323,7 @@ class ThinkingEngine:
|
||||
|
||||
def get_recent_thoughts(self, limit: int = 20) -> list[Thought]:
|
||||
"""Retrieve the most recent thoughts."""
|
||||
with closing(_get_conn(self._db_path)) as conn:
|
||||
with _get_conn(self._db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM thoughts ORDER BY created_at DESC LIMIT ?",
|
||||
(limit,),
|
||||
@@ -330,7 +332,7 @@ class ThinkingEngine:
|
||||
|
||||
def get_thought(self, thought_id: str) -> Thought | None:
|
||||
"""Retrieve a single thought by ID."""
|
||||
with closing(_get_conn(self._db_path)) as conn:
|
||||
with _get_conn(self._db_path) as conn:
|
||||
row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (thought_id,)).fetchone()
|
||||
return _row_to_thought(row) if row else None
|
||||
|
||||
@@ -342,7 +344,7 @@ class ThinkingEngine:
|
||||
chain = []
|
||||
current_id: str | None = thought_id
|
||||
|
||||
with closing(_get_conn(self._db_path)) as conn:
|
||||
with _get_conn(self._db_path) as conn:
|
||||
for _ in range(max_depth):
|
||||
if not current_id:
|
||||
break
|
||||
@@ -357,7 +359,7 @@ class ThinkingEngine:
|
||||
|
||||
def count_thoughts(self) -> int:
|
||||
"""Return total number of stored thoughts."""
|
||||
with closing(_get_conn(self._db_path)) as conn:
|
||||
with _get_conn(self._db_path) as conn:
|
||||
count = conn.execute("SELECT COUNT(*) as c FROM thoughts").fetchone()["c"]
|
||||
return count
|
||||
|
||||
@@ -366,7 +368,7 @@ class ThinkingEngine:
|
||||
|
||||
Returns the number of deleted rows.
|
||||
"""
|
||||
with closing(_get_conn(self._db_path)) as conn:
|
||||
with _get_conn(self._db_path) as conn:
|
||||
try:
|
||||
total = conn.execute("SELECT COUNT(*) as c FROM thoughts").fetchone()["c"]
|
||||
if total <= keep_min:
|
||||
@@ -603,7 +605,7 @@ class ThinkingEngine:
|
||||
# Thought count today (cheap DB query)
|
||||
try:
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
with closing(_get_conn(self._db_path)) as conn:
|
||||
with _get_conn(self._db_path) as conn:
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) as c FROM thoughts WHERE created_at >= ?",
|
||||
(today_start.isoformat(),),
|
||||
@@ -960,7 +962,7 @@ class ThinkingEngine:
|
||||
created_at=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
with closing(_get_conn(self._db_path)) as conn:
|
||||
with _get_conn(self._db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO thoughts (id, content, seed_type, parent_id, created_at)
|
||||
|
||||
@@ -216,21 +216,15 @@ class TestWALMode:
|
||||
with patch("infrastructure.models.registry.DB_PATH", db):
|
||||
from infrastructure.models.registry import _get_conn
|
||||
|
||||
conn = _get_conn()
|
||||
try:
|
||||
with _get_conn() as conn:
|
||||
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
|
||||
assert mode == "wal"
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def test_registry_db_busy_timeout(self, tmp_path):
|
||||
db = tmp_path / "wal_test.db"
|
||||
with patch("infrastructure.models.registry.DB_PATH", db):
|
||||
from infrastructure.models.registry import _get_conn
|
||||
|
||||
conn = _get_conn()
|
||||
try:
|
||||
with _get_conn() as conn:
|
||||
timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
|
||||
assert timeout == 5000
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@@ -481,29 +481,20 @@ class TestWALMode:
|
||||
def test_spark_memory_uses_wal(self):
|
||||
from spark.memory import _get_conn
|
||||
|
||||
conn = _get_conn()
|
||||
try:
|
||||
with _get_conn() as conn:
|
||||
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
|
||||
assert mode == "wal", f"Expected WAL mode, got {mode}"
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def test_spark_eidos_uses_wal(self):
|
||||
from spark.eidos import _get_conn
|
||||
|
||||
conn = _get_conn()
|
||||
try:
|
||||
with _get_conn() as conn:
|
||||
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
|
||||
assert mode == "wal", f"Expected WAL mode, got {mode}"
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def test_spark_memory_busy_timeout(self):
|
||||
from spark.memory import _get_conn
|
||||
|
||||
conn = _get_conn()
|
||||
try:
|
||||
with _get_conn() as conn:
|
||||
timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
|
||||
assert timeout == 5000
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@@ -142,14 +142,13 @@ class TestExpireOld:
|
||||
# Create item and manually backdate it
|
||||
item = create_item("Old", "D", "A", db_path=db_path)
|
||||
|
||||
conn = _get_conn(db_path)
|
||||
old_date = (datetime.now(UTC) - timedelta(days=30)).isoformat()
|
||||
conn.execute(
|
||||
"UPDATE approval_items SET created_at = ? WHERE id = ?",
|
||||
(old_date, item.id),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
with _get_conn(db_path) as conn:
|
||||
conn.execute(
|
||||
"UPDATE approval_items SET created_at = ? WHERE id = ?",
|
||||
(old_date, item.id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
count = expire_old(db_path)
|
||||
assert count == 1
|
||||
@@ -169,14 +168,13 @@ class TestExpireOld:
|
||||
approve(item.id, db_path)
|
||||
|
||||
# Backdate it
|
||||
conn = _get_conn(db_path)
|
||||
old_date = (datetime.now(UTC) - timedelta(days=30)).isoformat()
|
||||
conn.execute(
|
||||
"UPDATE approval_items SET created_at = ? WHERE id = ?",
|
||||
(old_date, item.id),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
with _get_conn(db_path) as conn:
|
||||
conn.execute(
|
||||
"UPDATE approval_items SET created_at = ? WHERE id = ?",
|
||||
(old_date, item.id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
count = expire_old(db_path)
|
||||
assert count == 0 # approved items not expired
|
||||
|
||||
Reference in New Issue
Block a user