159 lines
5.6 KiB
Python
159 lines
5.6 KiB
Python
"""Persistent chat message store backed by SQLite.
|
|
|
|
Provides the same API as the original in-memory MessageLog so all callers
|
|
(dashboard routes, chat_api, thinking, briefing) work without changes.
|
|
|
|
Data lives in ``data/chat.db`` under the repository root, so it survives server restarts.
|
|
A retention cap (``MAX_MESSAGES``, default 500) prunes the oldest messages on each append to keep the DB small.
|
|
"""
|
|
|
|
import sqlite3
|
|
import threading
|
|
from collections.abc import Generator
|
|
from contextlib import closing, contextmanager
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
# ── Data dir — resolved relative to repo root (three levels up from this file) ──
|
|
_REPO_ROOT = Path(__file__).resolve().parents[3]
|
|
DB_PATH: Path = _REPO_ROOT / "data" / "chat.db"
|
|
|
|
# Maximum messages to retain (oldest pruned on append)
|
|
MAX_MESSAGES: int = 500
|
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class Message:
    """A single chat message, mirroring one row of the ``chat_messages`` table."""

    role: str  # expected values: "user", "agent" or "error"
    content: str
    timestamp: str
    # Channel the message came from: "browser" (default), "api", "telegram",
    # "discord" or "system".
    source: str = "browser"
|
|
|
|
|
|
@contextmanager
|
|
def _get_conn(db_path: Path | None = None) -> Generator[sqlite3.Connection, None, None]:
|
|
"""Open (or create) the chat database and ensure schema exists."""
|
|
path = db_path or DB_PATH
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
with closing(sqlite3.connect(str(path), check_same_thread=False)) as conn:
|
|
conn.row_factory = sqlite3.Row
|
|
conn.execute("PRAGMA journal_mode=WAL")
|
|
conn.execute("""
|
|
CREATE TABLE IF NOT EXISTS chat_messages (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
role TEXT NOT NULL,
|
|
content TEXT NOT NULL,
|
|
timestamp TEXT NOT NULL,
|
|
source TEXT NOT NULL DEFAULT 'browser'
|
|
)
|
|
""")
|
|
conn.commit()
|
|
yield conn
|
|
|
|
|
|
class MessageLog:
|
|
"""SQLite-backed chat history — drop-in replacement for the old in-memory list."""
|
|
|
|
def __init__(self, db_path: Path | None = None) -> None:
|
|
self._db_path = db_path or DB_PATH
|
|
self._lock = threading.Lock()
|
|
self._conn: sqlite3.Connection | None = None
|
|
|
|
# Lazy connection — opened on first use, not at import time.
|
|
@contextmanager
|
|
def _get_conn(self) -> Generator[sqlite3.Connection, None, None]:
|
|
path = self._db_path or DB_PATH
|
|
with closing(sqlite3.connect(str(path), check_same_thread=False)) as conn:
|
|
conn.row_factory = sqlite3.Row
|
|
yield conn
|
|
if self._conn is None:
|
|
# Open a persistent connection for the class instance
|
|
path = self._db_path or DB_PATH
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
conn = sqlite3.connect(str(path), check_same_thread=False)
|
|
conn.row_factory = sqlite3.Row
|
|
conn.execute("PRAGMA journal_mode=WAL")
|
|
conn.execute("""
|
|
CREATE TABLE IF NOT EXISTS chat_messages (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
role TEXT NOT NULL,
|
|
content TEXT NOT NULL,
|
|
timestamp TEXT NOT NULL,
|
|
source TEXT NOT NULL DEFAULT 'browser'
|
|
)
|
|
""")
|
|
conn.commit()
|
|
self._conn = conn
|
|
return self._conn
|
|
|
|
def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None:
|
|
with self._lock:
|
|
conn = self._ensure_conn()
|
|
conn.execute(
|
|
"INSERT INTO chat_messages (role, content, timestamp, source) VALUES (?, ?, ?, ?)",
|
|
(role, content, timestamp, source),
|
|
)
|
|
conn.commit()
|
|
self._prune(conn)
|
|
|
|
def all(self) -> list[Message]:
|
|
with self._lock:
|
|
conn = self._ensure_conn()
|
|
rows = conn.execute(
|
|
"SELECT role, content, timestamp, source FROM chat_messages ORDER BY id"
|
|
).fetchall()
|
|
return [
|
|
Message(
|
|
role=r["role"], content=r["content"], timestamp=r["timestamp"], source=r["source"]
|
|
)
|
|
for r in rows
|
|
]
|
|
|
|
def recent(self, limit: int = 50) -> list[Message]:
|
|
"""Return the *limit* most recent messages (oldest-first)."""
|
|
with self._lock:
|
|
conn = self._ensure_conn()
|
|
rows = conn.execute(
|
|
"SELECT role, content, timestamp, source FROM chat_messages "
|
|
"ORDER BY id DESC LIMIT ?",
|
|
(limit,),
|
|
).fetchall()
|
|
return [
|
|
Message(
|
|
role=r["role"], content=r["content"], timestamp=r["timestamp"], source=r["source"]
|
|
)
|
|
for r in reversed(rows)
|
|
]
|
|
|
|
def clear(self) -> None:
|
|
with self._lock:
|
|
conn = self._ensure_conn()
|
|
conn.execute("DELETE FROM chat_messages")
|
|
conn.commit()
|
|
|
|
def _prune(self, conn: sqlite3.Connection) -> None:
|
|
"""Keep at most MAX_MESSAGES rows, deleting the oldest."""
|
|
count = conn.execute("SELECT COUNT(*) FROM chat_messages").fetchone()[0]
|
|
if count > MAX_MESSAGES:
|
|
excess = count - MAX_MESSAGES
|
|
conn.execute(
|
|
"DELETE FROM chat_messages WHERE id IN "
|
|
"(SELECT id FROM chat_messages ORDER BY id LIMIT ?)",
|
|
(excess,),
|
|
)
|
|
conn.commit()
|
|
|
|
def close(self) -> None:
|
|
if self._conn is not None:
|
|
self._conn.close()
|
|
self._conn = None
|
|
|
|
def __len__(self) -> int:
|
|
with self._lock:
|
|
conn = self._ensure_conn()
|
|
return conn.execute("SELECT COUNT(*) FROM chat_messages").fetchone()[0]
|
|
|
|
|
|
# Single shared instance used by the rest of the app. Constructing it does
# not open the database; the connection is created lazily on first use.
message_log: MessageLog = MessageLog()
|