Files
hermes-agent/tools/mcp_tool.py
teknium1 64ff8f065b feat(mcp): add HTTP transport, reconnection, security hardening
Upgrades the MCP client implementation from PR #291 with:

- HTTP/Streamable HTTP transport: support 'url' key in config for remote
  MCP servers (Notion, Slack, Sentry, Supabase, etc.)
- Automatic reconnection with exponential backoff (1s-60s, 5 retries)
  when a server connection drops unexpectedly
- Environment variable filtering: only pass safe vars (PATH, HOME, etc.)
  plus user-specified env to stdio subprocesses (prevents secret leaks)
- Credential stripping: sanitize error messages before returning to the
  LLM (strips GitHub PATs, OpenAI keys, Bearer tokens, etc.)
- Configurable per-server timeouts: 'timeout' and 'connect_timeout' keys
- Fix shutdown race condition in servers_snapshot variable scoping

Test coverage: 50 tests (up from 30), including new tests for env
filtering, credential sanitization, HTTP config detection, reconnection
logic, and configurable timeouts.

All 1162 tests pass (1162 passed, 3 skipped, 0 failed).
2026-03-02 18:40:03 -08:00

687 lines
24 KiB
Python

#!/usr/bin/env python3
"""
MCP (Model Context Protocol) Client Support
Connects to external MCP servers via stdio or HTTP/StreamableHTTP transport,
discovers their tools, and registers them into the hermes-agent tool registry
so the agent can call them like any built-in tool.
Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key.
The ``mcp`` Python package is optional -- if not installed, this module is a
no-op and logs a debug message.
Example config::
    mcp_servers:
      filesystem:
        command: "npx"
        args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
        env: {}
        timeout: 120          # per-tool-call timeout in seconds (default: 120)
        connect_timeout: 60   # initial connection timeout (default: 60)
      github:
        command: "npx"
        args: ["-y", "@modelcontextprotocol/server-github"]
        env:
          GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
      remote_api:
        url: "https://my-mcp-server.example.com/mcp"
        headers:
          Authorization: "Bearer sk-..."
        timeout: 180
Features:
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
- Automatic reconnection with exponential backoff (up to 5 retries)
- Environment variable filtering for stdio subprocesses (security)
- Credential stripping in error messages returned to the LLM
- Configurable per-server timeouts for tool calls and connections
- Thread-safe architecture with dedicated background event loop
Architecture:
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
Each MCP server runs as a long-lived asyncio Task on this loop, keeping
its transport context alive. Tool call coroutines are scheduled onto the
loop via ``run_coroutine_threadsafe()``.
On shutdown, each server Task is signalled to exit its ``async with``
block, ensuring the anyio cancel-scope cleanup happens in the *same*
Task that opened the connection (required by anyio).
Thread safety:
_servers and _mcp_loop/_mcp_thread are accessed from both the MCP
background thread and caller threads. All mutations are protected by
_lock so the code is safe regardless of GIL presence (e.g. Python 3.13+
free-threading).
"""
import asyncio
import json
import logging
import os
import re
import threading
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Graceful import -- MCP SDK is an optional dependency
# ---------------------------------------------------------------------------

# Both flags start out False and are only flipped once the matching import
# has actually succeeded.
_MCP_AVAILABLE = False
_MCP_HTTP_AVAILABLE = False

try:
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client
except ImportError:
    logger.debug("mcp package not installed -- MCP tool support disabled")
else:
    _MCP_AVAILABLE = True
    # The streamable-HTTP client is probed separately so that a failure to
    # import it only disables the HTTP transport, not MCP support as a whole.
    try:
        from mcp.client.streamable_http import streamablehttp_client
    except ImportError:
        _MCP_HTTP_AVAILABLE = False
    else:
        _MCP_HTTP_AVAILABLE = True
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_DEFAULT_TOOL_TIMEOUT = 120  # seconds for tool calls
_DEFAULT_DISCOVERY_TIMEOUT = 60  # seconds for server discovery
_DEFAULT_CONNECT_TIMEOUT = 60  # seconds for initial connection
_MAX_RECONNECT_RETRIES = 5
_MAX_BACKOFF_SECONDS = 60

