Add regression coverage for backfilling NULL gateway session models in SQLite, preserving existing models, and forwarding the resolved agent model through SessionStore updates.
936 lines
37 KiB
Python
936 lines
37 KiB
Python
"""Tests for hermes_state.py — SessionDB SQLite CRUD, FTS5 search, export."""
|
|
|
|
import time
|
|
import pytest
|
|
from pathlib import Path
|
|
|
|
from hermes_state import SessionDB
|
|
|
|
|
|
@pytest.fixture()
|
|
def db(tmp_path):
|
|
"""Create a SessionDB with a temp database file."""
|
|
db_path = tmp_path / "test_state.db"
|
|
session_db = SessionDB(db_path=db_path)
|
|
yield session_db
|
|
session_db.close()
|
|
|
|
|
|
# =========================================================================
|
|
# Session lifecycle
|
|
# =========================================================================
|
|
|
|
class TestSessionLifecycle:
|
|
def test_create_and_get_session(self, db):
|
|
sid = db.create_session(
|
|
session_id="s1",
|
|
source="cli",
|
|
model="test-model",
|
|
)
|
|
assert sid == "s1"
|
|
|
|
session = db.get_session("s1")
|
|
assert session is not None
|
|
assert session["source"] == "cli"
|
|
assert session["model"] == "test-model"
|
|
assert session["ended_at"] is None
|
|
|
|
def test_get_nonexistent_session(self, db):
|
|
assert db.get_session("nonexistent") is None
|
|
|
|
def test_end_session(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.end_session("s1", end_reason="user_exit")
|
|
|
|
session = db.get_session("s1")
|
|
assert isinstance(session["ended_at"], float)
|
|
assert session["end_reason"] == "user_exit"
|
|
|
|
def test_update_system_prompt(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.update_system_prompt("s1", "You are a helpful assistant.")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["system_prompt"] == "You are a helpful assistant."
|
|
|
|
def test_update_token_counts(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.update_token_counts("s1", input_tokens=200, output_tokens=100)
|
|
db.update_token_counts("s1", input_tokens=100, output_tokens=50)
|
|
|
|
session = db.get_session("s1")
|
|
assert session["input_tokens"] == 300
|
|
assert session["output_tokens"] == 150
|
|
|
|
def test_update_token_counts_backfills_model_when_null(self, db):
|
|
db.create_session(session_id="s1", source="telegram")
|
|
db.update_token_counts("s1", input_tokens=10, output_tokens=5, model="openai/gpt-5.4")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["model"] == "openai/gpt-5.4"
|
|
|
|
def test_update_token_counts_preserves_existing_model(self, db):
|
|
db.create_session(session_id="s1", source="cli", model="anthropic/claude-opus-4.6")
|
|
db.update_token_counts("s1", input_tokens=10, output_tokens=5, model="openai/gpt-5.4")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["model"] == "anthropic/claude-opus-4.6"
|
|
|
|
def test_parent_session(self, db):
|
|
db.create_session(session_id="parent", source="cli")
|
|
db.create_session(session_id="child", source="cli", parent_session_id="parent")
|
|
|
|
child = db.get_session("child")
|
|
assert child["parent_session_id"] == "parent"
|
|
|
|
|
|
# =========================================================================
|
|
# Message storage
|
|
# =========================================================================
|
|
|
|
class TestMessageStorage:
|
|
def test_append_and_get_messages(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi there!")
|
|
|
|
messages = db.get_messages("s1")
|
|
assert len(messages) == 2
|
|
assert messages[0]["role"] == "user"
|
|
assert messages[0]["content"] == "Hello"
|
|
assert messages[1]["role"] == "assistant"
|
|
|
|
def test_message_increments_session_count(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["message_count"] == 2
|
|
|
|
def test_tool_response_does_not_increment_tool_count(self, db):
|
|
"""Tool responses (role=tool) should not increment tool_call_count.
|
|
|
|
Only assistant messages with tool_calls should count.
|
|
"""
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="tool", content="result", tool_name="web_search")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["tool_call_count"] == 0
|
|
|
|
def test_assistant_tool_calls_increment_by_count(self, db):
|
|
"""An assistant message with N tool_calls should increment by N."""
|
|
db.create_session(session_id="s1", source="cli")
|
|
tool_calls = [
|
|
{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}},
|
|
]
|
|
db.append_message("s1", role="assistant", content="", tool_calls=tool_calls)
|
|
|
|
session = db.get_session("s1")
|
|
assert session["tool_call_count"] == 1
|
|
|
|
def test_tool_call_count_matches_actual_calls(self, db):
|
|
"""tool_call_count should equal the number of tool calls made, not messages."""
|
|
db.create_session(session_id="s1", source="cli")
|
|
|
|
# Assistant makes 2 parallel tool calls in one message
|
|
tool_calls = [
|
|
{"id": "call_1", "function": {"name": "ha_call_service", "arguments": "{}"}},
|
|
{"id": "call_2", "function": {"name": "ha_call_service", "arguments": "{}"}},
|
|
]
|
|
db.append_message("s1", role="assistant", content="", tool_calls=tool_calls)
|
|
|
|
# Two tool responses come back
|
|
db.append_message("s1", role="tool", content="ok", tool_name="ha_call_service")
|
|
db.append_message("s1", role="tool", content="ok", tool_name="ha_call_service")
|
|
|
|
session = db.get_session("s1")
|
|
# Should be 2 (the actual number of tool calls), not 3
|
|
assert session["tool_call_count"] == 2, (
|
|
f"Expected 2 tool calls but got {session['tool_call_count']}. "
|
|
"tool responses are double-counted and multi-call messages are under-counted"
|
|
)
|
|
|
|
def test_tool_calls_serialization(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
tool_calls = [{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}]
|
|
db.append_message("s1", role="assistant", tool_calls=tool_calls)
|
|
|
|
messages = db.get_messages("s1")
|
|
assert messages[0]["tool_calls"] == tool_calls
|
|
|
|
def test_get_messages_as_conversation(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi!")
|
|
|
|
conv = db.get_messages_as_conversation("s1")
|
|
assert len(conv) == 2
|
|
assert conv[0] == {"role": "user", "content": "Hello"}
|
|
assert conv[1] == {"role": "assistant", "content": "Hi!"}
|
|
|
|
def test_finish_reason_stored(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="assistant", content="Done", finish_reason="stop")
|
|
|
|
messages = db.get_messages("s1")
|
|
assert messages[0]["finish_reason"] == "stop"
|
|
|
|
|
|
# =========================================================================
|
|
# FTS5 search
|
|
# =========================================================================
|
|
|
|
class TestFTS5Search:
|
|
def test_search_finds_content(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="How do I deploy with Docker?")
|
|
db.append_message("s1", role="assistant", content="Use docker compose up.")
|
|
|
|
results = db.search_messages("docker")
|
|
assert len(results) == 2
|
|
# At least one result should mention docker
|
|
snippets = [r.get("snippet", "") for r in results]
|
|
assert any("docker" in s.lower() or "Docker" in s for s in snippets)
|
|
|
|
def test_search_empty_query(self, db):
|
|
assert db.search_messages("") == []
|
|
assert db.search_messages(" ") == []
|
|
|
|
def test_search_with_source_filter(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="CLI question about Python")
|
|
|
|
db.create_session(session_id="s2", source="telegram")
|
|
db.append_message("s2", role="user", content="Telegram question about Python")
|
|
|
|
results = db.search_messages("Python", source_filter=["telegram"])
|
|
# Should only find the telegram message
|
|
sources = [r["source"] for r in results]
|
|
assert all(s == "telegram" for s in sources)
|
|
|
|
def test_search_with_role_filter(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="What is FastAPI?")
|
|
db.append_message("s1", role="assistant", content="FastAPI is a web framework.")
|
|
|
|
results = db.search_messages("FastAPI", role_filter=["assistant"])
|
|
roles = [r["role"] for r in results]
|
|
assert all(r == "assistant" for r in roles)
|
|
|
|
def test_search_returns_context(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Tell me about Kubernetes")
|
|
db.append_message("s1", role="assistant", content="Kubernetes is an orchestrator.")
|
|
|
|
results = db.search_messages("Kubernetes")
|
|
assert len(results) == 2
|
|
assert "context" in results[0]
|
|
assert isinstance(results[0]["context"], list)
|
|
assert len(results[0]["context"]) > 0
|
|
|
|
def test_search_special_chars_do_not_crash(self, db):
|
|
"""FTS5 special characters in queries must not raise OperationalError."""
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="How do I use C++ templates?")
|
|
|
|
# Each of these previously caused sqlite3.OperationalError
|
|
dangerous_queries = [
|
|
'C++', # + is FTS5 column filter
|
|
'"unterminated', # unbalanced double-quote
|
|
'(problem', # unbalanced parenthesis
|
|
'hello AND', # dangling boolean operator
|
|
'***', # repeated wildcard
|
|
'{test}', # curly braces (column reference)
|
|
'OR hello', # leading boolean operator
|
|
'a AND OR b', # adjacent operators
|
|
]
|
|
for query in dangerous_queries:
|
|
# Must not raise — should return list (possibly empty)
|
|
results = db.search_messages(query)
|
|
assert isinstance(results, list), f"Query {query!r} did not return a list"
|
|
|
|
def test_search_sanitized_query_still_finds_content(self, db):
|
|
"""Sanitization must not break normal keyword search."""
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Learning C++ templates today")
|
|
|
|
# "C++" sanitized to "C" should still match "C++"
|
|
results = db.search_messages("C++")
|
|
# The word "C" appears in the content, so FTS5 should find it
|
|
assert isinstance(results, list)
|
|
|
|
def test_sanitize_fts5_query_strips_dangerous_chars(self):
|
|
"""Unit test for _sanitize_fts5_query static method."""
|
|
from hermes_state import SessionDB
|
|
s = SessionDB._sanitize_fts5_query
|
|
assert s('hello world') == 'hello world'
|
|
assert '+' not in s('C++')
|
|
assert '"' not in s('"unterminated')
|
|
assert '(' not in s('(problem')
|
|
assert '{' not in s('{test}')
|
|
# Dangling operators removed
|
|
assert s('hello AND') == 'hello'
|
|
assert s('OR world') == 'world'
|
|
# Leading bare * removed
|
|
assert s('***') == ''
|
|
# Valid prefix kept
|
|
assert s('deploy*') == 'deploy*'
|
|
|
|
|
|
# =========================================================================
|
|
# Session search and listing
|
|
# =========================================================================
|
|
|
|
class TestSearchSessions:
|
|
def test_list_all_sessions(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
|
|
sessions = db.search_sessions()
|
|
assert len(sessions) == 2
|
|
|
|
def test_filter_by_source(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
|
|
sessions = db.search_sessions(source="cli")
|
|
assert len(sessions) == 1
|
|
assert sessions[0]["source"] == "cli"
|
|
|
|
def test_pagination(self, db):
|
|
for i in range(5):
|
|
db.create_session(session_id=f"s{i}", source="cli")
|
|
|
|
page1 = db.search_sessions(limit=2)
|
|
page2 = db.search_sessions(limit=2, offset=2)
|
|
assert len(page1) == 2
|
|
assert len(page2) == 2
|
|
assert page1[0]["id"] != page2[0]["id"]
|
|
|
|
|
|
# =========================================================================
|
|
# Counts
|
|
# =========================================================================
|
|
|
|
class TestCounts:
|
|
def test_session_count(self, db):
|
|
assert db.session_count() == 0
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
assert db.session_count() == 2
|
|
|
|
def test_session_count_by_source(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
db.create_session(session_id="s3", source="cli")
|
|
assert db.session_count(source="cli") == 2
|
|
assert db.session_count(source="telegram") == 1
|
|
|
|
def test_message_count_total(self, db):
|
|
assert db.message_count() == 0
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi")
|
|
assert db.message_count() == 2
|
|
|
|
def test_message_count_per_session(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="cli")
|
|
db.append_message("s1", role="user", content="A")
|
|
db.append_message("s2", role="user", content="B")
|
|
db.append_message("s2", role="user", content="C")
|
|
assert db.message_count(session_id="s1") == 1
|
|
assert db.message_count(session_id="s2") == 2
|
|
|
|
|
|
# =========================================================================
|
|
# Delete and export
|
|
# =========================================================================
|
|
|
|
class TestDeleteAndExport:
|
|
def test_delete_session(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
|
|
assert db.delete_session("s1") is True
|
|
assert db.get_session("s1") is None
|
|
assert db.message_count(session_id="s1") == 0
|
|
|
|
def test_delete_nonexistent(self, db):
|
|
assert db.delete_session("nope") is False
|
|
|
|
def test_export_session(self, db):
|
|
db.create_session(session_id="s1", source="cli", model="test")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
db.append_message("s1", role="assistant", content="Hi")
|
|
|
|
export = db.export_session("s1")
|
|
assert isinstance(export, dict)
|
|
assert export["source"] == "cli"
|
|
assert len(export["messages"]) == 2
|
|
|
|
def test_export_nonexistent(self, db):
|
|
assert db.export_session("nope") is None
|
|
|
|
def test_export_all(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
db.append_message("s1", role="user", content="A")
|
|
|
|
exports = db.export_all()
|
|
assert len(exports) == 2
|
|
|
|
def test_export_all_with_source(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="telegram")
|
|
|
|
exports = db.export_all(source="cli")
|
|
assert len(exports) == 1
|
|
assert exports[0]["source"] == "cli"
|
|
|
|
|
|
# =========================================================================
|
|
# Prune
|
|
# =========================================================================
|
|
|
|
class TestPruneSessions:
|
|
def test_prune_old_ended_sessions(self, db):
|
|
# Create and end an "old" session
|
|
db.create_session(session_id="old", source="cli")
|
|
db.end_session("old", end_reason="done")
|
|
# Manually backdate started_at
|
|
db._conn.execute(
|
|
"UPDATE sessions SET started_at = ? WHERE id = ?",
|
|
(time.time() - 100 * 86400, "old"),
|
|
)
|
|
db._conn.commit()
|
|
|
|
# Create a recent session
|
|
db.create_session(session_id="new", source="cli")
|
|
|
|
pruned = db.prune_sessions(older_than_days=90)
|
|
assert pruned == 1
|
|
assert db.get_session("old") is None
|
|
session = db.get_session("new")
|
|
assert session is not None
|
|
assert session["id"] == "new"
|
|
|
|
def test_prune_skips_active_sessions(self, db):
|
|
db.create_session(session_id="active", source="cli")
|
|
# Backdate but don't end
|
|
db._conn.execute(
|
|
"UPDATE sessions SET started_at = ? WHERE id = ?",
|
|
(time.time() - 200 * 86400, "active"),
|
|
)
|
|
db._conn.commit()
|
|
|
|
pruned = db.prune_sessions(older_than_days=90)
|
|
assert pruned == 0
|
|
assert db.get_session("active") is not None
|
|
|
|
def test_prune_with_source_filter(self, db):
|
|
for sid, src in [("old_cli", "cli"), ("old_tg", "telegram")]:
|
|
db.create_session(session_id=sid, source=src)
|
|
db.end_session(sid, end_reason="done")
|
|
db._conn.execute(
|
|
"UPDATE sessions SET started_at = ? WHERE id = ?",
|
|
(time.time() - 200 * 86400, sid),
|
|
)
|
|
db._conn.commit()
|
|
|
|
pruned = db.prune_sessions(older_than_days=90, source="cli")
|
|
assert pruned == 1
|
|
assert db.get_session("old_cli") is None
|
|
assert db.get_session("old_tg") is not None
|
|
|
|
|
|
# =========================================================================
|
|
# Schema and WAL mode
|
|
# =========================================================================
|
|
|
|
# =========================================================================
|
|
# Session title
|
|
# =========================================================================
|
|
|
|
class TestSessionTitle:
|
|
def test_set_and_get_title(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
assert db.set_session_title("s1", "My Session") is True
|
|
|
|
session = db.get_session("s1")
|
|
assert session["title"] == "My Session"
|
|
|
|
def test_set_title_nonexistent_session(self, db):
|
|
assert db.set_session_title("nonexistent", "Title") is False
|
|
|
|
def test_title_initially_none(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
session = db.get_session("s1")
|
|
assert session["title"] is None
|
|
|
|
def test_update_title(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.set_session_title("s1", "First Title")
|
|
db.set_session_title("s1", "Updated Title")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["title"] == "Updated Title"
|
|
|
|
def test_title_in_search_sessions(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.set_session_title("s1", "Debugging Auth")
|
|
db.create_session(session_id="s2", source="cli")
|
|
|
|
sessions = db.search_sessions()
|
|
titled = [s for s in sessions if s.get("title") == "Debugging Auth"]
|
|
assert len(titled) == 1
|
|
assert titled[0]["id"] == "s1"
|
|
|
|
def test_title_in_export(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.set_session_title("s1", "Export Test")
|
|
db.append_message("s1", role="user", content="Hello")
|
|
|
|
export = db.export_session("s1")
|
|
assert export["title"] == "Export Test"
|
|
|
|
def test_title_with_special_characters(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
title = "PR #438 — fixing the 'auth' middleware"
|
|
db.set_session_title("s1", title)
|
|
|
|
session = db.get_session("s1")
|
|
assert session["title"] == title
|
|
|
|
def test_title_empty_string_normalized_to_none(self, db):
|
|
"""Empty strings are normalized to None (clearing the title)."""
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.set_session_title("s1", "My Title")
|
|
# Setting to empty string should clear the title (normalize to None)
|
|
db.set_session_title("s1", "")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["title"] is None
|
|
|
|
def test_multiple_empty_titles_no_conflict(self, db):
|
|
"""Multiple sessions can have empty-string (normalized to NULL) titles."""
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.create_session(session_id="s2", source="cli")
|
|
db.set_session_title("s1", "")
|
|
db.set_session_title("s2", "")
|
|
# Both should be None, no uniqueness conflict
|
|
assert db.get_session("s1")["title"] is None
|
|
assert db.get_session("s2")["title"] is None
|
|
|
|
def test_title_survives_end_session(self, db):
|
|
db.create_session(session_id="s1", source="cli")
|
|
db.set_session_title("s1", "Before End")
|
|
db.end_session("s1", end_reason="user_exit")
|
|
|
|
session = db.get_session("s1")
|
|
assert session["title"] == "Before End"
|
|
assert session["ended_at"] is not None
|
|
|
|
|
|
class TestSanitizeTitle:
|
|
"""Tests for SessionDB.sanitize_title() validation and cleaning."""
|
|
|
|
def test_normal_title_unchanged(self):
|
|
assert SessionDB.sanitize_title("My Project") == "My Project"
|
|
|
|
def test_strips_whitespace(self):
|
|
assert SessionDB.sanitize_title(" hello world ") == "hello world"
|
|
|
|
def test_collapses_internal_whitespace(self):
|
|
assert SessionDB.sanitize_title("hello world") == "hello world"
|
|
|
|
def test_tabs_and_newlines_collapsed(self):
|
|
assert SessionDB.sanitize_title("hello\t\nworld") == "hello world"
|
|
|
|
def test_none_returns_none(self):
|
|
assert SessionDB.sanitize_title(None) is None
|
|
|
|
def test_empty_string_returns_none(self):
|
|
assert SessionDB.sanitize_title("") is None
|
|
|
|
def test_whitespace_only_returns_none(self):
|
|
assert SessionDB.sanitize_title(" \t\n ") is None
|
|
|
|
def test_control_chars_stripped(self):
|
|
# Null byte, bell, backspace, etc.
|
|
assert SessionDB.sanitize_title("hello\x00world") == "helloworld"
|
|
assert SessionDB.sanitize_title("\x07\x08test\x1b") == "test"
|
|
|
|
def test_del_char_stripped(self):
|
|
assert SessionDB.sanitize_title("hello\x7fworld") == "helloworld"
|
|
|
|
def test_zero_width_chars_stripped(self):
|
|
# Zero-width space (U+200B), zero-width joiner (U+200D)
|
|
assert SessionDB.sanitize_title("hello\u200bworld") == "helloworld"
|
|
assert SessionDB.sanitize_title("hello\u200dworld") == "helloworld"
|
|
|
|
def test_rtl_override_stripped(self):
|
|
# Right-to-left override (U+202E) — used in filename spoofing attacks
|
|
assert SessionDB.sanitize_title("hello\u202eworld") == "helloworld"
|
|
|
|
def test_bom_stripped(self):
|
|
# Byte order mark (U+FEFF)
|
|
assert SessionDB.sanitize_title("\ufeffhello") == "hello"
|
|
|
|
def test_only_control_chars_returns_none(self):
|
|
assert SessionDB.sanitize_title("\x00\x01\x02\u200b\ufeff") is None
|
|
|
|
def test_max_length_allowed(self):
|
|
title = "A" * 100
|
|
assert SessionDB.sanitize_title(title) == title
|
|
|
|
def test_exceeds_max_length_raises(self):
|
|
title = "A" * 101
|
|
with pytest.raises(ValueError, match="too long"):
|
|
SessionDB.sanitize_title(title)
|
|
|
|
def test_unicode_emoji_allowed(self):
|
|
assert SessionDB.sanitize_title("🚀 My Project 🎉") == "🚀 My Project 🎉"
|
|
|
|
def test_cjk_characters_allowed(self):
|
|
assert SessionDB.sanitize_title("我的项目") == "我的项目"
|
|
|
|
def test_accented_characters_allowed(self):
|
|
assert SessionDB.sanitize_title("Résumé éditing") == "Résumé éditing"
|
|
|
|
def test_special_punctuation_allowed(self):
|
|
title = "PR #438 — fixing the 'auth' middleware"
|
|
assert SessionDB.sanitize_title(title) == title
|
|
|
|
def test_sanitize_applied_in_set_session_title(self, db):
|
|
"""set_session_title applies sanitize_title internally."""
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", " hello\x00 world ")
|
|
assert db.get_session("s1")["title"] == "hello world"
|
|
|
|
def test_too_long_title_rejected_by_set(self, db):
|
|
"""set_session_title raises ValueError for overly long titles."""
|
|
db.create_session("s1", "cli")
|
|
with pytest.raises(ValueError, match="too long"):
|
|
db.set_session_title("s1", "X" * 150)
|
|
|
|
|
|
class TestSchemaInit:
|
|
def test_wal_mode(self, db):
|
|
cursor = db._conn.execute("PRAGMA journal_mode")
|
|
mode = cursor.fetchone()[0]
|
|
assert mode == "wal"
|
|
|
|
def test_foreign_keys_enabled(self, db):
|
|
cursor = db._conn.execute("PRAGMA foreign_keys")
|
|
assert cursor.fetchone()[0] == 1
|
|
|
|
def test_tables_exist(self, db):
|
|
cursor = db._conn.execute(
|
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
|
)
|
|
tables = {row[0] for row in cursor.fetchall()}
|
|
assert "sessions" in tables
|
|
assert "messages" in tables
|
|
assert "schema_version" in tables
|
|
|
|
def test_schema_version(self, db):
|
|
cursor = db._conn.execute("SELECT version FROM schema_version")
|
|
version = cursor.fetchone()[0]
|
|
assert version == 4
|
|
|
|
def test_title_column_exists(self, db):
|
|
"""Verify the title column was created in the sessions table."""
|
|
cursor = db._conn.execute("PRAGMA table_info(sessions)")
|
|
columns = {row[1] for row in cursor.fetchall()}
|
|
assert "title" in columns
|
|
|
|
def test_migration_from_v2(self, tmp_path):
|
|
"""Simulate a v2 database and verify migration adds title column."""
|
|
import sqlite3
|
|
|
|
db_path = tmp_path / "migrate_test.db"
|
|
conn = sqlite3.connect(str(db_path))
|
|
# Create v2 schema (without title column)
|
|
conn.executescript("""
|
|
CREATE TABLE schema_version (version INTEGER NOT NULL);
|
|
INSERT INTO schema_version (version) VALUES (2);
|
|
|
|
CREATE TABLE sessions (
|
|
id TEXT PRIMARY KEY,
|
|
source TEXT NOT NULL,
|
|
user_id TEXT,
|
|
model TEXT,
|
|
model_config TEXT,
|
|
system_prompt TEXT,
|
|
parent_session_id TEXT,
|
|
started_at REAL NOT NULL,
|
|
ended_at REAL,
|
|
end_reason TEXT,
|
|
message_count INTEGER DEFAULT 0,
|
|
tool_call_count INTEGER DEFAULT 0,
|
|
input_tokens INTEGER DEFAULT 0,
|
|
output_tokens INTEGER DEFAULT 0
|
|
);
|
|
|
|
CREATE TABLE messages (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
session_id TEXT NOT NULL,
|
|
role TEXT NOT NULL,
|
|
content TEXT,
|
|
tool_call_id TEXT,
|
|
tool_calls TEXT,
|
|
tool_name TEXT,
|
|
timestamp REAL NOT NULL,
|
|
token_count INTEGER,
|
|
finish_reason TEXT
|
|
);
|
|
""")
|
|
conn.execute(
|
|
"INSERT INTO sessions (id, source, started_at) VALUES (?, ?, ?)",
|
|
("existing", "cli", 1000.0),
|
|
)
|
|
conn.commit()
|
|
conn.close()
|
|
|
|
# Open with SessionDB — should migrate to v4
|
|
migrated_db = SessionDB(db_path=db_path)
|
|
|
|
# Verify migration
|
|
cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
|
|
assert cursor.fetchone()[0] == 4
|
|
|
|
# Verify title column exists and is NULL for existing sessions
|
|
session = migrated_db.get_session("existing")
|
|
assert session is not None
|
|
assert session["title"] is None
|
|
|
|
# Verify we can set title on migrated session
|
|
assert migrated_db.set_session_title("existing", "Migrated Title") is True
|
|
session = migrated_db.get_session("existing")
|
|
assert session["title"] == "Migrated Title"
|
|
|
|
migrated_db.close()
|
|
|
|
|
|
class TestTitleUniqueness:
|
|
"""Tests for unique title enforcement and title-based lookups."""
|
|
|
|
def test_duplicate_title_raises(self, db):
|
|
"""Setting a title already used by another session raises ValueError."""
|
|
db.create_session("s1", "cli")
|
|
db.create_session("s2", "cli")
|
|
db.set_session_title("s1", "my project")
|
|
with pytest.raises(ValueError, match="already in use"):
|
|
db.set_session_title("s2", "my project")
|
|
|
|
def test_same_session_can_keep_title(self, db):
|
|
"""A session can re-set its own title without error."""
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "my project")
|
|
# Should not raise — it's the same session
|
|
assert db.set_session_title("s1", "my project") is True
|
|
|
|
def test_null_titles_not_unique(self, db):
|
|
"""Multiple sessions can have NULL titles (no constraint violation)."""
|
|
db.create_session("s1", "cli")
|
|
db.create_session("s2", "cli")
|
|
# Both have NULL titles — no error
|
|
assert db.get_session("s1")["title"] is None
|
|
assert db.get_session("s2")["title"] is None
|
|
|
|
def test_get_session_by_title(self, db):
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "refactoring auth")
|
|
result = db.get_session_by_title("refactoring auth")
|
|
assert result is not None
|
|
assert result["id"] == "s1"
|
|
|
|
def test_get_session_by_title_not_found(self, db):
|
|
assert db.get_session_by_title("nonexistent") is None
|
|
|
|
def test_get_session_title(self, db):
|
|
db.create_session("s1", "cli")
|
|
assert db.get_session_title("s1") is None
|
|
db.set_session_title("s1", "my title")
|
|
assert db.get_session_title("s1") == "my title"
|
|
|
|
def test_get_session_title_nonexistent(self, db):
|
|
assert db.get_session_title("nonexistent") is None
|
|
|
|
|
|
class TestTitleLineage:
|
|
"""Tests for title lineage resolution and auto-numbering."""
|
|
|
|
def test_resolve_exact_title(self, db):
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "my project")
|
|
assert db.resolve_session_by_title("my project") == "s1"
|
|
|
|
def test_resolve_returns_latest_numbered(self, db):
|
|
"""When numbered variants exist, return the most recent one."""
|
|
import time
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "my project")
|
|
time.sleep(0.01)
|
|
db.create_session("s2", "cli")
|
|
db.set_session_title("s2", "my project #2")
|
|
time.sleep(0.01)
|
|
db.create_session("s3", "cli")
|
|
db.set_session_title("s3", "my project #3")
|
|
# Resolving "my project" should return s3 (latest numbered variant)
|
|
assert db.resolve_session_by_title("my project") == "s3"
|
|
|
|
def test_resolve_exact_numbered(self, db):
|
|
"""Resolving an exact numbered title returns that specific session."""
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "my project")
|
|
db.create_session("s2", "cli")
|
|
db.set_session_title("s2", "my project #2")
|
|
# Resolving "my project #2" exactly should return s2
|
|
assert db.resolve_session_by_title("my project #2") == "s2"
|
|
|
|
def test_resolve_nonexistent_title(self, db):
|
|
assert db.resolve_session_by_title("nonexistent") is None
|
|
|
|
def test_next_title_no_existing(self, db):
|
|
"""With no existing sessions, base title is returned as-is."""
|
|
assert db.get_next_title_in_lineage("my project") == "my project"
|
|
|
|
def test_next_title_first_continuation(self, db):
|
|
"""First continuation after the original gets #2."""
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "my project")
|
|
assert db.get_next_title_in_lineage("my project") == "my project #2"
|
|
|
|
def test_next_title_increments(self, db):
|
|
"""Each continuation increments the number."""
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "my project")
|
|
db.create_session("s2", "cli")
|
|
db.set_session_title("s2", "my project #2")
|
|
db.create_session("s3", "cli")
|
|
db.set_session_title("s3", "my project #3")
|
|
assert db.get_next_title_in_lineage("my project") == "my project #4"
|
|
|
|
def test_next_title_strips_existing_number(self, db):
|
|
"""Passing a numbered title strips the number and finds the base."""
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "my project")
|
|
db.create_session("s2", "cli")
|
|
db.set_session_title("s2", "my project #2")
|
|
# Even when called with "my project #2", it should return #3
|
|
assert db.get_next_title_in_lineage("my project #2") == "my project #3"
|
|
|
|
|
|
class TestTitleSqlWildcards:
|
|
"""Titles containing SQL LIKE wildcards (%, _) must not cause false matches."""
|
|
|
|
def test_resolve_title_with_underscore(self, db):
|
|
"""A title like 'test_project' should not match 'testXproject #2'."""
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "test_project")
|
|
db.create_session("s2", "cli")
|
|
db.set_session_title("s2", "testXproject #2")
|
|
# Resolving "test_project" should return s1 (exact), not s2
|
|
assert db.resolve_session_by_title("test_project") == "s1"
|
|
|
|
def test_resolve_title_with_percent(self, db):
|
|
"""A title with '%' should not wildcard-match unrelated sessions."""
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "100% done")
|
|
db.create_session("s2", "cli")
|
|
db.set_session_title("s2", "100X done #2")
|
|
# Should resolve to s1 (exact), not s2
|
|
assert db.resolve_session_by_title("100% done") == "s1"
|
|
|
|
def test_next_lineage_with_underscore(self, db):
|
|
"""get_next_title_in_lineage with underscores doesn't match wrong sessions."""
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "test_project")
|
|
db.create_session("s2", "cli")
|
|
db.set_session_title("s2", "testXproject #2")
|
|
# Only "test_project" exists, so next should be "test_project #2"
|
|
assert db.get_next_title_in_lineage("test_project") == "test_project #2"
|
|
|
|
|
|
class TestListSessionsRich:
|
|
"""Tests for enhanced session listing with preview and last_active."""
|
|
|
|
def test_preview_from_first_user_message(self, db):
|
|
db.create_session("s1", "cli")
|
|
db.append_message("s1", "system", "You are a helpful assistant.")
|
|
db.append_message("s1", "user", "Help me refactor the auth module please")
|
|
db.append_message("s1", "assistant", "Sure, let me look at it.")
|
|
sessions = db.list_sessions_rich()
|
|
assert len(sessions) == 1
|
|
assert "Help me refactor the auth module" in sessions[0]["preview"]
|
|
|
|
def test_preview_truncated_at_60(self, db):
|
|
db.create_session("s1", "cli")
|
|
long_msg = "A" * 100
|
|
db.append_message("s1", "user", long_msg)
|
|
sessions = db.list_sessions_rich()
|
|
assert len(sessions[0]["preview"]) == 63 # 60 chars + "..."
|
|
assert sessions[0]["preview"].endswith("...")
|
|
|
|
def test_preview_empty_when_no_user_messages(self, db):
|
|
db.create_session("s1", "cli")
|
|
db.append_message("s1", "system", "System prompt")
|
|
sessions = db.list_sessions_rich()
|
|
assert sessions[0]["preview"] == ""
|
|
|
|
def test_last_active_from_latest_message(self, db):
|
|
import time
|
|
db.create_session("s1", "cli")
|
|
db.append_message("s1", "user", "Hello")
|
|
time.sleep(0.01)
|
|
db.append_message("s1", "assistant", "Hi there!")
|
|
sessions = db.list_sessions_rich()
|
|
# last_active should be close to now (the assistant message)
|
|
assert sessions[0]["last_active"] > sessions[0]["started_at"]
|
|
|
|
def test_last_active_fallback_to_started_at(self, db):
|
|
db.create_session("s1", "cli")
|
|
sessions = db.list_sessions_rich()
|
|
# No messages, so last_active falls back to started_at
|
|
assert sessions[0]["last_active"] == sessions[0]["started_at"]
|
|
|
|
def test_rich_list_includes_title(self, db):
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "refactoring auth")
|
|
sessions = db.list_sessions_rich()
|
|
assert sessions[0]["title"] == "refactoring auth"
|
|
|
|
def test_rich_list_source_filter(self, db):
|
|
db.create_session("s1", "cli")
|
|
db.create_session("s2", "telegram")
|
|
sessions = db.list_sessions_rich(source="cli")
|
|
assert len(sessions) == 1
|
|
assert sessions[0]["id"] == "s1"
|
|
|
|
def test_preview_newlines_collapsed(self, db):
|
|
db.create_session("s1", "cli")
|
|
db.append_message("s1", "user", "Line one\nLine two\nLine three")
|
|
sessions = db.list_sessions_rich()
|
|
assert "\n" not in sessions[0]["preview"]
|
|
assert "Line one Line two" in sessions[0]["preview"]
|
|
|
|
|
|
class TestResolveSessionByNameOrId:
|
|
"""Tests for the main.py helper that resolves names or IDs."""
|
|
|
|
def test_resolve_by_id(self, db):
|
|
db.create_session("test-id-123", "cli")
|
|
session = db.get_session("test-id-123")
|
|
assert session is not None
|
|
assert session["id"] == "test-id-123"
|
|
|
|
def test_resolve_by_title_falls_back(self, db):
|
|
db.create_session("s1", "cli")
|
|
db.set_session_title("s1", "my project")
|
|
result = db.resolve_session_by_title("my project")
|
|
assert result == "s1"
|