fix: resolve orphan subprocess leak on MCP server shutdown
Refactor MCP connections from AsyncExitStack to task-per-server architecture. Each server now runs as a long-lived asyncio Task with `async with stdio_client(...)`, ensuring anyio cancel-scope cleanup happens in the same Task that opened the connection.
This commit is contained in:
@@ -36,6 +36,15 @@ def _make_call_result(text="file contents here", is_error=False):
|
||||
return SimpleNamespace(content=[block], isError=is_error)
|
||||
|
||||
|
||||
def _make_mock_server(name, session=None, tools=None):
    """Build an MCPServerTask for tests without starting its background Task.

    The instance is constructed by name only; ``session`` and ``_tools`` are
    then assigned directly so tests can place it in the desired state.

    Args:
        name: Server name passed to ``MCPServerTask``.
        session: Value stored on ``server.session`` (``None`` by default).
        tools: Value stored on ``server._tools``; a falsy value becomes ``[]``.

    Returns:
        The configured ``MCPServerTask`` instance.
    """
    from tools.mcp_tool import MCPServerTask

    mock_server = MCPServerTask(name)
    mock_server.session = session
    mock_server._tools = tools if tools else []
    return mock_server
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config loading
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -43,12 +52,10 @@ def _make_call_result(text="file contents here", is_error=False):
|
||||
class TestLoadMCPConfig:
|
||||
def test_no_config_returns_empty(self):
|
||||
"""No mcp_servers key in config -> empty dict."""
|
||||
with patch("tools.mcp_tool.load_config", create=True) as mock_lc:
|
||||
# Patch the actual import inside the function
|
||||
with patch("hermes_cli.config.load_config", return_value={"model": "test"}):
|
||||
from tools.mcp_tool import _load_mcp_config
|
||||
result = _load_mcp_config()
|
||||
assert result == {}
|
||||
with patch("hermes_cli.config.load_config", return_value={"model": "test"}):
|
||||
from tools.mcp_tool import _load_mcp_config
|
||||
result = _load_mcp_config()
|
||||
assert result == {}
|
||||
|
||||
def test_valid_config_parsed(self):
|
||||
"""Valid mcp_servers config is returned as-is."""
|
||||
@@ -123,46 +130,37 @@ class TestSchemaConversion:
|
||||
|
||||
class TestCheckFunction:
|
||||
def test_disconnected_returns_false(self):
|
||||
from tools.mcp_tool import _make_check_fn, _connections
|
||||
from tools.mcp_tool import _make_check_fn, _servers
|
||||
|
||||
# Ensure no connection exists
|
||||
_connections.pop("test_server", None)
|
||||
_servers.pop("test_server", None)
|
||||
check = _make_check_fn("test_server")
|
||||
assert check() is False
|
||||
|
||||
def test_connected_returns_true(self):
|
||||
from tools.mcp_tool import _make_check_fn, _connections, MCPConnection
|
||||
from tools.mcp_tool import _make_check_fn, _servers
|
||||
|
||||
conn = MCPConnection(
|
||||
server_name="test_server",
|
||||
session=MagicMock(),
|
||||
stack=MagicMock(),
|
||||
)
|
||||
_connections["test_server"] = conn
|
||||
server = _make_mock_server("test_server", session=MagicMock())
|
||||
_servers["test_server"] = server
|
||||
try:
|
||||
check = _make_check_fn("test_server")
|
||||
assert check() is True
|
||||
finally:
|
||||
_connections.pop("test_server", None)
|
||||
_servers.pop("test_server", None)
|
||||
|
||||
def test_session_none_returns_false(self):
|
||||
from tools.mcp_tool import _make_check_fn, _connections, MCPConnection
|
||||
from tools.mcp_tool import _make_check_fn, _servers
|
||||
|
||||
conn = MCPConnection(
|
||||
server_name="test_server",
|
||||
session=None,
|
||||
stack=MagicMock(),
|
||||
)
|
||||
_connections["test_server"] = conn
|
||||
server = _make_mock_server("test_server", session=None)
|
||||
_servers["test_server"] = server
|
||||
try:
|
||||
check = _make_check_fn("test_server")
|
||||
assert check() is False
|
||||
finally:
|
||||
_connections.pop("test_server", None)
|
||||
_servers.pop("test_server", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool handler (async)
|
||||
# Tool handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolHandler:
|
||||
@@ -171,20 +169,24 @@ class TestToolHandler:
|
||||
def _patch_mcp_loop(self, coro_side_effect=None):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
if coro_side_effect:
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||
|
||||
def test_successful_call(self):
|
||||
from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection
|
||||
from tools.mcp_tool import _make_tool_handler, _servers
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.call_tool = AsyncMock(
|
||||
return_value=_make_call_result("hello world", is_error=False)
|
||||
)
|
||||
conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock())
|
||||
_connections["test_srv"] = conn
|
||||
server = _make_mock_server("test_srv", session=mock_session)
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "greet")
|
||||
@@ -193,17 +195,17 @@ class TestToolHandler:
|
||||
assert result["result"] == "hello world"
|
||||
mock_session.call_tool.assert_called_once_with("greet", arguments={"name": "world"})
|
||||
finally:
|
||||
_connections.pop("test_srv", None)
|
||||
_servers.pop("test_srv", None)
|
||||
|
||||
def test_mcp_error_result(self):
|
||||
from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection
|
||||
from tools.mcp_tool import _make_tool_handler, _servers
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.call_tool = AsyncMock(
|
||||
return_value=_make_call_result("something went wrong", is_error=True)
|
||||
)
|
||||
conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock())
|
||||
_connections["test_srv"] = conn
|
||||
server = _make_mock_server("test_srv", session=mock_session)
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "fail_tool")
|
||||
@@ -212,25 +214,24 @@ class TestToolHandler:
|
||||
assert "error" in result
|
||||
assert "something went wrong" in result["error"]
|
||||
finally:
|
||||
_connections.pop("test_srv", None)
|
||||
_servers.pop("test_srv", None)
|
||||
|
||||
def test_disconnected_server(self):
|
||||
from tools.mcp_tool import _make_tool_handler, _connections
|
||||
from tools.mcp_tool import _make_tool_handler, _servers
|
||||
|
||||
_connections.pop("ghost", None)
|
||||
_servers.pop("ghost", None)
|
||||
handler = _make_tool_handler("ghost", "any_tool")
|
||||
# Disconnected check happens before _run_on_mcp_loop, no patch needed
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
assert "not connected" in result["error"]
|
||||
|
||||
def test_exception_during_call(self):
|
||||
from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection
|
||||
from tools.mcp_tool import _make_tool_handler, _servers
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.call_tool = AsyncMock(side_effect=RuntimeError("connection lost"))
|
||||
conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock())
|
||||
_connections["test_srv"] = conn
|
||||
server = _make_mock_server("test_srv", session=mock_session)
|
||||
_servers["test_srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("test_srv", "broken_tool")
|
||||
@@ -239,7 +240,7 @@ class TestToolHandler:
|
||||
assert "error" in result
|
||||
assert "connection lost" in result["error"]
|
||||
finally:
|
||||
_connections.pop("test_srv", None)
|
||||
_servers.pop("test_srv", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -249,23 +250,21 @@ class TestToolHandler:
|
||||
class TestDiscoverAndRegister:
|
||||
def test_tools_registered_in_registry(self):
|
||||
"""_discover_and_register_server registers tools with correct names."""
|
||||
from tools.registry import ToolRegistry, registry as real_registry
|
||||
from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection
|
||||
from tools.registry import ToolRegistry
|
||||
from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask
|
||||
|
||||
mock_registry = ToolRegistry()
|
||||
mock_tools = [
|
||||
_make_mcp_tool("read_file", "Read a file"),
|
||||
_make_mcp_tool("write_file", "Write a file"),
|
||||
]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools = AsyncMock(
|
||||
return_value=SimpleNamespace(tools=mock_tools)
|
||||
)
|
||||
|
||||
async def fake_connect(name, config):
|
||||
return MCPConnection(name, session=mock_session, stack=MagicMock())
|
||||
server = MCPServerTask(name)
|
||||
server.session = mock_session
|
||||
server._tools = mock_tools
|
||||
return server
|
||||
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.registry.registry", mock_registry):
|
||||
@@ -278,22 +277,20 @@ class TestDiscoverAndRegister:
|
||||
assert "mcp_fs_read_file" in mock_registry.get_all_tool_names()
|
||||
assert "mcp_fs_write_file" in mock_registry.get_all_tool_names()
|
||||
|
||||
_connections.pop("fs", None)
|
||||
_servers.pop("fs", None)
|
||||
|
||||
def test_toolset_created(self):
|
||||
"""A custom toolset is created for the MCP server."""
|
||||
from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection
|
||||
from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask
|
||||
|
||||
mock_tools = [_make_mcp_tool("ping", "Ping")]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools = AsyncMock(
|
||||
return_value=SimpleNamespace(tools=mock_tools)
|
||||
)
|
||||
|
||||
async def fake_connect(name, config):
|
||||
return MCPConnection(name, session=mock_session, stack=MagicMock())
|
||||
server = MCPServerTask(name)
|
||||
server.session = mock_session
|
||||
server._tools = mock_tools
|
||||
return server
|
||||
|
||||
mock_create = MagicMock()
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
@@ -306,24 +303,22 @@ class TestDiscoverAndRegister:
|
||||
call_kwargs = mock_create.call_args
|
||||
assert call_kwargs[1]["name"] == "mcp-myserver" or call_kwargs[0][0] == "mcp-myserver"
|
||||
|
||||
_connections.pop("myserver", None)
|
||||
_servers.pop("myserver", None)
|
||||
|
||||
def test_schema_format_correct(self):
|
||||
"""Registered schemas have the correct format."""
|
||||
from tools.registry import ToolRegistry, registry as real_registry
|
||||
from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection
|
||||
from tools.registry import ToolRegistry
|
||||
from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask
|
||||
|
||||
mock_registry = ToolRegistry()
|
||||
mock_tools = [_make_mcp_tool("do_thing", "Do something")]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools = AsyncMock(
|
||||
return_value=SimpleNamespace(tools=mock_tools)
|
||||
)
|
||||
|
||||
async def fake_connect(name, config):
|
||||
return MCPConnection(name, session=mock_session, stack=MagicMock())
|
||||
server = MCPServerTask(name)
|
||||
server.session = mock_session
|
||||
server._tools = mock_tools
|
||||
return server
|
||||
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.registry.registry", mock_registry):
|
||||
@@ -338,91 +333,125 @@ class TestDiscoverAndRegister:
|
||||
assert entry.is_async is False
|
||||
assert entry.toolset == "mcp-srv"
|
||||
|
||||
_connections.pop("srv", None)
|
||||
_servers.pop("srv", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _connect_server (SDK interaction)
|
||||
# MCPServerTask (run / start / shutdown)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConnectServer:
|
||||
def test_calls_sdk_with_correct_params(self):
|
||||
"""_connect_server creates StdioServerParameters and calls stdio_client."""
|
||||
from tools.mcp_tool import _connect_server, MCPConnection
|
||||
class TestMCPServerTask:
|
||||
"""Test the MCPServerTask lifecycle with mocked MCP SDK."""
|
||||
|
||||
def _mock_stdio_and_session(self, session):
|
||||
"""Return patches for stdio_client and ClientSession as async CMs."""
|
||||
mock_read, mock_write = MagicMock(), MagicMock()
|
||||
|
||||
mock_stdio_cm = MagicMock()
|
||||
mock_stdio_cm.__aenter__ = AsyncMock(return_value=(mock_read, mock_write))
|
||||
mock_stdio_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_cs_cm = MagicMock()
|
||||
mock_cs_cm.__aenter__ = AsyncMock(return_value=session)
|
||||
mock_cs_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return (
|
||||
patch("tools.mcp_tool.stdio_client", return_value=mock_stdio_cm),
|
||||
patch("tools.mcp_tool.ClientSession", return_value=mock_cs_cm),
|
||||
mock_read, mock_write,
|
||||
)
|
||||
|
||||
def test_start_connects_and_discovers_tools(self):
|
||||
"""start() creates a Task that connects, discovers tools, and waits."""
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
mock_tools = [_make_mcp_tool("echo")]
|
||||
mock_session = MagicMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
|
||||
mock_read = MagicMock()
|
||||
mock_write = MagicMock()
|
||||
|
||||
with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
|
||||
patch("tools.mcp_tool.stdio_client") as mock_stdio, \
|
||||
patch("tools.mcp_tool.ClientSession") as mock_cs, \
|
||||
patch("tools.mcp_tool.AsyncExitStack") as mock_stack_cls:
|
||||
|
||||
mock_stack = MagicMock()
|
||||
mock_stack.enter_async_context = AsyncMock(
|
||||
side_effect=[(mock_read, mock_write), mock_session]
|
||||
)
|
||||
mock_stack_cls.return_value = mock_stack
|
||||
|
||||
conn = asyncio.run(_connect_server("test_srv", {
|
||||
"command": "npx",
|
||||
"args": ["-y", "some-server"],
|
||||
"env": {"MY_KEY": "secret"},
|
||||
}))
|
||||
|
||||
# StdioServerParameters called with correct values
|
||||
mock_params.assert_called_once_with(
|
||||
command="npx",
|
||||
args=["-y", "some-server"],
|
||||
env={"MY_KEY": "secret"},
|
||||
mock_session.list_tools = AsyncMock(
|
||||
return_value=SimpleNamespace(tools=mock_tools)
|
||||
)
|
||||
# ClientSession created with the streams
|
||||
mock_cs.assert_called_once_with(mock_read, mock_write)
|
||||
# initialize() was called
|
||||
mock_session.initialize.assert_called_once()
|
||||
# Returned connection is valid
|
||||
assert conn.server_name == "test_srv"
|
||||
assert conn.session is mock_session
|
||||
|
||||
p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session)
|
||||
|
||||
async def _test():
|
||||
with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs:
|
||||
server = MCPServerTask("test_srv")
|
||||
await server.start({"command": "npx", "args": ["-y", "test"]})
|
||||
|
||||
assert server.session is mock_session
|
||||
assert len(server._tools) == 1
|
||||
assert server._tools[0].name == "echo"
|
||||
mock_session.initialize.assert_called_once()
|
||||
|
||||
await server.shutdown()
|
||||
assert server.session is None
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_no_command_raises(self):
|
||||
"""Missing 'command' in config raises ValueError."""
|
||||
from tools.mcp_tool import _connect_server
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
with pytest.raises(ValueError, match="no 'command'"):
|
||||
asyncio.run(_connect_server("bad", {"args": []}))
|
||||
async def _test():
|
||||
server = MCPServerTask("bad")
|
||||
with pytest.raises(ValueError, match="no 'command'"):
|
||||
await server.start({"args": []})
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_empty_env_passed_as_none(self):
|
||||
"""Empty env dict is passed as None to StdioServerParameters."""
|
||||
from tools.mcp_tool import _connect_server
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools = AsyncMock(
|
||||
return_value=SimpleNamespace(tools=[])
|
||||
)
|
||||
|
||||
with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
|
||||
patch("tools.mcp_tool.stdio_client"), \
|
||||
patch("tools.mcp_tool.ClientSession", return_value=mock_session), \
|
||||
patch("tools.mcp_tool.AsyncExitStack") as mock_stack_cls:
|
||||
p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session)
|
||||
|
||||
mock_stack = MagicMock()
|
||||
mock_stack.enter_async_context = AsyncMock(
|
||||
side_effect=[
|
||||
(MagicMock(), MagicMock()),
|
||||
mock_session,
|
||||
]
|
||||
)
|
||||
mock_stack_cls.return_value = mock_stack
|
||||
async def _test():
|
||||
with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
|
||||
p_stdio, p_cs:
|
||||
server = MCPServerTask("srv")
|
||||
await server.start({"command": "node", "env": {}})
|
||||
|
||||
asyncio.run(_connect_server("srv", {
|
||||
"command": "node",
|
||||
"env": {},
|
||||
}))
|
||||
# Empty dict -> None
|
||||
call_kwargs = mock_params.call_args
|
||||
assert call_kwargs.kwargs.get("env") is None
|
||||
|
||||
# Empty dict -> None
|
||||
assert mock_params.call_args[1]["env"] is None or \
|
||||
mock_params.call_args.kwargs.get("env") is None
|
||||
await server.shutdown()
|
||||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_shutdown_signals_task_exit(self):
    """After start() the Task is still running; shutdown() finishes it and clears the session."""
    from tools.mcp_tool import MCPServerTask

    empty_listing = SimpleNamespace(tools=[])
    fake_session = MagicMock()
    fake_session.initialize = AsyncMock()
    fake_session.list_tools = AsyncMock(return_value=empty_listing)

    stdio_patch, session_patch, _, _ = self._mock_stdio_and_session(fake_session)

    async def _scenario():
        with patch("tools.mcp_tool.StdioServerParameters"), stdio_patch, session_patch:
            task_owner = MCPServerTask("srv")
            await task_owner.start({"command": "npx"})

            # Connected: a session is exposed and the Task has not finished.
            assert task_owner.session is not None
            assert not task_owner._task.done()

            await task_owner.shutdown()

            # Torn down: the session is cleared and the Task has completed.
            assert task_owner.session is None
            assert task_owner._task.done()

    asyncio.run(_scenario())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -432,17 +461,16 @@ class TestConnectServer:
|
||||
class TestToolsetInjection:
|
||||
def test_mcp_tools_added_to_platform_toolsets(self):
|
||||
"""Discovered MCP tools are injected into hermes-cli and platform toolsets."""
|
||||
from tools.mcp_tool import _connections, MCPConnection
|
||||
from tools.mcp_tool import _servers, MCPServerTask
|
||||
|
||||
mock_tools = [_make_mcp_tool("list_files", "List files")]
|
||||
mock_session = MagicMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools = AsyncMock(
|
||||
return_value=SimpleNamespace(tools=mock_tools)
|
||||
)
|
||||
|
||||
async def fake_connect(name, config):
|
||||
return MCPConnection(name, session=mock_session, stack=MagicMock())
|
||||
server = MCPServerTask(name)
|
||||
server.session = mock_session
|
||||
server._tools = mock_tools
|
||||
return server
|
||||
|
||||
fake_toolsets = {
|
||||
"hermes-cli": {"tools": ["terminal", "web_search"], "description": "CLI", "includes": []},
|
||||
@@ -455,7 +483,6 @@ class TestToolsetInjection:
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
|
||||
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.mcp_tool.TOOLSETS", fake_toolsets, create=True), \
|
||||
patch("toolsets.TOOLSETS", fake_toolsets):
|
||||
from tools.mcp_tool import discover_mcp_tools
|
||||
result = discover_mcp_tools()
|
||||
@@ -466,18 +493,14 @@ class TestToolsetInjection:
|
||||
# Original tools preserved
|
||||
assert "terminal" in fake_toolsets["hermes-cli"]["tools"]
|
||||
|
||||
_connections.pop("fs", None)
|
||||
_servers.pop("fs", None)
|
||||
|
||||
def test_server_connection_failure_skipped(self):
|
||||
"""If one server fails to connect, others still proceed."""
|
||||
from tools.mcp_tool import _connections, MCPConnection
|
||||
from tools.mcp_tool import _servers, MCPServerTask
|
||||
|
||||
mock_tools = [_make_mcp_tool("ping", "Ping")]
|
||||
mock_session = MagicMock()
|
||||
mock_session.initialize = AsyncMock()
|
||||
mock_session.list_tools = AsyncMock(
|
||||
return_value=SimpleNamespace(tools=mock_tools)
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
@@ -486,7 +509,10 @@ class TestToolsetInjection:
|
||||
call_count += 1
|
||||
if name == "broken":
|
||||
raise ConnectionError("cannot reach server")
|
||||
return MCPConnection(name, session=mock_session, stack=MagicMock())
|
||||
server = MCPServerTask(name)
|
||||
server.session = mock_session
|
||||
server._tools = mock_tools
|
||||
return server
|
||||
|
||||
fake_config = {
|
||||
"broken": {"command": "bad"},
|
||||
@@ -508,7 +534,7 @@ class TestToolsetInjection:
|
||||
assert "mcp_broken_ping" not in result
|
||||
assert call_count == 2 # Both were attempted
|
||||
|
||||
_connections.pop("good", None)
|
||||
_servers.pop("good", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -533,50 +559,46 @@ class TestGracefulFallback:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shutdown
|
||||
# Shutdown (public API)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestShutdown:
|
||||
def test_no_connections_safe(self):
|
||||
"""shutdown_mcp_servers with no connections does nothing."""
|
||||
from tools.mcp_tool import shutdown_mcp_servers, _connections
|
||||
def test_no_servers_safe(self):
|
||||
"""shutdown_mcp_servers with no servers does nothing."""
|
||||
from tools.mcp_tool import shutdown_mcp_servers, _servers
|
||||
|
||||
_connections.clear()
|
||||
_servers.clear()
|
||||
shutdown_mcp_servers() # Should not raise
|
||||
|
||||
def test_shutdown_clears_connections(self):
|
||||
"""shutdown_mcp_servers closes stacks and clears the dict."""
|
||||
def test_shutdown_clears_servers(self):
|
||||
"""shutdown_mcp_servers calls shutdown() on each server and clears dict."""
|
||||
import tools.mcp_tool as mcp_mod
|
||||
from tools.mcp_tool import shutdown_mcp_servers, _connections, MCPConnection
|
||||
from tools.mcp_tool import shutdown_mcp_servers, _servers
|
||||
|
||||
_connections.clear()
|
||||
mock_stack = MagicMock()
|
||||
mock_stack.aclose = AsyncMock()
|
||||
conn = MCPConnection("test", session=MagicMock(), stack=mock_stack)
|
||||
_connections["test"] = conn
|
||||
_servers.clear()
|
||||
mock_server = MagicMock()
|
||||
mock_server.shutdown = AsyncMock()
|
||||
_servers["test"] = mock_server
|
||||
|
||||
# Start a real background loop so shutdown can schedule on it
|
||||
mcp_mod._ensure_mcp_loop()
|
||||
try:
|
||||
shutdown_mcp_servers()
|
||||
finally:
|
||||
# _stop_mcp_loop is called by shutdown, but ensure cleanup
|
||||
mcp_mod._mcp_loop = None
|
||||
mcp_mod._mcp_thread = None
|
||||
|
||||
assert len(_connections) == 0
|
||||
mock_stack.aclose.assert_called_once()
|
||||
assert len(_servers) == 0
|
||||
mock_server.shutdown.assert_called_once()
|
||||
|
||||
def test_shutdown_handles_errors(self):
|
||||
"""shutdown_mcp_servers handles errors during close gracefully."""
|
||||
import tools.mcp_tool as mcp_mod
|
||||
from tools.mcp_tool import shutdown_mcp_servers, _connections, MCPConnection
|
||||
from tools.mcp_tool import shutdown_mcp_servers, _servers
|
||||
|
||||
_connections.clear()
|
||||
mock_stack = MagicMock()
|
||||
mock_stack.aclose = AsyncMock(side_effect=RuntimeError("close failed"))
|
||||
conn = MCPConnection("broken", session=MagicMock(), stack=mock_stack)
|
||||
_connections["broken"] = conn
|
||||
_servers.clear()
|
||||
mock_server = MagicMock()
|
||||
mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed"))
|
||||
_servers["broken"] = mock_server
|
||||
|
||||
mcp_mod._ensure_mcp_loop()
|
||||
try:
|
||||
@@ -585,4 +607,4 @@ class TestShutdown:
|
||||
mcp_mod._mcp_loop = None
|
||||
mcp_mod._mcp_thread = None
|
||||
|
||||
assert len(_connections) == 0
|
||||
assert len(_servers) == 0
|
||||
|
||||
@@ -25,8 +25,13 @@ Example config::
|
||||
|
||||
Architecture:
|
||||
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
|
||||
All MCP connections live on this loop. Tool handlers schedule coroutines
|
||||
onto it via run_coroutine_threadsafe(), so they work from any thread.
|
||||
Each MCP server runs as a long-lived asyncio Task on this loop, keeping
|
||||
its ``async with stdio_client(...)`` context alive. Tool call coroutines
|
||||
are scheduled onto the loop via ``run_coroutine_threadsafe()``.
|
||||
|
||||
On shutdown, each server Task is signalled to exit its ``async with``
|
||||
block, ensuring the anyio cancel-scope cleanup happens in the *same*
|
||||
Task that opened the connection (required by anyio).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -45,31 +50,114 @@ _MCP_AVAILABLE = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from contextlib import AsyncExitStack
|
||||
_MCP_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.debug("mcp package not installed -- MCP tool support disabled")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Connection tracking
|
||||
# Server task -- each MCP server lives in one long-lived asyncio Task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MCPConnection:
|
||||
"""Holds a live MCP server connection and its async resource stack."""
|
||||
class MCPServerTask:
|
||||
"""Manages a single MCP server connection in a dedicated asyncio Task.
|
||||
|
||||
__slots__ = ("server_name", "session", "stack")
|
||||
The entire connection lifecycle (connect, discover, serve, disconnect)
|
||||
runs inside one asyncio Task so that anyio cancel-scopes created by
|
||||
``stdio_client`` are entered and exited in the same Task context.
|
||||
"""
|
||||
|
||||
def __init__(self, server_name: str, session: Any, stack: Any):
|
||||
self.server_name = server_name
|
||||
self.session: Optional[Any] = session
|
||||
self.stack: Optional[Any] = stack
|
||||
__slots__ = (
|
||||
"name", "session",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.session: Optional[Any] = None
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._ready = asyncio.Event()
|
||||
self._shutdown_event = asyncio.Event()
|
||||
self._tools: list = []
|
||||
self._error: Optional[Exception] = None
|
||||
|
||||
async def run(self, config: dict):
    """Long-lived coroutine: connect, discover tools, wait, disconnect.

    Both ``async with`` blocks are entered and exited inside this coroutine,
    so whichever Task runs it owns the whole connection lifecycle.

    Args:
        config: Server settings. Reads ``command`` (required), ``args``
            (defaults to ``[]``) and ``env`` (a missing or empty value is
            passed on as ``None``).

    Failures are not raised from here: they are stored in ``self._error``.
    ``self._ready`` is always set before this coroutine finishes, so a
    caller waiting on it can never block forever. Cancellation is
    re-raised after the waiter has been released.
    """
    command = config.get("command")
    args = config.get("args", [])
    env = config.get("env")

    if not command:
        self._error = ValueError(
            f"MCP server '{self.name}' has no 'command' in config"
        )
        self._ready.set()
        return

    try:
        # Built inside the try so that a failure here (e.g. invalid
        # parameters) is recorded and wakes the waiter, instead of ending
        # the Task with ``_ready`` still unset.
        server_params = StdioServerParameters(
            command=command,
            args=args,
            env=env if env else None,
        )

        async with stdio_client(server_params) as (read_stream, write_stream):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()
                self.session = session

                tools_result = await session.list_tools()
                self._tools = (
                    tools_result.tools
                    if hasattr(tools_result, "tools")
                    else []
                )

                # Signal that connection is ready
                self._ready.set()

                # Block until shutdown is requested -- this keeps the
                # async-with contexts alive on THIS Task.
                await self._shutdown_event.wait()
    except asyncio.CancelledError:
        # Cancelled before the handshake finished: leave an error behind so
        # a waiter on ``_ready`` sees a failure rather than a half-started
        # server, then let the cancellation propagate.
        if not self._ready.is_set():
            self._error = ConnectionError(
                f"MCP server '{self.name}' was cancelled before it became ready"
            )
        raise
    except Exception as exc:
        if self._ready.is_set():
            # Nobody is waiting on ``_ready`` any more, so make a
            # post-startup failure visible instead of storing it silently.
            logger.debug(
                "MCP server '%s' failed after startup: %s", self.name, exc
            )
        self._error = exc
    finally:
        self.session = None
        # Always release the waiter, whatever path led here.
        self._ready.set()
|
||||
|
||||
async def start(self, config: dict):
    """Schedule ``run(config)`` in the background and wait for it to report in.

    Args:
        config: Passed unchanged to ``run``.

    Raises:
        Whatever is stored in ``self._error`` once ``_ready`` has been set.
    """
    lifecycle = self.run(config)
    self._task = asyncio.ensure_future(lifecycle)

    # Wakes up once ``_ready`` is set, on success or on failure.
    await self._ready.wait()

    failure = self._error
    if failure:
        raise failure
|
||||
|
||||
async def shutdown(self):
    """Ask the server Task to finish and wait for it before clearing the session.

    Sets the shutdown event, then waits up to 10 seconds for the Task. On
    timeout a warning is logged and the Task is cancelled and awaited, with
    the resulting ``CancelledError`` ignored. ``self.session`` is set to
    ``None`` before returning.
    """
    self._shutdown_event.set()

    task = self._task
    if not task or task.done():
        # Nothing is running; only the session reference needs clearing.
        self.session = None
        return

    try:
        await asyncio.wait_for(task, timeout=10)
    except asyncio.TimeoutError:
        logger.warning(
            "MCP server '%s' shutdown timed out, cancelling task",
            self.name,
        )
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    self.session = None
|
||||
|
||||
|
||||
_connections: Dict[str, MCPConnection] = {}
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_servers: Dict[str, MCPServerTask] = {}
|
||||
|
||||
# Dedicated event loop running in a background daemon thread.
|
||||
# All MCP async operations (connect, call_tool, shutdown) run here.
|
||||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
|
||||
@@ -118,42 +206,22 @@ def _load_mcp_config() -> Dict[str, dict]:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server connection
|
||||
# Server connection helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _connect_server(name: str, config: dict) -> MCPConnection:
|
||||
"""Start an MCP server subprocess and initialize a ClientSession.
|
||||
async def _connect_server(name: str, config: dict) -> MCPServerTask:
|
||||
"""Create an MCPServerTask, start it, and return when ready.
|
||||
|
||||
Args:
|
||||
name: Logical server name (e.g. "filesystem").
|
||||
config: Dict with ``command``, ``args``, and optional ``env``.
|
||||
|
||||
Returns:
|
||||
An ``MCPConnection`` with a live session.
|
||||
The server Task keeps the subprocess alive in the background.
|
||||
Call ``server.shutdown()`` (on the same event loop) to tear it down.
|
||||
|
||||
Raises:
|
||||
Exception on connection or initialization failure.
|
||||
ValueError: if ``command`` is missing from *config*.
|
||||
Exception: on connection or initialization failure.
|
||||
"""
|
||||
command = config.get("command")
|
||||
args = config.get("args", [])
|
||||
env = config.get("env")
|
||||
|
||||
if not command:
|
||||
raise ValueError(f"MCP server '{name}' has no 'command' in config")
|
||||
|
||||
server_params = StdioServerParameters(
|
||||
command=command,
|
||||
args=args,
|
||||
env=env if env else None,
|
||||
)
|
||||
|
||||
stack = AsyncExitStack()
|
||||
stdio_transport = await stack.enter_async_context(stdio_client(server_params))
|
||||
read_stream, write_stream = stdio_transport
|
||||
session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
|
||||
await session.initialize()
|
||||
|
||||
return MCPConnection(server_name=name, session=session, stack=stack)
|
||||
server = MCPServerTask(name)
|
||||
await server.start(config)
|
||||
return server
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -168,14 +236,14 @@ def _make_tool_handler(server_name: str, tool_name: str):
|
||||
"""
|
||||
|
||||
def _handler(args: dict, **kwargs) -> str:
|
||||
conn = _connections.get(server_name)
|
||||
if not conn or not conn.session:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
async def _call():
|
||||
result = await conn.session.call_tool(tool_name, arguments=args)
|
||||
result = await server.session.call_tool(tool_name, arguments=args)
|
||||
# MCP CallToolResult has .content (list of content blocks) and .isError
|
||||
if result.isError:
|
||||
error_text = ""
|
||||
@@ -204,8 +272,8 @@ def _make_check_fn(server_name: str):
|
||||
"""Return a check function that verifies the MCP connection is alive."""
|
||||
|
||||
def _check() -> bool:
|
||||
conn = _connections.get(server_name)
|
||||
return conn is not None and conn.session is not None
|
||||
server = _servers.get(server_name)
|
||||
return server is not None and server.session is not None
|
||||
|
||||
return _check
|
||||
|
||||
@@ -247,17 +315,13 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
from tools.registry import registry
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
conn = await _connect_server(name, config)
|
||||
_connections[name] = conn
|
||||
|
||||
# Discover tools
|
||||
tools_result = await conn.session.list_tools()
|
||||
tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
||||
server = await _connect_server(name, config)
|
||||
_servers[name] = server
|
||||
|
||||
registered_names: List[str] = []
|
||||
toolset_name = f"mcp-{name}"
|
||||
|
||||
for mcp_tool in tools:
|
||||
for mcp_tool in server._tools:
|
||||
schema = _convert_mcp_schema(name, mcp_tool)
|
||||
tool_name_prefixed = schema["name"]
|
||||
|
||||
@@ -339,29 +403,29 @@ def discover_mcp_tools() -> List[str]:
|
||||
|
||||
|
||||
def shutdown_mcp_servers():
|
||||
"""Close all MCP server connections and stop the background loop."""
|
||||
"""Close all MCP server connections and stop the background loop.
|
||||
|
||||
Each server Task is signalled to exit its ``async with`` block so that
|
||||
the anyio cancel-scope cleanup happens in the same Task that opened it.
|
||||
"""
|
||||
global _mcp_loop, _mcp_thread
|
||||
|
||||
if not _connections:
|
||||
if not _servers:
|
||||
_stop_mcp_loop()
|
||||
return
|
||||
|
||||
async def _shutdown():
|
||||
for name, conn in list(_connections.items()):
|
||||
for name, server in list(_servers.items()):
|
||||
try:
|
||||
if conn.stack:
|
||||
await conn.stack.aclose()
|
||||
await server.shutdown()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing MCP server '%s': %s", name, exc)
|
||||
finally:
|
||||
conn.session = None
|
||||
conn.stack = None
|
||||
_connections.clear()
|
||||
_servers.clear()
|
||||
|
||||
if _mcp_loop is not None and _mcp_loop.is_running():
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(_shutdown(), _mcp_loop)
|
||||
future.result(timeout=10)
|
||||
future.result(timeout=15)
|
||||
except Exception as exc:
|
||||
logger.debug("Error during MCP shutdown: %s", exc)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user