Files
hermes-agent/tests/tools/test_mcp_tool.py
teknium1 7df14227a9 feat(mcp): banner integration, /reload-mcp command, resources & prompts
Banner integration:
- MCP Servers section in CLI startup banner between Tools and Skills
- Shows each server with transport type, tool count, connection status
- Failed servers shown in red; section hidden when no MCP configured
- Summary line includes MCP server count
- Removed raw print() calls from discovery (banner handles display)

/reload-mcp command:
- New slash command in both CLI and gateway
- Disconnects all MCP servers, re-reads config.yaml, reconnects
- Reports what changed (added/removed/reconnected servers)
- Allows adding/removing MCP servers without restarting

Resources & Prompts support:
- 4 utility tools registered per server: list_resources, read_resource,
  list_prompts, get_prompt
- Exposes MCP Resources (data sources) and Prompts (templates) as tools
- Proper parameter schemas (uri for read_resource, name for get_prompt)
- Handles text and binary resource content
- 23 new tests covering schemas, handlers, and registration

Test coverage: 74 MCP tests total, 1186 tests pass overall.
2026-03-02 19:15:59 -08:00

1492 lines
55 KiB
Python

"""Tests for the MCP (Model Context Protocol) client support.
All tests use mocks -- no real MCP servers or subprocesses are started.
"""
import asyncio
import json
import os
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_mcp_tool(name="read_file", description="Read a file", input_schema=None):
"""Create a fake MCP Tool object matching the SDK interface."""
tool = SimpleNamespace()
tool.name = name
tool.description = description
tool.inputSchema = input_schema or {
"type": "object",
"properties": {
"path": {"type": "string", "description": "File path"},
},
"required": ["path"],
}
return tool
def _make_call_result(text="file contents here", is_error=False):
"""Create a fake MCP CallToolResult."""
block = SimpleNamespace(text=text)
return SimpleNamespace(content=[block], isError=is_error)
def _make_mock_server(name, session=None, tools=None):
    """Instantiate an ``MCPServerTask`` and preset the attributes tests rely on.

    Args:
        name: Name passed to the ``MCPServerTask`` constructor.
        session: Object assigned to ``server.session`` (``None`` by default).
        tools: List assigned to ``server._tools``; a falsy value becomes ``[]``.

    Returns:
        The prepared ``MCPServerTask`` instance.
    """
    from tools.mcp_tool import MCPServerTask

    fake_server = MCPServerTask(name)
    fake_server.session = session
    fake_server._tools = tools if tools else []
    return fake_server
# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------
class TestLoadMCPConfig:
    """Checks ``_load_mcp_config`` with ``hermes_cli.config.load_config`` patched."""

    @staticmethod
    def _load_with(raw_config):
        """Call ``_load_mcp_config`` while ``load_config`` returns *raw_config*."""
        with patch("hermes_cli.config.load_config", return_value=raw_config):
            from tools.mcp_tool import _load_mcp_config

            return _load_mcp_config()

    def test_no_config_returns_empty(self):
        """A config without an ``mcp_servers`` key yields an empty dict."""
        assert self._load_with({"model": "test"}) == {}

    def test_valid_config_parsed(self):
        """A well-formed ``mcp_servers`` mapping keeps its server entries."""
        configured = {
            "filesystem": {
                "command": "npx",
                "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
                "env": {},
            }
        }
        loaded = self._load_with({"mcp_servers": configured})
        assert "filesystem" in loaded
        assert loaded["filesystem"]["command"] == "npx"

    def test_mcp_servers_not_dict_returns_empty(self):
        """A non-dict ``mcp_servers`` value yields an empty dict."""
        assert self._load_with({"mcp_servers": "invalid"}) == {}
# ---------------------------------------------------------------------------
# Schema conversion
# ---------------------------------------------------------------------------
class TestSchemaConversion:
    """Checks the schema dict produced by ``_convert_mcp_schema``."""

    def test_converts_mcp_tool_to_hermes_schema(self):
        """Name is prefixed, description carried over, parameters populated."""
        from tools.mcp_tool import _convert_mcp_schema

        source_tool = _make_mcp_tool(name="read_file", description="Read a file")
        converted = _convert_mcp_schema("filesystem", source_tool)
        assert converted["name"] == "mcp_filesystem_read_file"
        assert converted["description"] == "Read a file"
        assert "properties" in converted["parameters"]

    def test_empty_input_schema_gets_default(self):
        """A tool whose ``inputSchema`` is ``None`` gets an empty object schema."""
        from tools.mcp_tool import _convert_mcp_schema

        bare_tool = _make_mcp_tool(name="ping", description="Ping", input_schema=None)
        # The helper substitutes a default schema, so force None explicitly.
        bare_tool.inputSchema = None
        parameters = _convert_mcp_schema("test", bare_tool)["parameters"]
        assert parameters["type"] == "object"
        assert parameters["properties"] == {}

    def test_tool_name_prefix_format(self):
        """Registered name has the form ``mcp_<server>_<tool>``."""
        from tools.mcp_tool import _convert_mcp_schema

        converted = _convert_mcp_schema("my_server", _make_mcp_tool(name="list_dir"))
        assert converted["name"] == "mcp_my_server_list_dir"

    def test_hyphens_sanitized_to_underscores(self):
        """Hyphens in tool/server names are replaced with underscores for LLM compat."""
        from tools.mcp_tool import _convert_mcp_schema

        registered_name = _convert_mcp_schema("my-server", _make_mcp_tool(name="get-sum"))["name"]
        assert registered_name == "mcp_my_server_get_sum"
        assert "-" not in registered_name
# ---------------------------------------------------------------------------
# Check function
# ---------------------------------------------------------------------------
class TestCheckFunction:
    """Checks the availability predicate built by ``_make_check_fn``."""

    def test_disconnected_returns_false(self):
        """A server name absent from ``_servers`` reports ``False``."""
        from tools.mcp_tool import _make_check_fn, _servers

        _servers.pop("test_server", None)
        is_available = _make_check_fn("test_server")
        assert is_available() is False

    def test_connected_returns_true(self):
        """A registered server holding a session reports ``True``."""
        from tools.mcp_tool import _make_check_fn, _servers

        _servers["test_server"] = _make_mock_server("test_server", session=MagicMock())
        try:
            is_available = _make_check_fn("test_server")
            assert is_available() is True
        finally:
            # Always undo the registration in the module-level dict.
            _servers.pop("test_server", None)

    def test_session_none_returns_false(self):
        """A registered server whose session is ``None`` reports ``False``."""
        from tools.mcp_tool import _make_check_fn, _servers

        _servers["test_server"] = _make_mock_server("test_server", session=None)
        try:
            is_available = _make_check_fn("test_server")
            assert is_available() is False
        finally:
            _servers.pop("test_server", None)
