test: enhance session source tests and add validation for chat types

- Renamed test method for clarity and added comprehensive tests for `SessionSource` including handling of numeric `chat_id`, missing optional fields, and invalid platforms.
- Introduced tests for session source descriptions based on chat types and names, ensuring accurate representation in prompts.
- Improved file tools tests by validating schema structures, ensuring no duplicate model IDs, and enhancing error handling in file operations.
This commit is contained in:
teknium1
2026-02-26 00:53:57 -08:00
parent d372eb1f0e
commit 178658bf9f
3 changed files with 303 additions and 64 deletions

View File

@@ -1,5 +1,6 @@
"""Tests for gateway session management."""
import pytest
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
from gateway.session import (
SessionSource,
@@ -9,7 +10,7 @@ from gateway.session import (
class TestSessionSourceRoundtrip:
def test_to_dict_from_dict(self):
def test_full_roundtrip(self):
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="12345",
@@ -36,22 +37,97 @@ class TestSessionSourceRoundtrip:
restored = SessionSource.from_dict(d)
assert restored.platform == Platform.LOCAL
assert restored.chat_id == "cli"
assert restored.chat_type == "dm" # default value preserved
def test_chat_id_coerced_to_string(self):
"""from_dict should handle numeric chat_id (common from Telegram)."""
restored = SessionSource.from_dict({
"platform": "telegram",
"chat_id": 12345,
})
assert restored.chat_id == "12345"
assert isinstance(restored.chat_id, str)
def test_missing_optional_fields(self):
restored = SessionSource.from_dict({
"platform": "discord",
"chat_id": "abc",
})
assert restored.chat_name is None
assert restored.user_id is None
assert restored.user_name is None
assert restored.thread_id is None
assert restored.chat_type == "dm"
def test_invalid_platform_raises(self):
with pytest.raises((ValueError, KeyError)):
SessionSource.from_dict({"platform": "nonexistent", "chat_id": "1"})
class TestLocalCliSource:
class TestSessionSourceDescription:
def test_local_cli(self):
source = SessionSource.local_cli()
assert source.description == "CLI terminal"
def test_dm_with_username(self):
source = SessionSource(
platform=Platform.TELEGRAM, chat_id="123",
chat_type="dm", user_name="bob",
)
assert "DM" in source.description
assert "bob" in source.description
def test_dm_without_username_falls_back_to_user_id(self):
source = SessionSource(
platform=Platform.TELEGRAM, chat_id="123",
chat_type="dm", user_id="456",
)
assert "456" in source.description
def test_group_shows_chat_name(self):
source = SessionSource(
platform=Platform.DISCORD, chat_id="789",
chat_type="group", chat_name="Dev Chat",
)
assert "group" in source.description
assert "Dev Chat" in source.description
def test_channel_type(self):
source = SessionSource(
platform=Platform.TELEGRAM, chat_id="100",
chat_type="channel", chat_name="Announcements",
)
assert "channel" in source.description
assert "Announcements" in source.description
def test_thread_id_appended(self):
source = SessionSource(
platform=Platform.DISCORD, chat_id="789",
chat_type="group", chat_name="General",
thread_id="thread-42",
)
assert "thread" in source.description
assert "thread-42" in source.description
def test_unknown_chat_type_uses_name(self):
source = SessionSource(
platform=Platform.SLACK, chat_id="C01",
chat_type="forum", chat_name="Questions",
)
assert "Questions" in source.description
class TestLocalCliFactory:
def test_local_cli_defaults(self):
source = SessionSource.local_cli()
assert source.platform == Platform.LOCAL
assert source.chat_id == "cli"
assert source.chat_type == "dm"
def test_description_local(self):
source = SessionSource.local_cli()
assert source.description == "CLI terminal"
assert source.chat_name == "CLI terminal"
class TestBuildSessionContextPrompt:
def test_contains_platform_info(self):
def test_telegram_prompt_contains_platform_and_chat(self):
config = GatewayConfig(
platforms={
Platform.TELEGRAM: PlatformConfig(
@@ -76,9 +152,29 @@ class TestBuildSessionContextPrompt:
assert "Telegram" in prompt
assert "Home Chat" in prompt
assert "Session Context" in prompt
def test_local_source_prompt(self):
def test_discord_prompt(self):
config = GatewayConfig(
platforms={
Platform.DISCORD: PlatformConfig(
enabled=True,
token="fake-discord-token",
),
},
)
source = SessionSource(
platform=Platform.DISCORD,
chat_id="guild-123",
chat_name="Server",
chat_type="group",
user_name="alice",
)
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "Discord" in prompt
def test_local_prompt_mentions_machine(self):
config = GatewayConfig()
source = SessionSource.local_cli()
ctx = build_session_context(source, config)
@@ -86,3 +182,20 @@ class TestBuildSessionContextPrompt:
assert "Local" in prompt
assert "machine running this agent" in prompt
def test_whatsapp_prompt(self):
config = GatewayConfig(
platforms={
Platform.WHATSAPP: PlatformConfig(enabled=True, token=""),
},
)
source = SessionSource(
platform=Platform.WHATSAPP,
chat_id="15551234567@s.whatsapp.net",
chat_type="dm",
user_name="Phone User",
)
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "WhatsApp" in prompt or "whatsapp" in prompt.lower()

View File

@@ -4,30 +4,53 @@ from hermes_cli.models import OPENROUTER_MODELS, menu_labels, model_ids
class TestModelIds:
def test_returns_strings(self):
def test_returns_non_empty_list(self):
ids = model_ids()
assert isinstance(ids, list)
assert len(ids) > 0
assert all(isinstance(mid, str) for mid in ids)
def test_ids_match_models_list(self):
ids = model_ids()
expected = [mid for mid, _ in OPENROUTER_MODELS]
assert ids == expected
def test_all_ids_contain_provider_slash(self):
"""Model IDs should follow the provider/model format."""
for mid in model_ids():
assert "/" in mid, f"Model ID '{mid}' missing provider/ prefix"
def test_no_duplicate_ids(self):
ids = model_ids()
assert len(ids) == len(set(ids)), "Duplicate model IDs found"
class TestMenuLabels:
def test_same_length_as_model_ids(self):
labels = menu_labels()
ids = model_ids()
assert len(labels) == len(ids)
assert len(menu_labels()) == len(model_ids())
def test_recommended_in_first(self):
def test_first_label_marked_recommended(self):
labels = menu_labels()
assert "recommended" in labels[0].lower()
def test_labels_contain_model_ids(self):
def test_each_label_contains_its_model_id(self):
for label, mid in zip(menu_labels(), model_ids()):
assert mid in label, f"Label '{label}' doesn't contain model ID '{mid}'"
def test_non_recommended_labels_have_no_tag(self):
"""Only the first model should have (recommended)."""
labels = menu_labels()
ids = model_ids()
for label, mid in zip(labels, ids):
assert mid in label
for label in labels[1:]:
assert "recommended" not in label.lower(), f"Unexpected 'recommended' in '{label}'"
class TestOpenRouterModels:
def test_structure_is_list_of_tuples(self):
for entry in OPENROUTER_MODELS:
assert isinstance(entry, tuple) and len(entry) == 2
mid, desc = entry
assert isinstance(mid, str) and len(mid) > 0
assert isinstance(desc, str)
def test_at_least_5_models(self):
"""Sanity check that the models list hasn't been accidentally truncated."""
assert len(OPENROUTER_MODELS) >= 5

View File

@@ -1,9 +1,7 @@
"""Tests for the file tools module (schema and handler wiring).
"""Tests for the file tools module (schema, handler wiring, error paths).
These tests verify the tool schemas and handler wiring without
requiring a running terminal environment. The actual file operations
(ShellFileOperations) depend on a terminal backend, so we mock
_get_file_ops to test the handler logic in isolation.
Tests verify tool schemas, handler dispatch, validation logic, and error
handling without requiring a running terminal environment.
"""
import json
@@ -18,82 +16,187 @@ from tools.file_tools import (
)
class TestSchemas:
def test_read_file_schema(self):
assert READ_FILE_SCHEMA["name"] == "read_file"
props = READ_FILE_SCHEMA["parameters"]["properties"]
assert "path" in props
assert "offset" in props
assert "limit" in props
def test_write_file_schema(self):
assert WRITE_FILE_SCHEMA["name"] == "write_file"
assert "path" in WRITE_FILE_SCHEMA["parameters"]["properties"]
assert "content" in WRITE_FILE_SCHEMA["parameters"]["properties"]
def test_patch_schema(self):
assert PATCH_SCHEMA["name"] == "patch"
props = PATCH_SCHEMA["parameters"]["properties"]
assert "mode" in props
assert "old_string" in props
assert "new_string" in props
def test_search_files_schema(self):
assert SEARCH_FILES_SCHEMA["name"] == "search_files"
props = SEARCH_FILES_SCHEMA["parameters"]["properties"]
assert "pattern" in props
assert "target" in props
class TestFileToolsList:
def test_file_tools_has_expected_entries(self):
def test_has_expected_entries(self):
names = {t["name"] for t in FILE_TOOLS}
assert names == {"read_file", "write_file", "patch", "search_files"}
def test_each_entry_has_callable_function(self):
for tool in FILE_TOOLS:
assert callable(tool["function"]), f"{tool['name']} missing callable"
def test_schemas_have_required_fields(self):
"""All schemas must have name, description, and parameters with properties."""
for schema in [READ_FILE_SCHEMA, WRITE_FILE_SCHEMA, PATCH_SCHEMA, SEARCH_FILES_SCHEMA]:
assert "name" in schema
assert "description" in schema
assert "properties" in schema["parameters"]
class TestReadFileHandler:
@patch("tools.file_tools._get_file_ops")
def test_read_file_returns_json(self, mock_get):
def test_returns_file_content(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"content": "hello", "total_lines": 1}
result_obj.to_dict.return_value = {"content": "line1\nline2", "total_lines": 2}
mock_ops.read_file.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import read_file_tool
result = json.loads(read_file_tool("/tmp/test.txt"))
assert result["content"] == "hello"
assert result["content"] == "line1\nline2"
assert result["total_lines"] == 2
mock_ops.read_file.assert_called_once_with("/tmp/test.txt", 1, 500)
@patch("tools.file_tools._get_file_ops")
def test_custom_offset_and_limit(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"content": "line10", "total_lines": 50}
mock_ops.read_file.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import read_file_tool
read_file_tool("/tmp/big.txt", offset=10, limit=20)
mock_ops.read_file.assert_called_once_with("/tmp/big.txt", 10, 20)
@patch("tools.file_tools._get_file_ops")
def test_exception_returns_error_json(self, mock_get):
mock_get.side_effect = RuntimeError("terminal not available")
from tools.file_tools import read_file_tool
result = json.loads(read_file_tool("/tmp/test.txt"))
assert "error" in result
assert "terminal not available" in result["error"]
class TestWriteFileHandler:
@patch("tools.file_tools._get_file_ops")
def test_write_file_returns_json(self, mock_get):
def test_writes_content(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"status": "ok", "path": "/tmp/test.txt"}
result_obj.to_dict.return_value = {"status": "ok", "path": "/tmp/out.txt", "bytes": 13}
mock_ops.write_file.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import write_file_tool
result = json.loads(write_file_tool("/tmp/test.txt", "content"))
result = json.loads(write_file_tool("/tmp/out.txt", "hello world!\n"))
assert result["status"] == "ok"
mock_ops.write_file.assert_called_once_with("/tmp/test.txt", "content")
mock_ops.write_file.assert_called_once_with("/tmp/out.txt", "hello world!\n")
@patch("tools.file_tools._get_file_ops")
def test_exception_returns_error_json(self, mock_get):
mock_get.side_effect = PermissionError("read-only filesystem")
from tools.file_tools import write_file_tool
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
assert "error" in result
assert "read-only" in result["error"]
class TestPatchHandler:
@patch("tools.file_tools._get_file_ops")
def test_replace_mode_calls_patch_replace(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"status": "ok", "replacements": 1}
mock_ops.patch_replace.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import patch_tool
result = json.loads(patch_tool(
mode="replace", path="/tmp/f.py",
old_string="foo", new_string="bar"
))
assert result["status"] == "ok"
mock_ops.patch_replace.assert_called_once_with("/tmp/f.py", "foo", "bar", False)
@patch("tools.file_tools._get_file_ops")
def test_replace_mode_replace_all_flag(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"status": "ok", "replacements": 5}
mock_ops.patch_replace.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import patch_tool
patch_tool(mode="replace", path="/tmp/f.py",
old_string="x", new_string="y", replace_all=True)
mock_ops.patch_replace.assert_called_once_with("/tmp/f.py", "x", "y", True)
@patch("tools.file_tools._get_file_ops")
def test_replace_mode_missing_path_errors(self, mock_get):
from tools.file_tools import patch_tool
result = json.loads(patch_tool(mode="replace", path=None, old_string="a", new_string="b"))
assert "error" in result
@patch("tools.file_tools._get_file_ops")
def test_replace_mode_missing_strings_errors(self, mock_get):
from tools.file_tools import patch_tool
result = json.loads(patch_tool(mode="replace", path="/tmp/f.py", old_string=None, new_string="b"))
assert "error" in result
@patch("tools.file_tools._get_file_ops")
def test_patch_mode_calls_patch_v4a(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"status": "ok", "operations": 1}
mock_ops.patch_v4a.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import patch_tool
result = json.loads(patch_tool(mode="patch", patch="*** Begin Patch\n..."))
assert result["status"] == "ok"
mock_ops.patch_v4a.assert_called_once()
@patch("tools.file_tools._get_file_ops")
def test_patch_mode_missing_content_errors(self, mock_get):
from tools.file_tools import patch_tool
result = json.loads(patch_tool(mode="patch", patch=None))
assert "error" in result
@patch("tools.file_tools._get_file_ops")
def test_unknown_mode_errors(self, mock_get):
from tools.file_tools import patch_tool
result = json.loads(patch_tool(mode="unknown"))
result = json.loads(patch_tool(mode="invalid_mode"))
assert "error" in result
assert "Unknown mode" in result["error"]
class TestSearchHandler:
@patch("tools.file_tools._get_file_ops")
def test_search_calls_file_ops(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"matches": ["file1.py:3:match"]}
mock_ops.search.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import search_tool
result = json.loads(search_tool(pattern="TODO", target="content", path="."))
assert "matches" in result
mock_ops.search.assert_called_once()
@patch("tools.file_tools._get_file_ops")
def test_search_passes_all_params(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"matches": []}
mock_ops.search.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import search_tool
search_tool(pattern="class", target="files", path="/src",
file_glob="*.py", limit=10, offset=5, output_mode="count", context=2)
mock_ops.search.assert_called_once_with(
pattern="class", path="/src", target="files", file_glob="*.py",
limit=10, offset=5, output_mode="count", context=2,
)
@patch("tools.file_tools._get_file_ops")
def test_search_exception_returns_error(self, mock_get):
mock_get.side_effect = RuntimeError("no terminal")
from tools.file_tools import search_tool
result = json.loads(search_tool(pattern="x"))
assert "error" in result