# Environment variables that are safe to pass to stdio subprocesses
_SAFE_ENV_KEYS = frozenset({
    "PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
})

# Regex for credential patterns to strip from error messages.
#
# The ``sk-`` alternative must allow hyphens inside the key body: keys such as
# ``sk-proj-...`` would otherwise only be redacted up to the first hyphen,
# leaving the remainder of the secret in the text.  GitHub tokens are matched
# for every ``gh?_`` prefix as well as the ``github_pat_`` form, not only
# classic ``ghp_`` PATs.
_CREDENTIAL_PATTERN = re.compile(
    r"(?:"
    r"gh[pousr]_[A-Za-z0-9_]{1,255}"      # GitHub tokens (ghp_, gho_, ghu_, ghs_, ghr_)
    r"|github_pat_[A-Za-z0-9_]{1,255}"    # GitHub fine-grained PAT
    r"|sk-[A-Za-z0-9_-]{1,255}"           # OpenAI-style key (incl. sk-proj-...)
    r"|Bearer\s+\S+"                      # Bearer token
    r"|token=[^\s&,;\"']{1,255}"          # token=...
    r"|key=[^\s&,;\"']{1,255}"            # key=...
    r"|API_KEY=[^\s&,;\"']{1,255}"        # API_KEY=...
    r"|password=[^\s&,;\"']{1,255}"       # password=...
    r"|secret=[^\s&,;\"']{1,255}"         # secret=...
    r")",
    re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Security helpers
# ---------------------------------------------------------------------------
def _build_safe_env(user_env: Optional[dict]) -> dict:
    """Return the environment mapping handed to a stdio MCP subprocess.

    The result starts from the current process environment reduced to the
    names listed in ``_SAFE_ENV_KEYS`` plus anything beginning with ``XDG_``.
    Entries from *user_env* (the ``env`` mapping of the server config) are
    then layered on top and win on conflict.

    Everything else in ``os.environ`` is dropped so that API keys, tokens and
    similar secrets of the host process are not handed to the subprocess.
    """
    inherited = {
        name: val
        for name, val in os.environ.items()
        if name.startswith("XDG_") or name in _SAFE_ENV_KEYS
    }
    if not user_env:
        return inherited
    inherited.update(user_env)
    return inherited
def _sanitize_error(text: str) -> str:
    """Return *text* with every credential-looking substring masked.

    Each match of ``_CREDENTIAL_PATTERN`` is replaced by the literal
    ``[REDACTED]`` so that error text handed back to the LLM does not carry
    tokens, keys or passwords.
    """
    redacted = re.sub(_CREDENTIAL_PATTERN, "[REDACTED]", text)
    return redacted
# ---------------------------------------------------------------------------
# Server task -- each MCP server lives in one long-lived asyncio Task
# ---------------------------------------------------------------------------
class MCPServerTask:
    """Manages a single MCP server connection in a dedicated asyncio Task.

    The entire connection lifecycle (connect, discover, serve, disconnect)
    runs inside one asyncio Task so that anyio cancel-scopes created by
    the transport client are entered and exited in the same Task context.

    Supports both stdio and HTTP/StreamableHTTP transports; the transport is
    selected by the presence of a ``url`` key in the server config.
    """

    __slots__ = (
        "name", "session", "tool_timeout",
        "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
        "_connected",
    )

    def __init__(self, name: str):
        self.name = name
        # Live ClientSession while connected, otherwise None.
        self.session: Optional[Any] = None
        self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
        self._task: Optional[asyncio.Task] = None
        # Set once the first connection attempt has either succeeded or failed.
        self._ready = asyncio.Event()
        # Set by shutdown() to make the serving coroutine leave its context.
        self._shutdown_event = asyncio.Event()
        self._tools: list = []
        # Exception from the first connection attempt, re-raised by start().
        self._error: Optional[Exception] = None
        self._config: dict = {}
        # True while the current attempt has completed initialize + discovery.
        self._connected = False

    def _is_http(self) -> bool:
        """Check if this server uses HTTP transport."""
        return "url" in self._config

    async def _serve_session(self, read_stream, write_stream):
        """Open a ClientSession on the given streams and serve until shutdown.

        Shared by both transports: initializes the session, publishes it on
        ``self.session``, discovers tools, signals readiness and then blocks
        until ``shutdown()`` sets the shutdown event.
        """
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            self.session = session
            await self._discover_tools()
            # Only now does this attempt count as a working connection; run()
            # uses this to decide whether to reset its retry budget.
            self._connected = True
            self._ready.set()
            await self._shutdown_event.wait()

    async def _run_stdio(self, config: dict):
        """Run the server using stdio transport.

        Raises:
            ValueError: if the config has no ``command``.
        """
        command = config.get("command")
        args = config.get("args", [])
        user_env = config.get("env")

        if not command:
            raise ValueError(
                f"MCP server '{self.name}' has no 'command' in config"
            )

        safe_env = _build_safe_env(user_env)
        server_params = StdioServerParameters(
            command=command,
            args=args,
            env=safe_env if safe_env else None,
        )

        async with stdio_client(server_params) as (read_stream, write_stream):
            await self._serve_session(read_stream, write_stream)

    async def _run_http(self, config: dict):
        """Run the server using HTTP/StreamableHTTP transport.

        Raises:
            ImportError: if the streamable-HTTP client could not be imported.
        """
        if not _MCP_HTTP_AVAILABLE:
            raise ImportError(
                f"MCP server '{self.name}' requires HTTP transport but "
                "mcp.client.streamable_http is not available. "
                "Upgrade the mcp package to get HTTP support."
            )

        url = config["url"]
        headers = config.get("headers")
        connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)

        async with streamablehttp_client(
            url,
            headers=headers,
            timeout=float(connect_timeout),
        ) as (read_stream, write_stream, _get_session_id):
            await self._serve_session(read_stream, write_stream)

    async def _discover_tools(self):
        """Discover tools from the connected session and cache them."""
        if self.session is None:
            return
        tools_result = await self.session.list_tools()
        self._tools = (
            tools_result.tools
            if hasattr(tools_result, "tools")
            else []
        )

    async def run(self, config: dict):
        """Long-lived coroutine: connect, discover tools, wait, disconnect.

        Includes automatic reconnection with exponential backoff if the
        connection drops unexpectedly (unless shutdown was requested).  A
        failure of the very first attempt is not retried; it is stored and
        surfaced through ``start()``.

        The retry budget (``_MAX_RECONNECT_RETRIES``) counts *consecutive*
        failures: once an attempt has reached the serving state, the counter
        and the backoff delay start over for the next drop.
        """
        self._config = config
        self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
        retries = 0
        backoff = 1.0

        while True:
            try:
                if self._is_http():
                    await self._run_http(config)
                else:
                    await self._run_stdio(config)
                # Normal exit (shutdown requested) -- break out
                break
            except Exception as exc:
                self.session = None
                # Remember whether this attempt had been serving before the
                # failure, then clear the marker for the next attempt.
                was_connected = self._connected
                self._connected = False

                # If this is the first connection attempt, report the error
                if not self._ready.is_set():
                    self._error = exc
                    self._ready.set()
                    return

                # If shutdown was requested, don't reconnect
                if self._shutdown_event.is_set():
                    logger.debug(
                        "MCP server '%s' disconnected during shutdown: %s",
                        self.name, exc,
                    )
                    return

                if was_connected:
                    # The previous reconnect worked, so this is a fresh drop
                    # rather than another failure in the same outage.
                    retries = 0
                    backoff = 1.0

                retries += 1
                if retries > _MAX_RECONNECT_RETRIES:
                    logger.warning(
                        "MCP server '%s' failed after %d reconnection attempts, "
                        "giving up: %s",
                        self.name, _MAX_RECONNECT_RETRIES, exc,
                    )
                    return

                logger.warning(
                    "MCP server '%s' connection lost (attempt %d/%d), "
                    "reconnecting in %.0fs: %s",
                    self.name, retries, _MAX_RECONNECT_RETRIES,
                    backoff, exc,
                )
                await asyncio.sleep(backoff)
                backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)

                # Check again after sleeping
                if self._shutdown_event.is_set():
                    return
            finally:
                self.session = None
                self._connected = False

    async def start(self, config: dict):
        """Create the background Task and wait until ready (or failed).

        Raises:
            Exception: whatever the first connection attempt raised.
            asyncio.CancelledError: if the caller is cancelled while waiting
                (for example by an enclosing ``asyncio.wait_for`` timeout).
                The background Task is cancelled first so that it does not
                keep running with nobody left to shut it down.
        """
        self._task = asyncio.ensure_future(self.run(config))
        try:
            await self._ready.wait()
        except asyncio.CancelledError:
            self._shutdown_event.set()
            self._task.cancel()
            raise
        if self._error:
            raise self._error

    async def shutdown(self):
        """Signal the Task to exit and wait for clean resource teardown.

        Waits up to 10 seconds for the Task to finish on its own; after that
        the Task is cancelled and awaited.
        """
        self._shutdown_event.set()
        if self._task and not self._task.done():
            try:
                await asyncio.wait_for(self._task, timeout=10)
            except asyncio.TimeoutError:
                logger.warning(
                    "MCP server '%s' shutdown timed out, cancelling task",
                    self.name,
                )
                self._task.cancel()
                try:
                    await self._task
                except asyncio.CancelledError:
                    pass
        self.session = None
