Files
hermes-agent/tests/gateway/test_session.py
teknium1 8602e61fca test: cover gateway session model backfill
Add regression coverage for backfilling NULL gateway session models in SQLite, preserving existing models, and forwarding the resolved agent model through SessionStore updates.
2026-03-14 06:44:14 -07:00

605 lines
21 KiB
Python

"""Tests for gateway session management."""
import json
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
from gateway.session import (
SessionSource,
SessionStore,
build_session_context,
build_session_context_prompt,
build_session_key,
)
class TestSessionSourceRoundtrip:
def test_full_roundtrip(self):
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="12345",
chat_name="My Group",
chat_type="group",
user_id="99",
user_name="alice",
thread_id="t1",
)
d = source.to_dict()
restored = SessionSource.from_dict(d)
assert restored.platform == Platform.TELEGRAM
assert restored.chat_id == "12345"
assert restored.chat_name == "My Group"
assert restored.chat_type == "group"
assert restored.user_id == "99"
assert restored.user_name == "alice"
assert restored.thread_id == "t1"
def test_full_roundtrip_with_chat_topic(self):
"""chat_topic should survive to_dict/from_dict roundtrip."""
source = SessionSource(
platform=Platform.DISCORD,
chat_id="789",
chat_name="Server / #project-planning",
chat_type="group",
user_id="42",
user_name="bob",
chat_topic="Planning and coordination for Project X",
)
d = source.to_dict()
assert d["chat_topic"] == "Planning and coordination for Project X"
restored = SessionSource.from_dict(d)
assert restored.chat_topic == "Planning and coordination for Project X"
assert restored.chat_name == "Server / #project-planning"
def test_minimal_roundtrip(self):
source = SessionSource(platform=Platform.LOCAL, chat_id="cli")
d = source.to_dict()
restored = SessionSource.from_dict(d)
assert restored.platform == Platform.LOCAL
assert restored.chat_id == "cli"
assert restored.chat_type == "dm" # default value preserved
def test_chat_id_coerced_to_string(self):
"""from_dict should handle numeric chat_id (common from Telegram)."""
restored = SessionSource.from_dict({
"platform": "telegram",
"chat_id": 12345,
})
assert restored.chat_id == "12345"
assert isinstance(restored.chat_id, str)
def test_missing_optional_fields(self):
restored = SessionSource.from_dict({
"platform": "discord",
"chat_id": "abc",
})
assert restored.chat_name is None
assert restored.user_id is None
assert restored.user_name is None
assert restored.thread_id is None
assert restored.chat_topic is None
assert restored.chat_type == "dm"
def test_invalid_platform_raises(self):
with pytest.raises((ValueError, KeyError)):
SessionSource.from_dict({"platform": "nonexistent", "chat_id": "1"})
class TestSessionSourceDescription:
def test_local_cli(self):
source = SessionSource.local_cli()
assert source.description == "CLI terminal"
def test_dm_with_username(self):
source = SessionSource(
platform=Platform.TELEGRAM, chat_id="123",
chat_type="dm", user_name="bob",
)
assert "DM" in source.description
assert "bob" in source.description
def test_dm_without_username_falls_back_to_user_id(self):
source = SessionSource(
platform=Platform.TELEGRAM, chat_id="123",
chat_type="dm", user_id="456",
)
assert "456" in source.description
def test_group_shows_chat_name(self):
source = SessionSource(
platform=Platform.DISCORD, chat_id="789",
chat_type="group", chat_name="Dev Chat",
)
assert "group" in source.description
assert "Dev Chat" in source.description
def test_channel_type(self):
source = SessionSource(
platform=Platform.TELEGRAM, chat_id="100",
chat_type="channel", chat_name="Announcements",
)
assert "channel" in source.description
assert "Announcements" in source.description
def test_thread_id_appended(self):
source = SessionSource(
platform=Platform.DISCORD, chat_id="789",
chat_type="group", chat_name="General",
thread_id="thread-42",
)
assert "thread" in source.description
assert "thread-42" in source.description
def test_unknown_chat_type_uses_name(self):
source = SessionSource(
platform=Platform.SLACK, chat_id="C01",
chat_type="forum", chat_name="Questions",
)
assert "Questions" in source.description
class TestLocalCliFactory:
def test_local_cli_defaults(self):
source = SessionSource.local_cli()
assert source.platform == Platform.LOCAL
assert source.chat_id == "cli"
assert source.chat_type == "dm"
assert source.chat_name == "CLI terminal"
class TestBuildSessionContextPrompt:
def test_telegram_prompt_contains_platform_and_chat(self):
config = GatewayConfig(
platforms={
Platform.TELEGRAM: PlatformConfig(
enabled=True,
token="fake-token",
home_channel=HomeChannel(
platform=Platform.TELEGRAM,
chat_id="111",
name="Home Chat",
),
),
},
)
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="111",
chat_name="Home Chat",
chat_type="dm",
)
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "Telegram" in prompt
assert "Home Chat" in prompt
def test_discord_prompt(self):
config = GatewayConfig(
platforms={
Platform.DISCORD: PlatformConfig(
enabled=True,
token="fake-d...oken",
),
},
)
source = SessionSource(
platform=Platform.DISCORD,
chat_id="guild-123",
chat_name="Server",
chat_type="group",
user_name="alice",
)
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "Discord" in prompt
assert "cannot search" in prompt.lower() or "do not have access" in prompt.lower()
def test_slack_prompt_includes_platform_notes(self):
config = GatewayConfig(
platforms={
Platform.SLACK: PlatformConfig(enabled=True, token="fake"),
},
)
source = SessionSource(
platform=Platform.SLACK,
chat_id="C123",
chat_name="general",
chat_type="group",
user_name="bob",
)
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "Slack" in prompt
assert "cannot search" in prompt.lower()
assert "pin" in prompt.lower()
def test_discord_prompt_with_channel_topic(self):
"""Channel topic should appear in the session context prompt."""
config = GatewayConfig(
platforms={
Platform.DISCORD: PlatformConfig(
enabled=True,
token="fake-discord-token",
),
},
)
source = SessionSource(
platform=Platform.DISCORD,
chat_id="guild-123",
chat_name="Server / #project-planning",
chat_type="group",
user_name="alice",
chat_topic="Planning and coordination for Project X",
)
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "Discord" in prompt
assert "**Channel Topic:** Planning and coordination for Project X" in prompt
def test_prompt_omits_channel_topic_when_none(self):
"""Channel Topic line should NOT appear when chat_topic is None."""
config = GatewayConfig(
platforms={
Platform.DISCORD: PlatformConfig(
enabled=True,
token="fake-discord-token",
),
},
)
source = SessionSource(
platform=Platform.DISCORD,
chat_id="guild-123",
chat_name="Server / #general",
chat_type="group",
user_name="alice",
)
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "Channel Topic" not in prompt
def test_local_prompt_mentions_machine(self):
config = GatewayConfig()
source = SessionSource.local_cli()
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "Local" in prompt
assert "machine running this agent" in prompt
def test_whatsapp_prompt(self):
config = GatewayConfig(
platforms={
Platform.WHATSAPP: PlatformConfig(enabled=True, token=""),
},
)
source = SessionSource(
platform=Platform.WHATSAPP,
chat_id="15551234567@s.whatsapp.net",
chat_type="dm",
user_name="Phone User",
)
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "WhatsApp" in prompt or "whatsapp" in prompt.lower()
class TestSessionStoreRewriteTranscript:
"""Regression: /retry and /undo must persist truncated history to disk."""
@pytest.fixture()
def store(self, tmp_path):
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
s = SessionStore(sessions_dir=tmp_path, config=config)
s._db = None # no SQLite for these tests
s._loaded = True
return s
def test_rewrite_replaces_jsonl(self, store, tmp_path):
session_id = "test_session_1"
# Write initial transcript
for msg in [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
{"role": "user", "content": "undo this"},
{"role": "assistant", "content": "ok"},
]:
store.append_to_transcript(session_id, msg)
# Rewrite with truncated history
store.rewrite_transcript(session_id, [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
])
reloaded = store.load_transcript(session_id)
assert len(reloaded) == 2
assert reloaded[0]["content"] == "hello"
assert reloaded[1]["content"] == "hi"
def test_rewrite_with_empty_list(self, store):
session_id = "test_session_2"
store.append_to_transcript(session_id, {"role": "user", "content": "hi"})
store.rewrite_transcript(session_id, [])
reloaded = store.load_transcript(session_id)
assert reloaded == []
class TestWhatsAppDMSessionKeyConsistency:
"""Regression: all session-key construction must go through build_session_key
so WhatsApp DMs include chat_id while other DMs do not."""
@pytest.fixture()
def store(self, tmp_path):
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
s = SessionStore(sessions_dir=tmp_path, config=config)
s._db = None
s._loaded = True
return s
def test_whatsapp_dm_includes_chat_id(self):
source = SessionSource(
platform=Platform.WHATSAPP,
chat_id="15551234567@s.whatsapp.net",
chat_type="dm",
user_name="Phone User",
)
key = build_session_key(source)
assert key == "agent:main:whatsapp:dm:15551234567@s.whatsapp.net"
def test_store_delegates_to_build_session_key(self, store):
"""SessionStore._generate_session_key must produce the same result."""
source = SessionSource(
platform=Platform.WHATSAPP,
chat_id="15551234567@s.whatsapp.net",
chat_type="dm",
user_name="Phone User",
)
assert store._generate_session_key(source) == build_session_key(source)
def test_telegram_dm_omits_chat_id(self):
"""Non-WhatsApp DMs should still omit chat_id (single owner DM)."""
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="99",
chat_type="dm",
)
key = build_session_key(source)
assert key == "agent:main:telegram:dm"
def test_discord_group_includes_chat_id(self):
"""Group/channel keys include chat_type and chat_id."""
source = SessionSource(
platform=Platform.DISCORD,
chat_id="guild-123",
chat_type="group",
)
key = build_session_key(source)
assert key == "agent:main:discord:group:guild-123"
def test_group_thread_includes_thread_id(self):
"""Forum-style threads need a distinct session key within one group."""
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="-1002285219667",
chat_type="group",
thread_id="17585",
)
key = build_session_key(source)
assert key == "agent:main:telegram:group:-1002285219667:17585"
class TestSessionStoreEntriesAttribute:
"""Regression: /reset must access _entries, not _sessions."""
def test_entries_attribute_exists(self):
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=Path("/tmp"), config=config)
store._loaded = True
assert hasattr(store, "_entries")
assert not hasattr(store, "_sessions")
class TestHasAnySessions:
"""Tests for has_any_sessions() fix (issue #351)."""
@pytest.fixture
def store_with_mock_db(self, tmp_path):
"""SessionStore with a mocked database."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
s = SessionStore(sessions_dir=tmp_path, config=config)
s._loaded = True
s._entries = {}
s._db = MagicMock()
return s
def test_uses_database_count_when_available(self, store_with_mock_db):
"""has_any_sessions should use database session_count, not len(_entries)."""
store = store_with_mock_db
# Simulate single-platform user with only 1 entry in memory
store._entries = {"telegram:12345": MagicMock()}
# But database has 3 sessions (current + 2 previous resets)
store._db.session_count.return_value = 3
assert store.has_any_sessions() is True
store._db.session_count.assert_called_once()
def test_first_session_ever_returns_false(self, store_with_mock_db):
"""First session ever should return False (only current session in DB)."""
store = store_with_mock_db
store._entries = {"telegram:12345": MagicMock()}
# Database has exactly 1 session (the current one just created)
store._db.session_count.return_value = 1
assert store.has_any_sessions() is False
def test_fallback_without_database(self, tmp_path):
"""Should fall back to len(_entries) when DB is not available."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
store._loaded = True
store._db = None
store._entries = {"key1": MagicMock(), "key2": MagicMock()}
# > 1 entries means has sessions
assert store.has_any_sessions() is True
store._entries = {"key1": MagicMock()}
assert store.has_any_sessions() is False
class TestLastPromptTokens:
"""Tests for the last_prompt_tokens field — actual API token tracking."""
def test_session_entry_default(self):
"""New sessions should have last_prompt_tokens=0."""
from gateway.session import SessionEntry
from datetime import datetime
entry = SessionEntry(
session_key="test",
session_id="s1",
created_at=datetime.now(),
updated_at=datetime.now(),
)
assert entry.last_prompt_tokens == 0
def test_session_entry_roundtrip(self):
"""last_prompt_tokens should survive serialization/deserialization."""
from gateway.session import SessionEntry
from datetime import datetime
entry = SessionEntry(
session_key="test",
session_id="s1",
created_at=datetime.now(),
updated_at=datetime.now(),
last_prompt_tokens=42000,
)
d = entry.to_dict()
assert d["last_prompt_tokens"] == 42000
restored = SessionEntry.from_dict(d)
assert restored.last_prompt_tokens == 42000
def test_session_entry_from_old_data(self):
"""Old session data without last_prompt_tokens should default to 0."""
from gateway.session import SessionEntry
data = {
"session_key": "test",
"session_id": "s1",
"created_at": "2025-01-01T00:00:00",
"updated_at": "2025-01-01T00:00:00",
"input_tokens": 100,
"output_tokens": 50,
"total_tokens": 150,
# No last_prompt_tokens — old format
}
entry = SessionEntry.from_dict(data)
assert entry.last_prompt_tokens == 0
def test_update_session_sets_last_prompt_tokens(self, tmp_path):
"""update_session should store the actual prompt token count."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
store._loaded = True
store._db = None
store._save = MagicMock()
from gateway.session import SessionEntry
from datetime import datetime
entry = SessionEntry(
session_key="k1",
session_id="s1",
created_at=datetime.now(),
updated_at=datetime.now(),
)
store._entries = {"k1": entry}
store.update_session("k1", last_prompt_tokens=85000)
assert entry.last_prompt_tokens == 85000
def test_update_session_none_does_not_change(self, tmp_path):
"""update_session with default (None) should not change last_prompt_tokens."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
store._loaded = True
store._db = None
store._save = MagicMock()
from gateway.session import SessionEntry
from datetime import datetime
entry = SessionEntry(
session_key="k1",
session_id="s1",
created_at=datetime.now(),
updated_at=datetime.now(),
last_prompt_tokens=50000,
)
store._entries = {"k1": entry}
store.update_session("k1") # No last_prompt_tokens arg
assert entry.last_prompt_tokens == 50000 # unchanged
def test_update_session_zero_resets(self, tmp_path):
"""update_session with last_prompt_tokens=0 should reset the field."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
store._loaded = True
store._db = None
store._save = MagicMock()
from gateway.session import SessionEntry
from datetime import datetime
entry = SessionEntry(
session_key="k1",
session_id="s1",
created_at=datetime.now(),
updated_at=datetime.now(),
last_prompt_tokens=85000,
)
store._entries = {"k1": entry}
store.update_session("k1", last_prompt_tokens=0)
assert entry.last_prompt_tokens == 0
def test_update_session_passes_model_to_db(self, tmp_path):
"""Gateway session updates should forward the resolved model to SQLite."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
store._loaded = True
store._save = MagicMock()
store._db = MagicMock()
from gateway.session import SessionEntry
from datetime import datetime
entry = SessionEntry(
session_key="k1",
session_id="s1",
created_at=datetime.now(),
updated_at=datetime.now(),
)
store._entries = {"k1": entry}
store.update_session("k1", model="openai/gpt-5.4")
store._db.update_token_counts.assert_called_once_with(
"s1", 0, 0, model="openai/gpt-5.4"
)