forked from Rockachopa/Timmy-time-dashboard
Merge pull request '[loop-cycle-5] Persist chat history in SQLite (#46)' (#63) from fix/issue-46-chat-persistence into main
This commit is contained in:
@@ -1,4 +1,23 @@
|
||||
"""Persistent chat message store backed by SQLite.
|
||||
|
||||
Provides the same API as the original in-memory MessageLog so all callers
|
||||
(dashboard routes, chat_api, thinking, briefing) work without changes.
|
||||
|
||||
Data lives in ``data/chat.db`` — survives server restarts.
|
||||
A configurable retention policy (default 500 messages) keeps the DB lean.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
# ── Data dir — resolved relative to repo root (two levels up from this file) ──
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
DB_PATH: Path = _REPO_ROOT / "data" / "chat.db"
|
||||
|
||||
# Maximum messages to retain (oldest pruned on append)
|
||||
MAX_MESSAGES: int = 500
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -9,25 +28,106 @@ class Message:
|
||||
source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system"
|
||||
|
||||
|
||||
class MessageLog:
|
||||
"""In-memory chat history for the lifetime of the server process."""
|
||||
def _get_conn(db_path: Path | None = None) -> sqlite3.Connection:
|
||||
"""Open (or create) the chat database and ensure schema exists."""
|
||||
path = db_path or DB_PATH
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(path), check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
source TEXT NOT NULL DEFAULT 'browser'
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: list[Message] = []
|
||||
|
||||
class MessageLog:
|
||||
"""SQLite-backed chat history — drop-in replacement for the old in-memory list."""
|
||||
|
||||
def __init__(self, db_path: Path | None = None) -> None:
|
||||
self._db_path = db_path or DB_PATH
|
||||
self._lock = threading.Lock()
|
||||
self._conn: sqlite3.Connection | None = None
|
||||
|
||||
# Lazy connection — opened on first use, not at import time.
|
||||
def _ensure_conn(self) -> sqlite3.Connection:
|
||||
if self._conn is None:
|
||||
self._conn = _get_conn(self._db_path)
|
||||
return self._conn
|
||||
|
||||
def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None:
|
||||
self._entries.append(
|
||||
Message(role=role, content=content, timestamp=timestamp, source=source)
|
||||
)
|
||||
with self._lock:
|
||||
conn = self._ensure_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO chat_messages (role, content, timestamp, source) VALUES (?, ?, ?, ?)",
|
||||
(role, content, timestamp, source),
|
||||
)
|
||||
conn.commit()
|
||||
self._prune(conn)
|
||||
|
||||
def all(self) -> list[Message]:
|
||||
return list(self._entries)
|
||||
with self._lock:
|
||||
conn = self._ensure_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT role, content, timestamp, source FROM chat_messages ORDER BY id"
|
||||
).fetchall()
|
||||
return [
|
||||
Message(
|
||||
role=r["role"], content=r["content"], timestamp=r["timestamp"], source=r["source"]
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
def recent(self, limit: int = 50) -> list[Message]:
|
||||
"""Return the *limit* most recent messages (oldest-first)."""
|
||||
with self._lock:
|
||||
conn = self._ensure_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT role, content, timestamp, source FROM chat_messages "
|
||||
"ORDER BY id DESC LIMIT ?",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [
|
||||
Message(
|
||||
role=r["role"], content=r["content"], timestamp=r["timestamp"], source=r["source"]
|
||||
)
|
||||
for r in reversed(rows)
|
||||
]
|
||||
|
||||
def clear(self) -> None:
|
||||
self._entries.clear()
|
||||
with self._lock:
|
||||
conn = self._ensure_conn()
|
||||
conn.execute("DELETE FROM chat_messages")
|
||||
conn.commit()
|
||||
|
||||
def _prune(self, conn: sqlite3.Connection) -> None:
|
||||
"""Keep at most MAX_MESSAGES rows, deleting the oldest."""
|
||||
count = conn.execute("SELECT COUNT(*) FROM chat_messages").fetchone()[0]
|
||||
if count > MAX_MESSAGES:
|
||||
excess = count - MAX_MESSAGES
|
||||
conn.execute(
|
||||
"DELETE FROM chat_messages WHERE id IN "
|
||||
"(SELECT id FROM chat_messages ORDER BY id LIMIT ?)",
|
||||
(excess,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._conn is not None:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._entries)
|
||||
with self._lock:
|
||||
conn = self._ensure_conn()
|
||||
return conn.execute("SELECT COUNT(*) FROM chat_messages").fetchone()[0]
|
||||
|
||||
|
||||
# Module-level singleton shared across the app
|
||||
|
||||
@@ -55,13 +55,27 @@ os.environ["TIMMY_SKIP_EMBEDDINGS"] = "1"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_message_log():
|
||||
"""Clear the in-memory chat log before and after every test."""
|
||||
from dashboard.store import message_log
|
||||
def reset_message_log(tmp_path):
|
||||
"""Redirect chat DB to temp dir and clear before/after every test."""
|
||||
import dashboard.store as _store_mod
|
||||
|
||||
message_log.clear()
|
||||
original_db_path = _store_mod.DB_PATH
|
||||
tmp_chat_db = tmp_path / "chat.db"
|
||||
_store_mod.DB_PATH = tmp_chat_db
|
||||
|
||||
# Close existing singleton connection and point it at tmp DB
|
||||
_store_mod.message_log.close()
|
||||
_store_mod.message_log._db_path = tmp_chat_db
|
||||
_store_mod.message_log._conn = None
|
||||
|
||||
_store_mod.message_log.clear()
|
||||
yield
|
||||
message_log.clear()
|
||||
_store_mod.message_log.clear()
|
||||
_store_mod.message_log.close()
|
||||
|
||||
_store_mod.DB_PATH = original_db_path
|
||||
_store_mod.message_log._db_path = original_db_path
|
||||
_store_mod.message_log._conn = None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
125
tests/dashboard/test_chat_persistence.py
Normal file
125
tests/dashboard/test_chat_persistence.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Tests for SQLite-backed chat persistence (issue #46)."""
|
||||
|
||||
from dashboard.store import Message, MessageLog
|
||||
|
||||
|
||||
def test_persistence_across_instances(tmp_path):
|
||||
"""Messages survive creating a new MessageLog pointing at the same DB."""
|
||||
db = tmp_path / "chat.db"
|
||||
log1 = MessageLog(db_path=db)
|
||||
log1.append(role="user", content="hello", timestamp="10:00:00", source="browser")
|
||||
log1.append(role="agent", content="hi back", timestamp="10:00:01", source="browser")
|
||||
log1.close()
|
||||
|
||||
# New instance — simulates server restart
|
||||
log2 = MessageLog(db_path=db)
|
||||
msgs = log2.all()
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0].role == "user"
|
||||
assert msgs[0].content == "hello"
|
||||
assert msgs[1].role == "agent"
|
||||
assert msgs[1].content == "hi back"
|
||||
log2.close()
|
||||
|
||||
|
||||
def test_retention_policy(tmp_path):
|
||||
"""Oldest messages are pruned when count exceeds MAX_MESSAGES."""
|
||||
import dashboard.store as store_mod
|
||||
|
||||
original_max = store_mod.MAX_MESSAGES
|
||||
store_mod.MAX_MESSAGES = 5 # Small limit for testing
|
||||
|
||||
try:
|
||||
db = tmp_path / "chat.db"
|
||||
log = MessageLog(db_path=db)
|
||||
for i in range(8):
|
||||
log.append(role="user", content=f"msg-{i}", timestamp=f"10:00:{i:02d}")
|
||||
|
||||
assert len(log) == 5
|
||||
msgs = log.all()
|
||||
# Oldest 3 should have been pruned
|
||||
assert msgs[0].content == "msg-3"
|
||||
assert msgs[-1].content == "msg-7"
|
||||
log.close()
|
||||
finally:
|
||||
store_mod.MAX_MESSAGES = original_max
|
||||
|
||||
|
||||
def test_clear_removes_all(tmp_path):
|
||||
db = tmp_path / "chat.db"
|
||||
log = MessageLog(db_path=db)
|
||||
log.append(role="user", content="data", timestamp="12:00:00")
|
||||
assert len(log) == 1
|
||||
log.clear()
|
||||
assert len(log) == 0
|
||||
assert log.all() == []
|
||||
log.close()
|
||||
|
||||
|
||||
def test_recent_returns_limited_newest(tmp_path):
|
||||
db = tmp_path / "chat.db"
|
||||
log = MessageLog(db_path=db)
|
||||
for i in range(10):
|
||||
log.append(role="user", content=f"msg-{i}", timestamp=f"10:00:{i:02d}")
|
||||
|
||||
recent = log.recent(limit=3)
|
||||
assert len(recent) == 3
|
||||
# Should be oldest-first within the window
|
||||
assert recent[0].content == "msg-7"
|
||||
assert recent[1].content == "msg-8"
|
||||
assert recent[2].content == "msg-9"
|
||||
log.close()
|
||||
|
||||
|
||||
def test_source_field_persisted(tmp_path):
|
||||
db = tmp_path / "chat.db"
|
||||
log = MessageLog(db_path=db)
|
||||
log.append(role="user", content="from api", timestamp="10:00:00", source="api")
|
||||
log.append(role="user", content="from tg", timestamp="10:00:01", source="telegram")
|
||||
log.close()
|
||||
|
||||
log2 = MessageLog(db_path=db)
|
||||
msgs = log2.all()
|
||||
assert msgs[0].source == "api"
|
||||
assert msgs[1].source == "telegram"
|
||||
log2.close()
|
||||
|
||||
|
||||
def test_message_dataclass_defaults():
|
||||
m = Message(role="user", content="hi", timestamp="12:00:00")
|
||||
assert m.source == "browser"
|
||||
|
||||
|
||||
def test_empty_db_returns_empty(tmp_path):
|
||||
db = tmp_path / "chat.db"
|
||||
log = MessageLog(db_path=db)
|
||||
assert log.all() == []
|
||||
assert len(log) == 0
|
||||
assert log.recent() == []
|
||||
log.close()
|
||||
|
||||
|
||||
def test_concurrent_appends(tmp_path):
|
||||
"""Multiple threads can append without corrupting data."""
|
||||
import threading
|
||||
|
||||
db = tmp_path / "chat.db"
|
||||
log = MessageLog(db_path=db)
|
||||
errors = []
|
||||
|
||||
def writer(thread_id):
|
||||
try:
|
||||
for i in range(20):
|
||||
log.append(role="user", content=f"t{thread_id}-{i}", timestamp="10:00:00")
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=writer, args=(t,)) for t in range(4)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert not errors
|
||||
assert len(log) == 80
|
||||
log.close()
|
||||
Reference in New Issue
Block a user