diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 75026c77b..2f5f4e4a5 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -1,5 +1,6 @@ """Tests for gateway session management.""" +import pytest from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig from gateway.session import ( SessionSource, @@ -9,7 +10,7 @@ from gateway.session import ( class TestSessionSourceRoundtrip: - def test_to_dict_from_dict(self): + def test_full_roundtrip(self): source = SessionSource( platform=Platform.TELEGRAM, chat_id="12345", @@ -36,22 +37,97 @@ class TestSessionSourceRoundtrip: restored = SessionSource.from_dict(d) assert restored.platform == Platform.LOCAL assert restored.chat_id == "cli" + assert restored.chat_type == "dm" # default value preserved + + def test_chat_id_coerced_to_string(self): + """from_dict should handle numeric chat_id (common from Telegram).""" + restored = SessionSource.from_dict({ + "platform": "telegram", + "chat_id": 12345, + }) + assert restored.chat_id == "12345" + assert isinstance(restored.chat_id, str) + + def test_missing_optional_fields(self): + restored = SessionSource.from_dict({ + "platform": "discord", + "chat_id": "abc", + }) + assert restored.chat_name is None + assert restored.user_id is None + assert restored.user_name is None + assert restored.thread_id is None + assert restored.chat_type == "dm" + + def test_invalid_platform_raises(self): + with pytest.raises((ValueError, KeyError)): + SessionSource.from_dict({"platform": "nonexistent", "chat_id": "1"}) -class TestLocalCliSource: +class TestSessionSourceDescription: def test_local_cli(self): + source = SessionSource.local_cli() + assert source.description == "CLI terminal" + + def test_dm_with_username(self): + source = SessionSource( + platform=Platform.TELEGRAM, chat_id="123", + chat_type="dm", user_name="bob", + ) + assert "DM" in source.description + assert "bob" in source.description + + def test_dm_without_username_falls_back_to_user_id(self): + source = SessionSource( + platform=Platform.TELEGRAM, chat_id="123", + chat_type="dm", user_id="456", + ) + assert "456" in source.description + + def test_group_shows_chat_name(self): + source = SessionSource( + platform=Platform.DISCORD, chat_id="789", + chat_type="group", chat_name="Dev Chat", + ) + assert "group" in source.description + assert "Dev Chat" in source.description + + def test_channel_type(self): + source = SessionSource( + platform=Platform.TELEGRAM, chat_id="100", + chat_type="channel", chat_name="Announcements", + ) + assert "channel" in source.description + assert "Announcements" in source.description + + def test_thread_id_appended(self): + source = SessionSource( + platform=Platform.DISCORD, chat_id="789", + chat_type="group", chat_name="General", + thread_id="thread-42", + ) + assert "thread" in source.description + assert "thread-42" in source.description + + def test_unknown_chat_type_uses_name(self): + source = SessionSource( + platform=Platform.SLACK, chat_id="C01", + chat_type="forum", chat_name="Questions", + ) + assert "Questions" in source.description + + +class TestLocalCliFactory: + def test_local_cli_defaults(self): source = SessionSource.local_cli() assert source.platform == Platform.LOCAL assert source.chat_id == "cli" assert source.chat_type == "dm" - - def test_description_local(self): - source = SessionSource.local_cli() - assert source.description == "CLI terminal" + assert source.chat_name == "CLI terminal" class TestBuildSessionContextPrompt: - def test_contains_platform_info(self): + def test_telegram_prompt_contains_platform_and_chat(self): config = GatewayConfig( platforms={ Platform.TELEGRAM: PlatformConfig( @@ -76,9 +152,29 @@ class TestBuildSessionContextPrompt: assert "Telegram" in prompt assert "Home Chat" in prompt - assert "Session Context" in prompt - def test_local_source_prompt(self): + def test_discord_prompt(self): + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + token="fake-discord-token", + ), + }, + ) + source = SessionSource( + platform=Platform.DISCORD, + chat_id="guild-123", + chat_name="Server", + chat_type="group", + user_name="alice", + ) + ctx = build_session_context(source, config) + prompt = build_session_context_prompt(ctx) + + assert "Discord" in prompt + + def test_local_prompt_mentions_machine(self): config = GatewayConfig() source = SessionSource.local_cli() ctx = build_session_context(source, config) @@ -86,3 +182,20 @@ class TestBuildSessionContextPrompt: assert "Local" in prompt assert "machine running this agent" in prompt + + def test_whatsapp_prompt(self): + config = GatewayConfig( + platforms={ + Platform.WHATSAPP: PlatformConfig(enabled=True, token=""), + }, + ) + source = SessionSource( + platform=Platform.WHATSAPP, + chat_id="15551234567@s.whatsapp.net", + chat_type="dm", + user_name="Phone User", + ) + ctx = build_session_context(source, config) + prompt = build_session_context_prompt(ctx) + + assert "WhatsApp" in prompt or "whatsapp" in prompt.lower() diff --git a/tests/hermes_cli/test_models.py b/tests/hermes_cli/test_models.py index 0a6cc21d4..3eff1faa7 100644 --- a/tests/hermes_cli/test_models.py +++ b/tests/hermes_cli/test_models.py @@ -4,30 +4,53 @@ from hermes_cli.models import OPENROUTER_MODELS, menu_labels, model_ids class TestModelIds: - def test_returns_strings(self): + def test_returns_non_empty_list(self): ids = model_ids() assert isinstance(ids, list) assert len(ids) > 0 - assert all(isinstance(mid, str) for mid in ids) def test_ids_match_models_list(self): ids = model_ids() expected = [mid for mid, _ in OPENROUTER_MODELS] assert ids == expected + def test_all_ids_contain_provider_slash(self): + """Model IDs should follow the provider/model format.""" + for mid in model_ids(): + assert "/" in mid, f"Model ID '{mid}' missing provider/ prefix" + + def test_no_duplicate_ids(self): + ids = model_ids() + assert len(ids) == len(set(ids)), "Duplicate model IDs found" + class TestMenuLabels: def test_same_length_as_model_ids(self): - labels = menu_labels() - ids = model_ids() - assert len(labels) == len(ids) + assert len(menu_labels()) == len(model_ids()) - def test_recommended_in_first(self): + def test_first_label_marked_recommended(self): labels = menu_labels() assert "recommended" in labels[0].lower() - def test_labels_contain_model_ids(self): + def test_each_label_contains_its_model_id(self): + for label, mid in zip(menu_labels(), model_ids()): + assert mid in label, f"Label '{label}' doesn't contain model ID '{mid}'" + + def test_non_recommended_labels_have_no_tag(self): + """Only the first model should have (recommended).""" labels = menu_labels() - ids = model_ids() - for label, mid in zip(labels, ids): - assert mid in label + for label in labels[1:]: + assert "recommended" not in label.lower(), f"Unexpected 'recommended' in '{label}'" + + +class TestOpenRouterModels: + def test_structure_is_list_of_tuples(self): + for entry in OPENROUTER_MODELS: + assert isinstance(entry, tuple) and len(entry) == 2 + mid, desc = entry + assert isinstance(mid, str) and len(mid) > 0 + assert isinstance(desc, str) + + def test_at_least_5_models(self): + """Sanity check that the models list hasn't been accidentally truncated.""" + assert len(OPENROUTER_MODELS) >= 5 diff --git a/tests/tools/test_file_tools.py b/tests/tools/test_file_tools.py index 997a7bf70..8b1bf3f7d 100644 --- a/tests/tools/test_file_tools.py +++ b/tests/tools/test_file_tools.py @@ -1,9 +1,7 @@ -"""Tests for the file tools module (schema and handler wiring). +"""Tests for the file tools module (schema, handler wiring, error paths). -These tests verify the tool schemas and handler wiring without -requiring a running terminal environment. The actual file operations -(ShellFileOperations) depend on a terminal backend, so we mock -_get_file_ops to test the handler logic in isolation. +Tests verify tool schemas, handler dispatch, validation logic, and error +handling without requiring a running terminal environment. """ import json @@ -18,82 +16,187 @@ from tools.file_tools import ( ) -class TestSchemas: - def test_read_file_schema(self): - assert READ_FILE_SCHEMA["name"] == "read_file" - props = READ_FILE_SCHEMA["parameters"]["properties"] - assert "path" in props - assert "offset" in props - assert "limit" in props - - def test_write_file_schema(self): - assert WRITE_FILE_SCHEMA["name"] == "write_file" - assert "path" in WRITE_FILE_SCHEMA["parameters"]["properties"] - assert "content" in WRITE_FILE_SCHEMA["parameters"]["properties"] - - def test_patch_schema(self): - assert PATCH_SCHEMA["name"] == "patch" - props = PATCH_SCHEMA["parameters"]["properties"] - assert "mode" in props - assert "old_string" in props - assert "new_string" in props - - def test_search_files_schema(self): - assert SEARCH_FILES_SCHEMA["name"] == "search_files" - props = SEARCH_FILES_SCHEMA["parameters"]["properties"] - assert "pattern" in props - assert "target" in props - - class TestFileToolsList: - def test_file_tools_has_expected_entries(self): + def test_has_expected_entries(self): names = {t["name"] for t in FILE_TOOLS} assert names == {"read_file", "write_file", "patch", "search_files"} + def test_each_entry_has_callable_function(self): + for tool in FILE_TOOLS: + assert callable(tool["function"]), f"{tool['name']} missing callable" + + def test_schemas_have_required_fields(self): + """All schemas must have name, description, and parameters with properties.""" + for schema in [READ_FILE_SCHEMA, WRITE_FILE_SCHEMA, PATCH_SCHEMA, SEARCH_FILES_SCHEMA]: + assert "name" in schema + assert "description" in schema + assert "properties" in schema["parameters"] + class TestReadFileHandler: @patch("tools.file_tools._get_file_ops") - def test_read_file_returns_json(self, mock_get): + def test_returns_file_content(self, mock_get): mock_ops = MagicMock() result_obj = MagicMock() - result_obj.to_dict.return_value = {"content": "hello", "total_lines": 1} + result_obj.to_dict.return_value = {"content": "line1\nline2", "total_lines": 2} mock_ops.read_file.return_value = result_obj mock_get.return_value = mock_ops from tools.file_tools import read_file_tool - result = json.loads(read_file_tool("/tmp/test.txt")) - assert result["content"] == "hello" + assert result["content"] == "line1\nline2" + assert result["total_lines"] == 2 mock_ops.read_file.assert_called_once_with("/tmp/test.txt", 1, 500) + @patch("tools.file_tools._get_file_ops") + def test_custom_offset_and_limit(self, mock_get): + mock_ops = MagicMock() + result_obj = MagicMock() + result_obj.to_dict.return_value = {"content": "line10", "total_lines": 50} + mock_ops.read_file.return_value = result_obj + mock_get.return_value = mock_ops + + from tools.file_tools import read_file_tool + read_file_tool("/tmp/big.txt", offset=10, limit=20) + mock_ops.read_file.assert_called_once_with("/tmp/big.txt", 10, 20) + + @patch("tools.file_tools._get_file_ops") + def test_exception_returns_error_json(self, mock_get): + mock_get.side_effect = RuntimeError("terminal not available") + + from tools.file_tools import read_file_tool + result = json.loads(read_file_tool("/tmp/test.txt")) + assert "error" in result + assert "terminal not available" in result["error"] + class TestWriteFileHandler: @patch("tools.file_tools._get_file_ops") - def test_write_file_returns_json(self, mock_get): + def test_writes_content(self, mock_get): mock_ops = MagicMock() result_obj = MagicMock() - result_obj.to_dict.return_value = {"status": "ok", "path": "/tmp/test.txt"} + result_obj.to_dict.return_value = {"status": "ok", "path": "/tmp/out.txt", "bytes": 13} mock_ops.write_file.return_value = result_obj mock_get.return_value = mock_ops from tools.file_tools import write_file_tool - - result = json.loads(write_file_tool("/tmp/test.txt", "content")) + result = json.loads(write_file_tool("/tmp/out.txt", "hello world!\n")) assert result["status"] == "ok" - mock_ops.write_file.assert_called_once_with("/tmp/test.txt", "content") + mock_ops.write_file.assert_called_once_with("/tmp/out.txt", "hello world!\n") + + @patch("tools.file_tools._get_file_ops") + def test_exception_returns_error_json(self, mock_get): + mock_get.side_effect = PermissionError("read-only filesystem") + + from tools.file_tools import write_file_tool + result = json.loads(write_file_tool("/tmp/out.txt", "data")) + assert "error" in result + assert "read-only" in result["error"] class TestPatchHandler: + @patch("tools.file_tools._get_file_ops") + def test_replace_mode_calls_patch_replace(self, mock_get): + mock_ops = MagicMock() + result_obj = MagicMock() + result_obj.to_dict.return_value = {"status": "ok", "replacements": 1} + mock_ops.patch_replace.return_value = result_obj + mock_get.return_value = mock_ops + + from tools.file_tools import patch_tool + result = json.loads(patch_tool( + mode="replace", path="/tmp/f.py", + old_string="foo", new_string="bar" + )) + assert result["status"] == "ok" + mock_ops.patch_replace.assert_called_once_with("/tmp/f.py", "foo", "bar", False) + + @patch("tools.file_tools._get_file_ops") + def test_replace_mode_replace_all_flag(self, mock_get): + mock_ops = MagicMock() + result_obj = MagicMock() + result_obj.to_dict.return_value = {"status": "ok", "replacements": 5} + mock_ops.patch_replace.return_value = result_obj + mock_get.return_value = mock_ops + + from tools.file_tools import patch_tool + patch_tool(mode="replace", path="/tmp/f.py", + old_string="x", new_string="y", replace_all=True) + mock_ops.patch_replace.assert_called_once_with("/tmp/f.py", "x", "y", True) + @patch("tools.file_tools._get_file_ops") def test_replace_mode_missing_path_errors(self, mock_get): from tools.file_tools import patch_tool - result = json.loads(patch_tool(mode="replace", path=None, old_string="a", new_string="b")) assert "error" in result + @patch("tools.file_tools._get_file_ops") + def test_replace_mode_missing_strings_errors(self, mock_get): + from tools.file_tools import patch_tool + result = json.loads(patch_tool(mode="replace", path="/tmp/f.py", old_string=None, new_string="b")) + assert "error" in result + + @patch("tools.file_tools._get_file_ops") + def test_patch_mode_calls_patch_v4a(self, mock_get): + mock_ops = MagicMock() + result_obj = MagicMock() + result_obj.to_dict.return_value = {"status": "ok", "operations": 1} + mock_ops.patch_v4a.return_value = result_obj + mock_get.return_value = mock_ops + + from tools.file_tools import patch_tool + result = json.loads(patch_tool(mode="patch", patch="*** Begin Patch\n...")) + assert result["status"] == "ok" + mock_ops.patch_v4a.assert_called_once() + + @patch("tools.file_tools._get_file_ops") + def test_patch_mode_missing_content_errors(self, mock_get): + from tools.file_tools import patch_tool + result = json.loads(patch_tool(mode="patch", patch=None)) + assert "error" in result + @patch("tools.file_tools._get_file_ops") def test_unknown_mode_errors(self, mock_get): from tools.file_tools import patch_tool - - result = json.loads(patch_tool(mode="unknown")) + result = json.loads(patch_tool(mode="invalid_mode")) + assert "error" in result + assert "Unknown mode" in result["error"] + + +class TestSearchHandler: + @patch("tools.file_tools._get_file_ops") + def test_search_calls_file_ops(self, mock_get): + mock_ops = MagicMock() + result_obj = MagicMock() + result_obj.to_dict.return_value = {"matches": ["file1.py:3:match"]} + mock_ops.search.return_value = result_obj + mock_get.return_value = mock_ops + + from tools.file_tools import search_tool + result = json.loads(search_tool(pattern="TODO", target="content", path=".")) + assert "matches" in result + mock_ops.search.assert_called_once() + + @patch("tools.file_tools._get_file_ops") + def test_search_passes_all_params(self, mock_get): + mock_ops = MagicMock() + result_obj = MagicMock() + result_obj.to_dict.return_value = {"matches": []} + mock_ops.search.return_value = result_obj + mock_get.return_value = mock_ops + + from tools.file_tools import search_tool + search_tool(pattern="class", target="files", path="/src", + file_glob="*.py", limit=10, offset=5, output_mode="count", context=2) + mock_ops.search.assert_called_once_with( + pattern="class", path="/src", target="files", file_glob="*.py", + limit=10, offset=5, output_mode="count", context=2, + ) + + @patch("tools.file_tools._get_file_ops") + def test_search_exception_returns_error(self, mock_get): + mock_get.side_effect = RuntimeError("no terminal") + + from tools.file_tools import search_tool + result = json.loads(search_tool(pattern="x")) assert "error" in result