feat(mcp): add HTTP transport, reconnection, security hardening

Upgrades the MCP client implementation from PR #291 with:

- HTTP/Streamable HTTP transport: support 'url' key in config for remote
  MCP servers (Notion, Slack, Sentry, Supabase, etc.)
- Automatic reconnection with exponential backoff (1s-60s, 5 retries)
  when a server connection drops unexpectedly
- Environment variable filtering: only pass safe vars (PATH, HOME, etc.)
  plus user-specified env to stdio subprocesses (prevents secret leaks)
- Credential stripping: sanitize error messages before returning to the
  LLM (strips GitHub PATs, OpenAI keys, Bearer tokens, etc.)
- Configurable per-server timeouts: 'timeout' and 'connect_timeout' keys
- Fix shutdown race condition in servers_snapshot variable scoping

Test coverage: 50 tests (up from 30), including new tests for env
filtering, credential sanitization, HTTP config detection, reconnection
logic, and configurable timeouts.

All tests pass: 1162 passed, 3 skipped, 0 failed.
This commit is contained in:
teknium1
2026-03-02 18:40:03 -08:00
parent 468b7fdbad
commit 64ff8f065b
2 changed files with 611 additions and 65 deletions

View File

@@ -5,6 +5,7 @@ All tests use mocks -- no real MCP servers or subprocesses are started.
import asyncio
import json
import os
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
@@ -189,7 +190,7 @@ class TestToolHandler:
_servers["test_srv"] = server
try:
handler = _make_tool_handler("test_srv", "greet")
handler = _make_tool_handler("test_srv", "greet", 120)
with self._patch_mcp_loop():
result = json.loads(handler({"name": "world"}))
assert result["result"] == "hello world"
@@ -208,7 +209,7 @@ class TestToolHandler:
_servers["test_srv"] = server
try:
handler = _make_tool_handler("test_srv", "fail_tool")
handler = _make_tool_handler("test_srv", "fail_tool", 120)
with self._patch_mcp_loop():
result = json.loads(handler({}))
assert "error" in result
@@ -220,7 +221,7 @@ class TestToolHandler:
from tools.mcp_tool import _make_tool_handler, _servers
_servers.pop("ghost", None)
handler = _make_tool_handler("ghost", "any_tool")
handler = _make_tool_handler("ghost", "any_tool", 120)
result = json.loads(handler({}))
assert "error" in result
assert "not connected" in result["error"]
@@ -234,7 +235,7 @@ class TestToolHandler:
_servers["test_srv"] = server
try:
handler = _make_tool_handler("test_srv", "broken_tool")
handler = _make_tool_handler("test_srv", "broken_tool", 120)
with self._patch_mcp_loop():
result = json.loads(handler({}))
assert "error" in result
@@ -400,8 +401,8 @@ class TestMCPServerTask:
asyncio.run(_test())
def test_empty_env_passed_as_none(self):
"""Empty env dict is passed as None to StdioServerParameters."""
def test_empty_env_gets_safe_defaults(self):
"""Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
from tools.mcp_tool import MCPServerTask
mock_session = MagicMock()
@@ -414,13 +415,18 @@ class TestMCPServerTask:
async def _test():
with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
p_stdio, p_cs:
p_stdio, p_cs, \
patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/home/test"}, clear=False):
server = MCPServerTask("srv")
await server.start({"command": "node", "env": {}})
# Empty dict -> None
# Empty dict -> safe env vars (not None)
call_kwargs = mock_params.call_args
assert call_kwargs.kwargs.get("env") is None
env_arg = call_kwargs.kwargs.get("env")
assert env_arg is not None
assert isinstance(env_arg, dict)
assert "PATH" in env_arg
assert "HOME" in env_arg
await server.shutdown()
@@ -698,3 +704,353 @@ class TestShutdown:
assert len(_servers) == 0
# Parallel: ~1s, not ~3s. Allow some margin.
assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)"
# ---------------------------------------------------------------------------
# _build_safe_env
# ---------------------------------------------------------------------------
class TestBuildSafeEnv:
    """Tests for _build_safe_env() environment filtering."""

    def test_only_safe_vars_passed(self):
        """Only safe baseline vars and XDG_* from os.environ are included."""
        from tools.mcp_tool import _build_safe_env

        simulated_environ = {
            "PATH": "/usr/bin",
            "HOME": "/home/test",
            "USER": "test",
            "LANG": "en_US.UTF-8",
            "LC_ALL": "C",
            "TERM": "xterm",
            "SHELL": "/bin/bash",
            "TMPDIR": "/tmp",
            "XDG_DATA_HOME": "/home/test/.local/share",
            "SECRET_KEY": "should_not_appear",
            "AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE",
        }
        with patch.dict("os.environ", simulated_environ, clear=True):
            filtered = _build_safe_env(None)
        # Whitelisted variables survive the filter with their values intact.
        for key in ("PATH", "HOME", "USER", "LANG", "XDG_DATA_HOME"):
            assert filtered[key] == simulated_environ[key]
        # Everything else is dropped.
        for key in ("SECRET_KEY", "AWS_ACCESS_KEY_ID"):
            assert key not in filtered

    def test_user_env_merged(self):
        """User-specified env vars are merged into the safe env."""
        from tools.mcp_tool import _build_safe_env

        with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
            merged = _build_safe_env({"MY_CUSTOM_VAR": "hello"})
        assert merged["PATH"] == "/usr/bin"
        assert merged["MY_CUSTOM_VAR"] == "hello"

    def test_user_env_overrides_safe(self):
        """User env can override safe defaults."""
        from tools.mcp_tool import _build_safe_env

        with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
            merged = _build_safe_env({"PATH": "/custom/bin"})
        assert merged["PATH"] == "/custom/bin"

    def test_none_user_env(self):
        """None user_env still returns safe vars from os.environ."""
        from tools.mcp_tool import _build_safe_env

        with patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/root"}, clear=True):
            safe = _build_safe_env(None)
        assert isinstance(safe, dict)
        assert safe["PATH"] == "/usr/bin"
        assert safe["HOME"] == "/root"

    def test_secret_vars_excluded(self):
        """Sensitive env vars from os.environ are NOT passed through."""
        from tools.mcp_tool import _build_safe_env

        secret_laden_environ = {
            "PATH": "/usr/bin",
            "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            "GITHUB_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
            "OPENAI_API_KEY": "sk-proj-abc123",
            "DATABASE_URL": "postgres://user:pass@localhost/db",
            "API_SECRET": "supersecret",
        }
        with patch.dict("os.environ", secret_laden_environ, clear=True):
            filtered = _build_safe_env(None)
        assert "PATH" in filtered
        for secret_key in (
            "AWS_SECRET_ACCESS_KEY",
            "GITHUB_TOKEN",
            "OPENAI_API_KEY",
            "DATABASE_URL",
            "API_SECRET",
        ):
            assert secret_key not in filtered
# ---------------------------------------------------------------------------
# _sanitize_error
# ---------------------------------------------------------------------------
class TestSanitizeError:
    """Tests for _sanitize_error() credential stripping."""

    def test_strips_github_pat(self):
        from tools.mcp_tool import _sanitize_error

        assert _sanitize_error("Error with ghp_abc123def456") == "Error with [REDACTED]"

    def test_strips_openai_key(self):
        from tools.mcp_tool import _sanitize_error

        assert _sanitize_error("key sk-projABC123xyz") == "key [REDACTED]"

    def test_strips_bearer_token(self):
        from tools.mcp_tool import _sanitize_error

        assert _sanitize_error("Authorization: Bearer eyJabc123def") == "Authorization: [REDACTED]"

    def test_strips_token_param(self):
        from tools.mcp_tool import _sanitize_error

        assert _sanitize_error("url?token=secret123") == "url?[REDACTED]"

    def test_no_credentials_unchanged(self):
        from tools.mcp_tool import _sanitize_error

        assert _sanitize_error("normal error message") == "normal error message"

    def test_multiple_credentials(self):
        from tools.mcp_tool import _sanitize_error

        cleaned = _sanitize_error("ghp_abc123 and sk-projXyz789 and token=foo")
        # None of the credential prefixes may survive sanitization.
        for leak_marker in ("ghp_", "sk-", "token="):
            assert leak_marker not in cleaned
        assert cleaned.count("[REDACTED]") == 3
# ---------------------------------------------------------------------------
# HTTP config
# ---------------------------------------------------------------------------
class TestHTTPConfig:
    """Tests for HTTP transport detection and handling."""

    def test_is_http_with_url(self):
        from tools.mcp_tool import MCPServerTask

        remote = MCPServerTask("remote")
        remote._config = {"url": "https://example.com/mcp"}
        assert remote._is_http() is True

    def test_is_stdio_with_command(self):
        from tools.mcp_tool import MCPServerTask

        local = MCPServerTask("local")
        local._config = {"command": "npx", "args": []}
        assert local._is_http() is False

    def test_http_unavailable_raises(self):
        from tools.mcp_tool import MCPServerTask

        remote = MCPServerTask("remote")

        async def exercise():
            # With the HTTP client flagged unavailable, _run_http must
            # refuse to start and explain what's missing.
            with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", False), \
                    pytest.raises(ImportError, match="HTTP transport"):
                await remote._run_http({"url": "https://example.com/mcp"})

        asyncio.run(exercise())
# ---------------------------------------------------------------------------
# Reconnection logic
# ---------------------------------------------------------------------------
class TestReconnection:
    """Tests for automatic reconnection behavior in MCPServerTask.run()."""

    def test_reconnect_on_disconnect(self):
        """After initial success, a connection drop triggers reconnection."""
        from tools.mcp_tool import MCPServerTask

        attempts = 0
        tracked = None
        real_run_stdio = MCPServerTask._run_stdio

        async def fake_run_stdio(srv, config):
            nonlocal attempts, tracked
            attempts += 1
            if tracked is not srv:
                # Not the server under test -- defer to the real code.
                return await real_run_stdio(srv, config)
            if attempts == 1:
                # First connection succeeds, then simulate disconnect.
                srv.session = MagicMock()
                srv._tools = []
                srv._ready.set()
                raise ConnectionError("connection dropped")
            # Reconnection succeeds; signal shutdown so run() exits.
            srv.session = MagicMock()
            srv._shutdown_event.set()
            await srv._shutdown_event.wait()

        async def scenario():
            nonlocal tracked
            tracked = MCPServerTask("test_srv")
            with patch.object(MCPServerTask, "_run_stdio", fake_run_stdio), \
                    patch("asyncio.sleep", new_callable=AsyncMock):
                await tracked.run({"command": "test"})
            assert attempts >= 2  # At least one reconnection attempt

        asyncio.run(scenario())

    def test_no_reconnect_on_shutdown(self):
        """If shutdown is requested, don't attempt reconnection."""
        from tools.mcp_tool import MCPServerTask

        attempts = 0
        tracked = None
        real_run_stdio = MCPServerTask._run_stdio

        async def fake_run_stdio(srv, config):
            nonlocal attempts, tracked
            attempts += 1
            if tracked is not srv:
                return await real_run_stdio(srv, config)
            # Connect successfully, then drop the connection.
            srv.session = MagicMock()
            srv._tools = []
            srv._ready.set()
            raise ConnectionError("connection dropped")

        async def scenario():
            nonlocal tracked
            tracked = MCPServerTask("test_srv")
            tracked._shutdown_event.set()  # Shutdown already requested
            with patch.object(MCPServerTask, "_run_stdio", fake_run_stdio), \
                    patch("asyncio.sleep", new_callable=AsyncMock):
                await tracked.run({"command": "test"})
            # Should not retry because shutdown was set.
            assert attempts == 1

        asyncio.run(scenario())

    def test_no_reconnect_on_initial_failure(self):
        """First connection failure reports error immediately, no retry."""
        from tools.mcp_tool import MCPServerTask

        attempts = 0
        tracked = None
        real_run_stdio = MCPServerTask._run_stdio

        async def fake_run_stdio(srv, config):
            nonlocal attempts, tracked
            attempts += 1
            if tracked is not srv:
                return await real_run_stdio(srv, config)
            raise ConnectionError("cannot connect")

        async def scenario():
            nonlocal tracked
            tracked = MCPServerTask("test_srv")
            with patch.object(MCPServerTask, "_run_stdio", fake_run_stdio), \
                    patch("asyncio.sleep", new_callable=AsyncMock):
                await tracked.run({"command": "test"})
            # Only one attempt, no retry on initial failure.
            assert attempts == 1
            assert tracked._error is not None
            assert "cannot connect" in str(tracked._error)

        asyncio.run(scenario())
