feat(mcp): add HTTP transport, reconnection, security hardening
Upgrades the MCP client implementation from PR #291 with: - HTTP/Streamable HTTP transport: support 'url' key in config for remote MCP servers (Notion, Slack, Sentry, Supabase, etc.) - Automatic reconnection with exponential backoff (1s-60s, 5 retries) when a server connection drops unexpectedly - Environment variable filtering: only pass safe vars (PATH, HOME, etc.) plus user-specified env to stdio subprocesses (prevents secret leaks) - Credential stripping: sanitize error messages before returning to the LLM (strips GitHub PATs, OpenAI keys, Bearer tokens, etc.) - Configurable per-server timeouts: 'timeout' and 'connect_timeout' keys - Fix shutdown race condition in servers_snapshot variable scoping Test coverage: 50 tests (up from 30), including new tests for env filtering, credential sanitization, HTTP config detection, reconnection logic, and configurable timeouts. All 1162 tests pass (1162 passed, 3 skipped, 0 failed).
This commit is contained in:
@@ -5,6 +5,7 @@ All tests use mocks -- no real MCP servers or subprocesses are started.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -189,7 +190,7 @@ class TestToolHandler:
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "greet")
|
||||
handler = _make_tool_handler("test_srv", "greet", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({"name": "world"}))
|
||||
assert result["result"] == "hello world"
|
||||
@@ -208,7 +209,7 @@ class TestToolHandler:
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "fail_tool")
|
||||
handler = _make_tool_handler("test_srv", "fail_tool", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
@@ -220,7 +221,7 @@ class TestToolHandler:
|
||||
from tools.mcp_tool import _make_tool_handler, _servers
|
||||
|
||||
_servers.pop("ghost", None)
|
||||
handler = _make_tool_handler("ghost", "any_tool")
|
||||
handler = _make_tool_handler("ghost", "any_tool", 120)
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
assert "not connected" in result["error"]
|
||||
@@ -234,7 +235,7 @@ class TestToolHandler:
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "broken_tool")
|
||||
handler = _make_tool_handler("test_srv", "broken_tool", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
@@ -400,8 +401,8 @@ class TestMCPServerTask:
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_empty_env_passed_as_none(self):
|
||||
"""Empty env dict is passed as None to StdioServerParameters."""
|
||||
def test_empty_env_gets_safe_defaults(self):
|
||||
"""Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
mock_session = MagicMock()
|
||||
@@ -414,13 +415,18 @@ class TestMCPServerTask:
|
||||
|
||||
async def _test():
|
||||
with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
|
||||
p_stdio, p_cs:
|
||||
p_stdio, p_cs, \
|
||||
patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/home/test"}, clear=False):
|
||||
server = MCPServerTask("srv")
|
||||
await server.start({"command": "node", "env": {}})
|
||||
|
||||
# Empty dict -> None
|
||||
# Empty dict -> safe env vars (not None)
|
||||
call_kwargs = mock_params.call_args
|
||||
assert call_kwargs.kwargs.get("env") is None
|
||||
env_arg = call_kwargs.kwargs.get("env")
|
||||
assert env_arg is not None
|
||||
assert isinstance(env_arg, dict)
|
||||
assert "PATH" in env_arg
|
||||
assert "HOME" in env_arg
|
||||
|
||||
await server.shutdown()
|
||||
|
||||
@@ -698,3 +704,353 @@ class TestShutdown:
|
||||
assert len(_servers) == 0
|
||||
# Parallel: ~1s, not ~3s. Allow some margin.
|
||||
assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_safe_env
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildSafeEnv:
|
||||
"""Tests for _build_safe_env() environment filtering."""
|
||||
|
||||
def test_only_safe_vars_passed(self):
|
||||
"""Only safe baseline vars and XDG_* from os.environ are included."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
fake_env = {
|
||||
"PATH": "/usr/bin",
|
||||
"HOME": "/home/test",
|
||||
"USER": "test",
|
||||
"LANG": "en_US.UTF-8",
|
||||
"LC_ALL": "C",
|
||||
"TERM": "xterm",
|
||||
"SHELL": "/bin/bash",
|
||||
"TMPDIR": "/tmp",
|
||||
"XDG_DATA_HOME": "/home/test/.local/share",
|
||||
"SECRET_KEY": "should_not_appear",
|
||||
"AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE",
|
||||
}
|
||||
with patch.dict("os.environ", fake_env, clear=True):
|
||||
result = _build_safe_env(None)
|
||||
|
||||
# Safe vars present
|
||||
assert result["PATH"] == "/usr/bin"
|
||||
assert result["HOME"] == "/home/test"
|
||||
assert result["USER"] == "test"
|
||||
assert result["LANG"] == "en_US.UTF-8"
|
||||
assert result["XDG_DATA_HOME"] == "/home/test/.local/share"
|
||||
# Unsafe vars excluded
|
||||
assert "SECRET_KEY" not in result
|
||||
assert "AWS_ACCESS_KEY_ID" not in result
|
||||
|
||||
def test_user_env_merged(self):
|
||||
"""User-specified env vars are merged into the safe env."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
|
||||
result = _build_safe_env({"MY_CUSTOM_VAR": "hello"})
|
||||
|
||||
assert result["PATH"] == "/usr/bin"
|
||||
assert result["MY_CUSTOM_VAR"] == "hello"
|
||||
|
||||
def test_user_env_overrides_safe(self):
|
||||
"""User env can override safe defaults."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
|
||||
result = _build_safe_env({"PATH": "/custom/bin"})
|
||||
|
||||
assert result["PATH"] == "/custom/bin"
|
||||
|
||||
def test_none_user_env(self):
|
||||
"""None user_env still returns safe vars from os.environ."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
with patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/root"}, clear=True):
|
||||
result = _build_safe_env(None)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["PATH"] == "/usr/bin"
|
||||
assert result["HOME"] == "/root"
|
||||
|
||||
def test_secret_vars_excluded(self):
|
||||
"""Sensitive env vars from os.environ are NOT passed through."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
fake_env = {
|
||||
"PATH": "/usr/bin",
|
||||
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
|
||||
"GITHUB_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"OPENAI_API_KEY": "sk-proj-abc123",
|
||||
"DATABASE_URL": "postgres://user:pass@localhost/db",
|
||||
"API_SECRET": "supersecret",
|
||||
}
|
||||
with patch.dict("os.environ", fake_env, clear=True):
|
||||
result = _build_safe_env(None)
|
||||
|
||||
assert "PATH" in result
|
||||
assert "AWS_SECRET_ACCESS_KEY" not in result
|
||||
assert "GITHUB_TOKEN" not in result
|
||||
assert "OPENAI_API_KEY" not in result
|
||||
assert "DATABASE_URL" not in result
|
||||
assert "API_SECRET" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sanitize_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSanitizeError:
|
||||
"""Tests for _sanitize_error() credential stripping."""
|
||||
|
||||
def test_strips_github_pat(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("Error with ghp_abc123def456")
|
||||
assert result == "Error with [REDACTED]"
|
||||
|
||||
def test_strips_openai_key(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("key sk-projABC123xyz")
|
||||
assert result == "key [REDACTED]"
|
||||
|
||||
def test_strips_bearer_token(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("Authorization: Bearer eyJabc123def")
|
||||
assert result == "Authorization: [REDACTED]"
|
||||
|
||||
def test_strips_token_param(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("url?token=secret123")
|
||||
assert result == "url?[REDACTED]"
|
||||
|
||||
def test_no_credentials_unchanged(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("normal error message")
|
||||
assert result == "normal error message"
|
||||
|
||||
def test_multiple_credentials(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("ghp_abc123 and sk-projXyz789 and token=foo")
|
||||
assert "ghp_" not in result
|
||||
assert "sk-" not in result
|
||||
assert "token=" not in result
|
||||
assert result.count("[REDACTED]") == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHTTPConfig:
|
||||
"""Tests for HTTP transport detection and handling."""
|
||||
|
||||
def test_is_http_with_url(self):
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
server = MCPServerTask("remote")
|
||||
server._config = {"url": "https://example.com/mcp"}
|
||||
assert server._is_http() is True
|
||||
|
||||
def test_is_stdio_with_command(self):
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
server = MCPServerTask("local")
|
||||
server._config = {"command": "npx", "args": []}
|
||||
assert server._is_http() is False
|
||||
|
||||
def test_http_unavailable_raises(self):
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
server = MCPServerTask("remote")
|
||||
config = {"url": "https://example.com/mcp"}
|
||||
|
||||
async def _test():
|
||||
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", False):
|
||||
with pytest.raises(ImportError, match="HTTP transport"):
|
||||
await server._run_http(config)
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reconnection logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReconnection:
|
||||
"""Tests for automatic reconnection behavior in MCPServerTask.run()."""
|
||||
|
||||
def test_reconnect_on_disconnect(self):
|
||||
"""After initial success, a connection drop triggers reconnection."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
run_count = 0
|
||||
target_server = None
|
||||
|
||||
original_run_stdio = MCPServerTask._run_stdio
|
||||
|
||||
async def patched_run_stdio(self_srv, config):
|
||||
nonlocal run_count, target_server
|
||||
run_count += 1
|
||||
if target_server is not self_srv:
|
||||
return await original_run_stdio(self_srv, config)
|
||||
if run_count == 1:
|
||||
# First connection succeeds, then simulate disconnect
|
||||
self_srv.session = MagicMock()
|
||||
self_srv._tools = []
|
||||
self_srv._ready.set()
|
||||
raise ConnectionError("connection dropped")
|
||||
else:
|
||||
# Reconnection succeeds; signal shutdown so run() exits
|
||||
self_srv.session = MagicMock()
|
||||
self_srv._shutdown_event.set()
|
||||
await self_srv._shutdown_event.wait()
|
||||
|
||||
async def _test():
|
||||
nonlocal target_server
|
||||
server = MCPServerTask("test_srv")
|
||||
target_server = server
|
||||
|
||||
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await server.run({"command": "test"})
|
||||
|
||||
assert run_count >= 2 # At least one reconnection attempt
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_no_reconnect_on_shutdown(self):
|
||||
"""If shutdown is requested, don't attempt reconnection."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
run_count = 0
|
||||
target_server = None
|
||||
|
||||
original_run_stdio = MCPServerTask._run_stdio
|
||||
|
||||
async def patched_run_stdio(self_srv, config):
|
||||
nonlocal run_count, target_server
|
||||
run_count += 1
|
||||
if target_server is not self_srv:
|
||||
return await original_run_stdio(self_srv, config)
|
||||
self_srv.session = MagicMock()
|
||||
self_srv._tools = []
|
||||
self_srv._ready.set()
|
||||
raise ConnectionError("connection dropped")
|
||||
|
||||
async def _test():
|
||||
nonlocal target_server
|
||||
server = MCPServerTask("test_srv")
|
||||
target_server = server
|
||||
server._shutdown_event.set() # Shutdown already requested
|
||||
|
||||
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await server.run({"command": "test"})
|
||||
|
||||
# Should not retry because shutdown was set
|
||||
assert run_count == 1
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_no_reconnect_on_initial_failure(self):
|
||||
"""First connection failure reports error immediately, no retry."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
run_count = 0
|
||||
target_server = None
|
||||
|
||||
original_run_stdio = MCPServerTask._run_stdio
|
||||
|
||||
async def patched_run_stdio(self_srv, config):
|
||||
nonlocal run_count, target_server
|
||||
run_count += 1
|
||||
if target_server is not self_srv:
|
||||
return await original_run_stdio(self_srv, config)
|
||||
raise ConnectionError("cannot connect")
|
||||
|
||||
async def _test():
|
||||
nonlocal target_server
|
||||
server = MCPServerTask("test_srv")
|
||||
target_server = server
|
||||
|
||||
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await server.run({"command": "test"})
|
||||
|
||||
# Only one attempt, no retry on initial failure
|
||||
assert run_count == 1
|
||||
assert server._error is not None
|
||||
assert "cannot connect" in str(server._error)
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configurable timeouts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigurableTimeouts:
|
||||
"""Tests for configurable per-server timeouts."""
|
||||
|
||||
def test_default_timeout(self):
|
||||
"""Server with no timeout config gets _DEFAULT_TOOL_TIMEOUT."""
|
||||
from tools.mcp_tool import MCPServerTask, _DEFAULT_TOOL_TIMEOUT
|
||||
|
||||
server = MCPServerTask("test_srv")
|
||||
assert server.tool_timeout == _DEFAULT_TOOL_TIMEOUT
|
||||
assert server.tool_timeout == 120
|
||||
|
||||
def test_custom_timeout(self):
|
||||
"""Server with timeout=180 in config gets 180."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
target_server = None
|
||||
|
||||
original_run_stdio = MCPServerTask._run_stdio
|
||||
|
||||
async def patched_run_stdio(self_srv, config):
|
||||
if target_server is not self_srv:
|
||||
return await original_run_stdio(self_srv, config)
|
||||
self_srv.session = MagicMock()
|
||||
self_srv._tools = []
|
||||
self_srv._ready.set()
|
||||
await self_srv._shutdown_event.wait()
|
||||
|
||||
async def _test():
|
||||
nonlocal target_server
|
||||
server = MCPServerTask("test_srv")
|
||||
target_server = server
|
||||
|
||||
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio):
|
||||
task = asyncio.ensure_future(
|
||||
server.run({"command": "test", "timeout": 180})
|
||||
)
|
||||
await server._ready.wait()
|
||||
assert server.tool_timeout == 180
|
||||
server._shutdown_event.set()
|
||||
await task
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_timeout_passed_to_handler(self):
|
||||
"""The tool handler uses the server's configured timeout."""
|
||||
from tools.mcp_tool import _make_tool_handler, _servers, MCPServerTask
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.call_tool = AsyncMock(
|
||||
return_value=_make_call_result("ok", is_error=False)
|
||||
)
|
||||
server = _make_mock_server("test_srv", session=mock_session)
|
||||
server.tool_timeout = 180
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "my_tool", 180)
|
||||
with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
|
||||
mock_run.return_value = json.dumps({"result": "ok"})
|
||||
handler({})
|
||||
# Verify timeout=180 was passed
|
||||
call_kwargs = mock_run.call_args
|
||||
assert call_kwargs.kwargs.get("timeout") == 180 or \
|
||||
(len(call_kwargs.args) > 1 and call_kwargs.args[1] == 180) or \
|
||||
call_kwargs[1].get("timeout") == 180
|
||||
finally:
|
||||
_servers.pop("test_srv", None)
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) Client Support
|
||||
|
||||
Connects to external MCP servers via stdio transport, discovers their tools,
|
||||
and registers them into the hermes-agent tool registry so the agent can call
|
||||
them like any built-in tool.
|
||||
Connects to external MCP servers via stdio or HTTP/StreamableHTTP transport,
|
||||
discovers their tools, and registers them into the hermes-agent tool registry
|
||||
so the agent can call them like any built-in tool.
|
||||
|
||||
Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key.
|
||||
The ``mcp`` Python package is optional -- if not installed, this module is a
|
||||
@@ -17,17 +17,32 @@ Example config::
|
||||
command: "npx"
|
||||
args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
||||
env: {}
|
||||
timeout: 120 # per-tool-call timeout in seconds (default: 120)
|
||||
connect_timeout: 60 # initial connection timeout (default: 60)
|
||||
github:
|
||||
command: "npx"
|
||||
args: ["-y", "@modelcontextprotocol/server-github"]
|
||||
env:
|
||||
GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
|
||||
remote_api:
|
||||
url: "https://my-mcp-server.example.com/mcp"
|
||||
headers:
|
||||
Authorization: "Bearer sk-..."
|
||||
timeout: 180
|
||||
|
||||
Features:
|
||||
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
|
||||
- Automatic reconnection with exponential backoff (up to 5 retries)
|
||||
- Environment variable filtering for stdio subprocesses (security)
|
||||
- Credential stripping in error messages returned to the LLM
|
||||
- Configurable per-server timeouts for tool calls and connections
|
||||
- Thread-safe architecture with dedicated background event loop
|
||||
|
||||
Architecture:
|
||||
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
|
||||
Each MCP server runs as a long-lived asyncio Task on this loop, keeping
|
||||
its ``async with stdio_client(...)`` context alive. Tool call coroutines
|
||||
are scheduled onto the loop via ``run_coroutine_threadsafe()``.
|
||||
its transport context alive. Tool call coroutines are scheduled onto the
|
||||
loop via ``run_coroutine_threadsafe()``.
|
||||
|
||||
On shutdown, each server Task is signalled to exit its ``async with``
|
||||
block, ensuring the anyio cancel-scope cleanup happens in the *same*
|
||||
@@ -43,6 +58,8 @@ Thread safety:
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -53,13 +70,81 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MCP_AVAILABLE = False
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
_MCP_AVAILABLE = True
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
_MCP_HTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
except ImportError:
|
||||
logger.debug("mcp package not installed -- MCP tool support disabled")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
|
||||
_DEFAULT_DISCOVERY_TIMEOUT = 60 # seconds for server discovery
|
||||
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection
|
||||
_MAX_RECONNECT_RETRIES = 5
|
||||
_MAX_BACKOFF_SECONDS = 60
|
||||
|
||||
# Environment variables that are safe to pass to stdio subprocesses
|
||||
_SAFE_ENV_KEYS = frozenset({
|
||||
"PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
|
||||
})
|
||||
|
||||
# Regex for credential patterns to strip from error messages
|
||||
_CREDENTIAL_PATTERN = re.compile(
|
||||
r"(?:"
|
||||
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
|
||||
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
|
||||
r"|Bearer\s+\S+" # Bearer token
|
||||
r"|token=[^\s&,;\"']{1,255}" # token=...
|
||||
r"|key=[^\s&,;\"']{1,255}" # key=...
|
||||
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
|
||||
r"|password=[^\s&,;\"']{1,255}" # password=...
|
||||
r"|secret=[^\s&,;\"']{1,255}" # secret=...
|
||||
r")",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_safe_env(user_env: Optional[dict]) -> dict:
|
||||
"""Build a filtered environment dict for stdio subprocesses.
|
||||
|
||||
Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_*
|
||||
variables from the current process environment, plus any variables
|
||||
explicitly specified by the user in the server config.
|
||||
|
||||
This prevents accidentally leaking secrets like API keys, tokens, or
|
||||
credentials to MCP server subprocesses.
|
||||
"""
|
||||
env = {}
|
||||
for key, value in os.environ.items():
|
||||
if key in _SAFE_ENV_KEYS or key.startswith("XDG_"):
|
||||
env[key] = value
|
||||
if user_env:
|
||||
env.update(user_env)
|
||||
return env
|
||||
|
||||
|
||||
def _sanitize_error(text: str) -> str:
|
||||
"""Strip credential-like patterns from error text before returning to LLM.
|
||||
|
||||
Replaces tokens, keys, and other secrets with [REDACTED] to prevent
|
||||
accidental credential exposure in tool error responses.
|
||||
"""
|
||||
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server task -- each MCP server lives in one long-lived asyncio Task
|
||||
@@ -70,66 +155,152 @@ class MCPServerTask:
|
||||
|
||||
The entire connection lifecycle (connect, discover, serve, disconnect)
|
||||
runs inside one asyncio Task so that anyio cancel-scopes created by
|
||||
``stdio_client`` are entered and exited in the same Task context.
|
||||
the transport client are entered and exited in the same Task context.
|
||||
|
||||
Supports both stdio and HTTP/StreamableHTTP transports.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"name", "session",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error",
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.session: Optional[Any] = None
|
||||
self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._ready = asyncio.Event()
|
||||
self._shutdown_event = asyncio.Event()
|
||||
self._tools: list = []
|
||||
self._error: Optional[Exception] = None
|
||||
self._config: dict = {}
|
||||
|
||||
async def run(self, config: dict):
|
||||
"""Long-lived coroutine: connect, discover tools, wait, disconnect."""
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
return "url" in self._config
|
||||
|
||||
async def _run_stdio(self, config: dict):
|
||||
"""Run the server using stdio transport."""
|
||||
command = config.get("command")
|
||||
args = config.get("args", [])
|
||||
env = config.get("env")
|
||||
user_env = config.get("env")
|
||||
|
||||
if not command:
|
||||
self._error = ValueError(
|
||||
raise ValueError(
|
||||
f"MCP server '{self.name}' has no 'command' in config"
|
||||
)
|
||||
self._ready.set()
|
||||
return
|
||||
|
||||
safe_env = _build_safe_env(user_env)
|
||||
server_params = StdioServerParameters(
|
||||
command=command,
|
||||
args=args,
|
||||
env=env if env else None,
|
||||
env=safe_env if safe_env else None,
|
||||
)
|
||||
|
||||
try:
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
self._ready.set()
|
||||
await self._shutdown_event.wait()
|
||||
|
||||
tools_result = await session.list_tools()
|
||||
self._tools = (
|
||||
tools_result.tools
|
||||
if hasattr(tools_result, "tools")
|
||||
else []
|
||||
)
|
||||
async def _run_http(self, config: dict):
|
||||
"""Run the server using HTTP/StreamableHTTP transport."""
|
||||
if not _MCP_HTTP_AVAILABLE:
|
||||
raise ImportError(
|
||||
f"MCP server '{self.name}' requires HTTP transport but "
|
||||
"mcp.client.streamable_http is not available. "
|
||||
"Upgrade the mcp package to get HTTP support."
|
||||
)
|
||||
|
||||
# Signal that connection is ready
|
||||
url = config["url"]
|
||||
headers = config.get("headers")
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
|
||||
async with streamablehttp_client(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=float(connect_timeout),
|
||||
) as (read_stream, write_stream, _get_session_id):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
self._ready.set()
|
||||
await self._shutdown_event.wait()
|
||||
|
||||
async def _discover_tools(self):
|
||||
"""Discover tools from the connected session."""
|
||||
if self.session is None:
|
||||
return
|
||||
tools_result = await self.session.list_tools()
|
||||
self._tools = (
|
||||
tools_result.tools
|
||||
if hasattr(tools_result, "tools")
|
||||
else []
|
||||
)
|
||||
|
||||
async def run(self, config: dict):
|
||||
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
|
||||
|
||||
Includes automatic reconnection with exponential backoff if the
|
||||
connection drops unexpectedly (unless shutdown was requested).
|
||||
"""
|
||||
self._config = config
|
||||
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
||||
retries = 0
|
||||
backoff = 1.0
|
||||
|
||||
while True:
|
||||
try:
|
||||
if self._is_http():
|
||||
await self._run_http(config)
|
||||
else:
|
||||
await self._run_stdio(config)
|
||||
# Normal exit (shutdown requested) -- break out
|
||||
break
|
||||
except Exception as exc:
|
||||
self.session = None
|
||||
|
||||
# If this is the first connection attempt, report the error
|
||||
if not self._ready.is_set():
|
||||
self._error = exc
|
||||
self._ready.set()
|
||||
return
|
||||
|
||||
# Block until shutdown is requested -- this keeps the
|
||||
# async-with contexts alive on THIS Task.
|
||||
await self._shutdown_event.wait()
|
||||
except Exception as exc:
|
||||
self._error = exc
|
||||
self._ready.set()
|
||||
finally:
|
||||
self.session = None
|
||||
# If shutdown was requested, don't reconnect
|
||||
if self._shutdown_event.is_set():
|
||||
logger.debug(
|
||||
"MCP server '%s' disconnected during shutdown: %s",
|
||||
self.name, exc,
|
||||
)
|
||||
return
|
||||
|
||||
retries += 1
|
||||
if retries > _MAX_RECONNECT_RETRIES:
|
||||
logger.warning(
|
||||
"MCP server '%s' failed after %d reconnection attempts, "
|
||||
"giving up: %s",
|
||||
self.name, _MAX_RECONNECT_RETRIES, exc,
|
||||
)
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"MCP server '%s' connection lost (attempt %d/%d), "
|
||||
"reconnecting in %.0fs: %s",
|
||||
self.name, retries, _MAX_RECONNECT_RETRIES,
|
||||
backoff, exc,
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
|
||||
|
||||
# Check again after sleeping
|
||||
if self._shutdown_event.is_set():
|
||||
return
|
||||
finally:
|
||||
self.session = None
|
||||
|
||||
async def start(self, config: dict):
|
||||
"""Create the background Task and wait until ready (or failed)."""
|
||||
@@ -203,7 +374,10 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
|
||||
def _load_mcp_config() -> Dict[str, dict]:
|
||||
"""Read ``mcp_servers`` from the Hermes config file.
|
||||
|
||||
Returns a dict of ``{server_name: {command, args, env}}`` or empty dict.
|
||||
Returns a dict of ``{server_name: server_config}`` or empty dict.
|
||||
Server config can contain either ``command``/``args``/``env`` for stdio
|
||||
transport or ``url``/``headers`` for HTTP transport, plus optional
|
||||
``timeout`` and ``connect_timeout`` overrides.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
@@ -224,11 +398,12 @@ def _load_mcp_config() -> Dict[str, dict]:
|
||||
async def _connect_server(name: str, config: dict) -> MCPServerTask:
|
||||
"""Create an MCPServerTask, start it, and return when ready.
|
||||
|
||||
The server Task keeps the subprocess alive in the background.
|
||||
The server Task keeps the connection alive in the background.
|
||||
Call ``server.shutdown()`` (on the same event loop) to tear it down.
|
||||
|
||||
Raises:
|
||||
ValueError: if ``command`` is missing from *config*.
|
||||
ValueError: if required config keys are missing.
|
||||
ImportError: if HTTP transport is needed but not available.
|
||||
Exception: on connection or initialization failure.
|
||||
"""
|
||||
server = MCPServerTask(name)
|
||||
@@ -240,7 +415,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask:
|
||||
# Handler / check-fn factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_handler(server_name: str, tool_name: str):
|
||||
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
"""Return a sync handler that calls an MCP tool via the background loop.
|
||||
|
||||
The handler conforms to the registry's dispatch interface:
|
||||
@@ -263,7 +438,11 @@ def _make_tool_handler(server_name: str, tool_name: str):
|
||||
for block in (result.content or []):
|
||||
if hasattr(block, "text"):
|
||||
error_text += block.text
|
||||
return json.dumps({"error": error_text or "MCP tool returned an error"})
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
error_text or "MCP tool returned an error"
|
||||
)
|
||||
})
|
||||
|
||||
# Collect text from content blocks
|
||||
parts: List[str] = []
|
||||
@@ -273,10 +452,17 @@ def _make_tool_handler(server_name: str, tool_name: str):
|
||||
return json.dumps({"result": "\n".join(parts) if parts else ""})
|
||||
|
||||
try:
|
||||
return _run_on_mcp_loop(_call(), timeout=120)
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error("MCP tool %s/%s call failed: %s", server_name, tool_name, exc)
|
||||
return json.dumps({"error": f"MCP call failed: {type(exc).__name__}: {exc}"})
|
||||
logger.error(
|
||||
"MCP tool %s/%s call failed: %s",
|
||||
server_name, tool_name, exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -339,7 +525,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
from tools.registry import registry
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
server = await _connect_server(name, config)
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
server = await asyncio.wait_for(
|
||||
_connect_server(name, config),
|
||||
timeout=connect_timeout,
|
||||
)
|
||||
with _lock:
|
||||
_servers[name] = server
|
||||
|
||||
@@ -354,7 +544,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
name=tool_name_prefixed,
|
||||
toolset=toolset_name,
|
||||
schema=schema,
|
||||
handler=_make_tool_handler(name, mcp_tool.name),
|
||||
handler=_make_tool_handler(name, mcp_tool.name, server.tool_timeout),
|
||||
check_fn=_make_check_fn(name),
|
||||
is_async=False,
|
||||
description=schema["description"],
|
||||
@@ -369,9 +559,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
tools=registered_names,
|
||||
)
|
||||
|
||||
transport_type = "HTTP" if "url" in config else "stdio"
|
||||
logger.info(
|
||||
"MCP server '%s': registered %d tool(s): %s",
|
||||
name, len(registered_names), ", ".join(registered_names),
|
||||
"MCP server '%s' (%s): registered %d tool(s): %s",
|
||||
name, transport_type, len(registered_names),
|
||||
", ".join(registered_names),
|
||||
)
|
||||
return registered_names
|
||||
|
||||
@@ -419,9 +611,12 @@ def discover_mcp_tools() -> List[str]:
|
||||
registered = await _discover_and_register_server(name, cfg)
|
||||
all_tools.extend(registered)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to connect to MCP server '%s': %s", name, exc)
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s': %s",
|
||||
name, exc,
|
||||
)
|
||||
|
||||
_run_on_mcp_loop(_discover_all(), timeout=60)
|
||||
_run_on_mcp_loop(_discover_all(), timeout=_DEFAULT_DISCOVERY_TIMEOUT)
|
||||
|
||||
if all_tools:
|
||||
# Dynamically inject into all hermes-* platform toolsets
|
||||
@@ -444,15 +639,10 @@ def shutdown_mcp_servers():
|
||||
All servers are shut down in parallel via ``asyncio.gather``.
|
||||
"""
|
||||
with _lock:
|
||||
if not _servers:
|
||||
# No servers -- just stop the loop. _stop_mcp_loop() also
|
||||
# acquires _lock, so we must release it first.
|
||||
pass
|
||||
else:
|
||||
servers_snapshot = list(_servers.values())
|
||||
servers_snapshot = list(_servers.values())
|
||||
|
||||
# Fast path: nothing to shut down.
|
||||
if not _servers:
|
||||
if not servers_snapshot:
|
||||
_stop_mcp_loop()
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user