# ---------------------------------------------------------------------------
# Tool handler
# ---------------------------------------------------------------------------
class TestToolHandler:
    """Tool handlers are sync callables that hand their coroutine to the MCP loop."""

    def _patch_mcp_loop(self, coro_side_effect=None):
        """Build a patch for ``tools.mcp_tool._run_on_mcp_loop``.

        Without an argument the replacement drives the coroutine to completion
        on a throwaway event loop. A truthy *coro_side_effect* is installed as
        the mock's side effect instead.
        """

        def run_inline(coro, timeout=30):
            scratch_loop = asyncio.new_event_loop()
            try:
                return scratch_loop.run_until_complete(coro)
            finally:
                scratch_loop.close()

        return patch(
            "tools.mcp_tool._run_on_mcp_loop",
            side_effect=coro_side_effect or run_inline,
        )

    def test_successful_call(self):
        """A non-error result is returned under ``result`` and args are forwarded."""
        from tools.mcp_tool import _make_tool_handler, _servers

        session = MagicMock()
        session.call_tool = AsyncMock(
            return_value=_make_call_result("hello world", is_error=False)
        )
        _servers["test_srv"] = _make_mock_server("test_srv", session=session)
        try:
            handler = _make_tool_handler("test_srv", "greet", 120)
            with self._patch_mcp_loop():
                raw_reply = handler({"name": "world"})
            payload = json.loads(raw_reply)
            assert payload["result"] == "hello world"
            session.call_tool.assert_called_once_with("greet", arguments={"name": "world"})
        finally:
            _servers.pop("test_srv", None)

    def test_mcp_error_result(self):
        """A result flagged ``isError`` is reported under ``error``."""
        from tools.mcp_tool import _make_tool_handler, _servers

        session = MagicMock()
        session.call_tool = AsyncMock(
            return_value=_make_call_result("something went wrong", is_error=True)
        )
        _servers["test_srv"] = _make_mock_server("test_srv", session=session)
        try:
            handler = _make_tool_handler("test_srv", "fail_tool", 120)
            with self._patch_mcp_loop():
                raw_reply = handler({})
            payload = json.loads(raw_reply)
            assert "error" in payload
            assert "something went wrong" in payload["error"]
        finally:
            _servers.pop("test_srv", None)

    def test_disconnected_server(self):
        """Calling a handler for an unregistered server yields a 'not connected' error."""
        from tools.mcp_tool import _make_tool_handler, _servers

        _servers.pop("ghost", None)
        handler = _make_tool_handler("ghost", "any_tool", 120)
        payload = json.loads(handler({}))
        assert "error" in payload
        assert "not connected" in payload["error"]

    def test_exception_during_call(self):
        """An exception raised by ``call_tool`` surfaces in the ``error`` text."""
        from tools.mcp_tool import _make_tool_handler, _servers

        session = MagicMock()
        session.call_tool = AsyncMock(side_effect=RuntimeError("connection lost"))
        _servers["test_srv"] = _make_mock_server("test_srv", session=session)
        try:
            handler = _make_tool_handler("test_srv", "broken_tool", 120)
            with self._patch_mcp_loop():
                raw_reply = handler({})
            payload = json.loads(raw_reply)
            assert "error" in payload
            assert "connection lost" in payload["error"]
        finally:
            _servers.pop("test_srv", None)
# ---------------------------------------------------------------------------
# Tool registration (discovery + register)
# ---------------------------------------------------------------------------
class TestDiscoverAndRegister:
    """Checks ``_discover_and_register_server`` with ``_connect_server`` patched.

    Every test removes its server from the module-level ``_servers`` dict in a
    ``finally`` clause so that a failing assertion cannot leak state into
    later tests.
    """

    @staticmethod
    def _connector(tools):
        """Return an async stand-in for ``_connect_server``.

        The stand-in builds an ``MCPServerTask`` for the requested name, gives
        it a ``MagicMock`` session and assigns *tools* to ``_tools``.
        """
        from tools.mcp_tool import MCPServerTask

        session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = session
            server._tools = tools
            return server

        return fake_connect

    def test_tools_registered_in_registry(self):
        """_discover_and_register_server registers tools with correct names."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers

        mock_registry = ToolRegistry()
        fake_connect = self._connector([
            _make_mcp_tool("read_file", "Read a file"),
            _make_mcp_tool("write_file", "Write a file"),
        ])
        try:
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("tools.registry.registry", mock_registry):
                registered = asyncio.run(
                    _discover_and_register_server("fs", {"command": "npx", "args": []})
                )
            assert "mcp_fs_read_file" in registered
            assert "mcp_fs_write_file" in registered
            assert "mcp_fs_read_file" in mock_registry.get_all_tool_names()
            assert "mcp_fs_write_file" in mock_registry.get_all_tool_names()
        finally:
            _servers.pop("fs", None)

    def test_toolset_created(self):
        """A custom toolset is created for the MCP server."""
        from tools.mcp_tool import _discover_and_register_server, _servers

        fake_connect = self._connector([_make_mcp_tool("ping", "Ping")])
        mock_create = MagicMock()
        try:
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("toolsets.create_custom_toolset", mock_create):
                asyncio.run(
                    _discover_and_register_server("myserver", {"command": "test"})
                )
            mock_create.assert_called_once()
            call = mock_create.call_args
            # Accept the name either as a keyword or as the first positional
            # argument. Use .get() and an emptiness guard so the alternative
            # that was not used cannot raise KeyError/IndexError.
            named = call.kwargs.get("name") == "mcp-myserver"
            positional = bool(call.args) and call.args[0] == "mcp-myserver"
            assert named or positional
        finally:
            _servers.pop("myserver", None)

    def test_schema_format_correct(self):
        """Registered schemas have the correct format."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers

        mock_registry = ToolRegistry()
        fake_connect = self._connector([_make_mcp_tool("do_thing", "Do something")])
        try:
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("tools.registry.registry", mock_registry):
                asyncio.run(
                    _discover_and_register_server("srv", {"command": "test"})
                )
            entry = mock_registry._tools.get("mcp_srv_do_thing")
            assert entry is not None
            assert entry.schema["name"] == "mcp_srv_do_thing"
            assert "parameters" in entry.schema
            assert entry.is_async is False
            assert entry.toolset == "mcp-srv"
        finally:
            _servers.pop("srv", None)
