Files
hermes-agent/tools/mcp_tool.py
0xbyt4 aa2ecaef29 fix: resolve orphan subprocess leak on MCP server shutdown
Refactor MCP connections from AsyncExitStack to task-per-server
architecture. Each server now runs as a long-lived asyncio Task
with `async with stdio_client(...)`, ensuring anyio cancel-scope
cleanup happens in the same Task that opened the connection.
2026-03-02 21:22:00 +03:00

445 lines
15 KiB
Python

#!/usr/bin/env python3
"""
MCP (Model Context Protocol) Client Support
Connects to external MCP servers via stdio transport, discovers their tools,
and registers them into the hermes-agent tool registry so the agent can call
them like any built-in tool.
Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key.
The ``mcp`` Python package is optional -- if not installed, this module is a
no-op and logs a debug message.
Example config::

    mcp_servers:
      filesystem:
        command: "npx"
        args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
        env: {}
      github:
        command: "npx"
        args: ["-y", "@modelcontextprotocol/server-github"]
        env:
          GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."

Architecture:
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
Each MCP server runs as a long-lived asyncio Task on this loop, keeping
its ``async with stdio_client(...)`` context alive. Tool call coroutines
are scheduled onto the loop via ``run_coroutine_threadsafe()``.
On shutdown, each server Task is signalled to exit its ``async with``
block, ensuring the anyio cancel-scope cleanup happens in the *same*
Task that opened the connection (required by anyio).
"""
import asyncio
import json
import logging
import threading
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Graceful import -- MCP SDK is an optional dependency
# ---------------------------------------------------------------------------
_MCP_AVAILABLE = False
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
_MCP_AVAILABLE = True
except ImportError:
logger.debug("mcp package not installed -- MCP tool support disabled")
# ---------------------------------------------------------------------------
# Server task -- each MCP server lives in one long-lived asyncio Task
# ---------------------------------------------------------------------------
class MCPServerTask:
    """Manage a single MCP server connection inside one dedicated asyncio Task.

    The whole connection lifecycle (connect, discover, serve, disconnect)
    runs inside one asyncio Task so that the anyio cancel-scopes created by
    ``stdio_client`` are entered and exited in the same Task context.

    Attributes are declared in ``__slots__``:
        name: Logical server name taken from the configuration.
        session: The live ``ClientSession`` while connected, otherwise ``None``.
    """

    __slots__ = (
        "name", "session",
        "_task", "_ready", "_shutdown_event", "_tools", "_error",
    )

    def __init__(self, name: str):
        self.name = name
        self.session: Optional[Any] = None
        self._task: Optional[asyncio.Task] = None
        # Set once the connection attempt has finished, successfully or not.
        self._ready = asyncio.Event()
        # Set by shutdown() to let run() leave its ``async with`` blocks.
        self._shutdown_event = asyncio.Event()
        self._tools: list = []
        self._error: Optional[Exception] = None

    async def run(self, config: dict):
        """Long-lived coroutine: connect, discover tools, wait, disconnect.

        Failures are stored in ``_error`` instead of being raised, and
        ``_ready`` is set on *every* exit path so that ``start()`` can never
        wait forever.

        Args:
            config: Server configuration with ``command`` and optional
                ``args`` / ``env`` entries.
        """
        try:
            # A malformed entry (e.g. ``server_name:`` with no body) is
            # reported the same way as a missing command.
            command = config.get("command") if isinstance(config, dict) else None
            if not command:
                raise ValueError(
                    f"MCP server '{self.name}' has no 'command' in config"
                )

            server_params = StdioServerParameters(
                command=command,
                # ``args: null`` in the config yields None; normalise to [].
                args=config.get("args") or [],
                env=config.get("env") or None,
            )

            async with stdio_client(server_params) as (read_stream, write_stream):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    self.session = session

                    tools_result = await session.list_tools()
                    self._tools = getattr(tools_result, "tools", [])

                    # Signal that the connection is ready.
                    self._ready.set()

                    # Block until shutdown is requested -- this keeps the
                    # async-with contexts alive on THIS Task.
                    await self._shutdown_event.wait()
        except asyncio.CancelledError:
            # Cancelled before the connection came up: give start() an error
            # to raise rather than letting it report success.
            if not self._ready.is_set():
                self._error = RuntimeError(
                    f"MCP server '{self.name}' connection was cancelled"
                )
            raise
        except Exception as exc:
            if self._ready.is_set():
                logger.debug(
                    "MCP server '%s' stopped with an error: %s", self.name, exc
                )
            self._error = exc
        finally:
            self.session = None
            self._ready.set()

    async def start(self, config: dict):
        """Create the background Task and wait until ready (or failed).

        Raises:
            ValueError: if ``command`` is missing from *config*.
            Exception: whatever error the connection attempt produced.
        """
        self._task = asyncio.create_task(self.run(config))
        try:
            await self._ready.wait()
        except asyncio.CancelledError:
            # Do not leave a half-started server Task running unattended.
            self._task.cancel()
            raise
        if self._error is not None:
            raise self._error

    async def shutdown(self):
        """Signal the Task to exit and wait for clean resource teardown.

        Waits up to 10 seconds for the Task to finish on its own before
        cancelling it.
        """
        self._shutdown_event.set()
        task = self._task
        if task is not None and not task.done():
            try:
                await asyncio.wait_for(task, timeout=10)
            except asyncio.TimeoutError:
                logger.warning(
                    "MCP server '%s' shutdown timed out, cancelling task",
                    self.name,
                )
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        self.session = None
# ---------------------------------------------------------------------------
# Module-level state
# ---------------------------------------------------------------------------
# Connected servers keyed by their configured name. Entries are added by
# _discover_and_register_server() and removed by shutdown_mcp_servers().
_servers: Dict[str, MCPServerTask] = {}
# Dedicated event loop running in a background daemon thread.
# Both names are None until _ensure_mcp_loop() creates them and are reset to
# None again by _stop_mcp_loop().
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None
def _ensure_mcp_loop():
    """Start the background event loop thread if not already running.

    Returns only after the new loop has actually begun running, so callers
    may schedule work on it immediately without racing thread start-up.
    """
    global _mcp_loop, _mcp_thread
    if _mcp_loop is not None and _mcp_loop.is_running():
        return

    loop = asyncio.new_event_loop()
    started = threading.Event()

    def _run_loop():
        # The callback fires on the loop's first iteration, i.e. once
        # run_forever() is live and is_running() reports True.
        loop.call_soon(started.set)
        loop.run_forever()

    thread = threading.Thread(
        target=_run_loop,
        name="mcp-event-loop",
        daemon=True,
    )
    _mcp_loop = loop
    _mcp_thread = thread
    thread.start()

    if not started.wait(timeout=5):
        logger.warning("MCP event loop thread did not start within 5 seconds")
