diff --git a/src/dashboard/store.py b/src/dashboard/store.py index 69fc565..ffd97e6 100644 --- a/src/dashboard/store.py +++ b/src/dashboard/store.py @@ -1,4 +1,23 @@ +"""Persistent chat message store backed by SQLite. + +Provides the same API as the original in-memory MessageLog so all callers +(dashboard routes, chat_api, thinking, briefing) work without changes. + +Data lives in ``data/chat.db`` — survives server restarts. +A configurable retention policy (default 500 messages) keeps the DB lean. +""" + +import sqlite3 +import threading from dataclasses import dataclass +from pathlib import Path + +# ── Data dir — resolved relative to repo root (two levels up from this file) ── +_REPO_ROOT = Path(__file__).resolve().parents[2] +DB_PATH: Path = _REPO_ROOT / "data" / "chat.db" + +# Maximum messages to retain (oldest pruned on append) +MAX_MESSAGES: int = 500 @dataclass @@ -9,25 +28,106 @@ class Message: source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system" -class MessageLog: - """In-memory chat history for the lifetime of the server process.""" +def _get_conn(db_path: Path | None = None) -> sqlite3.Connection: + """Open (or create) the chat database and ensure schema exists.""" + path = db_path or DB_PATH + path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(path), check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute(""" + CREATE TABLE IF NOT EXISTS chat_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TEXT NOT NULL, + source TEXT NOT NULL DEFAULT 'browser' + ) + """) + conn.commit() + return conn - def __init__(self) -> None: - self._entries: list[Message] = [] + +class MessageLog: + """SQLite-backed chat history — drop-in replacement for the old in-memory list.""" + + def __init__(self, db_path: Path | None = None) -> None: + self._db_path = db_path or DB_PATH + self._lock = threading.Lock() + self._conn: sqlite3.Connection | None = None + + # Lazy connection — opened on first use, not at import time. + def _ensure_conn(self) -> sqlite3.Connection: + if self._conn is None: + self._conn = _get_conn(self._db_path) + return self._conn def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None: - self._entries.append( - Message(role=role, content=content, timestamp=timestamp, source=source) - ) + with self._lock: + conn = self._ensure_conn() + conn.execute( + "INSERT INTO chat_messages (role, content, timestamp, source) VALUES (?, ?, ?, ?)", + (role, content, timestamp, source), + ) + conn.commit() + self._prune(conn) def all(self) -> list[Message]: - return list(self._entries) + with self._lock: + conn = self._ensure_conn() + rows = conn.execute( + "SELECT role, content, timestamp, source FROM chat_messages ORDER BY id" + ).fetchall() + return [ + Message( + role=r["role"], content=r["content"], timestamp=r["timestamp"], source=r["source"] + ) + for r in rows + ] + + def recent(self, limit: int = 50) -> list[Message]: + """Return the *limit* most recent messages (oldest-first).""" + with self._lock: + conn = self._ensure_conn() + rows = conn.execute( + "SELECT role, content, timestamp, source FROM chat_messages " + "ORDER BY id DESC LIMIT ?", + (limit,), + ).fetchall() + return [ + Message( + role=r["role"], content=r["content"], timestamp=r["timestamp"], source=r["source"] + ) + for r in reversed(rows) + ] def clear(self) -> None: - self._entries.clear() + with self._lock: + conn = self._ensure_conn() + conn.execute("DELETE FROM chat_messages") + conn.commit() + + def _prune(self, conn: sqlite3.Connection) -> None: + """Keep at most MAX_MESSAGES rows, deleting the oldest.""" + count = conn.execute("SELECT COUNT(*) FROM chat_messages").fetchone()[0] + if count > MAX_MESSAGES: + excess = count - MAX_MESSAGES + conn.execute( + "DELETE FROM chat_messages WHERE id IN " + "(SELECT id FROM chat_messages ORDER BY id LIMIT ?)", + (excess,), + ) + conn.commit() + + def close(self) -> None: + if self._conn is not None: + self._conn.close() + self._conn = None def __len__(self) -> int: - return len(self._entries) + with self._lock: + conn = self._ensure_conn() + return conn.execute("SELECT COUNT(*) FROM chat_messages").fetchone()[0] # Module-level singleton shared across the app diff --git a/tests/conftest.py b/tests/conftest.py index 1c8eea6..5c2be15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,13 +55,27 @@ os.environ["TIMMY_SKIP_EMBEDDINGS"] = "1" @pytest.fixture(autouse=True) -def reset_message_log(): - """Clear the in-memory chat log before and after every test.""" - from dashboard.store import message_log +def reset_message_log(tmp_path): + """Redirect chat DB to temp dir and clear before/after every test.""" + import dashboard.store as _store_mod - message_log.clear() + original_db_path = _store_mod.DB_PATH + tmp_chat_db = tmp_path / "chat.db" + _store_mod.DB_PATH = tmp_chat_db + + # Close existing singleton connection and point it at tmp DB + _store_mod.message_log.close() + _store_mod.message_log._db_path = tmp_chat_db + _store_mod.message_log._conn = None + + _store_mod.message_log.clear() yield - message_log.clear() + _store_mod.message_log.clear() + _store_mod.message_log.close() + + _store_mod.DB_PATH = original_db_path + _store_mod.message_log._db_path = original_db_path + _store_mod.message_log._conn = None @pytest.fixture(autouse=True) diff --git a/tests/dashboard/test_chat_persistence.py b/tests/dashboard/test_chat_persistence.py new file mode 100644 index 0000000..cea4084 --- /dev/null +++ b/tests/dashboard/test_chat_persistence.py @@ -0,0 +1,125 @@ +"""Tests for SQLite-backed chat persistence (issue #46).""" + +from dashboard.store import Message, MessageLog + + +def test_persistence_across_instances(tmp_path): + """Messages survive creating a new MessageLog pointing at the same DB.""" + db = tmp_path / "chat.db" + log1 = MessageLog(db_path=db) + log1.append(role="user", content="hello", timestamp="10:00:00", source="browser") + log1.append(role="agent", content="hi back", timestamp="10:00:01", source="browser") + log1.close() + + # New instance — simulates server restart + log2 = MessageLog(db_path=db) + msgs = log2.all() + assert len(msgs) == 2 + assert msgs[0].role == "user" + assert msgs[0].content == "hello" + assert msgs[1].role == "agent" + assert msgs[1].content == "hi back" + log2.close() + + +def test_retention_policy(tmp_path): + """Oldest messages are pruned when count exceeds MAX_MESSAGES.""" + import dashboard.store as store_mod + + original_max = store_mod.MAX_MESSAGES + store_mod.MAX_MESSAGES = 5 # Small limit for testing + + try: + db = tmp_path / "chat.db" + log = MessageLog(db_path=db) + for i in range(8): + log.append(role="user", content=f"msg-{i}", timestamp=f"10:00:{i:02d}") + + assert len(log) == 5 + msgs = log.all() + # Oldest 3 should have been pruned + assert msgs[0].content == "msg-3" + assert msgs[-1].content == "msg-7" + log.close() + finally: + store_mod.MAX_MESSAGES = original_max + + +def test_clear_removes_all(tmp_path): + db = tmp_path / "chat.db" + log = MessageLog(db_path=db) + log.append(role="user", content="data", timestamp="12:00:00") + assert len(log) == 1 + log.clear() + assert len(log) == 0 + assert log.all() == [] + log.close() + + +def test_recent_returns_limited_newest(tmp_path): + db = tmp_path / "chat.db" + log = MessageLog(db_path=db) + for i in range(10): + log.append(role="user", content=f"msg-{i}", timestamp=f"10:00:{i:02d}") + + recent = log.recent(limit=3) + assert len(recent) == 3 + # Should be oldest-first within the window + assert recent[0].content == "msg-7" + assert recent[1].content == "msg-8" + assert recent[2].content == "msg-9" + log.close() + + +def test_source_field_persisted(tmp_path): + db = tmp_path / "chat.db" + log = MessageLog(db_path=db) + log.append(role="user", content="from api", timestamp="10:00:00", source="api") + log.append(role="user", content="from tg", timestamp="10:00:01", source="telegram") + log.close() + + log2 = MessageLog(db_path=db) + msgs = log2.all() + assert msgs[0].source == "api" + assert msgs[1].source == "telegram" + log2.close() + + +def test_message_dataclass_defaults(): + m = Message(role="user", content="hi", timestamp="12:00:00") + assert m.source == "browser" + + +def test_empty_db_returns_empty(tmp_path): + db = tmp_path / "chat.db" + log = MessageLog(db_path=db) + assert log.all() == [] + assert len(log) == 0 + assert log.recent() == [] + log.close() + + +def test_concurrent_appends(tmp_path): + """Multiple threads can append without corrupting data.""" + import threading + + db = tmp_path / "chat.db" + log = MessageLog(db_path=db) + errors = [] + + def writer(thread_id): + try: + for i in range(20): + log.append(role="user", content=f"t{thread_id}-{i}", timestamp="10:00:00") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=writer, args=(t,)) for t in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + assert len(log) == 80 + log.close()