# ---------------------------------------------------------------------------
# Module-level state
# ---------------------------------------------------------------------------
# Guards _mcp_loop, _mcp_thread, and _servers against concurrent access.
_lock = threading.Lock()

# Connected servers, keyed by the server name used in the config.
_servers: Dict[str, MCPServerTask] = dict()

# Background daemon thread and the dedicated event loop it runs.
_mcp_thread: Optional[threading.Thread] = None
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
def _ensure_mcp_loop():
    """Start the background event loop thread if not already running.

    Returns only after the new loop has begun executing callbacks (or after
    a 5 second grace period).  Without that wait, a caller that immediately
    checks ``loop.is_running()`` could still see ``False`` because the thread
    had not yet entered ``run_forever()``, and a second call to this function
    in that window would create a duplicate loop.
    """
    global _mcp_loop, _mcp_thread
    with _lock:
        if _mcp_loop is not None and _mcp_loop.is_running():
            return

        loop = asyncio.new_event_loop()
        started = threading.Event()
        # Queued before the loop runs, so it fires as the loop's first
        # callback -- i.e. once run_forever() is actually live.
        loop.call_soon(started.set)

        thread = threading.Thread(
            target=loop.run_forever,
            name="mcp-event-loop",
            daemon=True,
        )
        _mcp_loop = loop
        _mcp_thread = thread
        thread.start()
        started.wait(timeout=5)
