"""Unit tests for the infrastructure.chat_store module."""
import threading
from contextlib import closing

from infrastructure.chat_store import Message, MessageLog, _get_conn


# ---------------------------------------------------------------------------
# Message dataclass
# ---------------------------------------------------------------------------


class TestMessageDataclass:
    """Behaviour of the plain ``Message`` record type."""

    def test_message_required_fields(self):
        """A Message built from only the mandatory fields exposes them unchanged."""
        message = Message(role="user", content="hello", timestamp="2024-01-01T00:00:00")
        assert (message.role, message.content, message.timestamp) == (
            "user",
            "hello",
            "2024-01-01T00:00:00",
        )

    def test_message_default_source(self):
        """Omitting ``source`` yields the 'browser' default."""
        assert (
            Message(role="user", content="hi", timestamp="2024-01-01T00:00:00").source
            == "browser"
        )

    def test_message_custom_source(self):
        """An explicit ``source`` argument replaces the default."""
        overridden = Message(
            role="agent",
            content="reply",
            timestamp="2024-01-01T00:00:00",
            source="api",
        )
        assert overridden.source == "api"

    def test_message_equality(self):
        """Field-wise identical Messages compare equal (generated ``__eq__``)."""
        left = Message(role="user", content="x", timestamp="t")
        right = Message(role="user", content="x", timestamp="t")
        assert left == right

    def test_message_inequality(self):
        """A differing ``content`` field makes two Messages unequal."""
        assert Message(role="user", content="x", timestamp="t") != Message(
            role="user", content="y", timestamp="t"
        )


# ---------------------------------------------------------------------------
# _get_conn context manager
# ---------------------------------------------------------------------------


class TestGetConnContextManager:
    """Tests for the _get_conn context manager."""

    def test_creates_db_file(self, tmp_path):
        """_get_conn creates the database file on first use."""
        db = tmp_path / "chat.db"
        assert not db.exists()
        with _get_conn(db) as conn:
            assert conn is not None
        assert db.exists()

    def test_creates_parent_directories(self, tmp_path):
        """_get_conn creates any missing parent directories."""
        db = tmp_path / "nested" / "deep" / "chat.db"
        with _get_conn(db):
            pass
        assert db.exists()

    def test_creates_schema(self, tmp_path):
        """_get_conn creates the chat_messages table."""
        db = tmp_path / "chat.db"
        with _get_conn(db) as conn:
            tables = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='chat_messages'"
            ).fetchall()
        assert len(tables) == 1

    def test_schema_has_expected_columns(self, tmp_path):
        """chat_messages table has the expected columns."""
        db = tmp_path / "chat.db"
        with _get_conn(db) as conn:
            info = conn.execute("PRAGMA table_info(chat_messages)").fetchall()
        # Rows are accessed by column name, relying on the row factory that
        # _get_conn installs.
        col_names = [row["name"] for row in info]
        assert set(col_names) == {"id", "role", "content", "timestamp", "source"}

    def test_idempotent_schema_creation(self, tmp_path):
        """Calling _get_conn twice does not fail and leaves a usable, empty table."""
        db = tmp_path / "chat.db"
        with _get_conn(db):
            pass
        with _get_conn(db) as conn:
            # The second entry re-runs schema setup. The table must survive it
            # and still be queryable; nothing has been inserted, so it is empty.
            # Aliasing the aggregate lets us read it by name, like the test above.
            row = conn.execute("SELECT COUNT(*) AS n FROM chat_messages").fetchone()
        assert row["n"] == 0


# ---------------------------------------------------------------------------
# MessageLog — basic operations
# ---------------------------------------------------------------------------