# ---------------------------------------------------------------------------
# MCPServerTask (run / start / shutdown)
# ---------------------------------------------------------------------------
class TestMCPServerTask:
    """Exercises ``MCPServerTask`` start/shutdown with the MCP SDK mocked out."""

    @staticmethod
    def _idle_session(tools):
        """Return a mock session whose ``list_tools`` reports *tools*."""
        session = MagicMock()
        session.initialize = AsyncMock()
        session.list_tools = AsyncMock(return_value=SimpleNamespace(tools=tools))
        return session

    def _mock_stdio_and_session(self, session):
        """Return ``(stdio_patch, session_patch, mock_read, mock_write)``.

        Both patches make their target return a mock async context manager:
        ``stdio_client`` enters to the ``(read, write)`` pair and
        ``ClientSession`` enters to *session*.
        """
        mock_read, mock_write = MagicMock(), MagicMock()

        def as_async_cm(entered_value):
            manager = MagicMock()
            manager.__aenter__ = AsyncMock(return_value=entered_value)
            manager.__aexit__ = AsyncMock(return_value=False)
            return manager

        stdio_cm = as_async_cm((mock_read, mock_write))
        session_cm = as_async_cm(session)
        return (
            patch("tools.mcp_tool.stdio_client", return_value=stdio_cm),
            patch("tools.mcp_tool.ClientSession", return_value=session_cm),
            mock_read,
            mock_write,
        )

    def test_start_connects_and_discovers_tools(self):
        """start() exposes the session and discovered tools; shutdown() clears the session."""
        from tools.mcp_tool import MCPServerTask

        session = self._idle_session([_make_mcp_tool("echo")])
        p_stdio, p_cs, _, _ = self._mock_stdio_and_session(session)

        async def scenario():
            with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs:
                server = MCPServerTask("test_srv")
                await server.start({"command": "npx", "args": ["-y", "test"]})
                assert server.session is session
                assert len(server._tools) == 1
                assert server._tools[0].name == "echo"
                session.initialize.assert_called_once()
                await server.shutdown()
                assert server.session is None

        asyncio.run(scenario())

    def test_no_command_raises(self):
        """Missing 'command' in config raises ValueError."""
        from tools.mcp_tool import MCPServerTask

        async def scenario():
            server = MCPServerTask("bad")
            with pytest.raises(ValueError, match="no 'command'"):
                await server.start({"args": []})

        asyncio.run(scenario())

    def test_empty_env_gets_safe_defaults(self):
        """An empty ``env`` config still passes a dict with PATH and HOME."""
        from tools.mcp_tool import MCPServerTask

        session = self._idle_session([])
        p_stdio, p_cs, _, _ = self._mock_stdio_and_session(session)

        async def scenario():
            with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
                 p_stdio, p_cs, \
                 patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/home/test"}, clear=False):
                server = MCPServerTask("srv")
                await server.start({"command": "node", "env": {}})
                # The env keyword handed to StdioServerParameters must be a
                # populated dict rather than None.
                passed_env = mock_params.call_args.kwargs.get("env")
                assert passed_env is not None
                assert isinstance(passed_env, dict)
                assert "PATH" in passed_env
                assert "HOME" in passed_env
                await server.shutdown()

        asyncio.run(scenario())

    def test_shutdown_signals_task_exit(self):
        """shutdown() leaves the background task done and the session cleared."""
        from tools.mcp_tool import MCPServerTask

        session = self._idle_session([])
        p_stdio, p_cs, _, _ = self._mock_stdio_and_session(session)

        async def scenario():
            with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs:
                server = MCPServerTask("srv")
                await server.start({"command": "npx"})
                assert server.session is not None
                assert not server._task.done()
                await server.shutdown()
                assert server.session is None
                assert server._task.done()

        asyncio.run(scenario())
# ---------------------------------------------------------------------------
# discover_mcp_tools toolset injection
# ---------------------------------------------------------------------------
class TestToolsetInjection:
    """Checks ``discover_mcp_tools`` toolset injection, failure isolation and retry.

    Each test patches ``_servers`` with a fresh dict, ``_load_mcp_config`` with
    a fixed config, ``_connect_server`` with a local coroutine and
    ``toolsets.TOOLSETS`` with a local mapping.
    """

    def test_mcp_tools_added_to_all_hermes_toolsets(self):
        """Discovered MCP tools are dynamically injected into all hermes-* toolsets."""
        from tools.mcp_tool import MCPServerTask

        mock_tools = [_make_mcp_tool("list_files", "List files")]
        mock_session = MagicMock()
        fresh_servers = {}

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_toolsets = {
            "hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
            "hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []},
            "hermes-gateway": {"tools": [], "description": "GW", "includes": []},
            "non-hermes": {"tools": [], "description": "other", "includes": []},
        }
        fake_config = {"fs": {"command": "npx", "args": []}}

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            from tools.mcp_tool import discover_mcp_tools

            result = discover_mcp_tools()

            assert "mcp_fs_list_files" in result
            # All hermes-* toolsets get injection
            for toolset_name in ("hermes-cli", "hermes-telegram", "hermes-gateway"):
                assert "mcp_fs_list_files" in fake_toolsets[toolset_name]["tools"]
            # Non-hermes toolset should NOT get injection
            assert "mcp_fs_list_files" not in fake_toolsets["non-hermes"]["tools"]
            # Original tools preserved
            assert "terminal" in fake_toolsets["hermes-cli"]["tools"]

    def test_server_connection_failure_skipped(self):
        """If one server fails to connect, others still proceed."""
        from tools.mcp_tool import MCPServerTask

        mock_tools = [_make_mcp_tool("ping", "Ping")]
        mock_session = MagicMock()
        fresh_servers = {}
        call_count = 0

        async def flaky_connect(name, config):
            nonlocal call_count
            call_count += 1
            if name == "broken":
                raise ConnectionError("cannot reach server")
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_config = {
            "broken": {"command": "bad"},
            "good": {"command": "npx", "args": []},
        }
        fake_toolsets = {
            "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
        }

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            from tools.mcp_tool import discover_mcp_tools

            result = discover_mcp_tools()

            assert "mcp_good_ping" in result
            assert "mcp_broken_ping" not in result
            assert call_count == 2

    def test_partial_failure_retry_on_second_call(self):
        """Failed servers are retried on subsequent discover_mcp_tools() calls."""
        from tools.mcp_tool import MCPServerTask

        mock_tools = [_make_mcp_tool("ping", "Ping")]
        mock_session = MagicMock()
        # Use a real dict so idempotency logic works correctly
        fresh_servers = {}
        call_count = 0
        broken_fixed = False

        async def flaky_connect(name, config):
            nonlocal call_count
            call_count += 1
            if name == "broken" and not broken_fixed:
                raise ConnectionError("cannot reach server")
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_config = {
            "broken": {"command": "bad"},
            "good": {"command": "npx", "args": []},
        }
        fake_toolsets = {
            "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
        }

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            from tools.mcp_tool import discover_mcp_tools

            # First call: good connects, broken fails
            result1 = discover_mcp_tools()
            assert "mcp_good_ping" in result1
            assert "mcp_broken_ping" not in result1
            first_attempts = call_count
            # Both configured servers must have been attempted on the first
            # pass (same expectation as test_server_connection_failure_skipped).
            assert first_attempts == 2

            # "Fix" the broken server
            broken_fixed = True
            call_count = 0

            # Second call: should retry broken, skip good
            result2 = discover_mcp_tools()
            assert "mcp_good_ping" in result2
            assert "mcp_broken_ping" in result2
            assert call_count == 1  # Only broken retried