def _run_on_mcp_loop(coro, timeout: float = 30):
    """Schedule a coroutine on the MCP event loop and block until done.

    Args:
        coro: The coroutine object to run on the background loop.
        timeout: Maximum number of seconds to wait for the result.

    Returns:
        Whatever *coro* returns.

    Raises:
        RuntimeError: if the background loop is not running.
        Exception: a timeout error from waiting on the result, or any
            exception raised by *coro* itself.
    """
    if _mcp_loop is None or not _mcp_loop.is_running():
        # The coroutine will never be scheduled; close it so Python does not
        # emit a "coroutine ... was never awaited" RuntimeWarning.
        if asyncio.iscoroutine(coro):
            coro.close()
        raise RuntimeError("MCP event loop is not running")
    future = asyncio.run_coroutine_threadsafe(coro, _mcp_loop)
    return future.result(timeout=timeout)
# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------
def _load_mcp_config() -> Dict[str, dict]:
    """Fetch the ``mcp_servers`` mapping from the Hermes configuration.

    Any failure while importing or reading the configuration is logged at
    debug level and treated as "nothing configured".

    Returns:
        ``{server_name: {command, args, env}}``, or an empty dict when the
        key is absent, empty, or not a mapping.
    """
    try:
        from hermes_cli.config import load_config

        configured = load_config().get("mcp_servers")
    except Exception as exc:
        logger.debug("Failed to load MCP config: %s", exc)
        return {}

    if isinstance(configured, dict) and configured:
        return configured
    return {}
# ---------------------------------------------------------------------------
# Server connection helper
# ---------------------------------------------------------------------------
async def _connect_server(name: str, config: dict) -> MCPServerTask:
    """Build an ``MCPServerTask`` for *name* and wait for it to come up.

    Once this returns, the server's Task keeps running in the background.
    Tear it down with ``server.shutdown()`` on the same event loop.

    Raises:
        ValueError: if ``command`` is missing from *config*.
        Exception: on connection or initialization failure.
    """
    task = MCPServerTask(name)
    await task.start(config)
    return task
