test: enhance session source tests and add validation for chat types
- Renamed test method for clarity and added comprehensive tests for `SessionSource` including handling of numeric `chat_id`, missing optional fields, and invalid platforms. - Introduced tests for session source descriptions based on chat types and names, ensuring accurate representation in prompts. - Improved file tools tests by validating schema structures, ensuring no duplicate model IDs, and enhancing error handling in file operations.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Tests for gateway session management."""
|
||||
|
||||
import pytest
|
||||
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
|
||||
from gateway.session import (
|
||||
SessionSource,
|
||||
@@ -9,7 +10,7 @@ from gateway.session import (
|
||||
|
||||
|
||||
class TestSessionSourceRoundtrip:
|
||||
def test_to_dict_from_dict(self):
|
||||
def test_full_roundtrip(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
@@ -36,22 +37,97 @@ class TestSessionSourceRoundtrip:
|
||||
restored = SessionSource.from_dict(d)
|
||||
assert restored.platform == Platform.LOCAL
|
||||
assert restored.chat_id == "cli"
|
||||
assert restored.chat_type == "dm" # default value preserved
|
||||
|
||||
def test_chat_id_coerced_to_string(self):
|
||||
"""from_dict should handle numeric chat_id (common from Telegram)."""
|
||||
restored = SessionSource.from_dict({
|
||||
"platform": "telegram",
|
||||
"chat_id": 12345,
|
||||
})
|
||||
assert restored.chat_id == "12345"
|
||||
assert isinstance(restored.chat_id, str)
|
||||
|
||||
def test_missing_optional_fields(self):
|
||||
restored = SessionSource.from_dict({
|
||||
"platform": "discord",
|
||||
"chat_id": "abc",
|
||||
})
|
||||
assert restored.chat_name is None
|
||||
assert restored.user_id is None
|
||||
assert restored.user_name is None
|
||||
assert restored.thread_id is None
|
||||
assert restored.chat_type == "dm"
|
||||
|
||||
def test_invalid_platform_raises(self):
|
||||
with pytest.raises((ValueError, KeyError)):
|
||||
SessionSource.from_dict({"platform": "nonexistent", "chat_id": "1"})
|
||||
|
||||
|
||||
class TestLocalCliSource:
|
||||
class TestSessionSourceDescription:
|
||||
def test_local_cli(self):
|
||||
source = SessionSource.local_cli()
|
||||
assert source.description == "CLI terminal"
|
||||
|
||||
def test_dm_with_username(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id="123",
|
||||
chat_type="dm", user_name="bob",
|
||||
)
|
||||
assert "DM" in source.description
|
||||
assert "bob" in source.description
|
||||
|
||||
def test_dm_without_username_falls_back_to_user_id(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id="123",
|
||||
chat_type="dm", user_id="456",
|
||||
)
|
||||
assert "456" in source.description
|
||||
|
||||
def test_group_shows_chat_name(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD, chat_id="789",
|
||||
chat_type="group", chat_name="Dev Chat",
|
||||
)
|
||||
assert "group" in source.description
|
||||
assert "Dev Chat" in source.description
|
||||
|
||||
def test_channel_type(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM, chat_id="100",
|
||||
chat_type="channel", chat_name="Announcements",
|
||||
)
|
||||
assert "channel" in source.description
|
||||
assert "Announcements" in source.description
|
||||
|
||||
def test_thread_id_appended(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD, chat_id="789",
|
||||
chat_type="group", chat_name="General",
|
||||
thread_id="thread-42",
|
||||
)
|
||||
assert "thread" in source.description
|
||||
assert "thread-42" in source.description
|
||||
|
||||
def test_unknown_chat_type_uses_name(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.SLACK, chat_id="C01",
|
||||
chat_type="forum", chat_name="Questions",
|
||||
)
|
||||
assert "Questions" in source.description
|
||||
|
||||
|
||||
class TestLocalCliFactory:
|
||||
def test_local_cli_defaults(self):
|
||||
source = SessionSource.local_cli()
|
||||
assert source.platform == Platform.LOCAL
|
||||
assert source.chat_id == "cli"
|
||||
assert source.chat_type == "dm"
|
||||
|
||||
def test_description_local(self):
|
||||
source = SessionSource.local_cli()
|
||||
assert source.description == "CLI terminal"
|
||||
assert source.chat_name == "CLI terminal"
|
||||
|
||||
|
||||
class TestBuildSessionContextPrompt:
|
||||
def test_contains_platform_info(self):
|
||||
def test_telegram_prompt_contains_platform_and_chat(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(
|
||||
@@ -76,9 +152,29 @@ class TestBuildSessionContextPrompt:
|
||||
|
||||
assert "Telegram" in prompt
|
||||
assert "Home Chat" in prompt
|
||||
assert "Session Context" in prompt
|
||||
|
||||
def test_local_source_prompt(self):
|
||||
def test_discord_prompt(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(
|
||||
enabled=True,
|
||||
token="fake-discord-token",
|
||||
),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_name="Server",
|
||||
chat_type="group",
|
||||
user_name="alice",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Discord" in prompt
|
||||
|
||||
def test_local_prompt_mentions_machine(self):
|
||||
config = GatewayConfig()
|
||||
source = SessionSource.local_cli()
|
||||
ctx = build_session_context(source, config)
|
||||
@@ -86,3 +182,20 @@ class TestBuildSessionContextPrompt:
|
||||
|
||||
assert "Local" in prompt
|
||||
assert "machine running this agent" in prompt
|
||||
|
||||
def test_whatsapp_prompt(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.WHATSAPP: PlatformConfig(enabled=True, token=""),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.WHATSAPP,
|
||||
chat_id="15551234567@s.whatsapp.net",
|
||||
chat_type="dm",
|
||||
user_name="Phone User",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "WhatsApp" in prompt or "whatsapp" in prompt.lower()
|
||||
|
||||
@@ -4,30 +4,53 @@ from hermes_cli.models import OPENROUTER_MODELS, menu_labels, model_ids
|
||||
|
||||
|
||||
class TestModelIds:
|
||||
def test_returns_strings(self):
|
||||
def test_returns_non_empty_list(self):
|
||||
ids = model_ids()
|
||||
assert isinstance(ids, list)
|
||||
assert len(ids) > 0
|
||||
assert all(isinstance(mid, str) for mid in ids)
|
||||
|
||||
def test_ids_match_models_list(self):
|
||||
ids = model_ids()
|
||||
expected = [mid for mid, _ in OPENROUTER_MODELS]
|
||||
assert ids == expected
|
||||
|
||||
def test_all_ids_contain_provider_slash(self):
|
||||
"""Model IDs should follow the provider/model format."""
|
||||
for mid in model_ids():
|
||||
assert "/" in mid, f"Model ID '{mid}' missing provider/ prefix"
|
||||
|
||||
def test_no_duplicate_ids(self):
|
||||
ids = model_ids()
|
||||
assert len(ids) == len(set(ids)), "Duplicate model IDs found"
|
||||
|
||||
|
||||
class TestMenuLabels:
|
||||
def test_same_length_as_model_ids(self):
|
||||
labels = menu_labels()
|
||||
ids = model_ids()
|
||||
assert len(labels) == len(ids)
|
||||
assert len(menu_labels()) == len(model_ids())
|
||||
|
||||
def test_recommended_in_first(self):
|
||||
def test_first_label_marked_recommended(self):
|
||||
labels = menu_labels()
|
||||
assert "recommended" in labels[0].lower()
|
||||
|
||||
def test_labels_contain_model_ids(self):
|
||||
def test_each_label_contains_its_model_id(self):
|
||||
for label, mid in zip(menu_labels(), model_ids()):
|
||||
assert mid in label, f"Label '{label}' doesn't contain model ID '{mid}'"
|
||||
|
||||
def test_non_recommended_labels_have_no_tag(self):
|
||||
"""Only the first model should have (recommended)."""
|
||||
labels = menu_labels()
|
||||
ids = model_ids()
|
||||
for label, mid in zip(labels, ids):
|
||||
assert mid in label
|
||||
for label in labels[1:]:
|
||||
assert "recommended" not in label.lower(), f"Unexpected 'recommended' in '{label}'"
|
||||
|
||||
|
||||
class TestOpenRouterModels:
|
||||
def test_structure_is_list_of_tuples(self):
|
||||
for entry in OPENROUTER_MODELS:
|
||||
assert isinstance(entry, tuple) and len(entry) == 2
|
||||
mid, desc = entry
|
||||
assert isinstance(mid, str) and len(mid) > 0
|
||||
assert isinstance(desc, str)
|
||||
|
||||
def test_at_least_5_models(self):
|
||||
"""Sanity check that the models list hasn't been accidentally truncated."""
|
||||
assert len(OPENROUTER_MODELS) >= 5
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""Tests for the file tools module (schema and handler wiring).
|
||||
"""Tests for the file tools module (schema, handler wiring, error paths).
|
||||
|
||||
These tests verify the tool schemas and handler wiring without
|
||||
requiring a running terminal environment. The actual file operations
|
||||
(ShellFileOperations) depend on a terminal backend, so we mock
|
||||
_get_file_ops to test the handler logic in isolation.
|
||||
Tests verify tool schemas, handler dispatch, validation logic, and error
|
||||
handling without requiring a running terminal environment.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -18,82 +16,187 @@ from tools.file_tools import (
|
||||
)
|
||||
|
||||
|
||||
class TestSchemas:
|
||||
def test_read_file_schema(self):
|
||||
assert READ_FILE_SCHEMA["name"] == "read_file"
|
||||
props = READ_FILE_SCHEMA["parameters"]["properties"]
|
||||
assert "path" in props
|
||||
assert "offset" in props
|
||||
assert "limit" in props
|
||||
|
||||
def test_write_file_schema(self):
|
||||
assert WRITE_FILE_SCHEMA["name"] == "write_file"
|
||||
assert "path" in WRITE_FILE_SCHEMA["parameters"]["properties"]
|
||||
assert "content" in WRITE_FILE_SCHEMA["parameters"]["properties"]
|
||||
|
||||
def test_patch_schema(self):
|
||||
assert PATCH_SCHEMA["name"] == "patch"
|
||||
props = PATCH_SCHEMA["parameters"]["properties"]
|
||||
assert "mode" in props
|
||||
assert "old_string" in props
|
||||
assert "new_string" in props
|
||||
|
||||
def test_search_files_schema(self):
|
||||
assert SEARCH_FILES_SCHEMA["name"] == "search_files"
|
||||
props = SEARCH_FILES_SCHEMA["parameters"]["properties"]
|
||||
assert "pattern" in props
|
||||
assert "target" in props
|
||||
|
||||
|
||||
class TestFileToolsList:
|
||||
def test_file_tools_has_expected_entries(self):
|
||||
def test_has_expected_entries(self):
|
||||
names = {t["name"] for t in FILE_TOOLS}
|
||||
assert names == {"read_file", "write_file", "patch", "search_files"}
|
||||
|
||||
def test_each_entry_has_callable_function(self):
|
||||
for tool in FILE_TOOLS:
|
||||
assert callable(tool["function"]), f"{tool['name']} missing callable"
|
||||
|
||||
def test_schemas_have_required_fields(self):
|
||||
"""All schemas must have name, description, and parameters with properties."""
|
||||
for schema in [READ_FILE_SCHEMA, WRITE_FILE_SCHEMA, PATCH_SCHEMA, SEARCH_FILES_SCHEMA]:
|
||||
assert "name" in schema
|
||||
assert "description" in schema
|
||||
assert "properties" in schema["parameters"]
|
||||
|
||||
|
||||
class TestReadFileHandler:
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_read_file_returns_json(self, mock_get):
|
||||
def test_returns_file_content(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"content": "hello", "total_lines": 1}
|
||||
result_obj.to_dict.return_value = {"content": "line1\nline2", "total_lines": 2}
|
||||
mock_ops.read_file.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import read_file_tool
|
||||
|
||||
result = json.loads(read_file_tool("/tmp/test.txt"))
|
||||
assert result["content"] == "hello"
|
||||
assert result["content"] == "line1\nline2"
|
||||
assert result["total_lines"] == 2
|
||||
mock_ops.read_file.assert_called_once_with("/tmp/test.txt", 1, 500)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_custom_offset_and_limit(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"content": "line10", "total_lines": 50}
|
||||
mock_ops.read_file.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import read_file_tool
|
||||
read_file_tool("/tmp/big.txt", offset=10, limit=20)
|
||||
mock_ops.read_file.assert_called_once_with("/tmp/big.txt", 10, 20)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_exception_returns_error_json(self, mock_get):
|
||||
mock_get.side_effect = RuntimeError("terminal not available")
|
||||
|
||||
from tools.file_tools import read_file_tool
|
||||
result = json.loads(read_file_tool("/tmp/test.txt"))
|
||||
assert "error" in result
|
||||
assert "terminal not available" in result["error"]
|
||||
|
||||
|
||||
class TestWriteFileHandler:
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_write_file_returns_json(self, mock_get):
|
||||
def test_writes_content(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"status": "ok", "path": "/tmp/test.txt"}
|
||||
result_obj.to_dict.return_value = {"status": "ok", "path": "/tmp/out.txt", "bytes": 13}
|
||||
mock_ops.write_file.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import write_file_tool
|
||||
|
||||
result = json.loads(write_file_tool("/tmp/test.txt", "content"))
|
||||
result = json.loads(write_file_tool("/tmp/out.txt", "hello world!\n"))
|
||||
assert result["status"] == "ok"
|
||||
mock_ops.write_file.assert_called_once_with("/tmp/test.txt", "content")
|
||||
mock_ops.write_file.assert_called_once_with("/tmp/out.txt", "hello world!\n")
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_exception_returns_error_json(self, mock_get):
|
||||
mock_get.side_effect = PermissionError("read-only filesystem")
|
||||
|
||||
from tools.file_tools import write_file_tool
|
||||
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
|
||||
assert "error" in result
|
||||
assert "read-only" in result["error"]
|
||||
|
||||
|
||||
class TestPatchHandler:
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_replace_mode_calls_patch_replace(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"status": "ok", "replacements": 1}
|
||||
mock_ops.patch_replace.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import patch_tool
|
||||
result = json.loads(patch_tool(
|
||||
mode="replace", path="/tmp/f.py",
|
||||
old_string="foo", new_string="bar"
|
||||
))
|
||||
assert result["status"] == "ok"
|
||||
mock_ops.patch_replace.assert_called_once_with("/tmp/f.py", "foo", "bar", False)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_replace_mode_replace_all_flag(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"status": "ok", "replacements": 5}
|
||||
mock_ops.patch_replace.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import patch_tool
|
||||
patch_tool(mode="replace", path="/tmp/f.py",
|
||||
old_string="x", new_string="y", replace_all=True)
|
||||
mock_ops.patch_replace.assert_called_once_with("/tmp/f.py", "x", "y", True)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_replace_mode_missing_path_errors(self, mock_get):
|
||||
from tools.file_tools import patch_tool
|
||||
|
||||
result = json.loads(patch_tool(mode="replace", path=None, old_string="a", new_string="b"))
|
||||
assert "error" in result
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_replace_mode_missing_strings_errors(self, mock_get):
|
||||
from tools.file_tools import patch_tool
|
||||
result = json.loads(patch_tool(mode="replace", path="/tmp/f.py", old_string=None, new_string="b"))
|
||||
assert "error" in result
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_patch_mode_calls_patch_v4a(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"status": "ok", "operations": 1}
|
||||
mock_ops.patch_v4a.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import patch_tool
|
||||
result = json.loads(patch_tool(mode="patch", patch="*** Begin Patch\n..."))
|
||||
assert result["status"] == "ok"
|
||||
mock_ops.patch_v4a.assert_called_once()
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_patch_mode_missing_content_errors(self, mock_get):
|
||||
from tools.file_tools import patch_tool
|
||||
result = json.loads(patch_tool(mode="patch", patch=None))
|
||||
assert "error" in result
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_unknown_mode_errors(self, mock_get):
|
||||
from tools.file_tools import patch_tool
|
||||
|
||||
result = json.loads(patch_tool(mode="unknown"))
|
||||
result = json.loads(patch_tool(mode="invalid_mode"))
|
||||
assert "error" in result
|
||||
assert "Unknown mode" in result["error"]
|
||||
|
||||
|
||||
class TestSearchHandler:
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_search_calls_file_ops(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"matches": ["file1.py:3:match"]}
|
||||
mock_ops.search.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import search_tool
|
||||
result = json.loads(search_tool(pattern="TODO", target="content", path="."))
|
||||
assert "matches" in result
|
||||
mock_ops.search.assert_called_once()
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_search_passes_all_params(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"matches": []}
|
||||
mock_ops.search.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import search_tool
|
||||
search_tool(pattern="class", target="files", path="/src",
|
||||
file_glob="*.py", limit=10, offset=5, output_mode="count", context=2)
|
||||
mock_ops.search.assert_called_once_with(
|
||||
pattern="class", path="/src", target="files", file_glob="*.py",
|
||||
limit=10, offset=5, output_mode="count", context=2,
|
||||
)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_search_exception_returns_error(self, mock_get):
|
||||
mock_get.side_effect = RuntimeError("no terminal")
|
||||
|
||||
from tools.file_tools import search_tool
|
||||
result = json.loads(search_tool(pattern="x"))
|
||||
assert "error" in result
|
||||
|
||||
Reference in New Issue
Block a user