# ---------------------------------------------------------------------------
# Configurable timeouts
# ---------------------------------------------------------------------------
class TestConfigurableTimeouts:
    """Tests for configurable per-server timeouts."""

    def test_default_timeout(self):
        """Server with no timeout config gets _DEFAULT_TOOL_TIMEOUT."""
        from tools.mcp_tool import MCPServerTask, _DEFAULT_TOOL_TIMEOUT

        fresh = MCPServerTask("test_srv")
        assert fresh.tool_timeout == _DEFAULT_TOOL_TIMEOUT
        assert fresh.tool_timeout == 120

    def test_custom_timeout(self):
        """Server with timeout=180 in config gets 180."""
        from tools.mcp_tool import MCPServerTask

        tracked = None
        real_run_stdio = MCPServerTask._run_stdio

        async def fake_run_stdio(srv, config):
            if tracked is not srv:
                return await real_run_stdio(srv, config)
            # Simulate a healthy connection that stays up until shutdown.
            srv.session = MagicMock()
            srv._tools = []
            srv._ready.set()
            await srv._shutdown_event.wait()

        async def scenario():
            nonlocal tracked
            tracked = MCPServerTask("test_srv")
            with patch.object(MCPServerTask, "_run_stdio", fake_run_stdio):
                run_task = asyncio.ensure_future(
                    tracked.run({"command": "test", "timeout": 180})
                )
                await tracked._ready.wait()
                assert tracked.tool_timeout == 180
                tracked._shutdown_event.set()
                await run_task

        asyncio.run(scenario())

    def test_timeout_passed_to_handler(self):
        """The tool handler uses the server's configured timeout."""
        from tools.mcp_tool import _make_tool_handler, _servers, MCPServerTask

        session_mock = MagicMock()
        session_mock.call_tool = AsyncMock(
            return_value=_make_call_result("ok", is_error=False)
        )
        srv = _make_mock_server("test_srv", session=session_mock)
        srv.tool_timeout = 180
        _servers["test_srv"] = srv
        try:
            handler = _make_tool_handler("test_srv", "my_tool", 180)
            with patch("tools.mcp_tool._run_on_mcp_loop") as loop_mock:
                loop_mock.return_value = json.dumps({"result": "ok"})
                handler({})
                # Verify timeout=180 was passed, positionally or by keyword.
                call = loop_mock.call_args
                passed_180 = (
                    call.kwargs.get("timeout") == 180
                    or (len(call.args) > 1 and call.args[1] == 180)
                    or call[1].get("timeout") == 180
                )
                assert passed_180
        finally:
            _servers.pop("test_srv", None)