def _run_on_mcp_loop(coro, timeout: float = 30):
    """Schedule a coroutine on the MCP event loop and block until done.

    Args:
        coro: Coroutine object to run on the background loop.
        timeout: Seconds to wait for the result.

    Returns:
        Whatever the coroutine returns.

    Raises:
        RuntimeError: if the background loop is not running.
        concurrent.futures.TimeoutError: if the result is not available in
            time.  The scheduled coroutine is cancelled in that case so it
            does not keep running on the loop after the caller has given up.
        Exception: anything the coroutine itself raises.
    """
    import concurrent.futures

    with _lock:
        loop = _mcp_loop
    if loop is None or not loop.is_running():
        # The coroutine will never be scheduled; close it so Python does not
        # emit a "coroutine was never awaited" warning.
        if asyncio.iscoroutine(coro):
            coro.close()
        raise RuntimeError("MCP event loop is not running")

    future = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        return future.result(timeout=timeout)
    except concurrent.futures.TimeoutError:
        future.cancel()
        raise
# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------
def _load_mcp_config() -> Dict[str, dict]:
    """Return the ``mcp_servers`` mapping from the Hermes config file.

    The mapping has the shape ``{server_name: server_config}``.  A server
    config holds either ``command``/``args``/``env`` (stdio transport) or
    ``url``/``headers`` (HTTP transport), and may additionally override
    ``timeout`` and ``connect_timeout``.

    An empty dict is returned when the key is missing, empty, not a mapping,
    or when loading the config fails for any reason (logged at debug level).
    """
    try:
        from hermes_cli.config import load_config

        configured = load_config().get("mcp_servers")
    except Exception as exc:
        logger.debug("Failed to load MCP config: %s", exc)
        return {}

    if isinstance(configured, dict) and configured:
        return configured
    return {}