class TestMessageLogAppend:
    """Tests for MessageLog.append().

    Each test wraps its MessageLog in ``contextlib.closing`` so the SQLite
    connection is released even when an assertion fails part-way through.
    """

    def test_append_single_message(self, tmp_path):
        """append() stores a message that can be retrieved."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "hello", "2024-01-01T00:00:00")
            messages = log.all()
            assert len(messages) == 1
            assert messages[0].role == "user"
            assert messages[0].content == "hello"
            assert messages[0].timestamp == "2024-01-01T00:00:00"
            assert messages[0].source == "browser"

    def test_append_custom_source(self, tmp_path):
        """append() stores the source field correctly."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("agent", "reply", "2024-01-01T00:00:01", source="api")
            assert log.all()[0].source == "api"

    def test_append_multiple_messages_preserves_order(self, tmp_path):
        """append() preserves insertion order."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "first", "2024-01-01T00:00:00")
            log.append("agent", "second", "2024-01-01T00:00:01")
            log.append("user", "third", "2024-01-01T00:00:02")
            assert [m.content for m in log.all()] == ["first", "second", "third"]

    def test_append_persists_across_instances(self, tmp_path):
        """Messages appended by one instance are readable by another."""
        db = tmp_path / "chat.db"
        # The writer is fully closed before the reader opens the same file.
        with closing(MessageLog(db)) as writer:
            writer.append("user", "persisted", "2024-01-01T00:00:00")
        with closing(MessageLog(db)) as reader:
            messages = reader.all()
            assert len(messages) == 1
            assert messages[0].content == "persisted"
class TestMessageLogAll:
    """Tests for MessageLog.all()."""

    def test_all_on_empty_store_returns_empty_list(self, tmp_path):
        """all() returns [] when there are no messages."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            assert log.all() == []

    def test_all_returns_message_objects(self, tmp_path):
        """all() returns a list of Message dataclass instances."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "hi", "2024-01-01T00:00:00")
            messages = log.all()
            # Pin the length first: all(...) over an empty list is vacuously
            # True, which would let a broken all() pass the type check.
            assert len(messages) == 1
            assert all(isinstance(m, Message) for m in messages)

    def test_all_returns_all_messages(self, tmp_path):
        """all() returns every stored message, in insertion order."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            for i in range(5):
                log.append("user", f"msg{i}", f"2024-01-01T00:00:0{i}")
            messages = log.all()
            assert len(messages) == 5
            assert [m.content for m in messages] == [f"msg{i}" for i in range(5)]
class TestMessageLogRecent:
    """Tests for MessageLog.recent()."""

    def test_recent_on_empty_store_returns_empty_list(self, tmp_path):
        """recent() returns [] when there are no messages."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            assert log.recent() == []

    def test_recent_default_limit(self, tmp_path):
        """recent() with the default limit returns the newest 50 messages."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            for i in range(60):
                log.append("user", f"msg{i}", f"2024-01-01T00:00:{i:02d}")
            msgs = log.recent()
            assert len(msgs) == 50
            # The 10 oldest fall outside the window; the rest come back oldest-first.
            assert [m.content for m in msgs] == [f"msg{i}" for i in range(10, 60)]

    def test_recent_custom_limit(self, tmp_path):
        """recent() respects a custom limit."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            for i in range(10):
                log.append("user", f"msg{i}", f"2024-01-01T00:00:0{i}")
            assert len(log.recent(limit=3)) == 3

    def test_recent_returns_newest_messages(self, tmp_path):
        """recent() returns the most-recently-inserted messages."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            for i in range(10):
                log.append("user", f"msg{i}", f"2024-01-01T00:00:0{i}")
            msgs = log.recent(limit=3)
            # Should be the last 3 inserted, in oldest-first order
            assert [m.content for m in msgs] == ["msg7", "msg8", "msg9"]

    def test_recent_fewer_than_limit_returns_all(self, tmp_path):
        """recent() returns all messages when count < limit."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "only", "2024-01-01T00:00:00")
            msgs = log.recent(limit=10)
            assert len(msgs) == 1
            assert msgs[0].content == "only"

    def test_recent_returns_oldest_first(self, tmp_path):
        """recent() returns messages in oldest-first order."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "a", "2024-01-01T00:00:00")
            log.append("user", "b", "2024-01-01T00:00:01")
            log.append("user", "c", "2024-01-01T00:00:02")
            assert [m.content for m in log.recent(limit=2)] == ["b", "c"]