View File

@@ -2,9 +2,9 @@
"""
MCP (Model Context Protocol) Client Support
Connects to external MCP servers via stdio transport, discovers their tools,
and registers them into the hermes-agent tool registry so the agent can call
them like any built-in tool.
Connects to external MCP servers via stdio or HTTP/StreamableHTTP transport,
discovers their tools, and registers them into the hermes-agent tool registry
so the agent can call them like any built-in tool.
Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key.
The ``mcp`` Python package is optional -- if not installed, this module is a
@@ -17,17 +17,32 @@ Example config::
command: "npx"
args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
env: {}
timeout: 120 # per-tool-call timeout in seconds (default: 120)
connect_timeout: 60 # initial connection timeout (default: 60)
github:
command: "npx"
args: ["-y", "@modelcontextprotocol/server-github"]
env:
GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
remote_api:
url: "https://my-mcp-server.example.com/mcp"
headers:
Authorization: "Bearer sk-..."
timeout: 180
Features:
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
- Automatic reconnection with exponential backoff (up to 5 retries)
- Environment variable filtering for stdio subprocesses (security)
- Credential stripping in error messages returned to the LLM
- Configurable per-server timeouts for tool calls and connections
- Thread-safe architecture with dedicated background event loop
Architecture:
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
Each MCP server runs as a long-lived asyncio Task on this loop, keeping
its ``async with stdio_client(...)`` context alive. Tool call coroutines
are scheduled onto the loop via ``run_coroutine_threadsafe()``.
its transport context alive. Tool call coroutines are scheduled onto the
loop via ``run_coroutine_threadsafe()``.
On shutdown, each server Task is signalled to exit its ``async with``
block, ensuring the anyio cancel-scope cleanup happens in the *same*
@@ -43,6 +58,8 @@ Thread safety:
import asyncio
import json
import logging
import os
import re
import threading
from typing import Any, Dict, List, Optional
@@ -53,13 +70,81 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
_MCP_AVAILABLE = False
_MCP_HTTP_AVAILABLE = False
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
_MCP_AVAILABLE = True
try:
from mcp.client.streamable_http import streamablehttp_client
_MCP_HTTP_AVAILABLE = True
except ImportError:
_MCP_HTTP_AVAILABLE = False
except ImportError:
logger.debug("mcp package not installed -- MCP tool support disabled")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_DEFAULT_TOOL_TIMEOUT = 120  # seconds for tool calls
_DEFAULT_DISCOVERY_TIMEOUT = 60  # seconds for server discovery
_DEFAULT_CONNECT_TIMEOUT = 60  # seconds for initial connection
_MAX_RECONNECT_RETRIES = 5  # reconnection attempts before giving up
_MAX_BACKOFF_SECONDS = 60  # cap for the exponential reconnect backoff

# Environment variables that are safe to pass to stdio subprocesses.
# XDG_* variables are additionally allowed by prefix (see _build_safe_env).
_SAFE_ENV_KEYS = frozenset({
    "PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
})

# Regex for credential patterns to strip from error messages.
# Compiled case-insensitively; each full match is replaced with
# "[REDACTED]" by _sanitize_error().
_CREDENTIAL_PATTERN = re.compile(
    r"(?:"
    r"ghp_[A-Za-z0-9_]{1,255}"      # GitHub PAT
    r"|sk-[A-Za-z0-9_]{1,255}"      # OpenAI-style key
    r"|Bearer\s+\S+"                # Bearer token
    r"|token=[^\s&,;\"']{1,255}"    # token=...
    r"|key=[^\s&,;\"']{1,255}"      # key=...
    r"|API_KEY=[^\s&,;\"']{1,255}"  # API_KEY=...
    r"|password=[^\s&,;\"']{1,255}" # password=...
    r"|secret=[^\s&,;\"']{1,255}"   # secret=...
    r")",
    re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Security helpers
# ---------------------------------------------------------------------------
def _build_safe_env(user_env: Optional[dict]) -> dict:
    """Build a filtered environment dict for stdio subprocesses.

    Only passes through safe baseline variables (PATH, HOME, etc.) and
    XDG_* variables from the current process environment, plus any
    variables explicitly specified by the user in the server config.
    User-specified values take precedence on conflict.

    This prevents accidentally leaking secrets like API keys, tokens, or
    credentials to MCP server subprocesses.
    """
    filtered = {
        name: value
        for name, value in os.environ.items()
        if name in _SAFE_ENV_KEYS or name.startswith("XDG_")
    }
    if user_env:
        filtered.update(user_env)
    return filtered
def _sanitize_error(text: str) -> str:
    """Strip credential-like patterns from error text before returning to LLM.

    Every substring matching _CREDENTIAL_PATTERN (GitHub PATs, OpenAI-style
    keys, Bearer tokens, token=/key=/password=/secret= parameters) is
    replaced with ``[REDACTED]`` to prevent accidental credential exposure
    in tool error responses.
    """
    redacted = _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
    return redacted
# ---------------------------------------------------------------------------
# Server task -- each MCP server lives in one long-lived asyncio Task
@@ -70,66 +155,152 @@ class MCPServerTask:
The entire connection lifecycle (connect, discover, serve, disconnect)
runs inside one asyncio Task so that anyio cancel-scopes created by
``stdio_client`` are entered and exited in the same Task context.
the transport client are entered and exited in the same Task context.
Supports both stdio and HTTP/StreamableHTTP transports.
"""
__slots__ = (
"name", "session",
"_task", "_ready", "_shutdown_event", "_tools", "_error",
"name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
)
    def __init__(self, name: str):
        """Initialize bookkeeping for server *name*; no connection is made here."""
        self.name = name
        # Live ClientSession while connected; None when disconnected.
        self.session: Optional[Any] = None
        # Per-tool-call timeout; may be overridden by the server config.
        self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
        # Background asyncio Task running the connection lifecycle.
        self._task: Optional[asyncio.Task] = None
        # Set once the first connection attempt has succeeded or failed.
        self._ready = asyncio.Event()
        # Set to request disconnect; also suppresses reconnection attempts.
        self._shutdown_event = asyncio.Event()
        # Tool objects discovered from the server after connecting.
        self._tools: list = []
        # First connection error, if any, recorded for callers to inspect.
        self._error: Optional[Exception] = None
        # Raw server config; consulted by _is_http() for transport choice.
        self._config: dict = {}