# ---------------------------------------------------------------------------
# Server connection helper
# ---------------------------------------------------------------------------
async def _connect_server(name: str, config: dict) -> MCPServerTask:
    """Build an ``MCPServerTask`` for *name*, start it and hand it back.

    Returns once the server's first connection attempt has completed; the
    connection itself stays alive in the server's own background Task.  Tear
    it down with ``server.shutdown()`` on the same event loop.

    Raises:
        ValueError: if required config keys are missing.
        ImportError: if HTTP transport is needed but not available.
        Exception: on connection or initialization failure.
    """
    task = MCPServerTask(name)
    await task.start(config)
    return task
# ---------------------------------------------------------------------------
# Handler / check-fn factories
# ---------------------------------------------------------------------------
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
    """Build the synchronous registry handler for one MCP tool.

    The returned callable has the dispatch signature
    ``handler(args_dict, **kwargs) -> str`` and always returns a JSON string:
    ``{"result": ...}`` on success or ``{"error": ...}`` on any failure.
    The actual MCP call is executed on the background event loop and is
    bounded by *tool_timeout* seconds.
    """

    def _handler(args: dict, **kwargs) -> str:
        with _lock:
            server = _servers.get(server_name)
        if not server or not server.session:
            payload = {"error": f"MCP server '{server_name}' is not connected"}
            return json.dumps(payload)

        async def _call():
            outcome = await server.session.call_tool(tool_name, arguments=args)

            # CallToolResult carries .isError and .content (list of blocks);
            # only blocks exposing a .text attribute contribute to the output.
            if outcome.isError:
                message = ""
                for blk in (outcome.content or []):
                    if hasattr(blk, "text"):
                        message += blk.text
                fallback = message or "MCP tool returned an error"
                return json.dumps({"error": _sanitize_error(fallback)})

            texts: List[str] = [
                blk.text
                for blk in (outcome.content or [])
                if hasattr(blk, "text")
            ]
            return json.dumps({"result": "\n".join(texts) if texts else ""})

        try:
            return _run_on_mcp_loop(_call(), timeout=tool_timeout)
        except Exception as exc:
            logger.error(
                "MCP tool %s/%s call failed: %s",
                server_name, tool_name, exc,
            )
            failure = f"MCP call failed: {type(exc).__name__}: {exc}"
            return json.dumps({"error": _sanitize_error(failure)})

    return _handler
def _make_check_fn(server_name: str):
    """Build a zero-argument predicate reporting whether *server_name* is live.

    The predicate is True only when the server is present in ``_servers`` and
    currently holds a session.
    """

    def _check() -> bool:
        with _lock:
            entry = _servers.get(server_name)
        if entry is None:
            return False
        return entry.session is not None

    return _check
