Add comprehensive test coverage for: - cron/jobs.py: schedule parsing, job CRUD, due-job detection (34 tests) - tools/memory_tool.py: security scanning, MemoryStore ops, dispatcher (32 tests) - toolsets.py: resolution, validation, composition, cycle detection (19 tests) - tools/file_operations.py: write deny list, result dataclasses, helpers (37 tests) - agent/prompt_builder.py: context scanning, truncation, skills index (24 tests) - agent/model_metadata.py: token estimation, context lengths (16 tests) - hermes_state.py: SessionDB SQLite CRUD, FTS5 search, export, prune (28 tests) Total: 210 new tests, all passing (380 total suite).
373 lines
14 KiB
Python
373 lines
14 KiB
Python
"""Tests for hermes_state.py — SessionDB SQLite CRUD, FTS5 search, export."""
|
|
|
|
import time
|
|
import pytest
|
|
from pathlib import Path
|
|
|
|
from hermes_state import SessionDB
|
|
|
|
|
|
@pytest.fixture()
|
|
def db(tmp_path):
|
|
"""Create a SessionDB with a temp database file."""
|
|
db_path = tmp_path / "test_state.db"
|
|
session_db = SessionDB(db_path=db_path)
|
|
yield session_db
|
|
session_db.close()
|
|
|
|
|
|
# =========================================================================
|
|
# Session lifecycle
|
|
# =========================================================================
|
|
|
|
class TestSessionLifecycle:
|
|
def test_create_and_get_session(self, db):
|
|
sid = db.create_session(
|
|
session_id="s1",
|
|
source="cli",
|
|
model="test-model",
|
|
)
|
|
assert sid == "s1"
|
|
|
|
session = db.get_session("s1")
|
|
assert session is not None
|
|
assert session["source"] == "cli"
|
|
assert session["model"] == "test-model"
|
|
assert session["ended_at"] is None
|
|
|
|
def test_get_nonexistent_session(self, db):
|
|
assert db.get_session("nonexistent") is None
|
|
|
|
def test_end_session(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.end_session("s1", end_reason="user_exit")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["ended_at"] is not None
|
|
assert session["end_reason"] == "user_exit"
|
|
|
|
def test_update_system_prompt(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.update_system_prompt("s1", "You are a helpful assistant.")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["system_prompt"] == "You are a helpful assistant."
|
|
|
|
def test_update_token_counts(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.update_token_counts("s1", input_tokens=100, output_tokens=50)
|
|
db.update_token_counts("s1", input_tokens=200, output_tokens=100)
|
|
|
|
session = db.get_session("s1")
|
|
assert session["input_tokens"] == 300
|
|
assert session["output_tokens"] == 150
|
|
|
|
def test_parent_session(self, db):
|
|
db.create_session(session_id="parent", source="cli")
|
|
db.create_session(session_id="child", source="cli", parent_session_id="parent")
|
|
|
|
child = db.get_session("child")
|
|
assert child["parent_session_id"] == "parent"
|
|
|
|
|
|
# =========================================================================
|
|
# Message storage
|
|
# =========================================================================
|
|
|
|
class TestMessageStorage:
|
|
def test_append_and_get_messages(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi there!")
|
|
|
|
messages = db.get_messages("s1")
|
|
assert len(messages) == 2
|
|
assert messages[0]["role"] == "user"
|
|
assert messages[0]["content"] == "Hello"
|
|
assert messages[1]["role"] == "assistant"
|
|
|
|
def test_message_increments_session_count(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["message_count"] == 2
|
|
|
|
def test_tool_message_increments_tool_count(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="tool", content="result", tool_name="web_search")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["tool_call_count"] == 1
|
|
|
|
def test_tool_calls_serialization(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
tool_calls = [{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}]
|
|
db.append_message("s1", role="assistant", tool_calls=tool_calls)
|
|
|
|
messages = db.get_messages("s1")
|
|
assert messages[0]["tool_calls"] == tool_calls
|
|
|
|
def test_get_messages_as_conversation(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi!")
|
|
|
|
conv = db.get_messages_as_conversation("s1")
|
|
assert len(conv) == 2
|
|
assert conv[0] == {"role": "user", "content": "Hello"}
|
|
assert conv[1] == {"role": "assistant", "content": "Hi!"}
|
|
|
|
def test_finish_reason_stored(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="assistant", content="Done", finish_reason="stop")
|
|
|
|
messages = db.get_messages("s1")
|
|
assert messages[0]["finish_reason"] == "stop"
|
|
|
|
|
|
# =========================================================================
|
|
# FTS5 search
|
|
# =========================================================================
|
|
|
|
class TestFTS5Search:
|
|
def test_search_finds_content(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="How do I deploy with Docker?")
|
|
db.append_message("s1", role="assistant", content="Use docker compose up.")
|
|
|
|
results = db.search_messages("docker")
|
|
assert len(results) >= 1
|
|
# At least one result should mention docker
|
|
snippets = [r.get("snippet", "") for r in results]
|
|
assert any("docker" in s.lower() or "Docker" in s for s in snippets)
|
|
|
|
def test_search_empty_query(self, db):
|
|
assert db.search_messages("") == []
|
|
assert db.search_messages(" ") == []
|
|
|
|
def test_search_with_source_filter(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="CLI question about Python")
|
|
|
|
db.create_session(session_id="s2", source="telegram")
|
|
db.append_message("s2", role="user", content="Telegram question about Python")
|
|
|
|
results = db.search_messages("Python", source_filter=["telegram"])
|
|
# Should only find the telegram message
|
|
sources = [r["source"] for r in results]
|
|
assert all(s == "telegram" for s in sources)
|
|
|
|
def test_search_with_role_filter(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="What is FastAPI?")
|
|
db.append_message("s1", role="assistant", content="FastAPI is a web framework.")
|
|
|
|
results = db.search_messages("FastAPI", role_filter=["assistant"])
|
|
roles = [r["role"] for r in results]
|
|
assert all(r == "assistant" for r in roles)
|
|
|
|
def test_search_returns_context(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Tell me about Kubernetes")
|
|
db.append_message("s1", role="assistant", content="Kubernetes is an orchestrator.")
|
|
|
|
results = db.search_messages("Kubernetes")
|
|
assert len(results) >= 1
|
|
assert "context" in results[0]
|
|
|
|
|
|
# =========================================================================
|
|
# Session search and listing
|
|
# =========================================================================
|
|
|
|
class TestSearchSessions:
|
|
def test_list_all_sessions(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
|
|
sessions = db.search_sessions()
|
|
assert len(sessions) == 2
|
|
|
|
def test_filter_by_source(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
|
|
sessions = db.search_sessions(source="cli")
|
|
assert len(sessions) == 1
|
|
assert sessions[0]["source"] == "cli"
|
|
|
|
def test_pagination(self, db):
|
|
for i in range(5):
|
|
db.create_session(session_id=f"s{i}", source="cli")
|
|
|
|
page1 = db.search_sessions(limit=2)
|
|
page2 = db.search_sessions(limit=2, offset=2)
|
|
assert len(page1) == 2
|
|
assert len(page2) == 2
|
|
assert page1[0]["id"] != page2[0]["id"]
|
|
|
|
|
|
# =========================================================================
|
|
# Counts
|
|
# =========================================================================
|
|
|
|
class TestCounts:
|
|
def test_session_count(self, db):
|
|
assert db.session_count() == 0
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
assert db.session_count() == 2
|
|
|
|
def test_session_count_by_source(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
db.create_session(session_id="s3", source="cli")
|
|
assert db.session_count(source="cli") == 2
|
|
assert db.session_count(source="telegram") == 1
|
|
|
|
def test_message_count_total(self, db):
|
|
assert db.message_count() == 0
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi")
|
|
assert db.message_count() == 2
|
|
|
|
def test_message_count_per_session(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="cli")
|
|
db.append_message("s1", role="user", content="A")
|
|
db.append_message("s2", role="user", content="B")
|
|
db.append_message("s2", role="user", content="C")
|
|
assert db.message_count(session_id="s1") == 1
|
|
assert db.message_count(session_id="s2") == 2
|
|
|
|
|
|
# =========================================================================
|
|
# Delete and export
|
|
# =========================================================================
|
|
|
|
class TestDeleteAndExport:
|
|
def test_delete_session(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
|
|
assert db.delete_session("s1") is True
|
|
assert db.get_session("s1") is None
|
|
assert db.message_count(session_id="s1") == 0
|
|
|
|
def test_delete_nonexistent(self, db):
|
|
assert db.delete_session("nope") is False
|
|
|
|
def test_export_session(self, db):
|
|
db.create_session(session_id="s1", source="cli", model="test")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi")
|
|
|
|
export = db.export_session("s1")
|
|
assert export is not None
|
|
assert export["source"] == "cli"
|
|
assert len(export["messages"]) == 2
|
|
|
|
def test_export_nonexistent(self, db):
|
|
assert db.export_session("nope") is None
|
|
|
|
def test_export_all(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
db.append_message("s1", role="user", content="A")
|
|
|
|
exports = db.export_all()
|
|
assert len(exports) == 2
|
|
|
|
def test_export_all_with_source(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
|
|
exports = db.export_all(source="cli")
|
|
assert len(exports) == 1
|
|
assert exports[0]["source"] == "cli"
|
|
|
|
|
|
# =========================================================================
|
|
# Prune
|
|
# =========================================================================
|
|
|
|
class TestPruneSessions:
|
|
def test_prune_old_ended_sessions(self, db):
|
|
# Create and end an "old" session
|
|
db.create_session(session_id="old", source="cli")
|
|
db.end_session("old", end_reason="done")
|
|
# Manually backdate started_at
|
|
db._conn.execute(
|
|
"UPDATE sessions SET started_at = ? WHERE id = ?",
|
|
(time.time() - 100 * 86400, "old"),
|
|
)
|
|
db._conn.commit()
|
|
|
|
# Create a recent session
|
|
db.create_session(session_id="new", source="cli")
|
|
|
|
pruned = db.prune_sessions(older_than_days=90)
|
|
assert pruned == 1
|
|
assert db.get_session("old") is None
|
|
assert db.get_session("new") is not None
|
|
|
|
def test_prune_skips_active_sessions(self, db):
|
|
db.create_session(session_id="active", source="cli")
|
|
# Backdate but don't end
|
|
db._conn.execute(
|
|
"UPDATE sessions SET started_at = ? WHERE id = ?",
|
|
(time.time() - 200 * 86400, "active"),
|
|
)
|
|
db._conn.commit()
|
|
|
|
pruned = db.prune_sessions(older_than_days=90)
|
|
assert pruned == 0
|
|
assert db.get_session("active") is not None
|
|
|
|
def test_prune_with_source_filter(self, db):
|
|
for sid, src in [("old_cli", "cli"), ("old_tg", "telegram")]:
|
|
db.create_session(session_id=sid, source=src)
|
|
db.end_session(sid, end_reason="done")
|
|
db._conn.execute(
|
|
"UPDATE sessions SET started_at = ? WHERE id = ?",
|
|
(time.time() - 200 * 86400, sid),
|
|
)
|
|
db._conn.commit()
|
|
|
|
pruned = db.prune_sessions(older_than_days=90, source="cli")
|
|
assert pruned == 1
|
|
assert db.get_session("old_cli") is None
|
|
assert db.get_session("old_tg") is not None
|
|
|
|
|
|
# =========================================================================
|
|
# Schema and WAL mode
|
|
# =========================================================================
|
|
|
|
class TestSchemaInit:
|
|
def test_wal_mode(self, db):
|
|
cursor = db._conn.execute("PRAGMA journal_mode")
|
|
mode = cursor.fetchone()[0]
|
|
assert mode == "wal"
|
|
|
|
def test_foreign_keys_enabled(self, db):
|
|
cursor = db._conn.execute("PRAGMA foreign_keys")
|
|
assert cursor.fetchone()[0] == 1
|
|
|
|
def test_tables_exist(self, db):
|
|
cursor = db._conn.execute(
|
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
|
)
|
|
tables = {row[0] for row in cursor.fetchall()}
|
|
assert "sessions" in tables
|
|
assert "messages" in tables
|
|
assert "schema_version" in tables
|
|
|
|
def test_schema_version(self, db):
|
|
cursor = db._conn.execute("SELECT version FROM schema_version")
|
|
version = cursor.fetchone()[0]
|
|
assert version == 2
|