# ---------------------------------------------------------------------------
# Graceful fallback
# ---------------------------------------------------------------------------
class TestGracefulFallback:
    """Checks that ``discover_mcp_tools`` returns ``[]`` when there is nothing to do."""

    def test_mcp_unavailable_returns_empty(self):
        """When _MCP_AVAILABLE is False, discover_mcp_tools is a no-op."""
        with patch("tools.mcp_tool._MCP_AVAILABLE", False):
            from tools.mcp_tool import discover_mcp_tools

            discovered = discover_mcp_tools()
        assert discovered == []

    def test_no_servers_returns_empty(self):
        """No MCP servers configured -> empty list."""
        availability = patch("tools.mcp_tool._MCP_AVAILABLE", True)
        empty_servers = patch("tools.mcp_tool._servers", {})
        empty_config = patch("tools.mcp_tool._load_mcp_config", return_value={})
        with availability, empty_servers, empty_config:
            from tools.mcp_tool import discover_mcp_tools

            discovered = discover_mcp_tools()
        assert discovered == []
# ---------------------------------------------------------------------------
# Shutdown (public API)
# ---------------------------------------------------------------------------
class TestShutdown:
    """Checks the public ``shutdown_mcp_servers`` entry point.

    Tests that register servers call ``_ensure_mcp_loop()`` first and reset the
    module's ``_mcp_loop`` / ``_mcp_thread`` globals to ``None`` afterwards.
    """

    def test_no_servers_safe(self):
        """shutdown_mcp_servers with no servers does nothing."""
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        _servers.clear()
        shutdown_mcp_servers()  # Should not raise

    def test_shutdown_clears_servers(self):
        """shutdown_mcp_servers calls shutdown() on each server and clears dict."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        _servers.clear()
        registered = MagicMock()
        registered.name = "test"
        registered.shutdown = AsyncMock()
        _servers["test"] = registered

        mcp_mod._ensure_mcp_loop()
        try:
            shutdown_mcp_servers()
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        assert len(_servers) == 0
        registered.shutdown.assert_called_once()

    def test_shutdown_handles_errors(self):
        """shutdown_mcp_servers handles errors during close gracefully."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        _servers.clear()
        failing = MagicMock()
        failing.name = "broken"
        failing.shutdown = AsyncMock(side_effect=RuntimeError("close failed"))
        _servers["broken"] = failing

        mcp_mod._ensure_mcp_loop()
        try:
            shutdown_mcp_servers()  # Should not raise
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        assert len(_servers) == 0

    def test_shutdown_is_parallel(self):
        """Three servers that each need 1s to stop finish in under 2.5s overall."""
        import time

        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        async def slow_shutdown():
            await asyncio.sleep(1)

        _servers.clear()
        for index in range(3):
            key = f"srv_{index}"
            slow_server = MagicMock()
            slow_server.name = key
            slow_server.shutdown = slow_shutdown
            _servers[key] = slow_server

        mcp_mod._ensure_mcp_loop()
        try:
            started_at = time.monotonic()
            shutdown_mcp_servers()
            elapsed = time.monotonic() - started_at
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        assert len(_servers) == 0
        # Sequential shutdown would need ~3s; concurrent shutdown ~1s.
        assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)"
# ---------------------------------------------------------------------------
# _build_safe_env
# ---------------------------------------------------------------------------
class TestBuildSafeEnv:
    """Checks which environment variables ``_build_safe_env()`` lets through."""

    @staticmethod
    def _build(process_env, user_env):
        """Call ``_build_safe_env(user_env)`` with ``os.environ`` replaced by *process_env*."""
        from tools.mcp_tool import _build_safe_env

        with patch.dict("os.environ", process_env, clear=True):
            return _build_safe_env(user_env)

    def test_only_safe_vars_passed(self):
        """Baseline vars and XDG_* are kept; unrelated secrets are dropped."""
        process_env = {
            "PATH": "/usr/bin",
            "HOME": "/home/test",
            "USER": "test",
            "LANG": "en_US.UTF-8",
            "LC_ALL": "C",
            "TERM": "xterm",
            "SHELL": "/bin/bash",
            "TMPDIR": "/tmp",
            "XDG_DATA_HOME": "/home/test/.local/share",
            "SECRET_KEY": "should_not_appear",
            "AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE",
        }
        safe_env = self._build(process_env, None)

        # Kept
        assert safe_env["PATH"] == "/usr/bin"
        assert safe_env["HOME"] == "/home/test"
        assert safe_env["USER"] == "test"
        assert safe_env["LANG"] == "en_US.UTF-8"
        assert safe_env["XDG_DATA_HOME"] == "/home/test/.local/share"
        # Dropped
        assert "SECRET_KEY" not in safe_env
        assert "AWS_ACCESS_KEY_ID" not in safe_env

    def test_user_env_merged(self):
        """User-specified env vars are merged into the safe env."""
        safe_env = self._build({"PATH": "/usr/bin"}, {"MY_CUSTOM_VAR": "hello"})
        assert safe_env["PATH"] == "/usr/bin"
        assert safe_env["MY_CUSTOM_VAR"] == "hello"

    def test_user_env_overrides_safe(self):
        """User env can override safe defaults."""
        safe_env = self._build({"PATH": "/usr/bin"}, {"PATH": "/custom/bin"})
        assert safe_env["PATH"] == "/custom/bin"

    def test_none_user_env(self):
        """None user_env still returns safe vars from os.environ."""
        safe_env = self._build({"PATH": "/usr/bin", "HOME": "/root"}, None)
        assert isinstance(safe_env, dict)
        assert safe_env["PATH"] == "/usr/bin"
        assert safe_env["HOME"] == "/root"

    def test_secret_vars_excluded(self):
        """Sensitive env vars from os.environ are NOT passed through."""
        process_env = {
            "PATH": "/usr/bin",
            "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            "GITHUB_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
            "OPENAI_API_KEY": "sk-proj-abc123",
            "DATABASE_URL": "postgres://user:pass@localhost/db",
            "API_SECRET": "supersecret",
        }
        safe_env = self._build(process_env, None)

        assert "PATH" in safe_env
        assert "AWS_SECRET_ACCESS_KEY" not in safe_env
        assert "GITHUB_TOKEN" not in safe_env
        assert "OPENAI_API_KEY" not in safe_env
        assert "DATABASE_URL" not in safe_env
        assert "API_SECRET" not in safe_env