# ---------------------------------------------------------------------------
# Handler / check-fn factories
# ---------------------------------------------------------------------------
def _make_tool_handler(server_name: str, tool_name: str):
    """Return a sync handler that calls an MCP tool via the background loop.

    The handler conforms to the registry's dispatch interface:
    ``handler(args_dict, **kwargs) -> str``

    The returned string is always JSON: ``{"result": ...}`` on success or
    ``{"error": ...}`` when the server is not connected, the tool reports an
    error, or the call itself fails.
    """

    def _collect_text(blocks, separator: str) -> str:
        # Only content blocks carrying a ``text`` attribute contribute.
        return separator.join(
            item.text for item in (blocks or []) if hasattr(item, "text")
        )

    def _handler(args: dict, **kwargs) -> str:
        server = _servers.get(server_name)
        # Snapshot the session once: a concurrent shutdown may reset
        # ``server.session`` to None between this check and the call below.
        session = server.session if server else None
        if not session:
            return json.dumps({
                "error": f"MCP server '{server_name}' is not connected"
            })

        async def _call():
            result = await session.call_tool(tool_name, arguments=args)
            # MCP CallToolResult has .content (list of content blocks) and .isError
            if result.isError:
                error_text = _collect_text(result.content, "")
                return json.dumps(
                    {"error": error_text or "MCP tool returned an error"}
                )
            return json.dumps({"result": _collect_text(result.content, "\n")})

        try:
            return _run_on_mcp_loop(_call(), timeout=120)
        except Exception as exc:
            logger.error(
                "MCP tool %s/%s call failed: %s", server_name, tool_name, exc
            )
            return json.dumps(
                {"error": f"MCP call failed: {type(exc).__name__}: {exc}"}
            )

    return _handler
def _make_check_fn(server_name: str):
    """Build an availability probe for the MCP server named *server_name*.

    The returned callable reports True only while the server is present in
    ``_servers`` and currently holds a session.
    """

    def _check() -> bool:
        entry = _servers.get(server_name)
        if entry is None:
            return False
        return entry.session is not None

    return _check
# ---------------------------------------------------------------------------
# Discovery & registration
# ---------------------------------------------------------------------------
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
    """Convert an MCP tool listing to the Hermes registry schema format.

    Args:
        server_name: The logical server name for prefixing.
        mcp_tool: An MCP ``Tool`` object with ``.name``, ``.description``,
            and ``.inputSchema``.

    Returns:
        A dict suitable for ``registry.register(schema=...)`` with the keys
        ``name`` (``mcp_<server>_<tool>``), ``description`` and ``parameters``.
    """

    def _sanitize(raw: str) -> str:
        # LLM function-calling APIs restrict tool names to a small ASCII
        # alphabet. Map everything outside [A-Za-z0-9_] to an underscore,
        # which covers hyphens and dots as well as spaces, slashes, colons
        # and non-ASCII characters.
        return "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_"
            for ch in raw
        )

    prefixed_name = f"mcp_{_sanitize(server_name)}_{_sanitize(mcp_tool.name)}"
    return {
        "name": prefixed_name,
        "description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
        # Tools without an input schema are exposed as taking no arguments.
        "parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
            "type": "object",
            "properties": {},
        },
    }
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
    """Connect to a single MCP server, discover tools, and register them.

    If a server is already tracked under *name* (for example because
    discovery ran before), it is shut down after the new connection is up so
    that its subprocess is not left running without an owner.

    Args:
        name: Logical server name from the configuration.
        config: That server's configuration mapping.

    Returns:
        List of registered tool names.

    Raises:
        Exception: whatever the connection attempt raised.
    """
    from tools.registry import registry
    from toolsets import create_custom_toolset

    server = await _connect_server(name, config)

    # Swap in the new connection first so tool handlers never see a gap,
    # then retire the connection it replaces.
    previous = _servers.get(name)
    _servers[name] = server
    if previous is not None and previous is not server:
        try:
            await previous.shutdown()
        except Exception as exc:
            logger.debug(
                "Error closing previous MCP server '%s': %s", name, exc
            )

    registered_names: List[str] = []
    toolset_name = f"mcp-{name}"

    for mcp_tool in server._tools:
        schema = _convert_mcp_schema(name, mcp_tool)
        tool_name_prefixed = schema["name"]

        registry.register(
            name=tool_name_prefixed,
            toolset=toolset_name,
            schema=schema,
            # The handler calls the tool by its original MCP name; only the
            # registry-facing name is prefixed and sanitised.
            handler=_make_tool_handler(name, mcp_tool.name),
            check_fn=_make_check_fn(name),
            is_async=False,
            description=schema["description"],
        )
        registered_names.append(tool_name_prefixed)

    # Create a custom toolset so these tools are discoverable
    if registered_names:
        create_custom_toolset(
            name=toolset_name,
            description=f"MCP tools from {name} server",
            tools=registered_names,
        )

    logger.info(
        "MCP server '%s': registered %d tool(s): %s",
        name, len(registered_names), ", ".join(registered_names),
    )
    return registered_names
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def discover_mcp_tools() -> List[str]:
    """Entry point: load config, connect to MCP servers, register tools.

    Called from ``model_tools._discover_tools()``. Safe to call even when
    the ``mcp`` package is not installed (returns empty list). A failure to
    connect to one server, or discovery running out of time, is logged and
    does not prevent the tools registered so far from being returned.

    Returns:
        List of all registered MCP tool names.
    """
    if not _MCP_AVAILABLE:
        logger.debug("MCP SDK not available -- skipping MCP tool discovery")
        return []

    servers = _load_mcp_config()
    if not servers:
        logger.debug("No MCP servers configured")
        return []

    # Start the background event loop for MCP connections
    _ensure_mcp_loop()

    all_tools: List[str] = []

    async def _discover_all():
        for name, cfg in servers.items():
            try:
                registered = await _discover_and_register_server(name, cfg)
                all_tools.extend(registered)
            except Exception as exc:
                logger.warning("Failed to connect to MCP server '%s': %s", name, exc)

    try:
        _run_on_mcp_loop(_discover_all(), timeout=60)
    except Exception as exc:
        # A timeout or loop failure must not abort the caller's tool
        # discovery; keep whatever was registered before the failure.
        logger.warning(
            "MCP tool discovery did not complete: %s: %s",
            type(exc).__name__, exc,
        )

    # Snapshot: after a timeout the coroutine may still be appending to
    # all_tools on the loop thread.
    discovered = list(all_tools)

    if discovered:
        # Add MCP tools to hermes-cli and other platform toolsets
        from toolsets import TOOLSETS

        for ts_name in ("hermes-cli", "hermes-telegram", "hermes-discord",
                        "hermes-whatsapp", "hermes-slack"):
            ts = TOOLSETS.get(ts_name)
            if ts:
                for tool_name in discovered:
                    if tool_name not in ts["tools"]:
                        ts["tools"].append(tool_name)

    return discovered
