feat(mcp): add HTTP transport, reconnection, security hardening
Upgrades the MCP client implementation from PR #291 with: - HTTP/Streamable HTTP transport: support 'url' key in config for remote MCP servers (Notion, Slack, Sentry, Supabase, etc.) - Automatic reconnection with exponential backoff (1s-60s, 5 retries) when a server connection drops unexpectedly - Environment variable filtering: only pass safe vars (PATH, HOME, etc.) plus user-specified env to stdio subprocesses (prevents secret leaks) - Credential stripping: sanitize error messages before returning to the LLM (strips GitHub PATs, OpenAI keys, Bearer tokens, etc.) - Configurable per-server timeouts: 'timeout' and 'connect_timeout' keys - Fix shutdown race condition in servers_snapshot variable scoping Test coverage: 50 tests (up from 30), including new tests for env filtering, credential sanitization, HTTP config detection, reconnection logic, and configurable timeouts. All 1162 tests pass (1162 passed, 3 skipped, 0 failed).
This commit is contained in:
@@ -5,6 +5,7 @@ All tests use mocks -- no real MCP servers or subprocesses are started.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -189,7 +190,7 @@ class TestToolHandler:
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "greet")
|
||||
handler = _make_tool_handler("test_srv", "greet", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({"name": "world"}))
|
||||
assert result["result"] == "hello world"
|
||||
@@ -208,7 +209,7 @@ class TestToolHandler:
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "fail_tool")
|
||||
handler = _make_tool_handler("test_srv", "fail_tool", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
@@ -220,7 +221,7 @@ class TestToolHandler:
|
||||
from tools.mcp_tool import _make_tool_handler, _servers
|
||||
|
||||
_servers.pop("ghost", None)
|
||||
handler = _make_tool_handler("ghost", "any_tool")
|
||||
handler = _make_tool_handler("ghost", "any_tool", 120)
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
assert "not connected" in result["error"]
|
||||
@@ -234,7 +235,7 @@ class TestToolHandler:
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "broken_tool")
|
||||
handler = _make_tool_handler("test_srv", "broken_tool", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
@@ -400,8 +401,8 @@ class TestMCPServerTask:
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_empty_env_passed_as_none(self):
|
||||
"""Empty env dict is passed as None to StdioServerParameters."""
|
||||
def test_empty_env_gets_safe_defaults(self):
|
||||
"""Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
mock_session = MagicMock()
|
||||
@@ -414,13 +415,18 @@ class TestMCPServerTask:
|
||||
|
||||
async def _test():
|
||||
with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
|
||||
p_stdio, p_cs:
|
||||
p_stdio, p_cs, \
|
||||
patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/home/test"}, clear=False):
|
||||
server = MCPServerTask("srv")
|
||||
await server.start({"command": "node", "env": {}})
|
||||
|
||||
# Empty dict -> None
|
||||
# Empty dict -> safe env vars (not None)
|
||||
call_kwargs = mock_params.call_args
|
||||
assert call_kwargs.kwargs.get("env") is None
|
||||
env_arg = call_kwargs.kwargs.get("env")
|
||||
assert env_arg is not None
|
||||
assert isinstance(env_arg, dict)
|
||||
assert "PATH" in env_arg
|
||||
assert "HOME" in env_arg
|
||||
|
||||
await server.shutdown()
|
||||
|
||||
@@ -698,3 +704,353 @@ class TestShutdown:
|
||||
assert len(_servers) == 0
|
||||
# Parallel: ~1s, not ~3s. Allow some margin.
|
||||
assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_safe_env
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildSafeEnv:
|
||||
"""Tests for _build_safe_env() environment filtering."""
|
||||
|
||||
def test_only_safe_vars_passed(self):
|
||||
"""Only safe baseline vars and XDG_* from os.environ are included."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
fake_env = {
|
||||
"PATH": "/usr/bin",
|
||||
"HOME": "/home/test",
|
||||
"USER": "test",
|
||||
"LANG": "en_US.UTF-8",
|
||||
"LC_ALL": "C",
|
||||
"TERM": "xterm",
|
||||
"SHELL": "/bin/bash",
|
||||
"TMPDIR": "/tmp",
|
||||
"XDG_DATA_HOME": "/home/test/.local/share",
|
||||
"SECRET_KEY": "should_not_appear",
|
||||
"AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE",
|
||||
}
|
||||
with patch.dict("os.environ", fake_env, clear=True):
|
||||
result = _build_safe_env(None)
|
||||
|
||||
# Safe vars present
|
||||
assert result["PATH"] == "/usr/bin"
|
||||
assert result["HOME"] == "/home/test"
|
||||
assert result["USER"] == "test"
|
||||
assert result["LANG"] == "en_US.UTF-8"
|
||||
assert result["XDG_DATA_HOME"] == "/home/test/.local/share"
|
||||
# Unsafe vars excluded
|
||||
assert "SECRET_KEY" not in result
|
||||
assert "AWS_ACCESS_KEY_ID" not in result
|
||||
|
||||
def test_user_env_merged(self):
|
||||
"""User-specified env vars are merged into the safe env."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
|
||||
result = _build_safe_env({"MY_CUSTOM_VAR": "hello"})
|
||||
|
||||
assert result["PATH"] == "/usr/bin"
|
||||
assert result["MY_CUSTOM_VAR"] == "hello"
|
||||
|
||||
def test_user_env_overrides_safe(self):
|
||||
"""User env can override safe defaults."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
|
||||
result = _build_safe_env({"PATH": "/custom/bin"})
|
||||
|
||||
assert result["PATH"] == "/custom/bin"
|
||||
|
||||
def test_none_user_env(self):
|
||||
"""None user_env still returns safe vars from os.environ."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
with patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/root"}, clear=True):
|
||||
result = _build_safe_env(None)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["PATH"] == "/usr/bin"
|
||||
assert result["HOME"] == "/root"
|
||||
|
||||
def test_secret_vars_excluded(self):
|
||||
"""Sensitive env vars from os.environ are NOT passed through."""
|
||||
from tools.mcp_tool import _build_safe_env
|
||||
|
||||
fake_env = {
|
||||
"PATH": "/usr/bin",
|
||||
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
|
||||
"GITHUB_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"OPENAI_API_KEY": "sk-proj-abc123",
|
||||
"DATABASE_URL": "postgres://user:pass@localhost/db",
|
||||
"API_SECRET": "supersecret",
|
||||
}
|
||||
with patch.dict("os.environ", fake_env, clear=True):
|
||||
result = _build_safe_env(None)
|
||||
|
||||
assert "PATH" in result
|
||||
assert "AWS_SECRET_ACCESS_KEY" not in result
|
||||
assert "GITHUB_TOKEN" not in result
|
||||
assert "OPENAI_API_KEY" not in result
|
||||
assert "DATABASE_URL" not in result
|
||||
assert "API_SECRET" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sanitize_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSanitizeError:
|
||||
"""Tests for _sanitize_error() credential stripping."""
|
||||
|
||||
def test_strips_github_pat(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("Error with ghp_abc123def456")
|
||||
assert result == "Error with [REDACTED]"
|
||||
|
||||
def test_strips_openai_key(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("key sk-projABC123xyz")
|
||||
assert result == "key [REDACTED]"
|
||||
|
||||
def test_strips_bearer_token(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("Authorization: Bearer eyJabc123def")
|
||||
assert result == "Authorization: [REDACTED]"
|
||||
|
||||
def test_strips_token_param(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("url?token=secret123")
|
||||
assert result == "url?[REDACTED]"
|
||||
|
||||
def test_no_credentials_unchanged(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("normal error message")
|
||||
assert result == "normal error message"
|
||||
|
||||
def test_multiple_credentials(self):
|
||||
from tools.mcp_tool import _sanitize_error
|
||||
result = _sanitize_error("ghp_abc123 and sk-projXyz789 and token=foo")
|
||||
assert "ghp_" not in result
|
||||
assert "sk-" not in result
|
||||
assert "token=" not in result
|
||||
assert result.count("[REDACTED]") == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHTTPConfig:
|
||||
"""Tests for HTTP transport detection and handling."""
|
||||
|
||||
def test_is_http_with_url(self):
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
server = MCPServerTask("remote")
|
||||
server._config = {"url": "https://example.com/mcp"}
|
||||
assert server._is_http() is True
|
||||
|
||||
def test_is_stdio_with_command(self):
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
server = MCPServerTask("local")
|
||||
server._config = {"command": "npx", "args": []}
|
||||
assert server._is_http() is False
|
||||
|
||||
def test_http_unavailable_raises(self):
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
server = MCPServerTask("remote")
|
||||
config = {"url": "https://example.com/mcp"}
|
||||
|
||||
async def _test():
|
||||
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", False):
|
||||
with pytest.raises(ImportError, match="HTTP transport"):
|
||||
await server._run_http(config)
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reconnection logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReconnection:
|
||||
"""Tests for automatic reconnection behavior in MCPServerTask.run()."""
|
||||
|
||||
def test_reconnect_on_disconnect(self):
|
||||
"""After initial success, a connection drop triggers reconnection."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
run_count = 0
|
||||
target_server = None
|
||||
|
||||
original_run_stdio = MCPServerTask._run_stdio
|
||||
|
||||
async def patched_run_stdio(self_srv, config):
|
||||
nonlocal run_count, target_server
|
||||
run_count += 1
|
||||
if target_server is not self_srv:
|
||||
return await original_run_stdio(self_srv, config)
|
||||
if run_count == 1:
|
||||
# First connection succeeds, then simulate disconnect
|
||||
self_srv.session = MagicMock()
|
||||
self_srv._tools = []
|
||||
self_srv._ready.set()
|
||||
raise ConnectionError("connection dropped")
|
||||
else:
|
||||
# Reconnection succeeds; signal shutdown so run() exits
|
||||
self_srv.session = MagicMock()
|
||||
self_srv._shutdown_event.set()
|
||||
await self_srv._shutdown_event.wait()
|
||||
|
||||
async def _test():
|
||||
nonlocal target_server
|
||||
server = MCPServerTask("test_srv")
|
||||
target_server = server
|
||||
|
||||
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await server.run({"command": "test"})
|
||||
|
||||
assert run_count >= 2 # At least one reconnection attempt
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_no_reconnect_on_shutdown(self):
|
||||
"""If shutdown is requested, don't attempt reconnection."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
run_count = 0
|
||||
target_server = None
|
||||
|
||||
original_run_stdio = MCPServerTask._run_stdio
|
||||
|
||||
async def patched_run_stdio(self_srv, config):
|
||||
nonlocal run_count, target_server
|
||||
run_count += 1
|
||||
if target_server is not self_srv:
|
||||
return await original_run_stdio(self_srv, config)
|
||||
self_srv.session = MagicMock()
|
||||
self_srv._tools = []
|
||||
self_srv._ready.set()
|
||||
raise ConnectionError("connection dropped")
|
||||
|
||||
async def _test():
|
||||
nonlocal target_server
|
||||
server = MCPServerTask("test_srv")
|
||||
target_server = server
|
||||
server._shutdown_event.set() # Shutdown already requested
|
||||
|
||||
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await server.run({"command": "test"})
|
||||
|
||||
# Should not retry because shutdown was set
|
||||
assert run_count == 1
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_no_reconnect_on_initial_failure(self):
|
||||
"""First connection failure reports error immediately, no retry."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
run_count = 0
|
||||
target_server = None
|
||||
|
||||
original_run_stdio = MCPServerTask._run_stdio
|
||||
|
||||
async def patched_run_stdio(self_srv, config):
|
||||
nonlocal run_count, target_server
|
||||
run_count += 1
|
||||
if target_server is not self_srv:
|
||||
return await original_run_stdio(self_srv, config)
|
||||
raise ConnectionError("cannot connect")
|
||||
|
||||
async def _test():
|
||||
nonlocal target_server
|
||||
server = MCPServerTask("test_srv")
|
||||
target_server = server
|
||||
|
||||
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await server.run({"command": "test"})
|
||||
|
||||
# Only one attempt, no retry on initial failure
|
||||
assert run_count == 1
|
||||
assert server._error is not None
|
||||
assert "cannot connect" in str(server._error)
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configurable timeouts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigurableTimeouts:
|
||||
"""Tests for configurable per-server timeouts."""
|
||||
|
||||
def test_default_timeout(self):
|
||||
"""Server with no timeout config gets _DEFAULT_TOOL_TIMEOUT."""
|
||||
from tools.mcp_tool import MCPServerTask, _DEFAULT_TOOL_TIMEOUT
|
||||
|
||||
server = MCPServerTask("test_srv")
|
||||
assert server.tool_timeout == _DEFAULT_TOOL_TIMEOUT
|
||||
assert server.tool_timeout == 120
|
||||
|
||||
def test_custom_timeout(self):
|
||||
"""Server with timeout=180 in config gets 180."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
target_server = None
|
||||
|
||||
original_run_stdio = MCPServerTask._run_stdio
|
||||
|
||||
async def patched_run_stdio(self_srv, config):
|
||||
if target_server is not self_srv:
|
||||
return await original_run_stdio(self_srv, config)
|
||||
self_srv.session = MagicMock()
|
||||
self_srv._tools = []
|
||||
self_srv._ready.set()
|
||||
await self_srv._shutdown_event.wait()
|
||||
|
||||
async def _test():
|
||||
nonlocal target_server
|
||||
server = MCPServerTask("test_srv")
|
||||
target_server = server
|
||||
|
||||
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio):
|
||||
task = asyncio.ensure_future(
|
||||
server.run({"command": "test", "timeout": 180})
|
||||
)
|
||||
await server._ready.wait()
|
||||
assert server.tool_timeout == 180
|
||||
server._shutdown_event.set()
|
||||
await task
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_timeout_passed_to_handler(self):
|
||||
"""The tool handler uses the server's configured timeout."""
|
||||
from tools.mcp_tool import _make_tool_handler, _servers, MCPServerTask
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.call_tool = AsyncMock(
|
||||
return_value=_make_call_result("ok", is_error=False)
|
||||
)
|
||||
server = _make_mock_server("test_srv", session=mock_session)
|
||||
server.tool_timeout = 180
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "my_tool", 180)
|
||||
with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
|
||||
mock_run.return_value = json.dumps({"result": "ok"})
|
||||
handler({})
|
||||
# Verify timeout=180 was passed
|
||||
call_kwargs = mock_run.call_args
|
||||
assert call_kwargs.kwargs.get("timeout") == 180 or \
|
||||
(len(call_kwargs.args) > 1 and call_kwargs.args[1] == 180) or \
|
||||
call_kwargs[1].get("timeout") == 180
|
||||
finally:
|
||||
_servers.pop("test_srv", None)
|
||||
|
||||
Reference in New Issue
Block a user