# ---------------------------------------------------------------------------
# _sanitize_error
# ---------------------------------------------------------------------------
class TestSanitizeError:
    """Checks that ``_sanitize_error()`` replaces credentials with ``[REDACTED]``."""

    @staticmethod
    def _sanitize(message):
        """Return ``_sanitize_error(message)``."""
        from tools.mcp_tool import _sanitize_error

        return _sanitize_error(message)

    def test_strips_github_pat(self):
        """A ``ghp_``-prefixed token is redacted."""
        assert self._sanitize("Error with ghp_abc123def456") == "Error with [REDACTED]"

    def test_strips_openai_key(self):
        """An ``sk-``-prefixed key is redacted."""
        assert self._sanitize("key sk-projABC123xyz") == "key [REDACTED]"

    def test_strips_bearer_token(self):
        """A ``Bearer <token>`` pair is redacted as a whole."""
        cleaned = self._sanitize("Authorization: Bearer eyJabc123def")
        assert cleaned == "Authorization: [REDACTED]"

    def test_strips_token_param(self):
        """A ``token=<value>`` query parameter is redacted as a whole."""
        assert self._sanitize("url?token=secret123") == "url?[REDACTED]"

    def test_no_credentials_unchanged(self):
        """Text without credentials passes through untouched."""
        assert self._sanitize("normal error message") == "normal error message"

    def test_multiple_credentials(self):
        """Each of three credentials in one message is redacted separately."""
        cleaned = self._sanitize("ghp_abc123 and sk-projXyz789 and token=foo")
        for leaked_prefix in ("ghp_", "sk-", "token="):
            assert leaked_prefix not in cleaned
        assert cleaned.count("[REDACTED]") == 3
# ---------------------------------------------------------------------------
# HTTP config
# ---------------------------------------------------------------------------
class TestHTTPConfig:
    """Checks HTTP-vs-stdio transport detection on ``MCPServerTask``."""

    @staticmethod
    def _server_with_config(name, config):
        """Return an ``MCPServerTask`` named *name* with ``_config`` set to *config*."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask(name)
        server._config = config
        return server

    def test_is_http_with_url(self):
        """A config with ``url`` is detected as HTTP."""
        server = self._server_with_config("remote", {"url": "https://example.com/mcp"})
        assert server._is_http() is True

    def test_is_stdio_with_command(self):
        """A config with only ``command`` is not detected as HTTP."""
        server = self._server_with_config("local", {"command": "npx", "args": []})
        assert server._is_http() is False

    def test_conflicting_url_and_command_warns(self):
        """A config with both ``url`` and ``command`` is treated as HTTP.

        Only the precedence of ``url`` is asserted here; no log output is
        inspected.
        """
        both = {"url": "https://example.com/mcp", "command": "npx", "args": []}
        server = self._server_with_config("conflict", both)
        assert server._is_http() is True

    def test_http_unavailable_raises(self):
        """``_run_http`` raises ImportError when ``_MCP_HTTP_AVAILABLE`` is False."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("remote")
        http_config = {"url": "https://example.com/mcp"}

        async def scenario():
            with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", False), \
                 pytest.raises(ImportError, match="HTTP transport"):
                await server._run_http(http_config)

        asyncio.run(scenario())
# ---------------------------------------------------------------------------
# Reconnection logic
# ---------------------------------------------------------------------------
class TestReconnection:
    """Checks the reconnect behaviour of ``MCPServerTask.run()``.

    Each test swaps ``MCPServerTask._run_stdio`` for a scripted coroutine and
    replaces ``asyncio.sleep`` with an ``AsyncMock``. Calls made on any
    instance other than the one under test are delegated to the real
    ``_run_stdio``.
    """

    def test_reconnect_on_disconnect(self):
        """After initial success, a connection drop triggers reconnection."""
        from tools.mcp_tool import MCPServerTask

        run_count = 0
        target_server = None
        real_run_stdio = MCPServerTask._run_stdio

        async def scripted_run_stdio(self_srv, config):
            nonlocal run_count
            run_count += 1
            if target_server is not self_srv:
                return await real_run_stdio(self_srv, config)
            if run_count == 1:
                # First attempt: report ready, then drop the connection.
                self_srv.session = MagicMock()
                self_srv._tools = []
                self_srv._ready.set()
                raise ConnectionError("connection dropped")
            # Later attempt: connect and request shutdown so run() returns.
            self_srv.session = MagicMock()
            self_srv._shutdown_event.set()
            await self_srv._shutdown_event.wait()

        async def scenario():
            nonlocal target_server
            server = MCPServerTask("test_srv")
            target_server = server
            with patch.object(MCPServerTask, "_run_stdio", scripted_run_stdio), \
                 patch("asyncio.sleep", new_callable=AsyncMock):
                await server.run({"command": "test"})
            assert run_count >= 2  # At least one reconnection attempt

        asyncio.run(scenario())

    def test_no_reconnect_on_shutdown(self):
        """If shutdown is requested, don't attempt reconnection."""
        from tools.mcp_tool import MCPServerTask

        run_count = 0
        target_server = None
        real_run_stdio = MCPServerTask._run_stdio

        async def scripted_run_stdio(self_srv, config):
            nonlocal run_count
            run_count += 1
            if target_server is not self_srv:
                return await real_run_stdio(self_srv, config)
            self_srv.session = MagicMock()
            self_srv._tools = []
            self_srv._ready.set()
            raise ConnectionError("connection dropped")

        async def scenario():
            nonlocal target_server
            server = MCPServerTask("test_srv")
            target_server = server
            server._shutdown_event.set()  # Shutdown already requested
            with patch.object(MCPServerTask, "_run_stdio", scripted_run_stdio), \
                 patch("asyncio.sleep", new_callable=AsyncMock):
                await server.run({"command": "test"})
            # Shutdown was already set, so exactly one attempt is expected.
            assert run_count == 1

        asyncio.run(scenario())

    def test_no_reconnect_on_initial_failure(self):
        """First connection failure reports error immediately, no retry."""
        from tools.mcp_tool import MCPServerTask

        run_count = 0
        target_server = None
        real_run_stdio = MCPServerTask._run_stdio

        async def scripted_run_stdio(self_srv, config):
            nonlocal run_count
            run_count += 1
            if target_server is not self_srv:
                return await real_run_stdio(self_srv, config)
            raise ConnectionError("cannot connect")

        async def scenario():
            nonlocal target_server
            server = MCPServerTask("test_srv")
            target_server = server
            with patch.object(MCPServerTask, "_run_stdio", scripted_run_stdio), \
                 patch("asyncio.sleep", new_callable=AsyncMock):
                await server.run({"command": "test"})
            # A single attempt, with the failure recorded on the server.
            assert run_count == 1
            assert server._error is not None
            assert "cannot connect" in str(server._error)

        asyncio.run(scenario())
