"""Tests for hermes_state.py — SessionDB SQLite CRUD, FTS5 search, export.""" import time import pytest from pathlib import Path from hermes_state import SessionDB @pytest.fixture() def db(tmp_path): """Create a SessionDB with a temp database file.""" db_path = tmp_path / "test_state.db" session_db = SessionDB(db_path=db_path) yield session_db session_db.close() # ========================================================================= # Session lifecycle # ========================================================================= class TestSessionLifecycle: def test_create_and_get_session(self, db): sid = db.create_session( session_id="s1", source="cli", model="test-model", ) assert sid == "s1" session = db.get_session("s1") assert session is not None assert session["source"] == "cli" assert session["model"] == "test-model" assert session["ended_at"] is None def test_get_nonexistent_session(self, db): assert db.get_session("nonexistent") is None def test_end_session(self, db): db.create_session(session_id="s1", source="cli") db.end_session("s1", end_reason="user_exit") session = db.get_session("s1") assert isinstance(session["ended_at"], float) assert session["end_reason"] == "user_exit" def test_update_system_prompt(self, db): db.create_session(session_id="s1", source="cli") db.update_system_prompt("s1", "You are a helpful assistant.") session = db.get_session("s1") assert session["system_prompt"] == "You are a helpful assistant." def test_update_token_counts(self, db): db.create_session(session_id="s1", source="cli") db.update_token_counts("s1", input_tokens=100, output_tokens=50) db.update_token_counts("s1", input_tokens=200, output_tokens=100) session = db.get_session("s1") assert session["input_tokens"] == 300 assert session["output_tokens"] == 150 def test_parent_session(self, db): db.create_session(session_id="parent", source="cli") db.create_session(session_id="child", source="cli", parent_session_id="parent") child = db.get_session("child") assert child["parent_session_id"] == "parent" # ========================================================================= # Message storage # ========================================================================= class TestMessageStorage: def test_append_and_get_messages(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi there!") messages = db.get_messages("s1") assert len(messages) == 2 assert messages[0]["role"] == "user" assert messages[0]["content"] == "Hello" assert messages[1]["role"] == "assistant" def test_message_increments_session_count(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi") session = db.get_session("s1") assert session["message_count"] == 2 def test_tool_message_increments_tool_count(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="tool", content="result", tool_name="web_search") session = db.get_session("s1") assert session["tool_call_count"] == 1 def test_tool_calls_serialization(self, db): db.create_session(session_id="s1", source="cli") tool_calls = [{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}] db.append_message("s1", role="assistant", tool_calls=tool_calls) messages = db.get_messages("s1") assert messages[0]["tool_calls"] == tool_calls def test_get_messages_as_conversation(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi!") conv = db.get_messages_as_conversation("s1") assert len(conv) == 2 assert conv[0] == {"role": "user", "content": "Hello"} assert conv[1] == {"role": "assistant", "content": "Hi!"} def test_finish_reason_stored(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="assistant", content="Done", finish_reason="stop") messages = db.get_messages("s1") assert messages[0]["finish_reason"] == "stop" # ========================================================================= # FTS5 search # ========================================================================= class TestFTS5Search: def test_search_finds_content(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="How do I deploy with Docker?") db.append_message("s1", role="assistant", content="Use docker compose up.") results = db.search_messages("docker") assert len(results) == 2 # At least one result should mention docker snippets = [r.get("snippet", "") for r in results] assert any("docker" in s.lower() or "Docker" in s for s in snippets) def test_search_empty_query(self, db): assert db.search_messages("") == [] assert db.search_messages(" ") == [] def test_search_with_source_filter(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="CLI question about Python") db.create_session(session_id="s2", source="telegram") db.append_message("s2", role="user", content="Telegram question about Python") results = db.search_messages("Python", source_filter=["telegram"]) # Should only find the telegram message sources = [r["source"] for r in results] assert all(s == "telegram" for s in sources) def test_search_with_role_filter(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="What is FastAPI?") db.append_message("s1", role="assistant", content="FastAPI is a web framework.") results = db.search_messages("FastAPI", role_filter=["assistant"]) roles = [r["role"] for r in results] assert all(r == "assistant" for r in roles) def test_search_returns_context(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Tell me about Kubernetes") db.append_message("s1", role="assistant", content="Kubernetes is an orchestrator.") results = db.search_messages("Kubernetes") assert len(results) == 2 assert "context" in results[0] assert isinstance(results[0]["context"], list) assert len(results[0]["context"]) > 0 # ========================================================================= # Session search and listing # ========================================================================= class TestSearchSessions: def test_list_all_sessions(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") sessions = db.search_sessions() assert len(sessions) == 2 def test_filter_by_source(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") sessions = db.search_sessions(source="cli") assert len(sessions) == 1 assert sessions[0]["source"] == "cli" def test_pagination(self, db): for i in range(5): db.create_session(session_id=f"s{i}", source="cli") page1 = db.search_sessions(limit=2) page2 = db.search_sessions(limit=2, offset=2) assert len(page1) == 2 assert len(page2) == 2 assert page1[0]["id"] != page2[0]["id"] # ========================================================================= # Counts # ========================================================================= class TestCounts: def test_session_count(self, db): assert db.session_count() == 0 db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") assert db.session_count() == 2 def test_session_count_by_source(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") db.create_session(session_id="s3", source="cli") assert db.session_count(source="cli") == 2 assert db.session_count(source="telegram") == 1 def test_message_count_total(self, db): assert db.message_count() == 0 db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi") assert db.message_count() == 2 def test_message_count_per_session(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="cli") db.append_message("s1", role="user", content="A") db.append_message("s2", role="user", content="B") db.append_message("s2", role="user", content="C") assert db.message_count(session_id="s1") == 1 assert db.message_count(session_id="s2") == 2 # ========================================================================= # Delete and export # ========================================================================= class TestDeleteAndExport: def test_delete_session(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") assert db.delete_session("s1") is True assert db.get_session("s1") is None assert db.message_count(session_id="s1") == 0 def test_delete_nonexistent(self, db): assert db.delete_session("nope") is False def test_export_session(self, db): db.create_session(session_id="s1", source="cli", model="test") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi") export = db.export_session("s1") assert isinstance(export, dict) assert export["source"] == "cli" assert len(export["messages"]) == 2 def test_export_nonexistent(self, db): assert db.export_session("nope") is None def test_export_all(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") db.append_message("s1", role="user", content="A") exports = db.export_all() assert len(exports) == 2 def test_export_all_with_source(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") exports = db.export_all(source="cli") assert len(exports) == 1 assert exports[0]["source"] == "cli" # ========================================================================= # Prune # ========================================================================= class TestPruneSessions: def test_prune_old_ended_sessions(self, db): # Create and end an "old" session db.create_session(session_id="old", source="cli") db.end_session("old", end_reason="done") # Manually backdate started_at db._conn.execute( "UPDATE sessions SET started_at = ? WHERE id = ?", (time.time() - 100 * 86400, "old"), ) db._conn.commit() # Create a recent session db.create_session(session_id="new", source="cli") pruned = db.prune_sessions(older_than_days=90) assert pruned == 1 assert db.get_session("old") is None session = db.get_session("new") assert session is not None assert session["id"] == "new" def test_prune_skips_active_sessions(self, db): db.create_session(session_id="active", source="cli") # Backdate but don't end db._conn.execute( "UPDATE sessions SET started_at = ? WHERE id = ?", (time.time() - 200 * 86400, "active"), ) db._conn.commit() pruned = db.prune_sessions(older_than_days=90) assert pruned == 0 assert db.get_session("active") is not None def test_prune_with_source_filter(self, db): for sid, src in [("old_cli", "cli"), ("old_tg", "telegram")]: db.create_session(session_id=sid, source=src) db.end_session(sid, end_reason="done") db._conn.execute( "UPDATE sessions SET started_at = ? WHERE id = ?", (time.time() - 200 * 86400, sid), ) db._conn.commit() pruned = db.prune_sessions(older_than_days=90, source="cli") assert pruned == 1 assert db.get_session("old_cli") is None assert db.get_session("old_tg") is not None # ========================================================================= # Schema and WAL mode # ========================================================================= class TestSchemaInit: def test_wal_mode(self, db): cursor = db._conn.execute("PRAGMA journal_mode") mode = cursor.fetchone()[0] assert mode == "wal" def test_foreign_keys_enabled(self, db): cursor = db._conn.execute("PRAGMA foreign_keys") assert cursor.fetchone()[0] == 1 def test_tables_exist(self, db): cursor = db._conn.execute( "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" ) tables = {row[0] for row in cursor.fetchall()} assert "sessions" in tables assert "messages" in tables assert "schema_version" in tables def test_schema_version(self, db): cursor = db._conn.execute("SELECT version FROM schema_version") version = cursor.fetchone()[0] assert version == 2