diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index caaffd484..f12a6c937 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -36,6 +36,15 @@ def _make_call_result(text="file contents here", is_error=False): return SimpleNamespace(content=[block], isError=is_error) +def _make_mock_server(name, session=None, tools=None): + """Create an MCPServerTask with mock attributes for testing.""" + from tools.mcp_tool import MCPServerTask + server = MCPServerTask(name) + server.session = session + server._tools = tools or [] + return server + + # --------------------------------------------------------------------------- # Config loading # --------------------------------------------------------------------------- @@ -43,12 +52,10 @@ def _make_call_result(text="file contents here", is_error=False): class TestLoadMCPConfig: def test_no_config_returns_empty(self): """No mcp_servers key in config -> empty dict.""" - with patch("tools.mcp_tool.load_config", create=True) as mock_lc: - # Patch the actual import inside the function - with patch("hermes_cli.config.load_config", return_value={"model": "test"}): - from tools.mcp_tool import _load_mcp_config - result = _load_mcp_config() - assert result == {} + with patch("hermes_cli.config.load_config", return_value={"model": "test"}): + from tools.mcp_tool import _load_mcp_config + result = _load_mcp_config() + assert result == {} def test_valid_config_parsed(self): """Valid mcp_servers config is returned as-is.""" @@ -123,46 +130,37 @@ class TestSchemaConversion: class TestCheckFunction: def test_disconnected_returns_false(self): - from tools.mcp_tool import _make_check_fn, _connections + from tools.mcp_tool import _make_check_fn, _servers - # Ensure no connection exists - _connections.pop("test_server", None) + _servers.pop("test_server", None) check = _make_check_fn("test_server") assert check() is False def test_connected_returns_true(self): - from tools.mcp_tool import _make_check_fn, _connections, MCPConnection + from tools.mcp_tool import _make_check_fn, _servers - conn = MCPConnection( - server_name="test_server", - session=MagicMock(), - stack=MagicMock(), - ) - _connections["test_server"] = conn + server = _make_mock_server("test_server", session=MagicMock()) + _servers["test_server"] = server try: check = _make_check_fn("test_server") assert check() is True finally: - _connections.pop("test_server", None) + _servers.pop("test_server", None) def test_session_none_returns_false(self): - from tools.mcp_tool import _make_check_fn, _connections, MCPConnection + from tools.mcp_tool import _make_check_fn, _servers - conn = MCPConnection( - server_name="test_server", - session=None, - stack=MagicMock(), - ) - _connections["test_server"] = conn + server = _make_mock_server("test_server", session=None) + _servers["test_server"] = server try: check = _make_check_fn("test_server") assert check() is False finally: - _connections.pop("test_server", None) + _servers.pop("test_server", None) # --------------------------------------------------------------------------- -# Tool handler (async) +# Tool handler # --------------------------------------------------------------------------- class TestToolHandler: @@ -171,20 +169,24 @@ class TestToolHandler: def _patch_mcp_loop(self, coro_side_effect=None): """Return a patch for _run_on_mcp_loop that runs the coroutine directly.""" def fake_run(coro, timeout=30): - return asyncio.get_event_loop().run_until_complete(coro) + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() if coro_side_effect: return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect) return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run) def test_successful_call(self): - from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection + from tools.mcp_tool import _make_tool_handler, _servers mock_session = MagicMock() mock_session.call_tool = AsyncMock( return_value=_make_call_result("hello world", is_error=False) ) - conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock()) - _connections["test_srv"] = conn + server = _make_mock_server("test_srv", session=mock_session) + _servers["test_srv"] = server try: handler = _make_tool_handler("test_srv", "greet") @@ -193,17 +195,17 @@ class TestToolHandler: assert result["result"] == "hello world" mock_session.call_tool.assert_called_once_with("greet", arguments={"name": "world"}) finally: - _connections.pop("test_srv", None) + _servers.pop("test_srv", None) def test_mcp_error_result(self): - from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection + from tools.mcp_tool import _make_tool_handler, _servers mock_session = MagicMock() mock_session.call_tool = AsyncMock( return_value=_make_call_result("something went wrong", is_error=True) ) - conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock()) - _connections["test_srv"] = conn + server = _make_mock_server("test_srv", session=mock_session) + _servers["test_srv"] = server try: handler = _make_tool_handler("test_srv", "fail_tool") @@ -212,25 +214,24 @@ class TestToolHandler: assert "error" in result assert "something went wrong" in result["error"] finally: - _connections.pop("test_srv", None) + _servers.pop("test_srv", None) def test_disconnected_server(self): - from tools.mcp_tool import _make_tool_handler, _connections + from tools.mcp_tool import _make_tool_handler, _servers - _connections.pop("ghost", None) + _servers.pop("ghost", None) handler = _make_tool_handler("ghost", "any_tool") - # Disconnected check happens before _run_on_mcp_loop, no patch needed result = json.loads(handler({})) assert "error" in result assert "not connected" in result["error"] def test_exception_during_call(self): - from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection + from tools.mcp_tool import _make_tool_handler, _servers mock_session = MagicMock() mock_session.call_tool = AsyncMock(side_effect=RuntimeError("connection lost")) - conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock()) - _connections["test_srv"] = conn + server = _make_mock_server("test_srv", session=mock_session) + _servers["test_srv"] = server try: handler = _make_tool_handler("test_srv", "broken_tool") @@ -239,7 +240,7 @@ class TestToolHandler: assert "error" in result assert "connection lost" in result["error"] finally: - _connections.pop("test_srv", None) + _servers.pop("test_srv", None) # --------------------------------------------------------------------------- @@ -249,23 +250,21 @@ class TestToolHandler: class TestDiscoverAndRegister: def test_tools_registered_in_registry(self): """_discover_and_register_server registers tools with correct names.""" - from tools.registry import ToolRegistry, registry as real_registry - from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection + from tools.registry import ToolRegistry + from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask mock_registry = ToolRegistry() mock_tools = [ _make_mcp_tool("read_file", "Read a file"), _make_mcp_tool("write_file", "Write a file"), ] - mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) async def fake_connect(name, config): - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ patch("tools.registry.registry", mock_registry): @@ -278,22 +277,20 @@ class TestDiscoverAndRegister: assert "mcp_fs_read_file" in mock_registry.get_all_tool_names() assert "mcp_fs_write_file" in mock_registry.get_all_tool_names() - _connections.pop("fs", None) + _servers.pop("fs", None) def test_toolset_created(self): """A custom toolset is created for the MCP server.""" - from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection + from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask mock_tools = [_make_mcp_tool("ping", "Ping")] - mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) async def fake_connect(name, config): - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server mock_create = MagicMock() with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ @@ -306,24 +303,22 @@ class TestDiscoverAndRegister: call_kwargs = mock_create.call_args assert call_kwargs[1]["name"] == "mcp-myserver" or call_kwargs[0][0] == "mcp-myserver" - _connections.pop("myserver", None) + _servers.pop("myserver", None) def test_schema_format_correct(self): """Registered schemas have the correct format.""" - from tools.registry import ToolRegistry, registry as real_registry - from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection + from tools.registry import ToolRegistry + from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask mock_registry = ToolRegistry() mock_tools = [_make_mcp_tool("do_thing", "Do something")] - mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) async def fake_connect(name, config): - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ patch("tools.registry.registry", mock_registry): @@ -338,91 +333,125 @@ class TestDiscoverAndRegister: assert entry.is_async is False assert entry.toolset == "mcp-srv" - _connections.pop("srv", None) + _servers.pop("srv", None) # --------------------------------------------------------------------------- -# _connect_server (SDK interaction) +# MCPServerTask (run / start / shutdown) # --------------------------------------------------------------------------- -class TestConnectServer: - def test_calls_sdk_with_correct_params(self): - """_connect_server creates StdioServerParameters and calls stdio_client.""" - from tools.mcp_tool import _connect_server, MCPConnection +class TestMCPServerTask: + """Test the MCPServerTask lifecycle with mocked MCP SDK.""" + def _mock_stdio_and_session(self, session): + """Return patches for stdio_client and ClientSession as async CMs.""" + mock_read, mock_write = MagicMock(), MagicMock() + + mock_stdio_cm = MagicMock() + mock_stdio_cm.__aenter__ = AsyncMock(return_value=(mock_read, mock_write)) + mock_stdio_cm.__aexit__ = AsyncMock(return_value=False) + + mock_cs_cm = MagicMock() + mock_cs_cm.__aenter__ = AsyncMock(return_value=session) + mock_cs_cm.__aexit__ = AsyncMock(return_value=False) + + return ( + patch("tools.mcp_tool.stdio_client", return_value=mock_stdio_cm), + patch("tools.mcp_tool.ClientSession", return_value=mock_cs_cm), + mock_read, mock_write, + ) + + def test_start_connects_and_discovers_tools(self): + """start() creates a Task that connects, discovers tools, and waits.""" + from tools.mcp_tool import MCPServerTask + + mock_tools = [_make_mcp_tool("echo")] mock_session = MagicMock() mock_session.initialize = AsyncMock() - - mock_read = MagicMock() - mock_write = MagicMock() - - with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ - patch("tools.mcp_tool.stdio_client") as mock_stdio, \ - patch("tools.mcp_tool.ClientSession") as mock_cs, \ - patch("tools.mcp_tool.AsyncExitStack") as mock_stack_cls: - - mock_stack = MagicMock() - mock_stack.enter_async_context = AsyncMock( - side_effect=[(mock_read, mock_write), mock_session] - ) - mock_stack_cls.return_value = mock_stack - - conn = asyncio.run(_connect_server("test_srv", { - "command": "npx", - "args": ["-y", "some-server"], - "env": {"MY_KEY": "secret"}, - })) - - # StdioServerParameters called with correct values - mock_params.assert_called_once_with( - command="npx", - args=["-y", "some-server"], - env={"MY_KEY": "secret"}, + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=mock_tools) ) - # ClientSession created with the streams - mock_cs.assert_called_once_with(mock_read, mock_write) - # initialize() was called - mock_session.initialize.assert_called_once() - # Returned connection is valid - assert conn.server_name == "test_srv" - assert conn.session is mock_session + + p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session) + + async def _test(): + with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs: + server = MCPServerTask("test_srv") + await server.start({"command": "npx", "args": ["-y", "test"]}) + + assert server.session is mock_session + assert len(server._tools) == 1 + assert server._tools[0].name == "echo" + mock_session.initialize.assert_called_once() + + await server.shutdown() + assert server.session is None + + asyncio.run(_test()) def test_no_command_raises(self): """Missing 'command' in config raises ValueError.""" - from tools.mcp_tool import _connect_server + from tools.mcp_tool import MCPServerTask - with pytest.raises(ValueError, match="no 'command'"): - asyncio.run(_connect_server("bad", {"args": []})) + async def _test(): + server = MCPServerTask("bad") + with pytest.raises(ValueError, match="no 'command'"): + await server.start({"args": []}) + + asyncio.run(_test()) def test_empty_env_passed_as_none(self): """Empty env dict is passed as None to StdioServerParameters.""" - from tools.mcp_tool import _connect_server + from tools.mcp_tool import MCPServerTask mock_session = MagicMock() mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=[]) + ) - with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ - patch("tools.mcp_tool.stdio_client"), \ - patch("tools.mcp_tool.ClientSession", return_value=mock_session), \ - patch("tools.mcp_tool.AsyncExitStack") as mock_stack_cls: + p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session) - mock_stack = MagicMock() - mock_stack.enter_async_context = AsyncMock( - side_effect=[ - (MagicMock(), MagicMock()), - mock_session, - ] - ) - mock_stack_cls.return_value = mock_stack + async def _test(): + with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ + p_stdio, p_cs: + server = MCPServerTask("srv") + await server.start({"command": "node", "env": {}}) - asyncio.run(_connect_server("srv", { - "command": "node", - "env": {}, - })) + # Empty dict -> None + call_kwargs = mock_params.call_args + assert call_kwargs.kwargs.get("env") is None - # Empty dict -> None - assert mock_params.call_args[1]["env"] is None or \ - mock_params.call_args.kwargs.get("env") is None + await server.shutdown() + + asyncio.run(_test()) + + def test_shutdown_signals_task_exit(self): + """shutdown() signals the event and waits for task completion.""" + from tools.mcp_tool import MCPServerTask + + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=[]) + ) + + p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session) + + async def _test(): + with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs: + server = MCPServerTask("srv") + await server.start({"command": "npx"}) + + assert server.session is not None + assert not server._task.done() + + await server.shutdown() + + assert server.session is None + assert server._task.done() + + asyncio.run(_test()) # --------------------------------------------------------------------------- @@ -432,17 +461,16 @@ class TestConnectServer: class TestToolsetInjection: def test_mcp_tools_added_to_platform_toolsets(self): """Discovered MCP tools are injected into hermes-cli and platform toolsets.""" - from tools.mcp_tool import _connections, MCPConnection + from tools.mcp_tool import _servers, MCPServerTask mock_tools = [_make_mcp_tool("list_files", "List files")] mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) async def fake_connect(name, config): - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server fake_toolsets = { "hermes-cli": {"tools": ["terminal", "web_search"], "description": "CLI", "includes": []}, @@ -455,7 +483,6 @@ class TestToolsetInjection: with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ - patch("tools.mcp_tool.TOOLSETS", fake_toolsets, create=True), \ patch("toolsets.TOOLSETS", fake_toolsets): from tools.mcp_tool import discover_mcp_tools result = discover_mcp_tools() @@ -466,18 +493,14 @@ class TestToolsetInjection: # Original tools preserved assert "terminal" in fake_toolsets["hermes-cli"]["tools"] - _connections.pop("fs", None) + _servers.pop("fs", None) def test_server_connection_failure_skipped(self): """If one server fails to connect, others still proceed.""" - from tools.mcp_tool import _connections, MCPConnection + from tools.mcp_tool import _servers, MCPServerTask mock_tools = [_make_mcp_tool("ping", "Ping")] mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) call_count = 0 @@ -486,7 +509,10 @@ class TestToolsetInjection: call_count += 1 if name == "broken": raise ConnectionError("cannot reach server") - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server fake_config = { "broken": {"command": "bad"}, @@ -508,7 +534,7 @@ class TestToolsetInjection: assert "mcp_broken_ping" not in result assert call_count == 2 # Both were attempted - _connections.pop("good", None) + _servers.pop("good", None) # --------------------------------------------------------------------------- @@ -533,50 +559,46 @@ class TestGracefulFallback: # --------------------------------------------------------------------------- -# Shutdown +# Shutdown (public API) # --------------------------------------------------------------------------- class TestShutdown: - def test_no_connections_safe(self): - """shutdown_mcp_servers with no connections does nothing.""" - from tools.mcp_tool import shutdown_mcp_servers, _connections + def test_no_servers_safe(self): + """shutdown_mcp_servers with no servers does nothing.""" + from tools.mcp_tool import shutdown_mcp_servers, _servers - _connections.clear() + _servers.clear() shutdown_mcp_servers() # Should not raise - def test_shutdown_clears_connections(self): - """shutdown_mcp_servers closes stacks and clears the dict.""" + def test_shutdown_clears_servers(self): + """shutdown_mcp_servers calls shutdown() on each server and clears dict.""" import tools.mcp_tool as mcp_mod - from tools.mcp_tool import shutdown_mcp_servers, _connections, MCPConnection + from tools.mcp_tool import shutdown_mcp_servers, _servers - _connections.clear() - mock_stack = MagicMock() - mock_stack.aclose = AsyncMock() - conn = MCPConnection("test", session=MagicMock(), stack=mock_stack) - _connections["test"] = conn + _servers.clear() + mock_server = MagicMock() + mock_server.shutdown = AsyncMock() + _servers["test"] = mock_server - # Start a real background loop so shutdown can schedule on it mcp_mod._ensure_mcp_loop() try: shutdown_mcp_servers() finally: - # _stop_mcp_loop is called by shutdown, but ensure cleanup mcp_mod._mcp_loop = None mcp_mod._mcp_thread = None - assert len(_connections) == 0 - mock_stack.aclose.assert_called_once() + assert len(_servers) == 0 + mock_server.shutdown.assert_called_once() def test_shutdown_handles_errors(self): """shutdown_mcp_servers handles errors during close gracefully.""" import tools.mcp_tool as mcp_mod - from tools.mcp_tool import shutdown_mcp_servers, _connections, MCPConnection + from tools.mcp_tool import shutdown_mcp_servers, _servers - _connections.clear() - mock_stack = MagicMock() - mock_stack.aclose = AsyncMock(side_effect=RuntimeError("close failed")) - conn = MCPConnection("broken", session=MagicMock(), stack=mock_stack) - _connections["broken"] = conn + _servers.clear() + mock_server = MagicMock() + mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed")) + _servers["broken"] = mock_server mcp_mod._ensure_mcp_loop() try: @@ -585,4 +607,4 @@ class TestShutdown: mcp_mod._mcp_loop = None mcp_mod._mcp_thread = None - assert len(_connections) == 0 + assert len(_servers) == 0 diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index eecbaa29f..5225d63f7 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -25,8 +25,13 @@ Example config:: Architecture: A dedicated background event loop (_mcp_loop) runs in a daemon thread. - All MCP connections live on this loop. Tool handlers schedule coroutines - onto it via run_coroutine_threadsafe(), so they work from any thread. + Each MCP server runs as a long-lived asyncio Task on this loop, keeping + its ``async with stdio_client(...)`` context alive. Tool call coroutines + are scheduled onto the loop via ``run_coroutine_threadsafe()``. + + On shutdown, each server Task is signalled to exit its ``async with`` + block, ensuring the anyio cancel-scope cleanup happens in the *same* + Task that opened the connection (required by anyio). """ import asyncio @@ -45,31 +50,114 @@ _MCP_AVAILABLE = False try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client - from contextlib import AsyncExitStack _MCP_AVAILABLE = True except ImportError: logger.debug("mcp package not installed -- MCP tool support disabled") # --------------------------------------------------------------------------- -# Connection tracking +# Server task -- each MCP server lives in one long-lived asyncio Task # --------------------------------------------------------------------------- -class MCPConnection: - """Holds a live MCP server connection and its async resource stack.""" +class MCPServerTask: + """Manages a single MCP server connection in a dedicated asyncio Task. - __slots__ = ("server_name", "session", "stack") + The entire connection lifecycle (connect, discover, serve, disconnect) + runs inside one asyncio Task so that anyio cancel-scopes created by + ``stdio_client`` are entered and exited in the same Task context. + """ - def __init__(self, server_name: str, session: Any, stack: Any): - self.server_name = server_name - self.session: Optional[Any] = session - self.stack: Optional[Any] = stack + __slots__ = ( + "name", "session", + "_task", "_ready", "_shutdown_event", "_tools", "_error", + ) + + def __init__(self, name: str): + self.name = name + self.session: Optional[Any] = None + self._task: Optional[asyncio.Task] = None + self._ready = asyncio.Event() + self._shutdown_event = asyncio.Event() + self._tools: list = [] + self._error: Optional[Exception] = None + + async def run(self, config: dict): + """Long-lived coroutine: connect, discover tools, wait, disconnect.""" + command = config.get("command") + args = config.get("args", []) + env = config.get("env") + + if not command: + self._error = ValueError( + f"MCP server '{self.name}' has no 'command' in config" + ) + self._ready.set() + return + + server_params = StdioServerParameters( + command=command, + args=args, + env=env if env else None, + ) + + try: + async with stdio_client(server_params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + self.session = session + + tools_result = await session.list_tools() + self._tools = ( + tools_result.tools + if hasattr(tools_result, "tools") + else [] + ) + + # Signal that connection is ready + self._ready.set() + + # Block until shutdown is requested -- this keeps the + # async-with contexts alive on THIS Task. + await self._shutdown_event.wait() + except Exception as exc: + self._error = exc + self._ready.set() + finally: + self.session = None + + async def start(self, config: dict): + """Create the background Task and wait until ready (or failed).""" + self._task = asyncio.ensure_future(self.run(config)) + await self._ready.wait() + if self._error: + raise self._error + + async def shutdown(self): + """Signal the Task to exit and wait for clean resource teardown.""" + self._shutdown_event.set() + if self._task and not self._task.done(): + try: + await asyncio.wait_for(self._task, timeout=10) + except asyncio.TimeoutError: + logger.warning( + "MCP server '%s' shutdown timed out, cancelling task", + self.name, + ) + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self.session = None -_connections: Dict[str, MCPConnection] = {} +# --------------------------------------------------------------------------- +# Module-level state +# --------------------------------------------------------------------------- + +_servers: Dict[str, MCPServerTask] = {} # Dedicated event loop running in a background daemon thread. -# All MCP async operations (connect, call_tool, shutdown) run here. _mcp_loop: Optional[asyncio.AbstractEventLoop] = None _mcp_thread: Optional[threading.Thread] = None @@ -118,42 +206,22 @@ def _load_mcp_config() -> Dict[str, dict]: # --------------------------------------------------------------------------- -# Server connection +# Server connection helper # --------------------------------------------------------------------------- -async def _connect_server(name: str, config: dict) -> MCPConnection: - """Start an MCP server subprocess and initialize a ClientSession. +async def _connect_server(name: str, config: dict) -> MCPServerTask: + """Create an MCPServerTask, start it, and return when ready. - Args: - name: Logical server name (e.g. "filesystem"). - config: Dict with ``command``, ``args``, and optional ``env``. - - Returns: - An ``MCPConnection`` with a live session. + The server Task keeps the subprocess alive in the background. + Call ``server.shutdown()`` (on the same event loop) to tear it down. Raises: - Exception on connection or initialization failure. + ValueError: if ``command`` is missing from *config*. + Exception: on connection or initialization failure. """ - command = config.get("command") - args = config.get("args", []) - env = config.get("env") - - if not command: - raise ValueError(f"MCP server '{name}' has no 'command' in config") - - server_params = StdioServerParameters( - command=command, - args=args, - env=env if env else None, - ) - - stack = AsyncExitStack() - stdio_transport = await stack.enter_async_context(stdio_client(server_params)) - read_stream, write_stream = stdio_transport - session = await stack.enter_async_context(ClientSession(read_stream, write_stream)) - await session.initialize() - - return MCPConnection(server_name=name, session=session, stack=stack) + server = MCPServerTask(name) + await server.start(config) + return server # --------------------------------------------------------------------------- @@ -168,14 +236,14 @@ def _make_tool_handler(server_name: str, tool_name: str): """ def _handler(args: dict, **kwargs) -> str: - conn = _connections.get(server_name) - if not conn or not conn.session: + server = _servers.get(server_name) + if not server or not server.session: return json.dumps({ "error": f"MCP server '{server_name}' is not connected" }) async def _call(): - result = await conn.session.call_tool(tool_name, arguments=args) + result = await server.session.call_tool(tool_name, arguments=args) # MCP CallToolResult has .content (list of content blocks) and .isError if result.isError: error_text = "" @@ -204,8 +272,8 @@ def _make_check_fn(server_name: str): """Return a check function that verifies the MCP connection is alive.""" def _check() -> bool: - conn = _connections.get(server_name) - return conn is not None and conn.session is not None + server = _servers.get(server_name) + return server is not None and server.session is not None return _check @@ -247,17 +315,13 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: from tools.registry import registry from toolsets import create_custom_toolset - conn = await _connect_server(name, config) - _connections[name] = conn - - # Discover tools - tools_result = await conn.session.list_tools() - tools = tools_result.tools if hasattr(tools_result, "tools") else [] + server = await _connect_server(name, config) + _servers[name] = server registered_names: List[str] = [] toolset_name = f"mcp-{name}" - for mcp_tool in tools: + for mcp_tool in server._tools: schema = _convert_mcp_schema(name, mcp_tool) tool_name_prefixed = schema["name"] @@ -339,29 +403,29 @@ def discover_mcp_tools() -> List[str]: def shutdown_mcp_servers(): - """Close all MCP server connections and stop the background loop.""" + """Close all MCP server connections and stop the background loop. + + Each server Task is signalled to exit its ``async with`` block so that + the anyio cancel-scope cleanup happens in the same Task that opened it. + """ global _mcp_loop, _mcp_thread - if not _connections: + if not _servers: _stop_mcp_loop() return async def _shutdown(): - for name, conn in list(_connections.items()): + for name, server in list(_servers.items()): try: - if conn.stack: - await conn.stack.aclose() + await server.shutdown() except Exception as exc: logger.debug("Error closing MCP server '%s': %s", name, exc) - finally: - conn.session = None - conn.stack = None - _connections.clear() + _servers.clear() if _mcp_loop is not None and _mcp_loop.is_running(): try: future = asyncio.run_coroutine_threadsafe(_shutdown(), _mcp_loop) - future.result(timeout=10) + future.result(timeout=15) except Exception as exc: logger.debug("Error during MCP shutdown: %s", exc)