# ---------------------------------------------------------------------------
# Configurable timeouts
# ---------------------------------------------------------------------------
class TestConfigurableTimeouts:
    """Tests for configurable per-server timeouts."""

    def test_default_timeout(self):
        """Server with no timeout config gets _DEFAULT_TOOL_TIMEOUT."""
        from tools.mcp_tool import MCPServerTask, _DEFAULT_TOOL_TIMEOUT

        server = MCPServerTask("test_srv")
        assert server.tool_timeout == _DEFAULT_TOOL_TIMEOUT
        assert server.tool_timeout == 120

    def test_custom_timeout(self):
        """Server with timeout=180 in config gets 180."""
        from tools.mcp_tool import MCPServerTask

        target_server = None
        original_run_stdio = MCPServerTask._run_stdio

        async def patched_run_stdio(self_srv, config):
            # Only script the instance under test; delegate for any other.
            if target_server is not self_srv:
                return await original_run_stdio(self_srv, config)
            self_srv.session = MagicMock()
            self_srv._tools = []
            self_srv._ready.set()
            await self_srv._shutdown_event.wait()

        async def _test():
            nonlocal target_server
            server = MCPServerTask("test_srv")
            target_server = server
            with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio):
                task = asyncio.ensure_future(
                    server.run({"command": "test", "timeout": 180})
                )
                await server._ready.wait()
                assert server.tool_timeout == 180
                server._shutdown_event.set()
                await task

        asyncio.run(_test())

    def test_timeout_passed_to_handler(self):
        """The tool handler forwards its configured timeout to ``_run_on_mcp_loop``."""
        from tools.mcp_tool import _make_tool_handler, _servers

        mock_session = MagicMock()
        mock_session.call_tool = AsyncMock(
            return_value=_make_call_result("ok", is_error=False)
        )
        server = _make_mock_server("test_srv", session=mock_session)
        server.tool_timeout = 180
        _servers["test_srv"] = server
        try:
            handler = _make_tool_handler("test_srv", "my_tool", 180)
            with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
                mock_run.return_value = json.dumps({"result": "ok"})
                handler({})
            mock_run.assert_called()
            call = mock_run.call_args
            # The mocked runner never awaits what it was given; close any
            # coroutine argument so it does not emit a "never awaited"
            # RuntimeWarning when garbage-collected.
            for argument in (*call.args, *call.kwargs.values()):
                if asyncio.iscoroutine(argument):
                    argument.close()
            # timeout=180 may arrive as a keyword or as the second positional.
            by_keyword = call.kwargs.get("timeout") == 180
            by_position = len(call.args) > 1 and call.args[1] == 180
            assert by_keyword or by_position
        finally:
            _servers.pop("test_srv", None)
# ---------------------------------------------------------------------------
# Utility tool schemas (Resources & Prompts)
# ---------------------------------------------------------------------------
class TestUtilitySchemas:
    """Checks on _build_utility_schemas(): tool names, parameters and descriptions."""

    @staticmethod
    def _parameters_for(handler_key):
        """Return the parameter schema of the "srv" utility tool with this handler key."""
        from tools.mcp_tool import _build_utility_schemas

        matching = (
            entry
            for entry in _build_utility_schemas("srv")
            if entry["handler_key"] == handler_key
        )
        return next(matching)["schema"]["parameters"]

    def test_builds_four_utility_schemas(self):
        """One schema is produced for each of the four utility tools."""
        from tools.mcp_tool import _build_utility_schemas

        entries = _build_utility_schemas("myserver")
        assert len(entries) == 4
        tool_names = [entry["schema"]["name"] for entry in entries]
        for expected in (
            "mcp_myserver_list_resources",
            "mcp_myserver_read_resource",
            "mcp_myserver_list_prompts",
            "mcp_myserver_get_prompt",
        ):
            assert expected in tool_names

    def test_hyphens_sanitized_in_utility_names(self):
        """Hyphens in the server name do not appear in generated tool names."""
        from tools.mcp_tool import _build_utility_schemas

        tool_names = [
            entry["schema"]["name"] for entry in _build_utility_schemas("my-server")
        ]
        for tool_name in tool_names:
            assert "-" not in tool_name
        assert "mcp_my_server_list_resources" in tool_names

    def test_list_resources_schema_no_required_params(self):
        """list_resources takes an empty object with no required keys."""
        parameters = self._parameters_for("list_resources")
        assert parameters["type"] == "object"
        assert parameters["properties"] == {}
        assert "required" not in parameters

    def test_read_resource_schema_requires_uri(self):
        """read_resource declares a required string parameter named uri."""
        parameters = self._parameters_for("read_resource")
        properties = parameters["properties"]
        assert "uri" in properties
        assert properties["uri"]["type"] == "string"
        assert parameters["required"] == ["uri"]

    def test_list_prompts_schema_no_required_params(self):
        """list_prompts takes an empty object with no required keys."""
        parameters = self._parameters_for("list_prompts")
        assert parameters["type"] == "object"
        assert parameters["properties"] == {}
        assert "required" not in parameters

    def test_get_prompt_schema_requires_name(self):
        """get_prompt requires name and also declares an object-typed arguments."""
        parameters = self._parameters_for("get_prompt")
        properties = parameters["properties"]
        assert "name" in properties
        assert properties["name"]["type"] == "string"
        assert "arguments" in properties
        assert properties["arguments"]["type"] == "object"
        assert parameters["required"] == ["name"]

    def test_schemas_have_descriptions(self):
        """Every utility schema has a non-empty description naming the server."""
        from tools.mcp_tool import _build_utility_schemas

        for entry in _build_utility_schemas("test_srv"):
            description = entry["schema"]["description"]
            assert description and len(description) > 0
            assert "test_srv" in description
