- Renamed a test method for clarity and added comprehensive tests for `SessionSource`, covering numeric `chat_id` values, missing optional fields, and invalid platforms.
- Introduced tests for session source descriptions based on chat type and chat name, ensuring they are represented accurately in prompts.
- Improved the file tools tests by validating schema structures, checking that there are no duplicate model IDs, and strengthening error handling in file operations.
202 lines
6.5 KiB
Python
202 lines
6.5 KiB
Python
"""Tests for gateway session management."""
|
|
|
|
import pytest
|
|
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
|
|
from gateway.session import (
|
|
SessionSource,
|
|
build_session_context,
|
|
build_session_context_prompt,
|
|
)
|
|
|
|
|
|
class TestSessionSourceRoundtrip:
    """Serialise a SessionSource with to_dict() and rebuild it with from_dict()."""

    def test_full_roundtrip(self):
        """Every explicitly set field comes back unchanged after a round trip."""
        fields = {
            "platform": Platform.TELEGRAM,
            "chat_id": "12345",
            "chat_name": "My Group",
            "chat_type": "group",
            "user_id": "99",
            "user_name": "alice",
            "thread_id": "t1",
        }
        original = SessionSource(**fields)

        restored = SessionSource.from_dict(original.to_dict())

        # Check each attribute against the value it was constructed with.
        for attr, expected in fields.items():
            assert getattr(restored, attr) == expected

    def test_minimal_roundtrip(self):
        """A source built from only platform and chat_id round-trips."""
        serialized = SessionSource(platform=Platform.LOCAL, chat_id="cli").to_dict()
        restored = SessionSource.from_dict(serialized)

        assert (restored.platform, restored.chat_id) == (Platform.LOCAL, "cli")
        # default value preserved
        assert restored.chat_type == "dm"

    def test_chat_id_coerced_to_string(self):
        """from_dict should handle numeric chat_id (common from Telegram)."""
        raw = {"platform": "telegram", "chat_id": 12345}

        restored = SessionSource.from_dict(raw)

        assert isinstance(restored.chat_id, str)
        assert restored.chat_id == "12345"

    def test_missing_optional_fields(self):
        """Fields absent from the dict come back as None; chat_type is "dm"."""
        restored = SessionSource.from_dict({"platform": "discord", "chat_id": "abc"})

        for attr in ("chat_name", "user_id", "user_name", "thread_id"):
            assert getattr(restored, attr) is None
        assert restored.chat_type == "dm"

    def test_invalid_platform_raises(self):
        """An unrecognised platform string raises ValueError or KeyError."""
        bad_payload = {"platform": "nonexistent", "chat_id": "1"}

        with pytest.raises((ValueError, KeyError)):
            SessionSource.from_dict(bad_payload)
|
|
|
|
|
|
class TestSessionSourceDescription:
    """Check the text exposed by SessionSource.description for several chat kinds."""

    def test_local_cli(self):
        """The local CLI source describes itself exactly as "CLI terminal"."""
        assert SessionSource.local_cli().description == "CLI terminal"

    def test_dm_with_username(self):
        """A DM with a user_name mentions both "DM" and that name."""
        description = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="123",
            chat_type="dm",
            user_name="bob",
        ).description

        assert "DM" in description
        assert "bob" in description

    def test_dm_without_username_falls_back_to_user_id(self):
        """With no user_name set, the user_id appears in the description."""
        description = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="123",
            chat_type="dm",
            user_id="456",
        ).description

        assert "456" in description

    def test_group_shows_chat_name(self):
        """A group description contains "group" and the chat name."""
        description = SessionSource(
            platform=Platform.DISCORD,
            chat_id="789",
            chat_type="group",
            chat_name="Dev Chat",
        ).description

        assert "group" in description
        assert "Dev Chat" in description

    def test_channel_type(self):
        """A channel description contains "channel" and the chat name."""
        description = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="100",
            chat_type="channel",
            chat_name="Announcements",
        ).description

        assert "channel" in description
        assert "Announcements" in description

    def test_thread_id_appended(self):
        """When thread_id is set, "thread" and the id both appear."""
        description = SessionSource(
            platform=Platform.DISCORD,
            chat_id="789",
            chat_type="group",
            chat_name="General",
            thread_id="thread-42",
        ).description

        assert "thread" in description
        assert "thread-42" in description

    def test_unknown_chat_type_uses_name(self):
        """A chat_type of "forum" still yields a description with the chat name."""
        description = SessionSource(
            platform=Platform.SLACK,
            chat_id="C01",
            chat_type="forum",
            chat_name="Questions",
        ).description

        assert "Questions" in description
|
|
|
|
|
|
class TestLocalCliFactory:
    """Check the attribute values produced by SessionSource.local_cli()."""

    def test_local_cli_defaults(self):
        """local_cli() yields the LOCAL platform, "cli" id, "dm" type and CLI name."""
        cli = SessionSource.local_cli()

        actual = (cli.platform, cli.chat_id, cli.chat_type, cli.chat_name)

        assert actual == (Platform.LOCAL, "cli", "dm", "CLI terminal")
|
|
|
|
|
|
class TestBuildSessionContextPrompt:
    """Check the prompt text built from a session source and a gateway config."""

    @staticmethod
    def _render(source, config):
        """Build the session context for *source* and return its prompt."""
        context = build_session_context(source, config)
        return build_session_context_prompt(context)

    def test_telegram_prompt_contains_platform_and_chat(self):
        """The Telegram prompt mentions "Telegram" and the chat name."""
        home = HomeChannel(
            platform=Platform.TELEGRAM,
            chat_id="111",
            name="Home Chat",
        )
        telegram = PlatformConfig(
            enabled=True,
            token="fake-token",
            home_channel=home,
        )
        config = GatewayConfig(platforms={Platform.TELEGRAM: telegram})
        source = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="111",
            chat_name="Home Chat",
            chat_type="dm",
        )

        prompt = self._render(source, config)

        assert "Telegram" in prompt
        assert "Home Chat" in prompt

    def test_discord_prompt(self):
        """The Discord prompt mentions "Discord"."""
        discord = PlatformConfig(enabled=True, token="fake-discord-token")
        config = GatewayConfig(platforms={Platform.DISCORD: discord})
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="guild-123",
            chat_name="Server",
            chat_type="group",
            user_name="alice",
        )

        prompt = self._render(source, config)

        assert "Discord" in prompt

    def test_local_prompt_mentions_machine(self):
        """The local CLI prompt mentions "Local" and the host machine phrase."""
        prompt = self._render(SessionSource.local_cli(), GatewayConfig())

        assert "Local" in prompt
        assert "machine running this agent" in prompt

    def test_whatsapp_prompt(self):
        """The WhatsApp prompt names the platform in some letter case."""
        whatsapp = PlatformConfig(enabled=True, token="")
        config = GatewayConfig(platforms={Platform.WHATSAPP: whatsapp})
        source = SessionSource(
            platform=Platform.WHATSAPP,
            chat_id="15551234567@s.whatsapp.net",
            chat_type="dm",
            user_name="Phone User",
        )

        prompt = self._render(source, config)

        assert "WhatsApp" in prompt or "whatsapp" in prompt.lower()
|