class TestMessageLogClear:
    """Tests for MessageLog.clear().

    ``contextlib.closing`` guarantees each test's connection is closed even
    when an assertion fails.
    """

    def test_clear_empties_the_store(self, tmp_path):
        """clear() removes all messages."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "hello", "2024-01-01T00:00:00")
            log.clear()
            assert log.all() == []

    def test_clear_on_empty_store_is_safe(self, tmp_path):
        """clear() on an empty store does not raise."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.clear()  # should not raise
            assert log.all() == []

    def test_clear_allows_new_appends(self, tmp_path):
        """After clear(), new messages can be appended."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "old", "2024-01-01T00:00:00")
            log.clear()
            log.append("user", "new", "2024-01-01T00:00:01")
            messages = log.all()
            assert len(messages) == 1
            assert messages[0].content == "new"

    def test_clear_resets_len_to_zero(self, tmp_path):
        """After clear(), __len__ returns 0."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "a", "t")
            log.append("user", "b", "t")
            log.clear()
            assert len(log) == 0


# ---------------------------------------------------------------------------
# MessageLog — __len__
# ---------------------------------------------------------------------------


class TestMessageLogLen:
    """Tests for MessageLog.__len__().

    ``contextlib.closing`` guarantees each test's connection is closed even
    when an assertion fails.
    """

    def test_len_empty_store(self, tmp_path):
        """__len__ returns 0 for an empty store."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            assert len(log) == 0

    def test_len_after_appends(self, tmp_path):
        """__len__ reflects the number of stored messages."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            for i in range(7):
                log.append("user", f"msg{i}", "t")
            assert len(log) == 7

    def test_len_after_clear(self, tmp_path):
        """__len__ is 0 after clear()."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "x", "t")
            log.clear()
            assert len(log) == 0


# ---------------------------------------------------------------------------
# MessageLog — pruning
# ---------------------------------------------------------------------------


class TestMessageLogPrune:
    """Tests for automatic pruning via _prune().

    _prune reads the module-level ``MAX_MESSAGES`` constant, so each test
    lowers it with pytest's ``monkeypatch`` fixture. monkeypatch restores the
    original value at teardown even if the test fails, which replaces the
    hand-written save/patch/restore in a ``finally`` block.
    """

    def test_prune_keeps_at_most_max_messages(self, tmp_path, monkeypatch):
        """After exceeding MAX_MESSAGES, oldest messages are pruned."""
        import infrastructure.chat_store as cs

        monkeypatch.setattr(cs, "MAX_MESSAGES", 5)
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            for i in range(8):
                log.append("user", f"msg{i}", f"t{i}")
            assert len(log) == 5

    def test_prune_keeps_newest_messages(self, tmp_path, monkeypatch):
        """Pruning removes oldest messages and keeps the newest ones."""
        import infrastructure.chat_store as cs

        monkeypatch.setattr(cs, "MAX_MESSAGES", 3)
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            for i in range(5):
                log.append("user", f"msg{i}", f"t{i}")
            contents = [m.content for m in log.all()]
            assert contents == ["msg2", "msg3", "msg4"]

    def test_no_prune_when_below_limit(self, tmp_path, monkeypatch):
        """No messages are pruned while count is at or below MAX_MESSAGES."""
        import infrastructure.chat_store as cs

        monkeypatch.setattr(cs, "MAX_MESSAGES", 10)
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            # Exactly at the limit: nothing may be dropped.
            for i in range(10):
                log.append("user", f"msg{i}", f"t{i}")
            assert len(log) == 10


# ---------------------------------------------------------------------------
# MessageLog — close / lifecycle
# ---------------------------------------------------------------------------


class TestMessageLogClose:
    """Lifecycle checks for MessageLog.close()."""

    def test_close_is_safe_before_first_use(self, tmp_path):
        """Closing an instance that was never used does not raise."""
        MessageLog(tmp_path / "chat.db").close()  # must not raise

    def test_close_multiple_times_is_safe(self, tmp_path):
        """Repeated close() calls after real use are tolerated."""
        store = MessageLog(tmp_path / "chat.db")
        store.append("user", "hi", "t")
        for _ in range(2):
            store.close()  # neither the first nor the second call may raise

    def test_close_sets_conn_to_none(self, tmp_path):
        """close() clears the internal ``_conn`` handle back to None."""
        store = MessageLog(tmp_path / "chat.db")
        store.append("user", "hi", "t")
        assert store._conn is not None, "a connection should be held after append()"
        store.close()
        assert store._conn is None


# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------


class TestMessageLogThreadSafety:
    """Thread-safety tests for MessageLog."""

    # Upper bound (seconds) on waiting for worker threads. It is generous and
    # only trips on a genuine hang, e.g. a deadlock, which would otherwise
    # block the test run forever.
    JOIN_TIMEOUT = 10.0

    def _start_and_join(self, threads):
        """Start every thread, wait for all of them, and fail if any is stuck."""
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=self.JOIN_TIMEOUT)
        stuck = [t.name for t in threads if t.is_alive()]
        assert not stuck, f"threads did not finish (possible deadlock): {stuck}"

    def test_concurrent_appends(self, tmp_path):
        """Multiple threads can append messages without data loss or errors."""
        errors: list[Exception] = []

        with closing(MessageLog(tmp_path / "chat.db")) as log:

            def worker(n: int) -> None:
                try:
                    for i in range(5):
                        log.append("user", f"t{n}-{i}", f"ts-{n}-{i}")
                except Exception as exc:  # noqa: BLE001 - surfaced by the assert below
                    errors.append(exc)

            # Daemon threads so a hung worker cannot keep the interpreter alive.
            threads = [
                threading.Thread(target=worker, args=(n,), daemon=True) for n in range(4)
            ]
            self._start_and_join(threads)

            assert errors == [], f"Concurrent append raised: {errors}"
            # All 20 messages should be present (4 threads x 5 messages),
            # each exactly once.
            assert len(log) == 20
            expected = sorted(f"t{n}-{i}" for n in range(4) for i in range(5))
            assert sorted(m.content for m in log.all()) == expected

    def test_concurrent_reads_and_writes(self, tmp_path):
        """Concurrent reads and writes neither raise nor lose writes."""
        errors: list[Exception] = []

        with closing(MessageLog(tmp_path / "chat.db")) as log:

            def writer() -> None:
                try:
                    for i in range(10):
                        log.append("user", f"msg{i}", f"t{i}")
                except Exception as exc:  # noqa: BLE001
                    errors.append(exc)

            def reader() -> None:
                try:
                    for _ in range(10):
                        log.all()
                except Exception as exc:  # noqa: BLE001
                    errors.append(exc)

            threads = [threading.Thread(target=writer, daemon=True)] + [
                threading.Thread(target=reader, daemon=True) for _ in range(3)
            ]
            self._start_and_join(threads)

            assert errors == [], f"Concurrent read/write raised: {errors}"
            # A single writer means every write must land, in its own order.
            assert [m.content for m in log.all()] == [f"msg{i}" for i in range(10)]


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


class TestMessageLogEdgeCases:
    """Edge-case tests for MessageLog."""

    def test_empty_content_stored_and_retrieved(self, tmp_path):
        """Empty string content can be stored and retrieved."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "", "2024-01-01T00:00:00")
            assert log.all()[0].content == ""

    def test_unicode_content_stored_and_retrieved(self, tmp_path):
        """Unicode characters in content are stored and retrieved correctly."""
        text = "こんにちは 🌍"
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", text, "2024-01-01T00:00:00")
            assert log.all()[0].content == text

    def test_newline_in_content(self, tmp_path):
        """Newlines in content are preserved."""
        multiline = "line1\nline2\nline3"
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("agent", multiline, "2024-01-01T00:00:00")
            assert log.all()[0].content == multiline

    def test_default_db_path_attribute(self):
        """MessageLog without explicit path uses the module-level DB_PATH."""
        from infrastructure.chat_store import DB_PATH

        log = MessageLog()
        assert log._db_path == DB_PATH
        # Do NOT call close() here — this is the global singleton's path

    def test_custom_db_path_used(self, tmp_path):
        """MessageLog uses the provided db_path."""
        db = tmp_path / "custom.db"
        with closing(MessageLog(db)) as log:
            log.append("user", "test", "t")
            assert db.exists()

    def test_recent_limit_zero_returns_empty(self, tmp_path):
        """recent(limit=0) returns an empty list."""
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            log.append("user", "msg", "t")
            assert log.recent(limit=0) == []

    def test_all_roles_stored_correctly(self, tmp_path):
        """Different role values are stored and retrieved correctly."""
        roles = ("user", "agent", "error", "system")
        with closing(MessageLog(tmp_path / "chat.db")) as log:
            for role in roles:
                log.append(role, f"{role} message", "t")
            assert [m.role for m in log.all()] == list(roles)