diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 065baf4a1..4b7e2c722 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -5,6 +5,7 @@ All tests use mocks -- no real MCP servers or subprocesses are started. import asyncio import json +import os from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch @@ -189,7 +190,7 @@ class TestToolHandler: _servers["test_srv"] = server try: - handler = _make_tool_handler("test_srv", "greet") + handler = _make_tool_handler("test_srv", "greet", 120) with self._patch_mcp_loop(): result = json.loads(handler({"name": "world"})) assert result["result"] == "hello world" @@ -208,7 +209,7 @@ class TestToolHandler: _servers["test_srv"] = server try: - handler = _make_tool_handler("test_srv", "fail_tool") + handler = _make_tool_handler("test_srv", "fail_tool", 120) with self._patch_mcp_loop(): result = json.loads(handler({})) assert "error" in result @@ -220,7 +221,7 @@ class TestToolHandler: from tools.mcp_tool import _make_tool_handler, _servers _servers.pop("ghost", None) - handler = _make_tool_handler("ghost", "any_tool") + handler = _make_tool_handler("ghost", "any_tool", 120) result = json.loads(handler({})) assert "error" in result assert "not connected" in result["error"] @@ -234,7 +235,7 @@ class TestToolHandler: _servers["test_srv"] = server try: - handler = _make_tool_handler("test_srv", "broken_tool") + handler = _make_tool_handler("test_srv", "broken_tool", 120) with self._patch_mcp_loop(): result = json.loads(handler({})) assert "error" in result @@ -400,8 +401,8 @@ class TestMCPServerTask: asyncio.run(_test()) - def test_empty_env_passed_as_none(self): - """Empty env dict is passed as None to StdioServerParameters.""" + def test_empty_env_gets_safe_defaults(self): + """Empty env dict gets safe default env vars (PATH, HOME, etc.).""" from tools.mcp_tool import MCPServerTask mock_session = MagicMock() @@ -414,13 +415,18 @@ class TestMCPServerTask: async def _test(): with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ - p_stdio, p_cs: + p_stdio, p_cs, \ + patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/home/test"}, clear=False): server = MCPServerTask("srv") await server.start({"command": "node", "env": {}}) - # Empty dict -> None + # Empty dict -> safe env vars (not None) call_kwargs = mock_params.call_args - assert call_kwargs.kwargs.get("env") is None + env_arg = call_kwargs.kwargs.get("env") + assert env_arg is not None + assert isinstance(env_arg, dict) + assert "PATH" in env_arg + assert "HOME" in env_arg await server.shutdown() @@ -698,3 +704,353 @@ class TestShutdown: assert len(_servers) == 0 # Parallel: ~1s, not ~3s. Allow some margin. assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)" + + +# --------------------------------------------------------------------------- +# _build_safe_env +# --------------------------------------------------------------------------- + +class TestBuildSafeEnv: + """Tests for _build_safe_env() environment filtering.""" + + def test_only_safe_vars_passed(self): + """Only safe baseline vars and XDG_* from os.environ are included.""" + from tools.mcp_tool import _build_safe_env + + fake_env = { + "PATH": "/usr/bin", + "HOME": "/home/test", + "USER": "test", + "LANG": "en_US.UTF-8", + "LC_ALL": "C", + "TERM": "xterm", + "SHELL": "/bin/bash", + "TMPDIR": "/tmp", + "XDG_DATA_HOME": "/home/test/.local/share", + "SECRET_KEY": "should_not_appear", + "AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE", + } + with patch.dict("os.environ", fake_env, clear=True): + result = _build_safe_env(None) + + # Safe vars present + assert result["PATH"] == "/usr/bin" + assert result["HOME"] == "/home/test" + assert result["USER"] == "test" + assert result["LANG"] == "en_US.UTF-8" + assert result["XDG_DATA_HOME"] == "/home/test/.local/share" + # Unsafe vars excluded + assert "SECRET_KEY" not in result + assert "AWS_ACCESS_KEY_ID" not in result + + def test_user_env_merged(self): + """User-specified env vars are merged into the safe env.""" + from tools.mcp_tool import _build_safe_env + + with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True): + result = _build_safe_env({"MY_CUSTOM_VAR": "hello"}) + + assert result["PATH"] == "/usr/bin" + assert result["MY_CUSTOM_VAR"] == "hello" + + def test_user_env_overrides_safe(self): + """User env can override safe defaults.""" + from tools.mcp_tool import _build_safe_env + + with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True): + result = _build_safe_env({"PATH": "/custom/bin"}) + + assert result["PATH"] == "/custom/bin" + + def test_none_user_env(self): + """None user_env still returns safe vars from os.environ.""" + from tools.mcp_tool import _build_safe_env + + with patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/root"}, clear=True): + result = _build_safe_env(None) + + assert isinstance(result, dict) + assert result["PATH"] == "/usr/bin" + assert result["HOME"] == "/root" + + def test_secret_vars_excluded(self): + """Sensitive env vars from os.environ are NOT passed through.""" + from tools.mcp_tool import _build_safe_env + + fake_env = { + "PATH": "/usr/bin", + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "GITHUB_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + "OPENAI_API_KEY": "sk-proj-abc123", + "DATABASE_URL": "postgres://user:pass@localhost/db", + "API_SECRET": "supersecret", + } + with patch.dict("os.environ", fake_env, clear=True): + result = _build_safe_env(None) + + assert "PATH" in result + assert "AWS_SECRET_ACCESS_KEY" not in result + assert "GITHUB_TOKEN" not in result + assert "OPENAI_API_KEY" not in result + assert "DATABASE_URL" not in result + assert "API_SECRET" not in result + + +# --------------------------------------------------------------------------- +# _sanitize_error +# --------------------------------------------------------------------------- + +class TestSanitizeError: + """Tests for _sanitize_error() credential stripping.""" + + def test_strips_github_pat(self): + from tools.mcp_tool import _sanitize_error + result = _sanitize_error("Error with ghp_abc123def456") + assert result == "Error with [REDACTED]" + + def test_strips_openai_key(self): + from tools.mcp_tool import _sanitize_error + result = _sanitize_error("key sk-projABC123xyz") + assert result == "key [REDACTED]" + + def test_strips_bearer_token(self): + from tools.mcp_tool import _sanitize_error + result = _sanitize_error("Authorization: Bearer eyJabc123def") + assert result == "Authorization: [REDACTED]" + + def test_strips_token_param(self): + from tools.mcp_tool import _sanitize_error + result = _sanitize_error("url?token=secret123") + assert result == "url?[REDACTED]" + + def test_no_credentials_unchanged(self): + from tools.mcp_tool import _sanitize_error + result = _sanitize_error("normal error message") + assert result == "normal error message" + + def test_multiple_credentials(self): + from tools.mcp_tool import _sanitize_error + result = _sanitize_error("ghp_abc123 and sk-projXyz789 and token=foo") + assert "ghp_" not in result + assert "sk-" not in result + assert "token=" not in result + assert result.count("[REDACTED]") == 3 + + +# --------------------------------------------------------------------------- +# HTTP config +# --------------------------------------------------------------------------- + +class TestHTTPConfig: + """Tests for HTTP transport detection and handling.""" + + def test_is_http_with_url(self): + from tools.mcp_tool import MCPServerTask + server = MCPServerTask("remote") + server._config = {"url": "https://example.com/mcp"} + assert server._is_http() is True + + def test_is_stdio_with_command(self): + from tools.mcp_tool import MCPServerTask + server = MCPServerTask("local") + server._config = {"command": "npx", "args": []} + assert server._is_http() is False + + def test_http_unavailable_raises(self): + from tools.mcp_tool import MCPServerTask + + server = MCPServerTask("remote") + config = {"url": "https://example.com/mcp"} + + async def _test(): + with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", False): + with pytest.raises(ImportError, match="HTTP transport"): + await server._run_http(config) + + asyncio.run(_test()) + + +# --------------------------------------------------------------------------- +# Reconnection logic +# --------------------------------------------------------------------------- + +class TestReconnection: + """Tests for automatic reconnection behavior in MCPServerTask.run().""" + + def test_reconnect_on_disconnect(self): + """After initial success, a connection drop triggers reconnection.""" + from tools.mcp_tool import MCPServerTask + + run_count = 0 + target_server = None + + original_run_stdio = MCPServerTask._run_stdio + + async def patched_run_stdio(self_srv, config): + nonlocal run_count, target_server + run_count += 1 + if target_server is not self_srv: + return await original_run_stdio(self_srv, config) + if run_count == 1: + # First connection succeeds, then simulate disconnect + self_srv.session = MagicMock() + self_srv._tools = [] + self_srv._ready.set() + raise ConnectionError("connection dropped") + else: + # Reconnection succeeds; signal shutdown so run() exits + self_srv.session = MagicMock() + self_srv._shutdown_event.set() + await self_srv._shutdown_event.wait() + + async def _test(): + nonlocal target_server + server = MCPServerTask("test_srv") + target_server = server + + with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \ + patch("asyncio.sleep", new_callable=AsyncMock): + await server.run({"command": "test"}) + + assert run_count >= 2 # At least one reconnection attempt + + asyncio.run(_test()) + + def test_no_reconnect_on_shutdown(self): + """If shutdown is requested, don't attempt reconnection.""" + from tools.mcp_tool import MCPServerTask + + run_count = 0 + target_server = None + + original_run_stdio = MCPServerTask._run_stdio + + async def patched_run_stdio(self_srv, config): + nonlocal run_count, target_server + run_count += 1 + if target_server is not self_srv: + return await original_run_stdio(self_srv, config) + self_srv.session = MagicMock() + self_srv._tools = [] + self_srv._ready.set() + raise ConnectionError("connection dropped") + + async def _test(): + nonlocal target_server + server = MCPServerTask("test_srv") + target_server = server + server._shutdown_event.set() # Shutdown already requested + + with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \ + patch("asyncio.sleep", new_callable=AsyncMock): + await server.run({"command": "test"}) + + # Should not retry because shutdown was set + assert run_count == 1 + + asyncio.run(_test()) + + def test_no_reconnect_on_initial_failure(self): + """First connection failure reports error immediately, no retry.""" + from tools.mcp_tool import MCPServerTask + + run_count = 0 + target_server = None + + original_run_stdio = MCPServerTask._run_stdio + + async def patched_run_stdio(self_srv, config): + nonlocal run_count, target_server + run_count += 1 + if target_server is not self_srv: + return await original_run_stdio(self_srv, config) + raise ConnectionError("cannot connect") + + async def _test(): + nonlocal target_server + server = MCPServerTask("test_srv") + target_server = server + + with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \ + patch("asyncio.sleep", new_callable=AsyncMock): + await server.run({"command": "test"}) + + # Only one attempt, no retry on initial failure + assert run_count == 1 + assert server._error is not None + assert "cannot connect" in str(server._error) + + asyncio.run(_test()) + + +# --------------------------------------------------------------------------- +# Configurable timeouts +# --------------------------------------------------------------------------- + +class TestConfigurableTimeouts: + """Tests for configurable per-server timeouts.""" + + def test_default_timeout(self): + """Server with no timeout config gets _DEFAULT_TOOL_TIMEOUT.""" + from tools.mcp_tool import MCPServerTask, _DEFAULT_TOOL_TIMEOUT + + server = MCPServerTask("test_srv") + assert server.tool_timeout == _DEFAULT_TOOL_TIMEOUT + assert server.tool_timeout == 120 + + def test_custom_timeout(self): + """Server with timeout=180 in config gets 180.""" + from tools.mcp_tool import MCPServerTask + + target_server = None + + original_run_stdio = MCPServerTask._run_stdio + + async def patched_run_stdio(self_srv, config): + if target_server is not self_srv: + return await original_run_stdio(self_srv, config) + self_srv.session = MagicMock() + self_srv._tools = [] + self_srv._ready.set() + await self_srv._shutdown_event.wait() + + async def _test(): + nonlocal target_server + server = MCPServerTask("test_srv") + target_server = server + + with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio): + task = asyncio.ensure_future( + server.run({"command": "test", "timeout": 180}) + ) + await server._ready.wait() + assert server.tool_timeout == 180 + server._shutdown_event.set() + await task + + asyncio.run(_test()) + + def test_timeout_passed_to_handler(self): + """The tool handler uses the server's configured timeout.""" + from tools.mcp_tool import _make_tool_handler, _servers, MCPServerTask + + mock_session = MagicMock() + mock_session.call_tool = AsyncMock( + return_value=_make_call_result("ok", is_error=False) + ) + server = _make_mock_server("test_srv", session=mock_session) + server.tool_timeout = 180 + _servers["test_srv"] = server + + try: + handler = _make_tool_handler("test_srv", "my_tool", 180) + with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run: + mock_run.return_value = json.dumps({"result": "ok"}) + handler({}) + # Verify timeout=180 was passed + call_kwargs = mock_run.call_args + assert call_kwargs.kwargs.get("timeout") == 180 or \ + (len(call_kwargs.args) > 1 and call_kwargs.args[1] == 180) or \ + call_kwargs[1].get("timeout") == 180 + finally: + _servers.pop("test_srv", None) diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 4ab55215b..1419327c8 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -2,9 +2,9 @@ """ MCP (Model Context Protocol) Client Support -Connects to external MCP servers via stdio transport, discovers their tools, -and registers them into the hermes-agent tool registry so the agent can call -them like any built-in tool. +Connects to external MCP servers via stdio or HTTP/StreamableHTTP transport, +discovers their tools, and registers them into the hermes-agent tool registry +so the agent can call them like any built-in tool. Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key. The ``mcp`` Python package is optional -- if not installed, this module is a @@ -17,17 +17,32 @@ Example config:: command: "npx" args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] env: {} + timeout: 120 # per-tool-call timeout in seconds (default: 120) + connect_timeout: 60 # initial connection timeout (default: 60) github: command: "npx" args: ["-y", "@modelcontextprotocol/server-github"] env: GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..." + remote_api: + url: "https://my-mcp-server.example.com/mcp" + headers: + Authorization: "Bearer sk-..." + timeout: 180 + +Features: + - Stdio transport (command + args) and HTTP/StreamableHTTP transport (url) + - Automatic reconnection with exponential backoff (up to 5 retries) + - Environment variable filtering for stdio subprocesses (security) + - Credential stripping in error messages returned to the LLM + - Configurable per-server timeouts for tool calls and connections + - Thread-safe architecture with dedicated background event loop Architecture: A dedicated background event loop (_mcp_loop) runs in a daemon thread. Each MCP server runs as a long-lived asyncio Task on this loop, keeping - its ``async with stdio_client(...)`` context alive. Tool call coroutines - are scheduled onto the loop via ``run_coroutine_threadsafe()``. + its transport context alive. Tool call coroutines are scheduled onto the + loop via ``run_coroutine_threadsafe()``. On shutdown, each server Task is signalled to exit its ``async with`` block, ensuring the anyio cancel-scope cleanup happens in the *same* @@ -43,6 +58,8 @@ Thread safety: import asyncio import json import logging +import os +import re import threading from typing import Any, Dict, List, Optional @@ -53,13 +70,81 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- _MCP_AVAILABLE = False +_MCP_HTTP_AVAILABLE = False try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client _MCP_AVAILABLE = True + try: + from mcp.client.streamable_http import streamablehttp_client + _MCP_HTTP_AVAILABLE = True + except ImportError: + _MCP_HTTP_AVAILABLE = False except ImportError: logger.debug("mcp package not installed -- MCP tool support disabled") +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls +_DEFAULT_DISCOVERY_TIMEOUT = 60 # seconds for server discovery +_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection +_MAX_RECONNECT_RETRIES = 5 +_MAX_BACKOFF_SECONDS = 60 + +# Environment variables that are safe to pass to stdio subprocesses +_SAFE_ENV_KEYS = frozenset({ + "PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR", +}) + +# Regex for credential patterns to strip from error messages +_CREDENTIAL_PATTERN = re.compile( + r"(?:" + r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT + r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key + r"|Bearer\s+\S+" # Bearer token + r"|token=[^\s&,;\"']{1,255}" # token=... + r"|key=[^\s&,;\"']{1,255}" # key=... + r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=... + r"|password=[^\s&,;\"']{1,255}" # password=... + r"|secret=[^\s&,;\"']{1,255}" # secret=... + r")", + re.IGNORECASE, +) + + +# --------------------------------------------------------------------------- +# Security helpers +# --------------------------------------------------------------------------- + +def _build_safe_env(user_env: Optional[dict]) -> dict: + """Build a filtered environment dict for stdio subprocesses. + + Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_* + variables from the current process environment, plus any variables + explicitly specified by the user in the server config. + + This prevents accidentally leaking secrets like API keys, tokens, or + credentials to MCP server subprocesses. + """ + env = {} + for key, value in os.environ.items(): + if key in _SAFE_ENV_KEYS or key.startswith("XDG_"): + env[key] = value + if user_env: + env.update(user_env) + return env + + +def _sanitize_error(text: str) -> str: + """Strip credential-like patterns from error text before returning to LLM. + + Replaces tokens, keys, and other secrets with [REDACTED] to prevent + accidental credential exposure in tool error responses. + """ + return _CREDENTIAL_PATTERN.sub("[REDACTED]", text) + # --------------------------------------------------------------------------- # Server task -- each MCP server lives in one long-lived asyncio Task @@ -70,66 +155,152 @@ class MCPServerTask: The entire connection lifecycle (connect, discover, serve, disconnect) runs inside one asyncio Task so that anyio cancel-scopes created by - ``stdio_client`` are entered and exited in the same Task context. + the transport client are entered and exited in the same Task context. + + Supports both stdio and HTTP/StreamableHTTP transports. """ __slots__ = ( - "name", "session", - "_task", "_ready", "_shutdown_event", "_tools", "_error", + "name", "session", "tool_timeout", + "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config", ) def __init__(self, name: str): self.name = name self.session: Optional[Any] = None + self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT self._task: Optional[asyncio.Task] = None self._ready = asyncio.Event() self._shutdown_event = asyncio.Event() self._tools: list = [] self._error: Optional[Exception] = None + self._config: dict = {} - async def run(self, config: dict): - """Long-lived coroutine: connect, discover tools, wait, disconnect.""" + def _is_http(self) -> bool: + """Check if this server uses HTTP transport.""" + return "url" in self._config + + async def _run_stdio(self, config: dict): + """Run the server using stdio transport.""" command = config.get("command") args = config.get("args", []) - env = config.get("env") + user_env = config.get("env") if not command: - self._error = ValueError( + raise ValueError( f"MCP server '{self.name}' has no 'command' in config" ) - self._ready.set() - return + safe_env = _build_safe_env(user_env) server_params = StdioServerParameters( command=command, args=args, - env=env if env else None, + env=safe_env if safe_env else None, ) - try: - async with stdio_client(server_params) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - self.session = session + async with stdio_client(server_params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + self.session = session + await self._discover_tools() + self._ready.set() + await self._shutdown_event.wait() - tools_result = await session.list_tools() - self._tools = ( - tools_result.tools - if hasattr(tools_result, "tools") - else [] - ) + async def _run_http(self, config: dict): + """Run the server using HTTP/StreamableHTTP transport.""" + if not _MCP_HTTP_AVAILABLE: + raise ImportError( + f"MCP server '{self.name}' requires HTTP transport but " + "mcp.client.streamable_http is not available. " + "Upgrade the mcp package to get HTTP support." + ) - # Signal that connection is ready + url = config["url"] + headers = config.get("headers") + connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) + + async with streamablehttp_client( + url, + headers=headers, + timeout=float(connect_timeout), + ) as (read_stream, write_stream, _get_session_id): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + self.session = session + await self._discover_tools() + self._ready.set() + await self._shutdown_event.wait() + + async def _discover_tools(self): + """Discover tools from the connected session.""" + if self.session is None: + return + tools_result = await self.session.list_tools() + self._tools = ( + tools_result.tools + if hasattr(tools_result, "tools") + else [] + ) + + async def run(self, config: dict): + """Long-lived coroutine: connect, discover tools, wait, disconnect. + + Includes automatic reconnection with exponential backoff if the + connection drops unexpectedly (unless shutdown was requested). + """ + self._config = config + self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT) + retries = 0 + backoff = 1.0 + + while True: + try: + if self._is_http(): + await self._run_http(config) + else: + await self._run_stdio(config) + # Normal exit (shutdown requested) -- break out + break + except Exception as exc: + self.session = None + + # If this is the first connection attempt, report the error + if not self._ready.is_set(): + self._error = exc self._ready.set() + return - # Block until shutdown is requested -- this keeps the - # async-with contexts alive on THIS Task. - await self._shutdown_event.wait() - except Exception as exc: - self._error = exc - self._ready.set() - finally: - self.session = None + # If shutdown was requested, don't reconnect + if self._shutdown_event.is_set(): + logger.debug( + "MCP server '%s' disconnected during shutdown: %s", + self.name, exc, + ) + return + + retries += 1 + if retries > _MAX_RECONNECT_RETRIES: + logger.warning( + "MCP server '%s' failed after %d reconnection attempts, " + "giving up: %s", + self.name, _MAX_RECONNECT_RETRIES, exc, + ) + return + + logger.warning( + "MCP server '%s' connection lost (attempt %d/%d), " + "reconnecting in %.0fs: %s", + self.name, retries, _MAX_RECONNECT_RETRIES, + backoff, exc, + ) + await asyncio.sleep(backoff) + backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS) + + # Check again after sleeping + if self._shutdown_event.is_set(): + return + finally: + self.session = None async def start(self, config: dict): """Create the background Task and wait until ready (or failed).""" @@ -203,7 +374,10 @@ def _run_on_mcp_loop(coro, timeout: float = 30): def _load_mcp_config() -> Dict[str, dict]: """Read ``mcp_servers`` from the Hermes config file. - Returns a dict of ``{server_name: {command, args, env}}`` or empty dict. + Returns a dict of ``{server_name: server_config}`` or empty dict. + Server config can contain either ``command``/``args``/``env`` for stdio + transport or ``url``/``headers`` for HTTP transport, plus optional + ``timeout`` and ``connect_timeout`` overrides. """ try: from hermes_cli.config import load_config @@ -224,11 +398,12 @@ def _load_mcp_config() -> Dict[str, dict]: async def _connect_server(name: str, config: dict) -> MCPServerTask: """Create an MCPServerTask, start it, and return when ready. - The server Task keeps the subprocess alive in the background. + The server Task keeps the connection alive in the background. Call ``server.shutdown()`` (on the same event loop) to tear it down. Raises: - ValueError: if ``command`` is missing from *config*. + ValueError: if required config keys are missing. + ImportError: if HTTP transport is needed but not available. Exception: on connection or initialization failure. """ server = MCPServerTask(name) @@ -240,7 +415,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask: # Handler / check-fn factories # --------------------------------------------------------------------------- -def _make_tool_handler(server_name: str, tool_name: str): +def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): """Return a sync handler that calls an MCP tool via the background loop. The handler conforms to the registry's dispatch interface: @@ -263,7 +438,11 @@ def _make_tool_handler(server_name: str, tool_name: str): for block in (result.content or []): if hasattr(block, "text"): error_text += block.text - return json.dumps({"error": error_text or "MCP tool returned an error"}) + return json.dumps({ + "error": _sanitize_error( + error_text or "MCP tool returned an error" + ) + }) # Collect text from content blocks parts: List[str] = [] @@ -273,10 +452,17 @@ def _make_tool_handler(server_name: str, tool_name: str): return json.dumps({"result": "\n".join(parts) if parts else ""}) try: - return _run_on_mcp_loop(_call(), timeout=120) + return _run_on_mcp_loop(_call(), timeout=tool_timeout) except Exception as exc: - logger.error("MCP tool %s/%s call failed: %s", server_name, tool_name, exc) - return json.dumps({"error": f"MCP call failed: {type(exc).__name__}: {exc}"}) + logger.error( + "MCP tool %s/%s call failed: %s", + server_name, tool_name, exc, + ) + return json.dumps({ + "error": _sanitize_error( + f"MCP call failed: {type(exc).__name__}: {exc}" + ) + }) return _handler @@ -339,7 +525,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: from tools.registry import registry from toolsets import create_custom_toolset - server = await _connect_server(name, config) + connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) + server = await asyncio.wait_for( + _connect_server(name, config), + timeout=connect_timeout, + ) with _lock: _servers[name] = server @@ -354,7 +544,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: name=tool_name_prefixed, toolset=toolset_name, schema=schema, - handler=_make_tool_handler(name, mcp_tool.name), + handler=_make_tool_handler(name, mcp_tool.name, server.tool_timeout), check_fn=_make_check_fn(name), is_async=False, description=schema["description"], @@ -369,9 +559,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: tools=registered_names, ) + transport_type = "HTTP" if "url" in config else "stdio" logger.info( - "MCP server '%s': registered %d tool(s): %s", - name, len(registered_names), ", ".join(registered_names), + "MCP server '%s' (%s): registered %d tool(s): %s", + name, transport_type, len(registered_names), + ", ".join(registered_names), ) return registered_names @@ -419,9 +611,12 @@ def discover_mcp_tools() -> List[str]: registered = await _discover_and_register_server(name, cfg) all_tools.extend(registered) except Exception as exc: - logger.warning("Failed to connect to MCP server '%s': %s", name, exc) + logger.warning( + "Failed to connect to MCP server '%s': %s", + name, exc, + ) - _run_on_mcp_loop(_discover_all(), timeout=60) + _run_on_mcp_loop(_discover_all(), timeout=_DEFAULT_DISCOVERY_TIMEOUT) if all_tools: # Dynamically inject into all hermes-* platform toolsets @@ -444,15 +639,10 @@ def shutdown_mcp_servers(): All servers are shut down in parallel via ``asyncio.gather``. """ with _lock: - if not _servers: - # No servers -- just stop the loop. _stop_mcp_loop() also - # acquires _lock, so we must release it first. - pass - else: - servers_snapshot = list(_servers.values()) + servers_snapshot = list(_servers.values()) # Fast path: nothing to shut down. - if not _servers: + if not servers_snapshot: _stop_mcp_loop() return