async def run(self, config: dict):
"""Long-lived coroutine: connect, discover tools, wait, disconnect."""
    def _is_http(self) -> bool:
        """Check if this server uses HTTP transport.

        A server config containing a ``url`` key selects the
        HTTP/StreamableHTTP transport; otherwise stdio is used.
        """
        return "url" in self._config
async def _run_stdio(self, config: dict):
"""Run the server using stdio transport."""
command = config.get("command")
args = config.get("args", [])
env = config.get("env")
user_env = config.get("env")
if not command:
self._error = ValueError(
raise ValueError(
f"MCP server '{self.name}' has no 'command' in config"
)
self._ready.set()
return
safe_env = _build_safe_env(user_env)
server_params = StdioServerParameters(
command=command,
args=args,
env=env if env else None,
env=safe_env if safe_env else None,
)
try:
async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
self.session = session
async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
self.session = session
await self._discover_tools()
self._ready.set()
await self._shutdown_event.wait()
tools_result = await session.list_tools()
self._tools = (
tools_result.tools
if hasattr(tools_result, "tools")
else []
)
async def _run_http(self, config: dict):
"""Run the server using HTTP/StreamableHTTP transport."""
if not _MCP_HTTP_AVAILABLE:
raise ImportError(
f"MCP server '{self.name}' requires HTTP transport but "
"mcp.client.streamable_http is not available. "
"Upgrade the mcp package to get HTTP support."
)
# Signal that connection is ready
url = config["url"]
headers = config.get("headers")
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
async with streamablehttp_client(
url,
headers=headers,
timeout=float(connect_timeout),
) as (read_stream, write_stream, _get_session_id):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
self.session = session
await self._discover_tools()
self._ready.set()
await self._shutdown_event.wait()
async def _discover_tools(self):
"""Discover tools from the connected session."""
if self.session is None:
return
tools_result = await self.session.list_tools()
self._tools = (
tools_result.tools
if hasattr(tools_result, "tools")
else []
)
async def run(self, config: dict):
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
Includes automatic reconnection with exponential backoff if the
connection drops unexpectedly (unless shutdown was requested).
"""
self._config = config
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
retries = 0
backoff = 1.0
while True:
try:
if self._is_http():
await self._run_http(config)
else:
await self._run_stdio(config)
# Normal exit (shutdown requested) -- break out
break
except Exception as exc:
self.session = None
# If this is the first connection attempt, report the error
if not self._ready.is_set():
self._error = exc
self._ready.set()
return
# Block until shutdown is requested -- this keeps the
# async-with contexts alive on THIS Task.
await self._shutdown_event.wait()
except Exception as exc:
self._error = exc
self._ready.set()
finally:
self.session = None
# If shutdown was requested, don't reconnect
if self._shutdown_event.is_set():
logger.debug(
"MCP server '%s' disconnected during shutdown: %s",
self.name, exc,
)
return
retries += 1
if retries > _MAX_RECONNECT_RETRIES:
logger.warning(
"MCP server '%s' failed after %d reconnection attempts, "
"giving up: %s",
self.name, _MAX_RECONNECT_RETRIES, exc,
)
return
logger.warning(
"MCP server '%s' connection lost (attempt %d/%d), "
"reconnecting in %.0fs: %s",
self.name, retries, _MAX_RECONNECT_RETRIES,
backoff, exc,
)
await asyncio.sleep(backoff)
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
# Check again after sleeping
if self._shutdown_event.is_set():
return
finally:
self.session = None
async def start(self, config: dict):
"""Create the background Task and wait until ready (or failed)."""
@@ -203,7 +374,10 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
def _load_mcp_config() -> Dict[str, dict]:
"""Read ``mcp_servers`` from the Hermes config file.
Returns a dict of ``{server_name: {command, args, env}}`` or empty dict.
Returns a dict of ``{server_name: server_config}`` or empty dict.
Server config can contain either ``command``/``args``/``env`` for stdio
transport or ``url``/``headers`` for HTTP transport, plus optional
``timeout`` and ``connect_timeout`` overrides.
"""
try:
from hermes_cli.config import load_config
@@ -224,11 +398,12 @@ def _load_mcp_config() -> Dict[str, dict]:
async def _connect_server(name: str, config: dict) -> MCPServerTask:
"""Create an MCPServerTask, start it, and return when ready.
The server Task keeps the subprocess alive in the background.
The server Task keeps the connection alive in the background.
Call ``server.shutdown()`` (on the same event loop) to tear it down.
Raises:
ValueError: if ``command`` is missing from *config*.
ValueError: if required config keys are missing.
ImportError: if HTTP transport is needed but not available.
Exception: on connection or initialization failure.
"""
server = MCPServerTask(name)
@@ -240,7 +415,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask:
# Handler / check-fn factories
# ---------------------------------------------------------------------------
def _make_tool_handler(server_name: str, tool_name: str):
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
"""Return a sync handler that calls an MCP tool via the background loop.
The handler conforms to the registry's dispatch interface:
@@ -263,7 +438,11 @@ def _make_tool_handler(server_name: str, tool_name: str):
for block in (result.content or []):
if hasattr(block, "text"):
error_text += block.text
return json.dumps({"error": error_text or "MCP tool returned an error"})
return json.dumps({
"error": _sanitize_error(
error_text or "MCP tool returned an error"
)
})
# Collect text from content blocks
parts: List[str] = []
@@ -273,10 +452,17 @@ def _make_tool_handler(server_name: str, tool_name: str):
return json.dumps({"result": "\n".join(parts) if parts else ""})
try:
return _run_on_mcp_loop(_call(), timeout=120)
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error("MCP tool %s/%s call failed: %s", server_name, tool_name, exc)
return json.dumps({"error": f"MCP call failed: {type(exc).__name__}: {exc}"})
logger.error(
"MCP tool %s/%s call failed: %s",
server_name, tool_name, exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return _handler
@@ -339,7 +525,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
from tools.registry import registry
from toolsets import create_custom_toolset
server = await _connect_server(name, config)
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
server = await asyncio.wait_for(
_connect_server(name, config),
timeout=connect_timeout,
)
with _lock:
_servers[name] = server
@@ -354,7 +544,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
name=tool_name_prefixed,
toolset=toolset_name,
schema=schema,
handler=_make_tool_handler(name, mcp_tool.name),
handler=_make_tool_handler(name, mcp_tool.name, server.tool_timeout),
check_fn=_make_check_fn(name),
is_async=False,
description=schema["description"],
@@ -369,9 +559,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
tools=registered_names,
)
transport_type = "HTTP" if "url" in config else "stdio"
logger.info(
"MCP server '%s': registered %d tool(s): %s",
name, len(registered_names), ", ".join(registered_names),
"MCP server '%s' (%s): registered %d tool(s): %s",
name, transport_type, len(registered_names),
", ".join(registered_names),
)
return registered_names
@@ -419,9 +611,12 @@ def discover_mcp_tools() -> List[str]:
registered = await _discover_and_register_server(name, cfg)
all_tools.extend(registered)
except Exception as exc:
logger.warning("Failed to connect to MCP server '%s': %s", name, exc)
logger.warning(
"Failed to connect to MCP server '%s': %s",
name, exc,
)
_run_on_mcp_loop(_discover_all(), timeout=60)
_run_on_mcp_loop(_discover_all(), timeout=_DEFAULT_DISCOVERY_TIMEOUT)
if all_tools:
# Dynamically inject into all hermes-* platform toolsets
@@ -444,15 +639,10 @@ def shutdown_mcp_servers():
All servers are shut down in parallel via ``asyncio.gather``.
"""
with _lock:
if not _servers:
# No servers -- just stop the loop. _stop_mcp_loop() also
# acquires _lock, so we must release it first.
pass
else:
servers_snapshot = list(_servers.values())
servers_snapshot = list(_servers.values())
# Fast path: nothing to shut down.
if not _servers:
if not servers_snapshot:
_stop_mcp_loop()
return