# ---------------------------------------------------------------------------
# Utility tool handlers (Resources & Prompts)
# ---------------------------------------------------------------------------
class TestUtilityHandlers:
    """Exercise the list/read resource and list/get prompt handler factories."""

    def _patch_mcp_loop(self):
        """Build a patcher that makes _run_on_mcp_loop execute coroutines inline."""

        def run_inline(coro, timeout=30):
            # Each call gets its own throwaway event loop; the timeout is
            # accepted but not enforced here.
            private_loop = asyncio.new_event_loop()
            try:
                return private_loop.run_until_complete(coro)
            finally:
                private_loop.close()

        return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=run_inline)

    # -- list_resources --

    def test_list_resources_success(self):
        """A single resource reported by the session is returned in the payload."""
        from tools.mcp_tool import _make_list_resources_handler, _servers

        session = MagicMock()
        session.list_resources = AsyncMock(
            return_value=SimpleNamespace(
                resources=[
                    SimpleNamespace(
                        uri="file:///tmp/test.txt",
                        name="test.txt",
                        description="A test file",
                        mimeType="text/plain",
                    )
                ]
            )
        )
        _servers["srv"] = _make_mock_server("srv", session=session)
        try:
            list_resources = _make_list_resources_handler("srv", 120)
            with self._patch_mcp_loop():
                payload = json.loads(list_resources({}))
            assert "resources" in payload
            listed = payload["resources"]
            assert len(listed) == 1
            assert listed[0]["uri"] == "file:///tmp/test.txt"
            assert listed[0]["name"] == "test.txt"
        finally:
            _servers.pop("srv", None)

    def test_list_resources_empty(self):
        """A session with no resources yields an empty list."""
        from tools.mcp_tool import _make_list_resources_handler, _servers

        session = MagicMock()
        session.list_resources = AsyncMock(
            return_value=SimpleNamespace(resources=[])
        )
        _servers["srv"] = _make_mock_server("srv", session=session)
        try:
            list_resources = _make_list_resources_handler("srv", 120)
            with self._patch_mcp_loop():
                payload = json.loads(list_resources({}))
            assert payload["resources"] == []
        finally:
            _servers.pop("srv", None)

    def test_list_resources_disconnected(self):
        """An unregistered server name produces a "not connected" error."""
        from tools.mcp_tool import _make_list_resources_handler, _servers

        _servers.pop("ghost", None)  # make sure no such server is registered
        list_resources = _make_list_resources_handler("ghost", 120)
        payload = json.loads(list_resources({}))
        assert "error" in payload
        assert "not connected" in payload["error"]

    # -- read_resource --

    def test_read_resource_success(self):
        """Text content is returned under "result" and the uri is forwarded."""
        from tools.mcp_tool import _make_read_resource_handler, _servers

        session = MagicMock()
        session.read_resource = AsyncMock(
            return_value=SimpleNamespace(
                contents=[SimpleNamespace(text="Hello from resource")]
            )
        )
        _servers["srv"] = _make_mock_server("srv", session=session)
        try:
            read_resource = _make_read_resource_handler("srv", 120)
            request = {"uri": "file:///tmp/test.txt"}
            with self._patch_mcp_loop():
                payload = json.loads(read_resource(request))
            assert payload["result"] == "Hello from resource"
            session.read_resource.assert_called_once_with("file:///tmp/test.txt")
        finally:
            _servers.pop("srv", None)

    def test_read_resource_missing_uri(self):
        """Calling without a uri returns an error that mentions "uri"."""
        from tools.mcp_tool import _make_read_resource_handler, _servers

        _servers["srv"] = _make_mock_server("srv", session=MagicMock())
        try:
            read_resource = _make_read_resource_handler("srv", 120)
            payload = json.loads(read_resource({}))
            assert "error" in payload
            assert "uri" in payload["error"].lower()
        finally:
            _servers.pop("srv", None)

    def test_read_resource_disconnected(self):
        """An unregistered server name produces a "not connected" error."""
        from tools.mcp_tool import _make_read_resource_handler, _servers

        _servers.pop("ghost", None)  # make sure no such server is registered
        read_resource = _make_read_resource_handler("ghost", 120)
        payload = json.loads(read_resource({"uri": "test://x"}))
        assert "error" in payload
        assert "not connected" in payload["error"]

    # -- list_prompts --

    def test_list_prompts_success(self):
        """A prompt and its argument metadata appear in the payload."""
        from tools.mcp_tool import _make_list_prompts_handler, _servers

        text_argument = SimpleNamespace(
            name="text", description="Text to summarize", required=True
        )
        summarize_prompt = SimpleNamespace(
            name="summarize",
            description="Summarize text",
            arguments=[text_argument],
        )
        session = MagicMock()
        session.list_prompts = AsyncMock(
            return_value=SimpleNamespace(prompts=[summarize_prompt])
        )
        _servers["srv"] = _make_mock_server("srv", session=session)
        try:
            list_prompts = _make_list_prompts_handler("srv", 120)
            with self._patch_mcp_loop():
                payload = json.loads(list_prompts({}))
            assert "prompts" in payload
            listed = payload["prompts"]
            assert len(listed) == 1
            assert listed[0]["name"] == "summarize"
            assert listed[0]["arguments"][0]["name"] == "text"
        finally:
            _servers.pop("srv", None)

    def test_list_prompts_empty(self):
        """A session with no prompts yields an empty list."""
        from tools.mcp_tool import _make_list_prompts_handler, _servers

        session = MagicMock()
        session.list_prompts = AsyncMock(
            return_value=SimpleNamespace(prompts=[])
        )
        _servers["srv"] = _make_mock_server("srv", session=session)
        try:
            list_prompts = _make_list_prompts_handler("srv", 120)
            with self._patch_mcp_loop():
                payload = json.loads(list_prompts({}))
            assert payload["prompts"] == []
        finally:
            _servers.pop("srv", None)

    def test_list_prompts_disconnected(self):
        """An unregistered server name produces a "not connected" error."""
        from tools.mcp_tool import _make_list_prompts_handler, _servers

        _servers.pop("ghost", None)  # make sure no such server is registered
        list_prompts = _make_list_prompts_handler("ghost", 120)
        payload = json.loads(list_prompts({}))
        assert "error" in payload
        assert "not connected" in payload["error"]

    # -- get_prompt --

    def test_get_prompt_success(self):
        """Prompt messages are returned and name/arguments are forwarded."""
        from tools.mcp_tool import _make_get_prompt_handler, _servers

        reply = SimpleNamespace(
            role="assistant",
            content=SimpleNamespace(text="Here is a summary of your text."),
        )
        session = MagicMock()
        session.get_prompt = AsyncMock(
            return_value=SimpleNamespace(messages=[reply], description=None)
        )
        _servers["srv"] = _make_mock_server("srv", session=session)
        try:
            get_prompt = _make_get_prompt_handler("srv", 120)
            request = {"name": "summarize", "arguments": {"text": "hello"}}
            with self._patch_mcp_loop():
                payload = json.loads(get_prompt(request))
            assert "messages" in payload
            messages = payload["messages"]
            assert len(messages) == 1
            assert messages[0]["role"] == "assistant"
            assert "summary" in messages[0]["content"].lower()
            session.get_prompt.assert_called_once_with(
                "summarize", arguments={"text": "hello"}
            )
        finally:
            _servers.pop("srv", None)

    def test_get_prompt_missing_name(self):
        """Calling without a name returns an error that mentions "name"."""
        from tools.mcp_tool import _make_get_prompt_handler, _servers

        _servers["srv"] = _make_mock_server("srv", session=MagicMock())
        try:
            get_prompt = _make_get_prompt_handler("srv", 120)
            payload = json.loads(get_prompt({}))
            assert "error" in payload
            assert "name" in payload["error"].lower()
        finally:
            _servers.pop("srv", None)

    def test_get_prompt_disconnected(self):
        """An unregistered server name produces a "not connected" error."""
        from tools.mcp_tool import _make_get_prompt_handler, _servers

        _servers.pop("ghost", None)  # make sure no such server is registered
        get_prompt = _make_get_prompt_handler("ghost", 120)
        payload = json.loads(get_prompt({"name": "test"}))
        assert "error" in payload
        assert "not connected" in payload["error"]

    def test_get_prompt_default_arguments(self):
        """When arguments are omitted, the session receives an empty dict."""
        from tools.mcp_tool import _make_get_prompt_handler, _servers

        session = MagicMock()
        session.get_prompt = AsyncMock(
            return_value=SimpleNamespace(messages=[], description=None)
        )
        _servers["srv"] = _make_mock_server("srv", session=session)
        try:
            get_prompt = _make_get_prompt_handler("srv", 120)
            with self._patch_mcp_loop():
                get_prompt({"name": "test_prompt"})
            # arguments defaults to {} when not provided
            session.get_prompt.assert_called_once_with(
                "test_prompt", arguments={}
            )
        finally:
            _servers.pop("srv", None)
