diff --git a/tests/tools/test_mcp_tool_issue_948.py b/tests/tools/test_mcp_tool_issue_948.py new file mode 100644 index 000000000..df6423034 --- /dev/null +++ b/tests/tools/test_mcp_tool_issue_948.py @@ -0,0 +1,86 @@ +import asyncio +import os +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tools.mcp_tool import MCPServerTask, _format_connect_error, _resolve_stdio_command + + +def test_resolve_stdio_command_falls_back_to_hermes_node_bin(tmp_path): + node_bin = tmp_path / "node" / "bin" + node_bin.mkdir(parents=True) + npx_path = node_bin / "npx" + npx_path.write_text("#!/bin/sh\nexit 0\n", encoding="utf-8") + npx_path.chmod(0o755) + + with patch("tools.mcp_tool.shutil.which", return_value=None), \ + patch.dict("os.environ", {"HERMES_HOME": str(tmp_path)}, clear=False): + command, env = _resolve_stdio_command("npx", {"PATH": "/usr/bin"}) + + assert command == str(npx_path) + assert env["PATH"].split(os.pathsep)[0] == str(node_bin) + + +def test_resolve_stdio_command_respects_explicit_empty_path(): + seen_paths = [] + + def _fake_which(_cmd, path=None): + seen_paths.append(path) + return None + + with patch("tools.mcp_tool.shutil.which", side_effect=_fake_which): + command, env = _resolve_stdio_command("python", {"PATH": ""}) + + assert command == "python" + assert env["PATH"] == "" + assert seen_paths == [""] + + +def test_format_connect_error_unwraps_exception_group(): + error = ExceptionGroup( + "unhandled errors in a TaskGroup", + [FileNotFoundError(2, "No such file or directory", "node")], + ) + + message = _format_connect_error(error) + + assert "missing executable 'node'" in message + + +def test_run_stdio_uses_resolved_command_and_prepended_path(tmp_path): + node_bin = tmp_path / "node" / "bin" + node_bin.mkdir(parents=True) + npx_path = node_bin / "npx" + npx_path.write_text("#!/bin/sh\nexit 0\n", encoding="utf-8") + npx_path.chmod(0o755) + + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=SimpleNamespace(tools=[])) + + mock_stdio_cm = MagicMock() + mock_stdio_cm.__aenter__ = AsyncMock(return_value=(object(), object())) + mock_stdio_cm.__aexit__ = AsyncMock(return_value=False) + + mock_session_cm = MagicMock() + mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_cm.__aexit__ = AsyncMock(return_value=False) + + async def _test(): + with patch("tools.mcp_tool.shutil.which", return_value=None), \ + patch.dict("os.environ", {"HERMES_HOME": str(tmp_path), "PATH": "/usr/bin", "HOME": str(tmp_path)}, clear=False), \ + patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ + patch("tools.mcp_tool.stdio_client", return_value=mock_stdio_cm), \ + patch("tools.mcp_tool.ClientSession", return_value=mock_session_cm): + server = MCPServerTask("srv") + await server.start({"command": "npx", "args": ["-y", "pkg"], "env": {"PATH": "/usr/bin"}}) + + call_kwargs = mock_params.call_args.kwargs + assert call_kwargs["command"] == str(npx_path) + assert call_kwargs["env"]["PATH"].split(os.pathsep)[0] == str(node_bin) + + await server.shutdown() + + asyncio.run(_test()) diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 2a4f5be86..448af9202 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -75,6 +75,7 @@ import logging import math import os import re +import shutil import threading import time from typing import Any, Dict, List, Optional @@ -176,6 +177,116 @@ def _sanitize_error(text: str) -> str: return _CREDENTIAL_PATTERN.sub("[REDACTED]", text) +def _prepend_path(env: dict, directory: str) -> dict: + """Prepend *directory* to env PATH if it is not already present.""" + updated = dict(env or {}) + if not directory: + return updated + + existing = updated.get("PATH", "") + parts = [part for part in existing.split(os.pathsep) if part] + if directory not in parts: + parts = [directory, *parts] + updated["PATH"] = os.pathsep.join(parts) if parts else directory + return updated + + +def _resolve_stdio_command(command: str, env: dict) -> tuple[str, dict]: + """Resolve a stdio MCP command against the exact subprocess environment. + + This primarily exists to make bare ``npx``/``npm``/``node`` commands work + reliably even when MCP subprocesses run under a filtered PATH. + """ + resolved_command = os.path.expanduser(str(command).strip()) + resolved_env = dict(env or {}) + + if os.sep not in resolved_command: + path_arg = resolved_env["PATH"] if "PATH" in resolved_env else None + which_hit = shutil.which(resolved_command, path=path_arg) + if which_hit: + resolved_command = which_hit + elif resolved_command in {"npx", "npm", "node"}: + hermes_home = os.path.expanduser( + os.getenv( + "HERMES_HOME", os.path.join(os.path.expanduser("~"), ".hermes") + ) + ) + candidates = [ + os.path.join(hermes_home, "node", "bin", resolved_command), + os.path.join(os.path.expanduser("~"), ".local", "bin", resolved_command), + ] + for candidate in candidates: + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + resolved_command = candidate + break + + command_dir = os.path.dirname(resolved_command) + if command_dir: + resolved_env = _prepend_path(resolved_env, command_dir) + + return resolved_command, resolved_env + + +def _format_connect_error(exc: BaseException) -> str: + """Render nested MCP connection errors into an actionable short message.""" + + def _find_missing(current: BaseException) -> Optional[str]: + nested = getattr(current, "exceptions", None) + if nested: + for child in nested: + missing = _find_missing(child) + if missing: + return missing + return None + if isinstance(current, FileNotFoundError): + if getattr(current, "filename", None): + return str(current.filename) + match = re.search(r"No such file or directory: '([^']+)'", str(current)) + if match: + return match.group(1) + for attr in ("__cause__", "__context__"): + nested_exc = getattr(current, attr, None) + if isinstance(nested_exc, BaseException): + missing = _find_missing(nested_exc) + if missing: + return missing + return None + + def _flatten_messages(current: BaseException) -> List[str]: + nested = getattr(current, "exceptions", None) + if nested: + flattened: List[str] = [] + for child in nested: + flattened.extend(_flatten_messages(child)) + return flattened + messages = [] + text = str(current).strip() + if text: + messages.append(text) + for attr in ("__cause__", "__context__"): + nested_exc = getattr(current, attr, None) + if isinstance(nested_exc, BaseException): + messages.extend(_flatten_messages(nested_exc)) + return messages or [current.__class__.__name__] + + missing = _find_missing(exc) + if missing: + message = f"missing executable '{missing}'" + if os.path.basename(missing) in {"npx", "npm", "node"}: + message += ( + " (ensure Node.js is installed and PATH includes its bin directory, " + "or set mcp_servers..command to an absolute path and include " + "that directory in mcp_servers..env.PATH)" + ) + return _sanitize_error(message) + + deduped: List[str] = [] + for item in _flatten_messages(exc): + if item not in deduped: + deduped.append(item) + return _sanitize_error("; ".join(deduped[:3])) + + # --------------------------------------------------------------------------- # Sampling -- server-initiated LLM requests (MCP sampling/createMessage) # --------------------------------------------------------------------------- @@ -608,6 +719,7 @@ class MCPServerTask: ) safe_env = _build_safe_env(user_env) + command, safe_env = _resolve_stdio_command(command, safe_env) server_params = StdioServerParameters( command=command, args=args, @@ -1340,9 +1452,12 @@ def discover_mcp_tools() -> List[str]: for name, result in zip(server_names, results): if isinstance(result, Exception): failed_count += 1 + command = new_servers.get(name, {}).get("command") logger.warning( - "Failed to connect to MCP server '%s': %s", - name, result, + "Failed to connect to MCP server '%s'%s: %s", + name, + f" (command={command})" if command else "", + _format_connect_error(result), ) elif isinstance(result, list): all_tools.extend(result)