diff --git a/tests/infrastructure/test_chat_store.py b/tests/infrastructure/test_chat_store.py new file mode 100644 index 00000000..0846af9c --- /dev/null +++ b/tests/infrastructure/test_chat_store.py @@ -0,0 +1,247 @@ +"""Unit tests for src/infrastructure/chat_store.py.""" + +import sqlite3 +import threading +from pathlib import Path +from unittest.mock import patch + +import pytest + +from src.infrastructure.chat_store import MAX_MESSAGES, Message, MessageLog, _get_conn + +pytestmark = pytest.mark.unit + + +@pytest.fixture() +def tmp_db(tmp_path: Path) -> Path: + """Return a temporary database path.""" + return tmp_path / "test_chat.db" + + +@pytest.fixture() +def log(tmp_db: Path) -> MessageLog: + """Return a MessageLog backed by a temp database.""" + ml = MessageLog(db_path=tmp_db) + yield ml + ml.close() + + +# ── Message dataclass ────────────────────────────────────────────────── + + +class TestMessage: + def test_default_source(self): + m = Message(role="user", content="hi", timestamp="2026-01-01T00:00:00") + assert m.source == "browser" + + def test_custom_source(self): + m = Message(role="agent", content="ok", timestamp="t1", source="telegram") + assert m.source == "telegram" + + def test_fields(self): + m = Message(role="error", content="boom", timestamp="t2", source="api") + assert m.role == "error" + assert m.content == "boom" + assert m.timestamp == "t2" + + +# ── _get_conn context manager ────────────────────────────────────────── + + +class TestGetConn: + def test_creates_db_and_table(self, tmp_db: Path): + with _get_conn(tmp_db) as conn: + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ).fetchall() + names = [t["name"] for t in tables] + assert "chat_messages" in names + + def test_creates_parent_dirs(self, tmp_path: Path): + deep = tmp_path / "a" / "b" / "c" / "chat.db" + with _get_conn(deep) as conn: + assert deep.parent.exists() + + def test_connection_closed_after_context(self, tmp_db: Path): + with _get_conn(tmp_db) as conn: + conn.execute("SELECT 1") + # Connection should be closed — operations should fail + with pytest.raises(Exception): + conn.execute("SELECT 1") + + +# ── MessageLog core operations ───────────────────────────────────────── + + +class TestMessageLogAppendAndAll: + def test_append_and_all(self, log: MessageLog): + log.append("user", "hello", "t1") + log.append("agent", "hi back", "t2", source="api") + msgs = log.all() + assert len(msgs) == 2 + assert msgs[0].role == "user" + assert msgs[0].content == "hello" + assert msgs[0].source == "browser" + assert msgs[1].role == "agent" + assert msgs[1].source == "api" + + def test_all_returns_ordered_by_id(self, log: MessageLog): + for i in range(5): + log.append("user", f"msg{i}", f"t{i}") + msgs = log.all() + assert [m.content for m in msgs] == [f"msg{i}" for i in range(5)] + + def test_all_empty_store(self, log: MessageLog): + assert log.all() == [] + + +class TestMessageLogRecent: + def test_recent_returns_newest(self, log: MessageLog): + for i in range(10): + log.append("user", f"msg{i}", f"t{i}") + recent = log.recent(limit=3) + assert len(recent) == 3 + assert recent[0].content == "msg7" + assert recent[2].content == "msg9" + + def test_recent_oldest_first(self, log: MessageLog): + for i in range(5): + log.append("user", f"msg{i}", f"t{i}") + recent = log.recent(limit=3) + # Should be oldest-first within the window + assert recent[0].content == "msg2" + assert recent[1].content == "msg3" + assert recent[2].content == "msg4" + + def test_recent_more_than_exists(self, log: MessageLog): + log.append("user", "only", "t0") + recent = log.recent(limit=100) + assert len(recent) == 1 + + def test_recent_empty_store(self, log: MessageLog): + assert log.recent() == [] + + +class TestMessageLogClear: + def test_clear_removes_all(self, log: MessageLog): + for i in range(5): + log.append("user", f"msg{i}", f"t{i}") + assert len(log) == 5 + log.clear() + assert len(log) == 0 + assert log.all() == [] + + def test_clear_empty_store(self, log: MessageLog): + log.clear() # Should not raise + assert len(log) == 0 + + +class TestMessageLogLen: + def test_len_empty(self, log: MessageLog): + assert len(log) == 0 + + def test_len_after_appends(self, log: MessageLog): + for i in range(7): + log.append("user", f"msg{i}", f"t{i}") + assert len(log) == 7 + + +class TestMessageLogClose: + def test_close_sets_conn_none(self, tmp_db: Path): + ml = MessageLog(db_path=tmp_db) + ml.append("user", "x", "t0") + ml.close() + assert ml._conn is None + + def test_close_idempotent(self, tmp_db: Path): + ml = MessageLog(db_path=tmp_db) + ml.close() + ml.close() # Should not raise + + def test_reopen_after_close(self, tmp_db: Path): + ml = MessageLog(db_path=tmp_db) + ml.append("user", "before", "t0") + ml.close() + # Should reconnect on next use + ml.append("user", "after", "t1") + assert len(ml) == 2 + ml.close() + + +# ── Pruning ──────────────────────────────────────────────────────────── + + +class TestPrune: + def test_prune_keeps_max_messages(self, tmp_db: Path): + with patch("src.infrastructure.chat_store.MAX_MESSAGES", 5): + ml = MessageLog(db_path=tmp_db) + for i in range(10): + ml.append("user", f"msg{i}", f"t{i}") + # Should have pruned to 5 + assert len(ml) == 5 + msgs = ml.all() + # Oldest should be pruned, newest kept + assert msgs[0].content == "msg5" + assert msgs[-1].content == "msg9" + ml.close() + + def test_no_prune_under_limit(self, tmp_db: Path): + with patch("src.infrastructure.chat_store.MAX_MESSAGES", 100): + ml = MessageLog(db_path=tmp_db) + for i in range(10): + ml.append("user", f"msg{i}", f"t{i}") + assert len(ml) == 10 + ml.close() + + +# ── Thread safety ────────────────────────────────────────────────────── + + +class TestThreadSafety: + def test_concurrent_appends(self, tmp_db: Path): + ml = MessageLog(db_path=tmp_db) + errors = [] + + def writer(start: int): + try: + for i in range(20): + ml.append("user", f"msg{start + i}", f"t{start + i}") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=writer, args=(i * 20,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors, f"Thread errors: {errors}" + assert len(ml) == 100 + ml.close() + + +# ── Edge cases ───────────────────────────────────────────────────────── + + +class TestEdgeCases: + def test_empty_content(self, log: MessageLog): + log.append("user", "", "t0") + msgs = log.all() + assert len(msgs) == 1 + assert msgs[0].content == "" + + def test_unicode_content(self, log: MessageLog): + log.append("user", "こんにちは 🎉 مرحبا", "t0") + msgs = log.all() + assert msgs[0].content == "こんにちは 🎉 مرحبا" + + def test_multiline_content(self, log: MessageLog): + content = "line1\nline2\nline3" + log.append("user", content, "t0") + assert log.all()[0].content == content + + def test_special_sql_characters(self, log: MessageLog): + log.append("user", "Robert'; DROP TABLE chat_messages;--", "t0") + msgs = log.all() + assert len(msgs) == 1 + assert "DROP TABLE" in msgs[0].content