# ---------------------------------------------------------------------------
# Utility tools registration in _discover_and_register_server
# ---------------------------------------------------------------------------
class TestUtilityToolRegistration:
    """Verify utility tools are registered alongside regular MCP tools."""

    @staticmethod
    def _discover(server_name, server_config, tools):
        """Run _discover_and_register_server against a fake connection.

        ``_connect_server`` is patched to return an ``MCPServerTask`` with a
        mock session and the given ``tools``; ``tools.registry.registry`` is
        patched with a fresh ``ToolRegistry``.

        Returns a ``(registry, registered)`` tuple: the fresh registry and
        whatever ``_discover_and_register_server`` returned.
        """
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, MCPServerTask

        registry = ToolRegistry()
        session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = session
            server._tools = tools
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", registry):
            registered = asyncio.run(
                _discover_and_register_server(server_name, server_config)
            )
        return registry, registered

    def test_utility_tools_registered(self):
        """_discover_and_register_server registers all 4 utility tools."""
        from tools.mcp_tool import _servers

        # try/finally guarantees the module-global _servers entry is removed
        # even when an assertion fails, so later tests start clean.
        try:
            registry, registered = self._discover(
                "fs",
                {"command": "npx", "args": []},
                [_make_mcp_tool("read_file", "Read a file")],
            )

            # Regular tool + 4 utility tools
            assert "mcp_fs_read_file" in registered
            assert "mcp_fs_list_resources" in registered
            assert "mcp_fs_read_resource" in registered
            assert "mcp_fs_list_prompts" in registered
            assert "mcp_fs_get_prompt" in registered
            assert len(registered) == 5

            # All in the registry
            all_names = registry.get_all_tool_names()
            for name in registered:
                assert name in all_names
        finally:
            _servers.pop("fs", None)

    def test_utility_tools_in_same_toolset(self):
        """Utility tools belong to the same mcp-{server} toolset."""
        from tools.mcp_tool import _servers

        try:
            registry, _ = self._discover("myserv", {"command": "test"}, [])

            # Check that utility tools are in the right toolset
            for tool_name in (
                "mcp_myserv_list_resources",
                "mcp_myserv_read_resource",
                "mcp_myserv_list_prompts",
                "mcp_myserv_get_prompt",
            ):
                entry = registry._tools.get(tool_name)
                assert entry is not None, f"{tool_name} not found in registry"
                assert entry.toolset == "mcp-myserv"
        finally:
            _servers.pop("myserv", None)

    def test_utility_tools_have_check_fn(self):
        """Utility tools have a working check_fn."""
        from tools.mcp_tool import _servers

        try:
            registry, _ = self._discover("chk", {"command": "test"}, [])

            entry = registry._tools.get("mcp_chk_list_resources")
            assert entry is not None
            # Server is connected, check_fn should return True
            assert entry.check_fn() is True

            # Disconnect the server
            _servers["chk"].session = None
            assert entry.check_fn() is False
        finally:
            _servers.pop("chk", None)