forked from Rockachopa/Timmy-time-dashboard
Compare commits
1 Commits
fix/test-c
...
test/chat-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
845070f70b |
247
tests/infrastructure/test_chat_store.py
Normal file
247
tests/infrastructure/test_chat_store.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
"""Unit tests for src/infrastructure/chat_store.py."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.infrastructure.chat_store import MAX_MESSAGES, Message, MessageLog, _get_conn
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def tmp_db(tmp_path: Path) -> Path:
|
||||||
|
"""Return a temporary database path."""
|
||||||
|
return tmp_path / "test_chat.db"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def log(tmp_db: Path) -> MessageLog:
|
||||||
|
"""Return a MessageLog backed by a temp database."""
|
||||||
|
ml = MessageLog(db_path=tmp_db)
|
||||||
|
yield ml
|
||||||
|
ml.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Message dataclass ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessage:
|
||||||
|
def test_default_source(self):
|
||||||
|
m = Message(role="user", content="hi", timestamp="2026-01-01T00:00:00")
|
||||||
|
assert m.source == "browser"
|
||||||
|
|
||||||
|
def test_custom_source(self):
|
||||||
|
m = Message(role="agent", content="ok", timestamp="t1", source="telegram")
|
||||||
|
assert m.source == "telegram"
|
||||||
|
|
||||||
|
def test_fields(self):
|
||||||
|
m = Message(role="error", content="boom", timestamp="t2", source="api")
|
||||||
|
assert m.role == "error"
|
||||||
|
assert m.content == "boom"
|
||||||
|
assert m.timestamp == "t2"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _get_conn context manager ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetConn:
|
||||||
|
def test_creates_db_and_table(self, tmp_db: Path):
|
||||||
|
with _get_conn(tmp_db) as conn:
|
||||||
|
tables = conn.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||||
|
).fetchall()
|
||||||
|
names = [t["name"] for t in tables]
|
||||||
|
assert "chat_messages" in names
|
||||||
|
|
||||||
|
def test_creates_parent_dirs(self, tmp_path: Path):
|
||||||
|
deep = tmp_path / "a" / "b" / "c" / "chat.db"
|
||||||
|
with _get_conn(deep) as conn:
|
||||||
|
assert deep.parent.exists()
|
||||||
|
|
||||||
|
def test_connection_closed_after_context(self, tmp_db: Path):
|
||||||
|
with _get_conn(tmp_db) as conn:
|
||||||
|
conn.execute("SELECT 1")
|
||||||
|
# Connection should be closed — operations should fail
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
conn.execute("SELECT 1")
|
||||||
|
|
||||||
|
|
||||||
|
# ── MessageLog core operations ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageLogAppendAndAll:
|
||||||
|
def test_append_and_all(self, log: MessageLog):
|
||||||
|
log.append("user", "hello", "t1")
|
||||||
|
log.append("agent", "hi back", "t2", source="api")
|
||||||
|
msgs = log.all()
|
||||||
|
assert len(msgs) == 2
|
||||||
|
assert msgs[0].role == "user"
|
||||||
|
assert msgs[0].content == "hello"
|
||||||
|
assert msgs[0].source == "browser"
|
||||||
|
assert msgs[1].role == "agent"
|
||||||
|
assert msgs[1].source == "api"
|
||||||
|
|
||||||
|
def test_all_returns_ordered_by_id(self, log: MessageLog):
|
||||||
|
for i in range(5):
|
||||||
|
log.append("user", f"msg{i}", f"t{i}")
|
||||||
|
msgs = log.all()
|
||||||
|
assert [m.content for m in msgs] == [f"msg{i}" for i in range(5)]
|
||||||
|
|
||||||
|
def test_all_empty_store(self, log: MessageLog):
|
||||||
|
assert log.all() == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageLogRecent:
|
||||||
|
def test_recent_returns_newest(self, log: MessageLog):
|
||||||
|
for i in range(10):
|
||||||
|
log.append("user", f"msg{i}", f"t{i}")
|
||||||
|
recent = log.recent(limit=3)
|
||||||
|
assert len(recent) == 3
|
||||||
|
assert recent[0].content == "msg7"
|
||||||
|
assert recent[2].content == "msg9"
|
||||||
|
|
||||||
|
def test_recent_oldest_first(self, log: MessageLog):
|
||||||
|
for i in range(5):
|
||||||
|
log.append("user", f"msg{i}", f"t{i}")
|
||||||
|
recent = log.recent(limit=3)
|
||||||
|
# Should be oldest-first within the window
|
||||||
|
assert recent[0].content == "msg2"
|
||||||
|
assert recent[1].content == "msg3"
|
||||||
|
assert recent[2].content == "msg4"
|
||||||
|
|
||||||
|
def test_recent_more_than_exists(self, log: MessageLog):
|
||||||
|
log.append("user", "only", "t0")
|
||||||
|
recent = log.recent(limit=100)
|
||||||
|
assert len(recent) == 1
|
||||||
|
|
||||||
|
def test_recent_empty_store(self, log: MessageLog):
|
||||||
|
assert log.recent() == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageLogClear:
|
||||||
|
def test_clear_removes_all(self, log: MessageLog):
|
||||||
|
for i in range(5):
|
||||||
|
log.append("user", f"msg{i}", f"t{i}")
|
||||||
|
assert len(log) == 5
|
||||||
|
log.clear()
|
||||||
|
assert len(log) == 0
|
||||||
|
assert log.all() == []
|
||||||
|
|
||||||
|
def test_clear_empty_store(self, log: MessageLog):
|
||||||
|
log.clear() # Should not raise
|
||||||
|
assert len(log) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageLogLen:
|
||||||
|
def test_len_empty(self, log: MessageLog):
|
||||||
|
assert len(log) == 0
|
||||||
|
|
||||||
|
def test_len_after_appends(self, log: MessageLog):
|
||||||
|
for i in range(7):
|
||||||
|
log.append("user", f"msg{i}", f"t{i}")
|
||||||
|
assert len(log) == 7
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageLogClose:
|
||||||
|
def test_close_sets_conn_none(self, tmp_db: Path):
|
||||||
|
ml = MessageLog(db_path=tmp_db)
|
||||||
|
ml.append("user", "x", "t0")
|
||||||
|
ml.close()
|
||||||
|
assert ml._conn is None
|
||||||
|
|
||||||
|
def test_close_idempotent(self, tmp_db: Path):
|
||||||
|
ml = MessageLog(db_path=tmp_db)
|
||||||
|
ml.close()
|
||||||
|
ml.close() # Should not raise
|
||||||
|
|
||||||
|
def test_reopen_after_close(self, tmp_db: Path):
|
||||||
|
ml = MessageLog(db_path=tmp_db)
|
||||||
|
ml.append("user", "before", "t0")
|
||||||
|
ml.close()
|
||||||
|
# Should reconnect on next use
|
||||||
|
ml.append("user", "after", "t1")
|
||||||
|
assert len(ml) == 2
|
||||||
|
ml.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pruning ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrune:
|
||||||
|
def test_prune_keeps_max_messages(self, tmp_db: Path):
|
||||||
|
with patch("src.infrastructure.chat_store.MAX_MESSAGES", 5):
|
||||||
|
ml = MessageLog(db_path=tmp_db)
|
||||||
|
for i in range(10):
|
||||||
|
ml.append("user", f"msg{i}", f"t{i}")
|
||||||
|
# Should have pruned to 5
|
||||||
|
assert len(ml) == 5
|
||||||
|
msgs = ml.all()
|
||||||
|
# Oldest should be pruned, newest kept
|
||||||
|
assert msgs[0].content == "msg5"
|
||||||
|
assert msgs[-1].content == "msg9"
|
||||||
|
ml.close()
|
||||||
|
|
||||||
|
def test_no_prune_under_limit(self, tmp_db: Path):
|
||||||
|
with patch("src.infrastructure.chat_store.MAX_MESSAGES", 100):
|
||||||
|
ml = MessageLog(db_path=tmp_db)
|
||||||
|
for i in range(10):
|
||||||
|
ml.append("user", f"msg{i}", f"t{i}")
|
||||||
|
assert len(ml) == 10
|
||||||
|
ml.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Thread safety ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadSafety:
|
||||||
|
def test_concurrent_appends(self, tmp_db: Path):
|
||||||
|
ml = MessageLog(db_path=tmp_db)
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def writer(start: int):
|
||||||
|
try:
|
||||||
|
for i in range(20):
|
||||||
|
ml.append("user", f"msg{start + i}", f"t{start + i}")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=writer, args=(i * 20,)) for i in range(5)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert not errors, f"Thread errors: {errors}"
|
||||||
|
assert len(ml) == 100
|
||||||
|
ml.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Edge cases ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
def test_empty_content(self, log: MessageLog):
|
||||||
|
log.append("user", "", "t0")
|
||||||
|
msgs = log.all()
|
||||||
|
assert len(msgs) == 1
|
||||||
|
assert msgs[0].content == ""
|
||||||
|
|
||||||
|
def test_unicode_content(self, log: MessageLog):
|
||||||
|
log.append("user", "こんにちは 🎉 مرحبا", "t0")
|
||||||
|
msgs = log.all()
|
||||||
|
assert msgs[0].content == "こんにちは 🎉 مرحبا"
|
||||||
|
|
||||||
|
def test_multiline_content(self, log: MessageLog):
|
||||||
|
content = "line1\nline2\nline3"
|
||||||
|
log.append("user", content, "t0")
|
||||||
|
assert log.all()[0].content == content
|
||||||
|
|
||||||
|
def test_special_sql_characters(self, log: MessageLog):
|
||||||
|
log.append("user", "Robert'; DROP TABLE chat_messages;--", "t0")
|
||||||
|
msgs = log.all()
|
||||||
|
assert len(msgs) == 1
|
||||||
|
assert "DROP TABLE" in msgs[0].content
|
||||||
Reference in New Issue
Block a user