# ---------------------------------------------------------------------------
# Discovery & registration
# ---------------------------------------------------------------------------
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
    """Convert an MCP tool listing to the Hermes registry schema format.

    The registered name is ``mcp_<server>_<tool>``, where both components
    have every character outside ``[A-Za-z0-9_]`` replaced by an underscore.
    Hyphens and dots are handled exactly as before; other characters that
    LLM function-calling APIs typically reject in tool names (spaces,
    slashes, colons, ...) are now normalized too.

    Args:
        server_name: The logical server name for prefixing.
        mcp_tool: An MCP ``Tool`` object with ``.name``, ``.description``,
            and ``.inputSchema``.

    Returns:
        A dict suitable for ``registry.register(schema=...)`` with the keys
        ``name``, ``description`` and ``parameters``.  A missing description
        falls back to a generated one; a missing/empty input schema falls
        back to an empty object schema.
    """

    def _sanitize(component: str) -> str:
        # One pass over the string instead of chained .replace() calls.
        return re.sub(r"[^A-Za-z0-9_]", "_", component)

    prefixed_name = f"mcp_{_sanitize(server_name)}_{_sanitize(mcp_tool.name)}"
    return {
        "name": prefixed_name,
        "description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
        "parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
            "type": "object",
            "properties": {},
        },
    }
def _existing_tool_names() -> List[str]:
    """Return the registry tool names for every server in ``_servers``.

    Names are derived with ``_convert_mcp_schema`` from each server's cached
    tool list.
    """
    # Iterate over a copy: the MCP thread may add a server to _servers while
    # this runs, and iterating the live dict would then raise
    # "dictionary changed size during iteration".  _lock is deliberately not
    # taken here -- it is non-reentrant, so copying keeps this helper safe to
    # call whether or not the caller already holds it.
    snapshot = list(_servers.items())
    return [
        _convert_mcp_schema(sname, mcp_tool)["name"]
        for sname, server in snapshot
        for mcp_tool in server._tools
    ]
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
    """Connect to one MCP server and register each of its tools.

    The connection attempt is bounded by the config's ``connect_timeout``
    (default ``_DEFAULT_CONNECT_TIMEOUT``).  On success the server is stored
    in ``_servers``, every discovered tool is registered under the toolset
    ``mcp-<name>``, and -- if at least one tool was registered -- a custom
    toolset of that name is created.

    Returns:
        The list of registered (prefixed) tool names.
    """
    from tools.registry import registry
    from toolsets import create_custom_toolset

    deadline = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
    server = await asyncio.wait_for(_connect_server(name, config), timeout=deadline)

    with _lock:
        _servers[name] = server

    toolset_name = f"mcp-{name}"
    registered_names: List[str] = []

    for tool in server._tools:
        schema = _convert_mcp_schema(name, tool)
        full_name = schema["name"]
        registry.register(
            name=full_name,
            toolset=toolset_name,
            schema=schema,
            # The handler talks to the server using the tool's original name.
            handler=_make_tool_handler(name, tool.name, server.tool_timeout),
            check_fn=_make_check_fn(name),
            is_async=False,
            description=schema["description"],
        )
        registered_names.append(full_name)

    # Create a custom toolset so these tools are discoverable
    if registered_names:
        create_custom_toolset(
            name=toolset_name,
            description=f"MCP tools from {name} server",
            tools=registered_names,
        )

    logger.info(
        "MCP server '%s' (%s): registered %d tool(s): %s",
        name,
        "HTTP" if "url" in config else "stdio",
        len(registered_names),
        ", ".join(registered_names),
    )
    return registered_names
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def discover_mcp_tools() -> List[str]:
    """Entry point: load config, connect to MCP servers, register tools.

    Called from ``model_tools._discover_tools()``.  Safe to call even when
    the ``mcp`` package is not installed (returns empty list).

    Idempotent for already-connected servers.  If some servers failed on a
    previous call, only the missing ones are retried.

    Discovery is best-effort: a failure of a single server is logged and
    skipped, and a failure of the discovery run as a whole (for example the
    overall ``_DEFAULT_DISCOVERY_TIMEOUT`` expiring) is logged as well
    instead of propagating to the caller.

    Returns:
        List of all registered MCP tool names.
    """
    if not _MCP_AVAILABLE:
        logger.debug("MCP SDK not available -- skipping MCP tool discovery")
        return []

    servers = _load_mcp_config()
    if not servers:
        logger.debug("No MCP servers configured")
        return []

    # Only attempt servers that aren't already connected
    with _lock:
        new_servers = {k: v for k, v in servers.items() if k not in _servers}

    if not new_servers:
        return _existing_tool_names()

    # Start the background event loop for MCP connections
    _ensure_mcp_loop()

    all_tools: List[str] = []

    async def _discover_all():
        for name, cfg in new_servers.items():
            try:
                registered = await _discover_and_register_server(name, cfg)
                all_tools.extend(registered)
            except Exception as exc:
                logger.warning(
                    "Failed to connect to MCP server '%s': %s",
                    name, exc,
                )

    try:
        _run_on_mcp_loop(_discover_all(), timeout=_DEFAULT_DISCOVERY_TIMEOUT)
    except Exception as exc:
        # Keep whatever was registered before the failure; the exception type
        # is logged because timeout errors carry an empty message.
        logger.warning(
            "MCP tool discovery did not complete: %s: %s",
            type(exc).__name__, exc,
        )

    # Copy so a discovery coroutine still running after a timeout cannot
    # change the list while it is being iterated below.
    discovered = list(all_tools)
    if discovered:
        # Dynamically inject into all hermes-* platform toolsets
        from toolsets import TOOLSETS

        for ts_name, ts in TOOLSETS.items():
            if ts_name.startswith("hermes-"):
                for tool_name in discovered:
                    if tool_name not in ts["tools"]:
                        ts["tools"].append(tool_name)

    # Return ALL registered tools (existing + newly discovered)
    return _existing_tool_names()