def shutdown_mcp_servers():
    """Close all MCP server connections and stop the background loop.

    Each server Task is signalled to exit its ``async with`` block so that
    the anyio cancel-scope cleanup happens in the same Task that opened it.
    Servers are shut down concurrently so that one slow server cannot use
    up the time allowed for the others. ``_servers`` is always emptied and
    the loop is always stopped, even if the shutdown itself fails.
    """
    if _servers and _mcp_loop is not None and _mcp_loop.is_running():

        async def _shutdown():
            pending = list(_servers.items())
            # gather() only awaits each server's shutdown(); the actual
            # teardown still runs inside the Task that owns the connection.
            outcomes = await asyncio.gather(
                *(server.shutdown() for _, server in pending),
                return_exceptions=True,
            )
            for (name, _), outcome in zip(pending, outcomes):
                if isinstance(outcome, BaseException):
                    logger.debug(
                        "Error closing MCP server '%s': %s", name, outcome
                    )

        try:
            future = asyncio.run_coroutine_threadsafe(_shutdown(), _mcp_loop)
            future.result(timeout=15)
        except Exception as exc:
            logger.debug("Error during MCP shutdown: %s", exc)

    # Once the loop is gone no entry can be used any more, so drop them all
    # regardless of how the shutdown above went.
    _servers.clear()
    _stop_mcp_loop()
def _stop_mcp_loop():
    """Stop the background event loop and join its thread.

    The module-level references are cleared first, so the module state is
    consistent even if stopping fails. The loop is closed only once it has
    really stopped, because closing a running loop raises ``RuntimeError``.
    """
    global _mcp_loop, _mcp_thread
    loop, thread = _mcp_loop, _mcp_thread
    _mcp_loop = None
    _mcp_thread = None

    if loop is None or loop.is_closed():
        return

    # Scheduled unconditionally: if the thread has not reached run_forever()
    # yet, the queued stop makes it return straight away.
    loop.call_soon_threadsafe(loop.stop)
    if thread is not None:
        thread.join(timeout=5)

    if loop.is_running():
        logger.warning(
            "MCP event loop did not stop within 5 seconds; leaving it open"
        )
        return
    loop.close()