Merge pull request '[loop-cycle-5] Persist chat history in SQLite (#46)' (#63) from fix/issue-46-chat-persistence into main

This commit is contained in:
rockachopa
2026-03-14 16:10:55 -04:00
3 changed files with 254 additions and 15 deletions

View File

@@ -1,4 +1,23 @@
"""Persistent chat message store backed by SQLite.
Provides the same API as the original in-memory MessageLog so all callers
(dashboard routes, chat_api, thinking, briefing) work without changes.
Data lives in ``data/chat.db`` — survives server restarts.
A configurable retention policy (default 500 messages) keeps the DB lean.
"""
import sqlite3
import threading
from dataclasses import dataclass
from pathlib import Path
# ── Data dir — resolved relative to repo root (two levels up from this file) ──
_REPO_ROOT = Path(__file__).resolve().parents[2]
DB_PATH: Path = _REPO_ROOT / "data" / "chat.db"
# Maximum messages to retain (oldest pruned on append)
MAX_MESSAGES: int = 500
@dataclass
@@ -9,25 +28,106 @@ class Message:
source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system"
class MessageLog:
"""In-memory chat history for the lifetime of the server process."""
def _get_conn(db_path: Path | None = None) -> sqlite3.Connection:
    """Open (or create) the chat database and make sure the schema exists.

    Falls back to the module-level ``DB_PATH`` when *db_path* is omitted.
    Returns a connection with ``sqlite3.Row`` rows for name-based access.
    """
    target = db_path or DB_PATH
    # Create the containing directory before SQLite touches the file.
    target.parent.mkdir(parents=True, exist_ok=True)
    # check_same_thread=False: callers serialise access with their own lock.
    connection = sqlite3.connect(str(target), check_same_thread=False)
    connection.row_factory = sqlite3.Row
    # WAL journal mode — readers don't block the writer (see SQLite docs).
    connection.execute("PRAGMA journal_mode=WAL")
    schema = """
        CREATE TABLE IF NOT EXISTS chat_messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            role TEXT NOT NULL,
            content TEXT NOT NULL,
            timestamp TEXT NOT NULL,
            source TEXT NOT NULL DEFAULT 'browser'
        )
    """
    connection.execute(schema)
    connection.commit()
    return connection
def __init__(self) -> None:
self._entries: list[Message] = []
class MessageLog:
"""SQLite-backed chat history — drop-in replacement for the old in-memory list."""
def __init__(self, db_path: Path | None = None) -> None:
self._db_path = db_path or DB_PATH
self._lock = threading.Lock()
self._conn: sqlite3.Connection | None = None
# Lazy connection — opened on first use, not at import time.
def _ensure_conn(self) -> sqlite3.Connection:
if self._conn is None:
self._conn = _get_conn(self._db_path)
return self._conn
def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None:
    """Persist one message row, then enforce the retention policy.

    Parameters mirror the ``Message`` dataclass fields; *source* identifies
    the originating channel ("browser" | "api" | "telegram" | "discord" | "system").
    (The stale ``self._entries.append(...)`` lines from the removed
    in-memory implementation have been deleted.)
    """
    with self._lock:
        conn = self._ensure_conn()
        conn.execute(
            "INSERT INTO chat_messages (role, content, timestamp, source) VALUES (?, ?, ?, ?)",
            (role, content, timestamp, source),
        )
        conn.commit()
        # Retention policy: drop oldest rows once the cap is exceeded.
        self._prune(conn)
def all(self) -> list[Message]:
    """Return every stored message, oldest first.

    (The stale ``return list(self._entries)`` line from the removed
    in-memory implementation has been deleted.)
    """
    with self._lock:
        conn = self._ensure_conn()
        rows = conn.execute(
            "SELECT role, content, timestamp, source FROM chat_messages ORDER BY id"
        ).fetchall()
        return [
            Message(
                role=r["role"], content=r["content"], timestamp=r["timestamp"], source=r["source"]
            )
            for r in rows
        ]
def recent(self, limit: int = 50) -> list[Message]:
    """Return the *limit* most recent messages (oldest-first)."""
    with self._lock:
        db = self._ensure_conn()
        # Fetch newest-first so LIMIT grabs the tail of the history...
        newest_first = db.execute(
            "SELECT role, content, timestamp, source FROM chat_messages "
            "ORDER BY id DESC LIMIT ?",
            (limit,),
        ).fetchall()
        # ...then flip back to chronological order for the caller.
        oldest_first = newest_first[::-1]
        return [
            Message(
                role=row["role"],
                content=row["content"],
                timestamp=row["timestamp"],
                source=row["source"],
            )
            for row in oldest_first
        ]
def clear(self) -> None:
    """Delete every stored message.

    (The stale ``self._entries.clear()`` line from the removed in-memory
    implementation has been deleted.)
    """
    with self._lock:
        conn = self._ensure_conn()
        conn.execute("DELETE FROM chat_messages")
        conn.commit()
def _prune(self, conn: sqlite3.Connection) -> None:
    """Keep at most MAX_MESSAGES rows, deleting the oldest."""
    total = conn.execute("SELECT COUNT(*) FROM chat_messages").fetchone()[0]
    # Nothing to do while we are under the cap.
    if total <= MAX_MESSAGES:
        return
    overflow = total - MAX_MESSAGES
    # Lowest ids are the oldest rows (AUTOINCREMENT primary key).
    conn.execute(
        "DELETE FROM chat_messages WHERE id IN "
        "(SELECT id FROM chat_messages ORDER BY id LIMIT ?)",
        (overflow,),
    )
    conn.commit()
def close(self) -> None:
if self._conn is not None:
self._conn.close()
self._conn = None
def __len__(self) -> int:
    """Return the number of stored messages.

    (The stale ``return len(self._entries)`` line from the removed
    in-memory implementation has been deleted.)
    """
    with self._lock:
        conn = self._ensure_conn()
        return conn.execute("SELECT COUNT(*) FROM chat_messages").fetchone()[0]
# Module-level singleton shared across the app

View File

@@ -55,13 +55,27 @@ os.environ["TIMMY_SKIP_EMBEDDINGS"] = "1"
@pytest.fixture(autouse=True)
def reset_message_log(tmp_path):
    """Redirect chat DB to temp dir and clear before/after every test.

    (Stale lines from the old in-memory fixture — the parameterless ``def``,
    ``from dashboard.store import message_log`` and the bare
    ``message_log.clear()`` calls — have been deleted.)
    """
    import dashboard.store as _store_mod

    original_db_path = _store_mod.DB_PATH
    tmp_chat_db = tmp_path / "chat.db"
    _store_mod.DB_PATH = tmp_chat_db
    # Close existing singleton connection and point it at tmp DB
    _store_mod.message_log.close()
    _store_mod.message_log._db_path = tmp_chat_db
    _store_mod.message_log._conn = None
    _store_mod.message_log.clear()
    yield
    # Restore the real DB path and drop the temp-DB connection afterwards.
    _store_mod.message_log.clear()
    _store_mod.message_log.close()
    _store_mod.DB_PATH = original_db_path
    _store_mod.message_log._db_path = original_db_path
    _store_mod.message_log._conn = None
@pytest.fixture(autouse=True)

View File

@@ -0,0 +1,125 @@
"""Tests for SQLite-backed chat persistence (issue #46)."""
from dashboard.store import Message, MessageLog
def test_persistence_across_instances(tmp_path):
    """Messages survive creating a new MessageLog pointing at the same DB."""
    db_file = tmp_path / "chat.db"
    writer = MessageLog(db_path=db_file)
    writer.append(role="user", content="hello", timestamp="10:00:00", source="browser")
    writer.append(role="agent", content="hi back", timestamp="10:00:01", source="browser")
    writer.close()
    # New instance — simulates server restart
    reader = MessageLog(db_path=db_file)
    stored = reader.all()
    assert len(stored) == 2
    assert (stored[0].role, stored[0].content) == ("user", "hello")
    assert (stored[1].role, stored[1].content) == ("agent", "hi back")
    reader.close()
def test_retention_policy(tmp_path):
    """Oldest messages are pruned when count exceeds MAX_MESSAGES."""
    import dashboard.store as store_mod

    saved_limit = store_mod.MAX_MESSAGES
    store_mod.MAX_MESSAGES = 5  # Small limit for testing
    try:
        log = MessageLog(db_path=tmp_path / "chat.db")
        for i in range(8):
            log.append(role="user", content=f"msg-{i}", timestamp=f"10:00:{i:02d}")
        assert len(log) == 5
        remaining = log.all()
        # Oldest 3 should have been pruned
        assert remaining[0].content == "msg-3"
        assert remaining[-1].content == "msg-7"
        log.close()
    finally:
        store_mod.MAX_MESSAGES = saved_limit
def test_clear_removes_all(tmp_path):
    """clear() leaves the log empty."""
    chat_log = MessageLog(db_path=tmp_path / "chat.db")
    chat_log.append(role="user", content="data", timestamp="12:00:00")
    assert len(chat_log) == 1
    chat_log.clear()
    assert len(chat_log) == 0
    assert chat_log.all() == []
    chat_log.close()
def test_recent_returns_limited_newest(tmp_path):
    """recent() returns only the newest messages, in chronological order."""
    log = MessageLog(db_path=tmp_path / "chat.db")
    for i in range(10):
        log.append(role="user", content=f"msg-{i}", timestamp=f"10:00:{i:02d}")
    window = log.recent(limit=3)
    # Should be oldest-first within the window
    assert [m.content for m in window] == ["msg-7", "msg-8", "msg-9"]
    log.close()
def test_source_field_persisted(tmp_path):
    """The source column round-trips through the database."""
    db_file = tmp_path / "chat.db"
    writer = MessageLog(db_path=db_file)
    writer.append(role="user", content="from api", timestamp="10:00:00", source="api")
    writer.append(role="user", content="from tg", timestamp="10:00:01", source="telegram")
    writer.close()
    reader = MessageLog(db_path=db_file)
    stored = reader.all()
    assert (stored[0].source, stored[1].source) == ("api", "telegram")
    reader.close()
def test_message_dataclass_defaults():
    """source defaults to "browser" when omitted."""
    msg = Message(role="user", content="hi", timestamp="12:00:00")
    assert msg.source == "browser"
def test_empty_db_returns_empty(tmp_path):
    """A fresh database reports no messages via every accessor."""
    empty_log = MessageLog(db_path=tmp_path / "chat.db")
    assert empty_log.all() == []
    assert len(empty_log) == 0
    assert empty_log.recent() == []
    empty_log.close()
def test_concurrent_appends(tmp_path):
    """Multiple threads can append without corrupting data."""
    import threading

    log = MessageLog(db_path=tmp_path / "chat.db")
    failures = []

    def worker(thread_id):
        try:
            for i in range(20):
                log.append(role="user", content=f"t{thread_id}-{i}", timestamp="10:00:00")
        except Exception as exc:
            failures.append(exc)

    workers = [threading.Thread(target=worker, args=(t,)) for t in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    assert not failures
    assert len(log) == 80
    log.close()