def shutdown_mcp_servers():
    """Close all MCP server connections and stop the background loop.

    Each server Task is signalled to exit its ``async with`` block so that
    the anyio cancel-scope cleanup happens in the same Task that opened it.
    All servers are shut down in parallel via ``asyncio.gather``.

    ``_servers`` is emptied whether or not the on-loop shutdown completed
    (loop not running, timeout, error).  Otherwise stale entries would make
    a later ``discover_mcp_tools()`` treat those servers as still connected
    and never reconnect them.
    """
    with _lock:
        servers_snapshot = list(_servers.values())

    # Fast path: nothing to shut down.
    if not servers_snapshot:
        _stop_mcp_loop()
        return

    async def _shutdown():
        results = await asyncio.gather(
            *(server.shutdown() for server in servers_snapshot),
            return_exceptions=True,
        )
        for server, result in zip(servers_snapshot, results):
            if isinstance(result, Exception):
                logger.debug(
                    "Error closing MCP server '%s': %s", server.name, result,
                )

    with _lock:
        loop = _mcp_loop

    try:
        if loop is not None and loop.is_running():
            future = asyncio.run_coroutine_threadsafe(_shutdown(), loop)
            future.result(timeout=15)
    except Exception as exc:
        logger.debug("Error during MCP shutdown: %s", exc)
    finally:
        with _lock:
            _servers.clear()

    _stop_mcp_loop()
def _stop_mcp_loop():
    """Stop the background event loop and join its thread.

    The module-level loop/thread references are cleared under ``_lock``
    first.  The loop is then asked to stop, its thread is joined for up to
    5 seconds, and the loop is closed -- but only if it has really stopped,
    because closing a running loop raises ``RuntimeError``.
    """
    global _mcp_loop, _mcp_thread
    with _lock:
        loop = _mcp_loop
        thread = _mcp_thread
        _mcp_loop = None
        _mcp_thread = None

    if loop is None or loop.is_closed():
        # Nothing to stop; scheduling onto a closed loop would raise.
        return

    loop.call_soon_threadsafe(loop.stop)
    if thread is not None:
        thread.join(timeout=5)

    if loop.is_running():
        logger.warning(
            "MCP event loop did not stop within the join timeout; "
            "leaving it open"
        )
        return
    loop.close()