From 7afb72209a344b130d938b4d8a8845467912bd2f Mon Sep 17 00:00:00 2001 From: "Claude (Opus 4.6)" Date: Mon, 23 Mar 2026 21:58:38 +0000 Subject: [PATCH] [claude] Add unit tests for chat_store.py (#1192) (#1198) --- tests/infrastructure/test_chat_store.py | 513 ++++++++++++++++++++++++ 1 file changed, 513 insertions(+) create mode 100644 tests/infrastructure/test_chat_store.py diff --git a/tests/infrastructure/test_chat_store.py b/tests/infrastructure/test_chat_store.py new file mode 100644 index 00000000..0b50f628 --- /dev/null +++ b/tests/infrastructure/test_chat_store.py @@ -0,0 +1,513 @@ +"""Unit tests for infrastructure.chat_store module.""" + +import threading +from pathlib import Path + +import pytest + +from infrastructure.chat_store import MAX_MESSAGES, Message, MessageLog, _get_conn + + +# --------------------------------------------------------------------------- +# Message dataclass +# --------------------------------------------------------------------------- + + +class TestMessageDataclass: + """Tests for the Message dataclass.""" + + def test_message_required_fields(self): + """Message can be created with required fields only.""" + msg = Message(role="user", content="hello", timestamp="2024-01-01T00:00:00") + assert msg.role == "user" + assert msg.content == "hello" + assert msg.timestamp == "2024-01-01T00:00:00" + + def test_message_default_source(self): + """Message source defaults to 'browser'.""" + msg = Message(role="user", content="hi", timestamp="2024-01-01T00:00:00") + assert msg.source == "browser" + + def test_message_custom_source(self): + """Message source can be overridden.""" + msg = Message(role="agent", content="reply", timestamp="2024-01-01T00:00:00", source="api") + assert msg.source == "api" + + def test_message_equality(self): + """Two Messages with the same fields are equal (dataclass default).""" + m1 = Message(role="user", content="x", timestamp="t") + m2 = Message(role="user", content="x", timestamp="t") + assert m1 == m2 + + def test_message_inequality(self): + """Messages with different content are not equal.""" + m1 = Message(role="user", content="x", timestamp="t") + m2 = Message(role="user", content="y", timestamp="t") + assert m1 != m2 + + +# --------------------------------------------------------------------------- +# _get_conn context manager +# --------------------------------------------------------------------------- + + +class TestGetConnContextManager: + """Tests for the _get_conn context manager.""" + + def test_creates_db_file(self, tmp_path): + """_get_conn creates the database file on first use.""" + db = tmp_path / "chat.db" + assert not db.exists() + with _get_conn(db) as conn: + assert conn is not None + assert db.exists() + + def test_creates_parent_directories(self, tmp_path): + """_get_conn creates any missing parent directories.""" + db = tmp_path / "nested" / "deep" / "chat.db" + with _get_conn(db): + pass + assert db.exists() + + def test_creates_schema(self, tmp_path): + """_get_conn creates the chat_messages table.""" + db = tmp_path / "chat.db" + with _get_conn(db) as conn: + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='chat_messages'" + ).fetchall() + assert len(tables) == 1 + + def test_schema_has_expected_columns(self, tmp_path): + """chat_messages table has the expected columns.""" + db = tmp_path / "chat.db" + with _get_conn(db) as conn: + info = conn.execute("PRAGMA table_info(chat_messages)").fetchall() + col_names = [row["name"] for row in info] + assert set(col_names) == {"id", "role", "content", "timestamp", "source"} + + def test_idempotent_schema_creation(self, tmp_path): + """Calling _get_conn twice does not fail (CREATE TABLE IF NOT EXISTS).""" + db = tmp_path / "chat.db" + with _get_conn(db): + pass + with _get_conn(db) as conn: + # Table still exists and is usable + conn.execute("SELECT COUNT(*) FROM chat_messages") + + +# --------------------------------------------------------------------------- +# MessageLog — basic operations +# --------------------------------------------------------------------------- + + +class TestMessageLogAppend: + """Tests for MessageLog.append().""" + + def test_append_single_message(self, tmp_path): + """append() stores a message that can be retrieved.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "hello", "2024-01-01T00:00:00") + messages = log.all() + assert len(messages) == 1 + assert messages[0].role == "user" + assert messages[0].content == "hello" + assert messages[0].timestamp == "2024-01-01T00:00:00" + assert messages[0].source == "browser" + log.close() + + def test_append_custom_source(self, tmp_path): + """append() stores the source field correctly.""" + log = MessageLog(tmp_path / "chat.db") + log.append("agent", "reply", "2024-01-01T00:00:01", source="api") + msg = log.all()[0] + assert msg.source == "api" + log.close() + + def test_append_multiple_messages_preserves_order(self, tmp_path): + """append() preserves insertion order.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "first", "2024-01-01T00:00:00") + log.append("agent", "second", "2024-01-01T00:00:01") + log.append("user", "third", "2024-01-01T00:00:02") + messages = log.all() + assert [m.content for m in messages] == ["first", "second", "third"] + log.close() + + def test_append_persists_across_instances(self, tmp_path): + """Messages appended by one instance are readable by another.""" + db = tmp_path / "chat.db" + log1 = MessageLog(db) + log1.append("user", "persisted", "2024-01-01T00:00:00") + log1.close() + + log2 = MessageLog(db) + messages = log2.all() + assert len(messages) == 1 + assert messages[0].content == "persisted" + log2.close() + + +class TestMessageLogAll: + """Tests for MessageLog.all().""" + + def test_all_on_empty_store_returns_empty_list(self, tmp_path): + """all() returns [] when there are no messages.""" + log = MessageLog(tmp_path / "chat.db") + assert log.all() == [] + log.close() + + def test_all_returns_message_objects(self, tmp_path): + """all() returns a list of Message dataclass instances.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "hi", "2024-01-01T00:00:00") + messages = log.all() + assert all(isinstance(m, Message) for m in messages) + log.close() + + def test_all_returns_all_messages(self, tmp_path): + """all() returns every stored message.""" + log = MessageLog(tmp_path / "chat.db") + for i in range(5): + log.append("user", f"msg{i}", f"2024-01-01T00:00:0{i}") + assert len(log.all()) == 5 + log.close() + + +class TestMessageLogRecent: + """Tests for MessageLog.recent().""" + + def test_recent_on_empty_store_returns_empty_list(self, tmp_path): + """recent() returns [] when there are no messages.""" + log = MessageLog(tmp_path / "chat.db") + assert log.recent() == [] + log.close() + + def test_recent_default_limit(self, tmp_path): + """recent() with default limit returns up to 50 messages.""" + log = MessageLog(tmp_path / "chat.db") + for i in range(60): + log.append("user", f"msg{i}", f"2024-01-01T00:00:{i:02d}") + msgs = log.recent() + assert len(msgs) == 50 + log.close() + + def test_recent_custom_limit(self, tmp_path): + """recent() respects a custom limit.""" + log = MessageLog(tmp_path / "chat.db") + for i in range(10): + log.append("user", f"msg{i}", f"2024-01-01T00:00:0{i}") + msgs = log.recent(limit=3) + assert len(msgs) == 3 + log.close() + + def test_recent_returns_newest_messages(self, tmp_path): + """recent() returns the most-recently-inserted messages.""" + log = MessageLog(tmp_path / "chat.db") + for i in range(10): + log.append("user", f"msg{i}", f"2024-01-01T00:00:0{i}") + msgs = log.recent(limit=3) + # Should be the last 3 inserted, in oldest-first order + assert [m.content for m in msgs] == ["msg7", "msg8", "msg9"] + log.close() + + def test_recent_fewer_than_limit_returns_all(self, tmp_path): + """recent() returns all messages when count < limit.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "only", "2024-01-01T00:00:00") + msgs = log.recent(limit=10) + assert len(msgs) == 1 + log.close() + + def test_recent_returns_oldest_first(self, tmp_path): + """recent() returns messages in oldest-first order.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "a", "2024-01-01T00:00:00") + log.append("user", "b", "2024-01-01T00:00:01") + log.append("user", "c", "2024-01-01T00:00:02") + msgs = log.recent(limit=2) + assert [m.content for m in msgs] == ["b", "c"] + log.close() + + +class TestMessageLogClear: + """Tests for MessageLog.clear().""" + + def test_clear_empties_the_store(self, tmp_path): + """clear() removes all messages.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "hello", "2024-01-01T00:00:00") + log.clear() + assert log.all() == [] + log.close() + + def test_clear_on_empty_store_is_safe(self, tmp_path): + """clear() on an empty store does not raise.""" + log = MessageLog(tmp_path / "chat.db") + log.clear() # should not raise + assert log.all() == [] + log.close() + + def test_clear_allows_new_appends(self, tmp_path): + """After clear(), new messages can be appended.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "old", "2024-01-01T00:00:00") + log.clear() + log.append("user", "new", "2024-01-01T00:00:01") + messages = log.all() + assert len(messages) == 1 + assert messages[0].content == "new" + log.close() + + def test_clear_resets_len_to_zero(self, tmp_path): + """After clear(), __len__ returns 0.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "a", "t") + log.append("user", "b", "t") + log.clear() + assert len(log) == 0 + log.close() + + +# --------------------------------------------------------------------------- +# MessageLog — __len__ +# --------------------------------------------------------------------------- + + +class TestMessageLogLen: + """Tests for MessageLog.__len__().""" + + def test_len_empty_store(self, tmp_path): + """__len__ returns 0 for an empty store.""" + log = MessageLog(tmp_path / "chat.db") + assert len(log) == 0 + log.close() + + def test_len_after_appends(self, tmp_path): + """__len__ reflects the number of stored messages.""" + log = MessageLog(tmp_path / "chat.db") + for i in range(7): + log.append("user", f"msg{i}", "t") + assert len(log) == 7 + log.close() + + def test_len_after_clear(self, tmp_path): + """__len__ is 0 after clear().""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "x", "t") + log.clear() + assert len(log) == 0 + log.close() + + +# --------------------------------------------------------------------------- +# MessageLog — pruning +# --------------------------------------------------------------------------- + + +class TestMessageLogPrune: + """Tests for automatic pruning via _prune().""" + + def test_prune_keeps_at_most_max_messages(self, tmp_path): + """After exceeding MAX_MESSAGES, oldest messages are pruned.""" + log = MessageLog(tmp_path / "chat.db") + # Temporarily lower the limit via monkeypatching is not straightforward + # because _prune reads the module-level MAX_MESSAGES constant. + # We therefore patch it directly. + import infrastructure.chat_store as cs + + original = cs.MAX_MESSAGES + cs.MAX_MESSAGES = 5 + try: + for i in range(8): + log.append("user", f"msg{i}", f"t{i}") + assert len(log) == 5 + finally: + cs.MAX_MESSAGES = original + log.close() + + def test_prune_keeps_newest_messages(self, tmp_path): + """Pruning removes oldest messages and keeps the newest ones.""" + import infrastructure.chat_store as cs + + log = MessageLog(tmp_path / "chat.db") + original = cs.MAX_MESSAGES + cs.MAX_MESSAGES = 3 + try: + for i in range(5): + log.append("user", f"msg{i}", f"t{i}") + messages = log.all() + contents = [m.content for m in messages] + assert contents == ["msg2", "msg3", "msg4"] + finally: + cs.MAX_MESSAGES = original + log.close() + + def test_no_prune_when_below_limit(self, tmp_path): + """No messages are pruned while count is at or below MAX_MESSAGES.""" + log = MessageLog(tmp_path / "chat.db") + import infrastructure.chat_store as cs + + original = cs.MAX_MESSAGES + cs.MAX_MESSAGES = 10 + try: + for i in range(10): + log.append("user", f"msg{i}", f"t{i}") + assert len(log) == 10 + finally: + cs.MAX_MESSAGES = original + log.close() + + +# --------------------------------------------------------------------------- +# MessageLog — close / lifecycle +# --------------------------------------------------------------------------- + + +class TestMessageLogClose: + """Tests for MessageLog.close().""" + + def test_close_is_safe_before_first_use(self, tmp_path): + """close() on a fresh (never-used) instance does not raise.""" + log = MessageLog(tmp_path / "chat.db") + log.close() # should not raise + + def test_close_multiple_times_is_safe(self, tmp_path): + """close() can be called multiple times without error.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "hi", "t") + log.close() + log.close() # second close should not raise + + def test_close_sets_conn_to_none(self, tmp_path): + """close() sets the internal _conn attribute to None.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "hi", "t") + assert log._conn is not None + log.close() + assert log._conn is None + + +# --------------------------------------------------------------------------- +# Thread safety +# --------------------------------------------------------------------------- + + +class TestMessageLogThreadSafety: + """Thread-safety tests for MessageLog.""" + + def test_concurrent_appends(self, tmp_path): + """Multiple threads can append messages without data loss or errors.""" + log = MessageLog(tmp_path / "chat.db") + errors: list[Exception] = [] + + def worker(n: int) -> None: + try: + for i in range(5): + log.append("user", f"t{n}-{i}", f"ts-{n}-{i}") + except Exception as exc: # noqa: BLE001 + errors.append(exc) + + threads = [threading.Thread(target=worker, args=(n,)) for n in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [], f"Concurrent append raised: {errors}" + # All 20 messages should be present (4 threads × 5 messages) + assert len(log) == 20 + log.close() + + def test_concurrent_reads_and_writes(self, tmp_path): + """Concurrent reads and writes do not corrupt state.""" + log = MessageLog(tmp_path / "chat.db") + errors: list[Exception] = [] + + def writer() -> None: + try: + for i in range(10): + log.append("user", f"msg{i}", f"t{i}") + except Exception as exc: # noqa: BLE001 + errors.append(exc) + + def reader() -> None: + try: + for _ in range(10): + log.all() + except Exception as exc: # noqa: BLE001 + errors.append(exc) + + threads = [threading.Thread(target=writer)] + [ + threading.Thread(target=reader) for _ in range(3) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [], f"Concurrent read/write raised: {errors}" + log.close() + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestMessageLogEdgeCases: + """Edge-case tests for MessageLog.""" + + def test_empty_content_stored_and_retrieved(self, tmp_path): + """Empty string content can be stored and retrieved.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "", "2024-01-01T00:00:00") + assert log.all()[0].content == "" + log.close() + + def test_unicode_content_stored_and_retrieved(self, tmp_path): + """Unicode characters in content are stored and retrieved correctly.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "こんにちは 🌍", "2024-01-01T00:00:00") + assert log.all()[0].content == "こんにちは 🌍" + log.close() + + def test_newline_in_content(self, tmp_path): + """Newlines in content are preserved.""" + log = MessageLog(tmp_path / "chat.db") + multiline = "line1\nline2\nline3" + log.append("agent", multiline, "2024-01-01T00:00:00") + assert log.all()[0].content == multiline + log.close() + + def test_default_db_path_attribute(self): + """MessageLog without explicit path uses the module-level DB_PATH.""" + from infrastructure.chat_store import DB_PATH + + log = MessageLog() + assert log._db_path == DB_PATH + # Do NOT call close() here — this is the global singleton's path + + def test_custom_db_path_used(self, tmp_path): + """MessageLog uses the provided db_path.""" + db = tmp_path / "custom.db" + log = MessageLog(db) + log.append("user", "test", "t") + assert db.exists() + log.close() + + def test_recent_limit_zero_returns_empty(self, tmp_path): + """recent(limit=0) returns an empty list.""" + log = MessageLog(tmp_path / "chat.db") + log.append("user", "msg", "t") + assert log.recent(limit=0) == [] + log.close() + + def test_all_roles_stored_correctly(self, tmp_path): + """Different role values are stored and retrieved correctly.""" + log = MessageLog(tmp_path / "chat.db") + for role in ("user", "agent", "error", "system"): + log.append(role, f"{role} message", "t") + messages = log.all() + assert [m.role for m in messages] == ["user", "agent", "error", "system"] + log.close()