From 4a8f23eddff6fe0dbe01c4b0ee37efdb06e31f82 Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Tue, 10 Mar 2026 02:27:59 +0300 Subject: [PATCH] fix: correctly track failed MCP server connections in discovery _discover_one() caught all exceptions and returned [], making asyncio.gather(return_exceptions=True) redundant. The isinstance(result, Exception) branch in _discover_all() was dead code, so failed_count was always 0. This caused: - No summary printed when all servers fail (silent failure) - ok_servers always equaling total_servers (misleading count) - Unused variables transport_desc and transport_type Fix: let exceptions propagate to gather() so failed_count increments correctly. Move per-server failure logging to _discover_all(). Remove dead variables. --- tests/tools/test_mcp_tool.py | 124 +++++++++++++++++++++++++++++++++++ tools/mcp_tool.py | 20 ++---- 2 files changed, 131 insertions(+), 13 deletions(-) diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 446f80d3e..0f7fc18a5 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -2326,3 +2326,127 @@ class TestMCPServerTaskSamplingIntegration: kwargs = server._sampling.session_kwargs() assert "sampling_callback" in kwargs assert "sampling_capabilities" in kwargs + + +# --------------------------------------------------------------------------- +# Discovery failed_count tracking +# --------------------------------------------------------------------------- + +class TestDiscoveryFailedCount: + """Verify discover_mcp_tools() correctly tracks failed server connections.""" + + def test_failed_server_increments_failed_count(self): + """When _discover_and_register_server raises, failed_count increments.""" + from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop + + fake_config = { + "good_server": {"command": "npx", "args": ["good"]}, + "bad_server": {"command": "npx", "args": ["bad"]}, + } + + async def fake_register(name, cfg): + if name == "bad_server": + raise ConnectionError("Connection refused") + # Simulate successful registration + from tools.mcp_tool import MCPServerTask + server = MCPServerTask(name) + server.session = MagicMock() + server._tools = [_make_mcp_tool("tool_a")] + _servers[name] = server + return [f"mcp_{name}_tool_a"] + + with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ + patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \ + patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_good_server_tool_a"]): + _ensure_mcp_loop() + + # Capture the logger to verify failed_count in summary + with patch("tools.mcp_tool.logger") as mock_logger: + discover_mcp_tools() + + # Find the summary info call + info_calls = [ + str(call) + for call in mock_logger.info.call_args_list + if "failed" in str(call).lower() or "MCP:" in str(call) + ] + # The summary should mention the failure + assert any("1 failed" in str(c) for c in info_calls), ( + f"Summary should report 1 failed server, got: {info_calls}" + ) + + _servers.pop("good_server", None) + _servers.pop("bad_server", None) + + def test_all_servers_fail_still_prints_summary(self): + """When all servers fail, a summary with failure count is still printed.""" + from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop + + fake_config = { + "srv1": {"command": "npx", "args": ["a"]}, + "srv2": {"command": "npx", "args": ["b"]}, + } + + async def always_fail(name, cfg): + raise ConnectionError(f"Server {name} refused") + + with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ + patch("tools.mcp_tool._discover_and_register_server", side_effect=always_fail), \ + patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._existing_tool_names", return_value=[]): + _ensure_mcp_loop() + + with patch("tools.mcp_tool.logger") as mock_logger: + discover_mcp_tools() + + # Summary must be printed even when all servers fail + info_calls = [str(call) for call in mock_logger.info.call_args_list] + assert any("2 failed" in str(c) for c in info_calls), ( + f"Summary should report 2 failed servers, got: {info_calls}" + ) + + _servers.pop("srv1", None) + _servers.pop("srv2", None) + + def test_ok_servers_excludes_failures(self): + """ok_servers count correctly excludes failed servers.""" + from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop + + fake_config = { + "ok1": {"command": "npx", "args": ["ok1"]}, + "ok2": {"command": "npx", "args": ["ok2"]}, + "fail1": {"command": "npx", "args": ["fail"]}, + } + + async def selective_register(name, cfg): + if name == "fail1": + raise ConnectionError("Refused") + from tools.mcp_tool import MCPServerTask + server = MCPServerTask(name) + server.session = MagicMock() + server._tools = [_make_mcp_tool("t")] + _servers[name] = server + return [f"mcp_{name}_t"] + + with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ + patch("tools.mcp_tool._discover_and_register_server", side_effect=selective_register), \ + patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_ok1_t", "mcp_ok2_t"]): + _ensure_mcp_loop() + + with patch("tools.mcp_tool.logger") as mock_logger: + discover_mcp_tools() + + info_calls = [str(call) for call in mock_logger.info.call_args_list] + # Should say "2 server(s)" not "3 server(s)" + assert any("2 server" in str(c) for c in info_calls), ( + f"Summary should report 2 ok servers, got: {info_calls}" + ) + assert any("1 failed" in str(c) for c in info_calls), ( + f"Summary should report 1 failed, got: {info_calls}" + ) + + _servers.pop("ok1", None) + _servers.pop("ok2", None) + _servers.pop("fail1", None) diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index b0fc35f7f..94495430b 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -1331,29 +1331,23 @@ def discover_mcp_tools() -> List[str]: async def _discover_one(name: str, cfg: dict) -> List[str]: """Connect to a single server and return its registered tool names.""" - transport_desc = cfg.get("url", f'{cfg.get("command", "?")} {" ".join(cfg.get("args", [])[:2])}') - try: - registered = await _discover_and_register_server(name, cfg) - transport_type = "HTTP" if "url" in cfg else "stdio" - return registered - except Exception as exc: - logger.warning( - "Failed to connect to MCP server '%s': %s", - name, exc, - ) - return [] + return await _discover_and_register_server(name, cfg) async def _discover_all(): nonlocal failed_count + server_names = list(new_servers.keys()) # Connect to all servers in PARALLEL results = await asyncio.gather( *(_discover_one(name, cfg) for name, cfg in new_servers.items()), return_exceptions=True, ) - for result in results: + for name, result in zip(server_names, results): if isinstance(result, Exception): failed_count += 1 - logger.warning("MCP discovery error: %s", result) + logger.warning( + "Failed to connect to MCP server '%s': %s", + name, result, + ) elif isinstance(result, list): all_tools.extend(result) else: