"""Unit tests for infrastructure.chat_store module."""

import threading

import infrastructure.chat_store as cs
from infrastructure.chat_store import Message, MessageLog, _get_conn

# ---------------------------------------------------------------------------
# Message dataclass
# ---------------------------------------------------------------------------


class TestMessageDataclass:
    """Tests for the Message dataclass."""

    def test_message_required_fields(self):
        """Message can be created with required fields only."""
        msg = Message(role="user", content="hello", timestamp="2024-01-01T00:00:00")
        assert msg.role == "user"
        assert msg.content == "hello"
        assert msg.timestamp == "2024-01-01T00:00:00"

    def test_message_default_source(self):
        """Message source defaults to 'browser'."""
        msg = Message(role="user", content="hi", timestamp="2024-01-01T00:00:00")
        assert msg.source == "browser"

    def test_message_custom_source(self):
        """Message source can be overridden."""
        msg = Message(
            role="agent", content="reply", timestamp="2024-01-01T00:00:00", source="api"
        )
        assert msg.source == "api"

    def test_message_equality(self):
        """Two Messages with the same fields are equal (dataclass default)."""
        m1 = Message(role="user", content="x", timestamp="t")
        m2 = Message(role="user", content="x", timestamp="t")
        assert m1 == m2

    def test_message_inequality(self):
        """Messages with different content are not equal."""
        m1 = Message(role="user", content="x", timestamp="t")
        m2 = Message(role="user", content="y", timestamp="t")
        assert m1 != m2


# ---------------------------------------------------------------------------
# _get_conn context manager
# ---------------------------------------------------------------------------


class TestGetConnContextManager:
    """Tests for the _get_conn context manager."""

    def test_creates_db_file(self, tmp_path):
        """_get_conn creates the database file on first use."""
        db = tmp_path / "chat.db"
        assert not db.exists()
        with _get_conn(db) as conn:
            assert conn is not None
        assert db.exists()

    def test_creates_parent_directories(self, tmp_path):
        """_get_conn creates any missing parent directories."""
        db = tmp_path / "nested" / "deep" / "chat.db"
        with _get_conn(db):
            pass
        assert db.exists()

    def test_creates_schema(self, tmp_path):
        """_get_conn creates the chat_messages table."""
        db = tmp_path / "chat.db"
        with _get_conn(db) as conn:
            tables = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='chat_messages'"
            ).fetchall()
        assert len(tables) == 1

    def test_schema_has_expected_columns(self, tmp_path):
        """chat_messages table has the expected columns."""
        db = tmp_path / "chat.db"
        with _get_conn(db) as conn:
            info = conn.execute("PRAGMA table_info(chat_messages)").fetchall()
        col_names = [row["name"] for row in info]
        assert set(col_names) == {"id", "role", "content", "timestamp", "source"}

    def test_idempotent_schema_creation(self, tmp_path):
        """Calling _get_conn twice does not fail (CREATE TABLE IF NOT EXISTS)."""
        db = tmp_path / "chat.db"
        with _get_conn(db):
            pass
        with _get_conn(db) as conn:
            # Table still exists and is usable
            conn.execute("SELECT COUNT(*) FROM chat_messages")


# ---------------------------------------------------------------------------
# MessageLog — basic operations
# ---------------------------------------------------------------------------


class TestMessageLogAppend:
    """Tests for MessageLog.append()."""

    def test_append_single_message(self, tmp_path):
        """append() stores a message that can be retrieved."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "hello", "2024-01-01T00:00:00")
        messages = log.all()
        assert len(messages) == 1
        assert messages[0].role == "user"
        assert messages[0].content == "hello"
        assert messages[0].timestamp == "2024-01-01T00:00:00"
        assert messages[0].source == "browser"
        log.close()

    def test_append_custom_source(self, tmp_path):
        """append() stores the source field correctly."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("agent", "reply", "2024-01-01T00:00:01", source="api")
        msg = log.all()[0]
        assert msg.source == "api"
        log.close()

    def test_append_multiple_messages_preserves_order(self, tmp_path):
        """append() preserves insertion order."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "first", "2024-01-01T00:00:00")
        log.append("agent", "second", "2024-01-01T00:00:01")
        log.append("user", "third", "2024-01-01T00:00:02")
        messages = log.all()
        assert [m.content for m in messages] == ["first", "second", "third"]
        log.close()

    def test_append_persists_across_instances(self, tmp_path):
        """Messages appended by one instance are readable by another."""
        db = tmp_path / "chat.db"
        log1 = MessageLog(db)
        log1.append("user", "persisted", "2024-01-01T00:00:00")
        log1.close()

        log2 = MessageLog(db)
        messages = log2.all()
        assert len(messages) == 1
        assert messages[0].content == "persisted"
        log2.close()


class TestMessageLogAll:
    """Tests for MessageLog.all()."""

    def test_all_on_empty_store_returns_empty_list(self, tmp_path):
        """all() returns [] when there are no messages."""
        log = MessageLog(tmp_path / "chat.db")
        assert log.all() == []
        log.close()

    def test_all_returns_message_objects(self, tmp_path):
        """all() returns a list of Message dataclass instances."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "hi", "2024-01-01T00:00:00")
        messages = log.all()
        assert all(isinstance(m, Message) for m in messages)
        log.close()

    def test_all_returns_all_messages(self, tmp_path):
        """all() returns every stored message."""
        log = MessageLog(tmp_path / "chat.db")
        for i in range(5):
            log.append("user", f"msg{i}", f"2024-01-01T00:00:{i:02d}")
        assert len(log.all()) == 5
        log.close()


class TestMessageLogRecent:
    """Tests for MessageLog.recent()."""

    def test_recent_on_empty_store_returns_empty_list(self, tmp_path):
        """recent() returns [] when there are no messages."""
        log = MessageLog(tmp_path / "chat.db")
        assert log.recent() == []
        log.close()

    def test_recent_default_limit(self, tmp_path):
        """recent() with default limit returns the newest 50 messages."""
        log = MessageLog(tmp_path / "chat.db")
        for i in range(60):
            log.append("user", f"msg{i}", f"2024-01-01T00:00:{i:02d}")
        msgs = log.recent()
        assert len(msgs) == 50
        # Not just any 50: the newest ones (msg10..msg59), oldest-first.
        assert [m.content for m in msgs] == [f"msg{i}" for i in range(10, 60)]
        log.close()

    def test_recent_custom_limit(self, tmp_path):
        """recent() respects a custom limit."""
        log = MessageLog(tmp_path / "chat.db")
        for i in range(10):
            log.append("user", f"msg{i}", f"2024-01-01T00:00:{i:02d}")
        msgs = log.recent(limit=3)
        assert len(msgs) == 3
        log.close()

    def test_recent_returns_newest_messages(self, tmp_path):
        """recent() returns the most-recently-inserted messages."""
        log = MessageLog(tmp_path / "chat.db")
        for i in range(10):
            log.append("user", f"msg{i}", f"2024-01-01T00:00:{i:02d}")
        msgs = log.recent(limit=3)
        # Should be the last 3 inserted, in oldest-first order
        assert [m.content for m in msgs] == ["msg7", "msg8", "msg9"]
        log.close()

    def test_recent_fewer_than_limit_returns_all(self, tmp_path):
        """recent() returns all messages when count < limit."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "only", "2024-01-01T00:00:00")
        msgs = log.recent(limit=10)
        assert len(msgs) == 1
        log.close()

    def test_recent_returns_oldest_first(self, tmp_path):
        """recent() returns messages in oldest-first order."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "a", "2024-01-01T00:00:00")
        log.append("user", "b", "2024-01-01T00:00:01")
        log.append("user", "c", "2024-01-01T00:00:02")
        msgs = log.recent(limit=2)
        assert [m.content for m in msgs] == ["b", "c"]
        log.close()


class TestMessageLogClear:
    """Tests for MessageLog.clear()."""

    def test_clear_empties_the_store(self, tmp_path):
        """clear() removes all messages."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "hello", "2024-01-01T00:00:00")
        log.clear()
        assert log.all() == []
        log.close()

    def test_clear_on_empty_store_is_safe(self, tmp_path):
        """clear() on an empty store does not raise."""
        log = MessageLog(tmp_path / "chat.db")
        log.clear()  # should not raise
        assert log.all() == []
        log.close()

    def test_clear_allows_new_appends(self, tmp_path):
        """After clear(), new messages can be appended."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "old", "2024-01-01T00:00:00")
        log.clear()
        log.append("user", "new", "2024-01-01T00:00:01")
        messages = log.all()
        assert len(messages) == 1
        assert messages[0].content == "new"
        log.close()

    def test_clear_resets_len_to_zero(self, tmp_path):
        """After clear(), __len__ returns 0."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "a", "t")
        log.append("user", "b", "t")
        log.clear()
        assert len(log) == 0
        log.close()


# ---------------------------------------------------------------------------
# MessageLog — __len__
# ---------------------------------------------------------------------------


class TestMessageLogLen:
    """Tests for MessageLog.__len__()."""

    def test_len_empty_store(self, tmp_path):
        """__len__ returns 0 for an empty store."""
        log = MessageLog(tmp_path / "chat.db")
        assert len(log) == 0
        log.close()

    def test_len_after_appends(self, tmp_path):
        """__len__ reflects the number of stored messages."""
        log = MessageLog(tmp_path / "chat.db")
        for i in range(7):
            log.append("user", f"msg{i}", "t")
        assert len(log) == 7
        log.close()

    def test_len_after_clear(self, tmp_path):
        """__len__ is 0 after clear()."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "x", "t")
        log.clear()
        assert len(log) == 0
        log.close()


# ---------------------------------------------------------------------------
# MessageLog — pruning
# ---------------------------------------------------------------------------


class TestMessageLogPrune:
    """Tests for automatic pruning via _prune().

    ``_prune`` reads the module-level ``MAX_MESSAGES`` constant, so each test
    lowers the limit with ``monkeypatch.setattr`` on the ``chat_store``
    module. pytest restores the original value after the test, even when an
    assertion fails or ``MessageLog`` raises, so the patch cannot leak into
    other tests.
    """

    def test_prune_keeps_at_most_max_messages(self, tmp_path, monkeypatch):
        """After exceeding MAX_MESSAGES, oldest messages are pruned."""
        log = MessageLog(tmp_path / "chat.db")
        monkeypatch.setattr(cs, "MAX_MESSAGES", 5)
        try:
            for i in range(8):
                log.append("user", f"msg{i}", f"t{i}")
            assert len(log) == 5
        finally:
            log.close()

    def test_prune_keeps_newest_messages(self, tmp_path, monkeypatch):
        """Pruning removes oldest messages and keeps the newest ones."""
        log = MessageLog(tmp_path / "chat.db")
        monkeypatch.setattr(cs, "MAX_MESSAGES", 3)
        try:
            for i in range(5):
                log.append("user", f"msg{i}", f"t{i}")
            contents = [m.content for m in log.all()]
            assert contents == ["msg2", "msg3", "msg4"]
        finally:
            log.close()

    def test_no_prune_when_below_limit(self, tmp_path, monkeypatch):
        """No messages are pruned while count is at or below MAX_MESSAGES."""
        log = MessageLog(tmp_path / "chat.db")
        monkeypatch.setattr(cs, "MAX_MESSAGES", 10)
        try:
            for i in range(10):
                log.append("user", f"msg{i}", f"t{i}")
            assert len(log) == 10
        finally:
            log.close()


# ---------------------------------------------------------------------------
# MessageLog — close / lifecycle
# ---------------------------------------------------------------------------


class TestMessageLogClose:
    """Tests for MessageLog.close()."""

    def test_close_is_safe_before_first_use(self, tmp_path):
        """close() on a fresh (never-used) instance does not raise."""
        log = MessageLog(tmp_path / "chat.db")
        log.close()  # should not raise

    def test_close_multiple_times_is_safe(self, tmp_path):
        """close() can be called multiple times without error."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "hi", "t")
        log.close()
        log.close()  # second close should not raise

    def test_close_sets_conn_to_none(self, tmp_path):
        """close() sets the internal _conn attribute to None."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "hi", "t")
        assert log._conn is not None
        log.close()
        assert log._conn is None
# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------


class TestMessageLogThreadSafety:
    """Thread-safety tests for MessageLog."""

    # Upper bound (seconds) on how long a worker may run. If MessageLog
    # deadlocks, the test then fails with a clear message instead of hanging
    # the whole test session on an unbounded join().
    _JOIN_TIMEOUT = 30.0

    @classmethod
    def _run_all(cls, threads):
        """Start every thread, then wait for all of them with a bounded join."""
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=cls._JOIN_TIMEOUT)
        stuck = [t.name for t in threads if t.is_alive()]
        assert not stuck, f"Threads did not finish in time: {stuck}"

    def test_concurrent_appends(self, tmp_path):
        """Multiple threads can append messages without data loss or errors."""
        log = MessageLog(tmp_path / "chat.db")
        errors: list[Exception] = []

        def worker(n: int) -> None:
            try:
                for i in range(5):
                    log.append("user", f"t{n}-{i}", f"ts-{n}-{i}")
            except Exception as exc:  # noqa: BLE001
                errors.append(exc)

        threads = [threading.Thread(target=worker, args=(n,)) for n in range(4)]
        self._run_all(threads)

        assert errors == [], f"Concurrent append raised: {errors}"
        # All 20 messages should be present (4 threads × 5 messages)
        assert len(log) == 20
        log.close()

    def test_concurrent_reads_and_writes(self, tmp_path):
        """Concurrent reads and writes do not corrupt state."""
        log = MessageLog(tmp_path / "chat.db")
        errors: list[Exception] = []

        def writer() -> None:
            try:
                for i in range(10):
                    log.append("user", f"msg{i}", f"t{i}")
            except Exception as exc:  # noqa: BLE001
                errors.append(exc)

        def reader() -> None:
            try:
                for _ in range(10):
                    log.all()
            except Exception as exc:  # noqa: BLE001
                errors.append(exc)

        threads = [threading.Thread(target=writer)] + [
            threading.Thread(target=reader) for _ in range(3)
        ]
        self._run_all(threads)

        assert errors == [], f"Concurrent read/write raised: {errors}"
        # Every write must have landed despite the interleaved readers.
        assert len(log) == 10
        log.close()


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


class TestMessageLogEdgeCases:
    """Edge-case tests for MessageLog."""

    def test_empty_content_stored_and_retrieved(self, tmp_path):
        """Empty string content can be stored and retrieved."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "", "2024-01-01T00:00:00")
        assert log.all()[0].content == ""
        log.close()

    def test_unicode_content_stored_and_retrieved(self, tmp_path):
        """Unicode characters in content are stored and retrieved correctly."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "こんにちは 🌍", "2024-01-01T00:00:00")
        assert log.all()[0].content == "こんにちは 🌍"
        log.close()

    def test_newline_in_content(self, tmp_path):
        """Newlines in content are preserved."""
        log = MessageLog(tmp_path / "chat.db")
        multiline = "line1\nline2\nline3"
        log.append("agent", multiline, "2024-01-01T00:00:00")
        assert log.all()[0].content == multiline
        log.close()

    def test_default_db_path_attribute(self):
        """MessageLog without explicit path uses the module-level DB_PATH."""
        from infrastructure.chat_store import DB_PATH

        log = MessageLog()
        assert log._db_path == DB_PATH
        # close() is deliberately not called: this instance points at the
        # real default database (DB_PATH), and the test only inspects the
        # configured path without reading or writing any messages.

    def test_custom_db_path_used(self, tmp_path):
        """MessageLog uses the provided db_path."""
        db = tmp_path / "custom.db"
        log = MessageLog(db)
        log.append("user", "test", "t")
        assert db.exists()
        log.close()

    def test_recent_limit_zero_returns_empty(self, tmp_path):
        """recent(limit=0) returns an empty list."""
        log = MessageLog(tmp_path / "chat.db")
        log.append("user", "msg", "t")
        assert log.recent(limit=0) == []
        log.close()

    def test_all_roles_stored_correctly(self, tmp_path):
        """Different role values are stored and retrieved correctly."""
        log = MessageLog(tmp_path / "chat.db")
        for role in ("user", "agent", "error", "system"):
            log.append(role, f"{role} message", "t")
        messages = log.all()
        assert [m.role for m in messages] == ["user", "agent", "error", "system"]
        log.close()