This commit was merged in pull request #1198.
This commit is contained in:
513
tests/infrastructure/test_chat_store.py
Normal file
513
tests/infrastructure/test_chat_store.py
Normal file
@@ -0,0 +1,513 @@
|
||||
"""Unit tests for infrastructure.chat_store module."""
|
||||
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.chat_store import MAX_MESSAGES, Message, MessageLog, _get_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageDataclass:
|
||||
"""Tests for the Message dataclass."""
|
||||
|
||||
def test_message_required_fields(self):
|
||||
"""Message can be created with required fields only."""
|
||||
msg = Message(role="user", content="hello", timestamp="2024-01-01T00:00:00")
|
||||
assert msg.role == "user"
|
||||
assert msg.content == "hello"
|
||||
assert msg.timestamp == "2024-01-01T00:00:00"
|
||||
|
||||
def test_message_default_source(self):
|
||||
"""Message source defaults to 'browser'."""
|
||||
msg = Message(role="user", content="hi", timestamp="2024-01-01T00:00:00")
|
||||
assert msg.source == "browser"
|
||||
|
||||
def test_message_custom_source(self):
|
||||
"""Message source can be overridden."""
|
||||
msg = Message(role="agent", content="reply", timestamp="2024-01-01T00:00:00", source="api")
|
||||
assert msg.source == "api"
|
||||
|
||||
def test_message_equality(self):
|
||||
"""Two Messages with the same fields are equal (dataclass default)."""
|
||||
m1 = Message(role="user", content="x", timestamp="t")
|
||||
m2 = Message(role="user", content="x", timestamp="t")
|
||||
assert m1 == m2
|
||||
|
||||
def test_message_inequality(self):
|
||||
"""Messages with different content are not equal."""
|
||||
m1 = Message(role="user", content="x", timestamp="t")
|
||||
m2 = Message(role="user", content="y", timestamp="t")
|
||||
assert m1 != m2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_conn context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetConnContextManager:
|
||||
"""Tests for the _get_conn context manager."""
|
||||
|
||||
def test_creates_db_file(self, tmp_path):
|
||||
"""_get_conn creates the database file on first use."""
|
||||
db = tmp_path / "chat.db"
|
||||
assert not db.exists()
|
||||
with _get_conn(db) as conn:
|
||||
assert conn is not None
|
||||
assert db.exists()
|
||||
|
||||
def test_creates_parent_directories(self, tmp_path):
|
||||
"""_get_conn creates any missing parent directories."""
|
||||
db = tmp_path / "nested" / "deep" / "chat.db"
|
||||
with _get_conn(db):
|
||||
pass
|
||||
assert db.exists()
|
||||
|
||||
def test_creates_schema(self, tmp_path):
|
||||
"""_get_conn creates the chat_messages table."""
|
||||
db = tmp_path / "chat.db"
|
||||
with _get_conn(db) as conn:
|
||||
tables = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='chat_messages'"
|
||||
).fetchall()
|
||||
assert len(tables) == 1
|
||||
|
||||
def test_schema_has_expected_columns(self, tmp_path):
|
||||
"""chat_messages table has the expected columns."""
|
||||
db = tmp_path / "chat.db"
|
||||
with _get_conn(db) as conn:
|
||||
info = conn.execute("PRAGMA table_info(chat_messages)").fetchall()
|
||||
col_names = [row["name"] for row in info]
|
||||
assert set(col_names) == {"id", "role", "content", "timestamp", "source"}
|
||||
|
||||
def test_idempotent_schema_creation(self, tmp_path):
|
||||
"""Calling _get_conn twice does not fail (CREATE TABLE IF NOT EXISTS)."""
|
||||
db = tmp_path / "chat.db"
|
||||
with _get_conn(db):
|
||||
pass
|
||||
with _get_conn(db) as conn:
|
||||
# Table still exists and is usable
|
||||
conn.execute("SELECT COUNT(*) FROM chat_messages")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageLog — basic operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageLogAppend:
|
||||
"""Tests for MessageLog.append()."""
|
||||
|
||||
def test_append_single_message(self, tmp_path):
|
||||
"""append() stores a message that can be retrieved."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "hello", "2024-01-01T00:00:00")
|
||||
messages = log.all()
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == "user"
|
||||
assert messages[0].content == "hello"
|
||||
assert messages[0].timestamp == "2024-01-01T00:00:00"
|
||||
assert messages[0].source == "browser"
|
||||
log.close()
|
||||
|
||||
def test_append_custom_source(self, tmp_path):
|
||||
"""append() stores the source field correctly."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("agent", "reply", "2024-01-01T00:00:01", source="api")
|
||||
msg = log.all()[0]
|
||||
assert msg.source == "api"
|
||||
log.close()
|
||||
|
||||
def test_append_multiple_messages_preserves_order(self, tmp_path):
|
||||
"""append() preserves insertion order."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "first", "2024-01-01T00:00:00")
|
||||
log.append("agent", "second", "2024-01-01T00:00:01")
|
||||
log.append("user", "third", "2024-01-01T00:00:02")
|
||||
messages = log.all()
|
||||
assert [m.content for m in messages] == ["first", "second", "third"]
|
||||
log.close()
|
||||
|
||||
def test_append_persists_across_instances(self, tmp_path):
|
||||
"""Messages appended by one instance are readable by another."""
|
||||
db = tmp_path / "chat.db"
|
||||
log1 = MessageLog(db)
|
||||
log1.append("user", "persisted", "2024-01-01T00:00:00")
|
||||
log1.close()
|
||||
|
||||
log2 = MessageLog(db)
|
||||
messages = log2.all()
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content == "persisted"
|
||||
log2.close()
|
||||
|
||||
|
||||
class TestMessageLogAll:
|
||||
"""Tests for MessageLog.all()."""
|
||||
|
||||
def test_all_on_empty_store_returns_empty_list(self, tmp_path):
|
||||
"""all() returns [] when there are no messages."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
assert log.all() == []
|
||||
log.close()
|
||||
|
||||
def test_all_returns_message_objects(self, tmp_path):
|
||||
"""all() returns a list of Message dataclass instances."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "hi", "2024-01-01T00:00:00")
|
||||
messages = log.all()
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
log.close()
|
||||
|
||||
def test_all_returns_all_messages(self, tmp_path):
|
||||
"""all() returns every stored message."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
for i in range(5):
|
||||
log.append("user", f"msg{i}", f"2024-01-01T00:00:0{i}")
|
||||
assert len(log.all()) == 5
|
||||
log.close()
|
||||
|
||||
|
||||
class TestMessageLogRecent:
|
||||
"""Tests for MessageLog.recent()."""
|
||||
|
||||
def test_recent_on_empty_store_returns_empty_list(self, tmp_path):
|
||||
"""recent() returns [] when there are no messages."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
assert log.recent() == []
|
||||
log.close()
|
||||
|
||||
def test_recent_default_limit(self, tmp_path):
|
||||
"""recent() with default limit returns up to 50 messages."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
for i in range(60):
|
||||
log.append("user", f"msg{i}", f"2024-01-01T00:00:{i:02d}")
|
||||
msgs = log.recent()
|
||||
assert len(msgs) == 50
|
||||
log.close()
|
||||
|
||||
def test_recent_custom_limit(self, tmp_path):
|
||||
"""recent() respects a custom limit."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
for i in range(10):
|
||||
log.append("user", f"msg{i}", f"2024-01-01T00:00:0{i}")
|
||||
msgs = log.recent(limit=3)
|
||||
assert len(msgs) == 3
|
||||
log.close()
|
||||
|
||||
def test_recent_returns_newest_messages(self, tmp_path):
|
||||
"""recent() returns the most-recently-inserted messages."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
for i in range(10):
|
||||
log.append("user", f"msg{i}", f"2024-01-01T00:00:0{i}")
|
||||
msgs = log.recent(limit=3)
|
||||
# Should be the last 3 inserted, in oldest-first order
|
||||
assert [m.content for m in msgs] == ["msg7", "msg8", "msg9"]
|
||||
log.close()
|
||||
|
||||
def test_recent_fewer_than_limit_returns_all(self, tmp_path):
|
||||
"""recent() returns all messages when count < limit."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "only", "2024-01-01T00:00:00")
|
||||
msgs = log.recent(limit=10)
|
||||
assert len(msgs) == 1
|
||||
log.close()
|
||||
|
||||
def test_recent_returns_oldest_first(self, tmp_path):
|
||||
"""recent() returns messages in oldest-first order."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "a", "2024-01-01T00:00:00")
|
||||
log.append("user", "b", "2024-01-01T00:00:01")
|
||||
log.append("user", "c", "2024-01-01T00:00:02")
|
||||
msgs = log.recent(limit=2)
|
||||
assert [m.content for m in msgs] == ["b", "c"]
|
||||
log.close()
|
||||
|
||||
|
||||
class TestMessageLogClear:
|
||||
"""Tests for MessageLog.clear()."""
|
||||
|
||||
def test_clear_empties_the_store(self, tmp_path):
|
||||
"""clear() removes all messages."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "hello", "2024-01-01T00:00:00")
|
||||
log.clear()
|
||||
assert log.all() == []
|
||||
log.close()
|
||||
|
||||
def test_clear_on_empty_store_is_safe(self, tmp_path):
|
||||
"""clear() on an empty store does not raise."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.clear() # should not raise
|
||||
assert log.all() == []
|
||||
log.close()
|
||||
|
||||
def test_clear_allows_new_appends(self, tmp_path):
|
||||
"""After clear(), new messages can be appended."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "old", "2024-01-01T00:00:00")
|
||||
log.clear()
|
||||
log.append("user", "new", "2024-01-01T00:00:01")
|
||||
messages = log.all()
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content == "new"
|
||||
log.close()
|
||||
|
||||
def test_clear_resets_len_to_zero(self, tmp_path):
|
||||
"""After clear(), __len__ returns 0."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "a", "t")
|
||||
log.append("user", "b", "t")
|
||||
log.clear()
|
||||
assert len(log) == 0
|
||||
log.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageLog — __len__
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageLogLen:
|
||||
"""Tests for MessageLog.__len__()."""
|
||||
|
||||
def test_len_empty_store(self, tmp_path):
|
||||
"""__len__ returns 0 for an empty store."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
assert len(log) == 0
|
||||
log.close()
|
||||
|
||||
def test_len_after_appends(self, tmp_path):
|
||||
"""__len__ reflects the number of stored messages."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
for i in range(7):
|
||||
log.append("user", f"msg{i}", "t")
|
||||
assert len(log) == 7
|
||||
log.close()
|
||||
|
||||
def test_len_after_clear(self, tmp_path):
|
||||
"""__len__ is 0 after clear()."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "x", "t")
|
||||
log.clear()
|
||||
assert len(log) == 0
|
||||
log.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageLog — pruning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageLogPrune:
|
||||
"""Tests for automatic pruning via _prune()."""
|
||||
|
||||
def test_prune_keeps_at_most_max_messages(self, tmp_path):
|
||||
"""After exceeding MAX_MESSAGES, oldest messages are pruned."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
# Temporarily lower the limit via monkeypatching is not straightforward
|
||||
# because _prune reads the module-level MAX_MESSAGES constant.
|
||||
# We therefore patch it directly.
|
||||
import infrastructure.chat_store as cs
|
||||
|
||||
original = cs.MAX_MESSAGES
|
||||
cs.MAX_MESSAGES = 5
|
||||
try:
|
||||
for i in range(8):
|
||||
log.append("user", f"msg{i}", f"t{i}")
|
||||
assert len(log) == 5
|
||||
finally:
|
||||
cs.MAX_MESSAGES = original
|
||||
log.close()
|
||||
|
||||
def test_prune_keeps_newest_messages(self, tmp_path):
|
||||
"""Pruning removes oldest messages and keeps the newest ones."""
|
||||
import infrastructure.chat_store as cs
|
||||
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
original = cs.MAX_MESSAGES
|
||||
cs.MAX_MESSAGES = 3
|
||||
try:
|
||||
for i in range(5):
|
||||
log.append("user", f"msg{i}", f"t{i}")
|
||||
messages = log.all()
|
||||
contents = [m.content for m in messages]
|
||||
assert contents == ["msg2", "msg3", "msg4"]
|
||||
finally:
|
||||
cs.MAX_MESSAGES = original
|
||||
log.close()
|
||||
|
||||
def test_no_prune_when_below_limit(self, tmp_path):
|
||||
"""No messages are pruned while count is at or below MAX_MESSAGES."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
import infrastructure.chat_store as cs
|
||||
|
||||
original = cs.MAX_MESSAGES
|
||||
cs.MAX_MESSAGES = 10
|
||||
try:
|
||||
for i in range(10):
|
||||
log.append("user", f"msg{i}", f"t{i}")
|
||||
assert len(log) == 10
|
||||
finally:
|
||||
cs.MAX_MESSAGES = original
|
||||
log.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageLog — close / lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageLogClose:
|
||||
"""Tests for MessageLog.close()."""
|
||||
|
||||
def test_close_is_safe_before_first_use(self, tmp_path):
|
||||
"""close() on a fresh (never-used) instance does not raise."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.close() # should not raise
|
||||
|
||||
def test_close_multiple_times_is_safe(self, tmp_path):
|
||||
"""close() can be called multiple times without error."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "hi", "t")
|
||||
log.close()
|
||||
log.close() # second close should not raise
|
||||
|
||||
def test_close_sets_conn_to_none(self, tmp_path):
|
||||
"""close() sets the internal _conn attribute to None."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "hi", "t")
|
||||
assert log._conn is not None
|
||||
log.close()
|
||||
assert log._conn is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageLogThreadSafety:
|
||||
"""Thread-safety tests for MessageLog."""
|
||||
|
||||
def test_concurrent_appends(self, tmp_path):
|
||||
"""Multiple threads can append messages without data loss or errors."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
errors: list[Exception] = []
|
||||
|
||||
def worker(n: int) -> None:
|
||||
try:
|
||||
for i in range(5):
|
||||
log.append("user", f"t{n}-{i}", f"ts-{n}-{i}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(n,)) for n in range(4)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert errors == [], f"Concurrent append raised: {errors}"
|
||||
# All 20 messages should be present (4 threads × 5 messages)
|
||||
assert len(log) == 20
|
||||
log.close()
|
||||
|
||||
def test_concurrent_reads_and_writes(self, tmp_path):
|
||||
"""Concurrent reads and writes do not corrupt state."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
errors: list[Exception] = []
|
||||
|
||||
def writer() -> None:
|
||||
try:
|
||||
for i in range(10):
|
||||
log.append("user", f"msg{i}", f"t{i}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
|
||||
def reader() -> None:
|
||||
try:
|
||||
for _ in range(10):
|
||||
log.all()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
|
||||
threads = [threading.Thread(target=writer)] + [
|
||||
threading.Thread(target=reader) for _ in range(3)
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert errors == [], f"Concurrent read/write raised: {errors}"
|
||||
log.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageLogEdgeCases:
|
||||
"""Edge-case tests for MessageLog."""
|
||||
|
||||
def test_empty_content_stored_and_retrieved(self, tmp_path):
|
||||
"""Empty string content can be stored and retrieved."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "", "2024-01-01T00:00:00")
|
||||
assert log.all()[0].content == ""
|
||||
log.close()
|
||||
|
||||
def test_unicode_content_stored_and_retrieved(self, tmp_path):
|
||||
"""Unicode characters in content are stored and retrieved correctly."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "こんにちは 🌍", "2024-01-01T00:00:00")
|
||||
assert log.all()[0].content == "こんにちは 🌍"
|
||||
log.close()
|
||||
|
||||
def test_newline_in_content(self, tmp_path):
|
||||
"""Newlines in content are preserved."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
multiline = "line1\nline2\nline3"
|
||||
log.append("agent", multiline, "2024-01-01T00:00:00")
|
||||
assert log.all()[0].content == multiline
|
||||
log.close()
|
||||
|
||||
def test_default_db_path_attribute(self):
|
||||
"""MessageLog without explicit path uses the module-level DB_PATH."""
|
||||
from infrastructure.chat_store import DB_PATH
|
||||
|
||||
log = MessageLog()
|
||||
assert log._db_path == DB_PATH
|
||||
# Do NOT call close() here — this is the global singleton's path
|
||||
|
||||
def test_custom_db_path_used(self, tmp_path):
|
||||
"""MessageLog uses the provided db_path."""
|
||||
db = tmp_path / "custom.db"
|
||||
log = MessageLog(db)
|
||||
log.append("user", "test", "t")
|
||||
assert db.exists()
|
||||
log.close()
|
||||
|
||||
def test_recent_limit_zero_returns_empty(self, tmp_path):
|
||||
"""recent(limit=0) returns an empty list."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
log.append("user", "msg", "t")
|
||||
assert log.recent(limit=0) == []
|
||||
log.close()
|
||||
|
||||
def test_all_roles_stored_correctly(self, tmp_path):
|
||||
"""Different role values are stored and retrieved correctly."""
|
||||
log = MessageLog(tmp_path / "chat.db")
|
||||
for role in ("user", "agent", "error", "system"):
|
||||
log.append(role, f"{role} message", "t")
|
||||
messages = log.all()
|
||||
assert [m.role for m in messages] == ["user", "agent", "error", "system"]
|
||||
log.close()
|
||||
Reference in New Issue
Block a user