feat(mcp): add HTTP transport, reconnection, security hardening

Upgrades the MCP client implementation from PR #291 with:

- HTTP/Streamable HTTP transport: support 'url' key in config for remote
  MCP servers (Notion, Slack, Sentry, Supabase, etc.)
- Automatic reconnection with exponential backoff (1s-60s, 5 retries)
  when a server connection drops unexpectedly
- Environment variable filtering: only pass safe vars (PATH, HOME, etc.)
  plus user-specified env to stdio subprocesses (prevents secret leaks)
- Credential stripping: sanitize error messages before returning to the
  LLM (strips GitHub PATs, OpenAI keys, Bearer tokens, etc.)
- Configurable per-server timeouts: 'timeout' and 'connect_timeout' keys
- Fix shutdown race condition in servers_snapshot variable scoping

Test coverage: 50 tests (up from 30), including new tests for env
filtering, credential sanitization, HTTP config detection, reconnection
logic, and configurable timeouts.

All 1162 tests pass (1162 passed, 3 skipped, 0 failed).
This commit is contained in:
teknium1
2026-03-02 18:40:03 -08:00
parent 468b7fdbad
commit 64ff8f065b
2 changed files with 611 additions and 65 deletions

View File

@@ -5,6 +5,7 @@ All tests use mocks -- no real MCP servers or subprocesses are started.
import asyncio
import json
import os
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
@@ -189,7 +190,7 @@ class TestToolHandler:
_servers["test_srv"] = server
try:
handler = _make_tool_handler("test_srv", "greet")
handler = _make_tool_handler("test_srv", "greet", 120)
with self._patch_mcp_loop():
result = json.loads(handler({"name": "world"}))
assert result["result"] == "hello world"
@@ -208,7 +209,7 @@ class TestToolHandler:
_servers["test_srv"] = server
try:
handler = _make_tool_handler("test_srv", "fail_tool")
handler = _make_tool_handler("test_srv", "fail_tool", 120)
with self._patch_mcp_loop():
result = json.loads(handler({}))
assert "error" in result
@@ -220,7 +221,7 @@ class TestToolHandler:
from tools.mcp_tool import _make_tool_handler, _servers
_servers.pop("ghost", None)
handler = _make_tool_handler("ghost", "any_tool")
handler = _make_tool_handler("ghost", "any_tool", 120)
result = json.loads(handler({}))
assert "error" in result
assert "not connected" in result["error"]
@@ -234,7 +235,7 @@ class TestToolHandler:
_servers["test_srv"] = server
try:
handler = _make_tool_handler("test_srv", "broken_tool")
handler = _make_tool_handler("test_srv", "broken_tool", 120)
with self._patch_mcp_loop():
result = json.loads(handler({}))
assert "error" in result
@@ -400,8 +401,8 @@ class TestMCPServerTask:
asyncio.run(_test())
def test_empty_env_passed_as_none(self):
"""Empty env dict is passed as None to StdioServerParameters."""
def test_empty_env_gets_safe_defaults(self):
"""Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
from tools.mcp_tool import MCPServerTask
mock_session = MagicMock()
@@ -414,13 +415,18 @@ class TestMCPServerTask:
async def _test():
with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
p_stdio, p_cs:
p_stdio, p_cs, \
patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/home/test"}, clear=False):
server = MCPServerTask("srv")
await server.start({"command": "node", "env": {}})
# Empty dict -> None
# Empty dict -> safe env vars (not None)
call_kwargs = mock_params.call_args
assert call_kwargs.kwargs.get("env") is None
env_arg = call_kwargs.kwargs.get("env")
assert env_arg is not None
assert isinstance(env_arg, dict)
assert "PATH" in env_arg
assert "HOME" in env_arg
await server.shutdown()
@@ -698,3 +704,353 @@ class TestShutdown:
assert len(_servers) == 0
# Parallel: ~1s, not ~3s. Allow some margin.
assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)"
# ---------------------------------------------------------------------------
# _build_safe_env
# ---------------------------------------------------------------------------
class TestBuildSafeEnv:
"""Tests for _build_safe_env() environment filtering."""
def test_only_safe_vars_passed(self):
"""Only safe baseline vars and XDG_* from os.environ are included."""
from tools.mcp_tool import _build_safe_env
fake_env = {
"PATH": "/usr/bin",
"HOME": "/home/test",
"USER": "test",
"LANG": "en_US.UTF-8",
"LC_ALL": "C",
"TERM": "xterm",
"SHELL": "/bin/bash",
"TMPDIR": "/tmp",
"XDG_DATA_HOME": "/home/test/.local/share",
"SECRET_KEY": "should_not_appear",
"AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE",
}
with patch.dict("os.environ", fake_env, clear=True):
result = _build_safe_env(None)
# Safe vars present
assert result["PATH"] == "/usr/bin"
assert result["HOME"] == "/home/test"
assert result["USER"] == "test"
assert result["LANG"] == "en_US.UTF-8"
assert result["XDG_DATA_HOME"] == "/home/test/.local/share"
# Unsafe vars excluded
assert "SECRET_KEY" not in result
assert "AWS_ACCESS_KEY_ID" not in result
def test_user_env_merged(self):
"""User-specified env vars are merged into the safe env."""
from tools.mcp_tool import _build_safe_env
with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
result = _build_safe_env({"MY_CUSTOM_VAR": "hello"})
assert result["PATH"] == "/usr/bin"
assert result["MY_CUSTOM_VAR"] == "hello"
def test_user_env_overrides_safe(self):
"""User env can override safe defaults."""
from tools.mcp_tool import _build_safe_env
with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
result = _build_safe_env({"PATH": "/custom/bin"})
assert result["PATH"] == "/custom/bin"
def test_none_user_env(self):
"""None user_env still returns safe vars from os.environ."""
from tools.mcp_tool import _build_safe_env
with patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/root"}, clear=True):
result = _build_safe_env(None)
assert isinstance(result, dict)
assert result["PATH"] == "/usr/bin"
assert result["HOME"] == "/root"
def test_secret_vars_excluded(self):
"""Sensitive env vars from os.environ are NOT passed through."""
from tools.mcp_tool import _build_safe_env
fake_env = {
"PATH": "/usr/bin",
"AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
"GITHUB_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"OPENAI_API_KEY": "sk-proj-abc123",
"DATABASE_URL": "postgres://user:pass@localhost/db",
"API_SECRET": "supersecret",
}
with patch.dict("os.environ", fake_env, clear=True):
result = _build_safe_env(None)
assert "PATH" in result
assert "AWS_SECRET_ACCESS_KEY" not in result
assert "GITHUB_TOKEN" not in result
assert "OPENAI_API_KEY" not in result
assert "DATABASE_URL" not in result
assert "API_SECRET" not in result
# ---------------------------------------------------------------------------
# _sanitize_error
# ---------------------------------------------------------------------------
class TestSanitizeError:
"""Tests for _sanitize_error() credential stripping."""
def test_strips_github_pat(self):
from tools.mcp_tool import _sanitize_error
result = _sanitize_error("Error with ghp_abc123def456")
assert result == "Error with [REDACTED]"
def test_strips_openai_key(self):
from tools.mcp_tool import _sanitize_error
result = _sanitize_error("key sk-projABC123xyz")
assert result == "key [REDACTED]"
def test_strips_bearer_token(self):
from tools.mcp_tool import _sanitize_error
result = _sanitize_error("Authorization: Bearer eyJabc123def")
assert result == "Authorization: [REDACTED]"
def test_strips_token_param(self):
from tools.mcp_tool import _sanitize_error
result = _sanitize_error("url?token=secret123")
assert result == "url?[REDACTED]"
def test_no_credentials_unchanged(self):
from tools.mcp_tool import _sanitize_error
result = _sanitize_error("normal error message")
assert result == "normal error message"
def test_multiple_credentials(self):
from tools.mcp_tool import _sanitize_error
result = _sanitize_error("ghp_abc123 and sk-projXyz789 and token=foo")
assert "ghp_" not in result
assert "sk-" not in result
assert "token=" not in result
assert result.count("[REDACTED]") == 3
# ---------------------------------------------------------------------------
# HTTP config
# ---------------------------------------------------------------------------
class TestHTTPConfig:
"""Tests for HTTP transport detection and handling."""
def test_is_http_with_url(self):
from tools.mcp_tool import MCPServerTask
server = MCPServerTask("remote")
server._config = {"url": "https://example.com/mcp"}
assert server._is_http() is True
def test_is_stdio_with_command(self):
from tools.mcp_tool import MCPServerTask
server = MCPServerTask("local")
server._config = {"command": "npx", "args": []}
assert server._is_http() is False
def test_http_unavailable_raises(self):
from tools.mcp_tool import MCPServerTask
server = MCPServerTask("remote")
config = {"url": "https://example.com/mcp"}
async def _test():
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", False):
with pytest.raises(ImportError, match="HTTP transport"):
await server._run_http(config)
asyncio.run(_test())
# ---------------------------------------------------------------------------
# Reconnection logic
# ---------------------------------------------------------------------------
class TestReconnection:
"""Tests for automatic reconnection behavior in MCPServerTask.run()."""
def test_reconnect_on_disconnect(self):
"""After initial success, a connection drop triggers reconnection."""
from tools.mcp_tool import MCPServerTask
run_count = 0
target_server = None
original_run_stdio = MCPServerTask._run_stdio
async def patched_run_stdio(self_srv, config):
nonlocal run_count, target_server
run_count += 1
if target_server is not self_srv:
return await original_run_stdio(self_srv, config)
if run_count == 1:
# First connection succeeds, then simulate disconnect
self_srv.session = MagicMock()
self_srv._tools = []
self_srv._ready.set()
raise ConnectionError("connection dropped")
else:
# Reconnection succeeds; signal shutdown so run() exits
self_srv.session = MagicMock()
self_srv._shutdown_event.set()
await self_srv._shutdown_event.wait()
async def _test():
nonlocal target_server
server = MCPServerTask("test_srv")
target_server = server
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
patch("asyncio.sleep", new_callable=AsyncMock):
await server.run({"command": "test"})
assert run_count >= 2 # At least one reconnection attempt
asyncio.run(_test())
def test_no_reconnect_on_shutdown(self):
"""If shutdown is requested, don't attempt reconnection."""
from tools.mcp_tool import MCPServerTask
run_count = 0
target_server = None
original_run_stdio = MCPServerTask._run_stdio
async def patched_run_stdio(self_srv, config):
nonlocal run_count, target_server
run_count += 1
if target_server is not self_srv:
return await original_run_stdio(self_srv, config)
self_srv.session = MagicMock()
self_srv._tools = []
self_srv._ready.set()
raise ConnectionError("connection dropped")
async def _test():
nonlocal target_server
server = MCPServerTask("test_srv")
target_server = server
server._shutdown_event.set() # Shutdown already requested
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
patch("asyncio.sleep", new_callable=AsyncMock):
await server.run({"command": "test"})
# Should not retry because shutdown was set
assert run_count == 1
asyncio.run(_test())
def test_no_reconnect_on_initial_failure(self):
"""First connection failure reports error immediately, no retry."""
from tools.mcp_tool import MCPServerTask
run_count = 0
target_server = None
original_run_stdio = MCPServerTask._run_stdio
async def patched_run_stdio(self_srv, config):
nonlocal run_count, target_server
run_count += 1
if target_server is not self_srv:
return await original_run_stdio(self_srv, config)
raise ConnectionError("cannot connect")
async def _test():
nonlocal target_server
server = MCPServerTask("test_srv")
target_server = server
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
patch("asyncio.sleep", new_callable=AsyncMock):
await server.run({"command": "test"})
# Only one attempt, no retry on initial failure
assert run_count == 1
assert server._error is not None
assert "cannot connect" in str(server._error)
asyncio.run(_test())
# ---------------------------------------------------------------------------
# Configurable timeouts
# ---------------------------------------------------------------------------
class TestConfigurableTimeouts:
"""Tests for configurable per-server timeouts."""
def test_default_timeout(self):
"""Server with no timeout config gets _DEFAULT_TOOL_TIMEOUT."""
from tools.mcp_tool import MCPServerTask, _DEFAULT_TOOL_TIMEOUT
server = MCPServerTask("test_srv")
assert server.tool_timeout == _DEFAULT_TOOL_TIMEOUT
assert server.tool_timeout == 120
def test_custom_timeout(self):
"""Server with timeout=180 in config gets 180."""
from tools.mcp_tool import MCPServerTask
target_server = None
original_run_stdio = MCPServerTask._run_stdio
async def patched_run_stdio(self_srv, config):
if target_server is not self_srv:
return await original_run_stdio(self_srv, config)
self_srv.session = MagicMock()
self_srv._tools = []
self_srv._ready.set()
await self_srv._shutdown_event.wait()
async def _test():
nonlocal target_server
server = MCPServerTask("test_srv")
target_server = server
with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio):
task = asyncio.ensure_future(
server.run({"command": "test", "timeout": 180})
)
await server._ready.wait()
assert server.tool_timeout == 180
server._shutdown_event.set()
await task
asyncio.run(_test())
def test_timeout_passed_to_handler(self):
"""The tool handler uses the server's configured timeout."""
from tools.mcp_tool import _make_tool_handler, _servers, MCPServerTask
mock_session = MagicMock()
mock_session.call_tool = AsyncMock(
return_value=_make_call_result("ok", is_error=False)
)
server = _make_mock_server("test_srv", session=mock_session)
server.tool_timeout = 180
_servers["test_srv"] = server
try:
handler = _make_tool_handler("test_srv", "my_tool", 180)
with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
mock_run.return_value = json.dumps({"result": "ok"})
handler({})
# Verify timeout=180 was passed
call_kwargs = mock_run.call_args
assert call_kwargs.kwargs.get("timeout") == 180 or \
(len(call_kwargs.args) > 1 and call_kwargs.args[1] == 180) or \
call_kwargs[1].get("timeout") == 180
finally:
_servers.pop("test_srv", None)