fix(mcp): resolve npx stdio connection failures (#1291)
Salvaged from PR #977 onto current main. Preserves the MCP stdio command resolution and improved error diagnostics, with deterministic regression tests for the npx/node PATH cases. Co-authored-by: kshitij <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
86
tests/tools/test_mcp_tool_issue_948.py
Normal file
86
tests/tools/test_mcp_tool_issue_948.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import asyncio
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.mcp_tool import MCPServerTask, _format_connect_error, _resolve_stdio_command
|
||||
|
||||
|
||||
def test_resolve_stdio_command_falls_back_to_hermes_node_bin(tmp_path):
    """A bare ``npx`` that ``which`` cannot find resolves under HERMES_HOME.

    An executable ``npx`` stub is placed in ``<HERMES_HOME>/node/bin`` and
    ``shutil.which`` is forced to miss; the resolver must return the stub's
    path and put its directory first on the returned PATH.
    """
    bin_dir = tmp_path / "node" / "bin"
    bin_dir.mkdir(parents=True)
    fake_npx = bin_dir / "npx"
    fake_npx.write_text("#!/bin/sh\nexit 0\n", encoding="utf-8")
    fake_npx.chmod(0o755)

    which_miss = patch("tools.mcp_tool.shutil.which", return_value=None)
    hermes_env = patch.dict("os.environ", {"HERMES_HOME": str(tmp_path)}, clear=False)
    with which_miss, hermes_env:
        resolved, resolved_env = _resolve_stdio_command("npx", {"PATH": "/usr/bin"})

    first_entry = resolved_env["PATH"].split(os.pathsep)[0]
    assert resolved == str(fake_npx)
    assert first_entry == str(bin_dir)
|
||||
|
||||
|
||||
def test_resolve_stdio_command_respects_explicit_empty_path():
    """An explicitly empty PATH is forwarded to ``which`` and left untouched.

    The stubbed ``which`` records the ``path`` it receives and always misses,
    so the command and environment must come back unchanged.
    """
    recorded_paths = []

    def _recording_which(_cmd, path=None):
        # Capture the search path the resolver passes, then report a miss.
        recorded_paths.append(path)
        return None

    with patch("tools.mcp_tool.shutil.which", side_effect=_recording_which):
        resolved, resolved_env = _resolve_stdio_command("python", {"PATH": ""})

    assert resolved == "python"
    assert resolved_env["PATH"] == ""
    assert recorded_paths == [""]
|
||||
|
||||
|
||||
def test_format_connect_error_unwraps_exception_group():
    """A FileNotFoundError inside an ExceptionGroup is reported as a missing executable."""
    missing_node = FileNotFoundError(2, "No such file or directory", "node")
    group = ExceptionGroup("unhandled errors in a TaskGroup", [missing_node])

    rendered = _format_connect_error(group)

    assert "missing executable 'node'" in rendered
|
||||
|
||||
|
||||
def test_run_stdio_uses_resolved_command_and_prepended_path(tmp_path):
    """Starting a stdio server passes the resolved command and PATH onward.

    With ``shutil.which`` forced to miss and an executable ``npx`` stub under
    ``<HERMES_HOME>/node/bin``, ``StdioServerParameters`` must be called with
    the stub's path as ``command`` and with the stub's directory first on
    ``env["PATH"]``. The MCP client pieces are replaced by mocks.
    """
    bin_dir = tmp_path / "node" / "bin"
    bin_dir.mkdir(parents=True)
    fake_npx = bin_dir / "npx"
    fake_npx.write_text("#!/bin/sh\nexit 0\n", encoding="utf-8")
    fake_npx.chmod(0o755)

    def _async_cm(entered):
        """Build a mock async context manager whose ``__aenter__`` yields *entered*."""
        manager = MagicMock()
        manager.__aenter__ = AsyncMock(return_value=entered)
        manager.__aexit__ = AsyncMock(return_value=False)
        return manager

    session = MagicMock()
    session.initialize = AsyncMock()
    session.list_tools = AsyncMock(return_value=SimpleNamespace(tools=[]))

    stdio_cm = _async_cm((object(), object()))
    session_cm = _async_cm(session)

    fake_environ = {"HERMES_HOME": str(tmp_path), "PATH": "/usr/bin", "HOME": str(tmp_path)}
    server_config = {"command": "npx", "args": ["-y", "pkg"], "env": {"PATH": "/usr/bin"}}

    async def _scenario():
        with (
            patch("tools.mcp_tool.shutil.which", return_value=None),
            patch.dict("os.environ", fake_environ, clear=False),
            patch("tools.mcp_tool.StdioServerParameters") as params_cls,
            patch("tools.mcp_tool.stdio_client", return_value=stdio_cm),
            patch("tools.mcp_tool.ClientSession", return_value=session_cm),
        ):
            server = MCPServerTask("srv")
            await server.start(server_config)

            passed = params_cls.call_args.kwargs
            assert passed["command"] == str(fake_npx)
            assert passed["env"]["PATH"].split(os.pathsep)[0] == str(bin_dir)

            await server.shutdown()

    asyncio.run(_scenario())
|
||||
@@ -75,6 +75,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -176,6 +177,116 @@ def _sanitize_error(text: str) -> str:
|
||||
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
|
||||
|
||||
|
||||
def _prepend_path(env: dict, directory: str) -> dict:
|
||||
"""Prepend *directory* to env PATH if it is not already present."""
|
||||
updated = dict(env or {})
|
||||
if not directory:
|
||||
return updated
|
||||
|
||||
existing = updated.get("PATH", "")
|
||||
parts = [part for part in existing.split(os.pathsep) if part]
|
||||
if directory not in parts:
|
||||
parts = [directory, *parts]
|
||||
updated["PATH"] = os.pathsep.join(parts) if parts else directory
|
||||
return updated
|
||||
|
||||
|
||||
def _resolve_stdio_command(command: str, env: dict) -> tuple[str, dict]:
|
||||
"""Resolve a stdio MCP command against the exact subprocess environment.
|
||||
|
||||
This primarily exists to make bare ``npx``/``npm``/``node`` commands work
|
||||
reliably even when MCP subprocesses run under a filtered PATH.
|
||||
"""
|
||||
resolved_command = os.path.expanduser(str(command).strip())
|
||||
resolved_env = dict(env or {})
|
||||
|
||||
if os.sep not in resolved_command:
|
||||
path_arg = resolved_env["PATH"] if "PATH" in resolved_env else None
|
||||
which_hit = shutil.which(resolved_command, path=path_arg)
|
||||
if which_hit:
|
||||
resolved_command = which_hit
|
||||
elif resolved_command in {"npx", "npm", "node"}:
|
||||
hermes_home = os.path.expanduser(
|
||||
os.getenv(
|
||||
"HERMES_HOME", os.path.join(os.path.expanduser("~"), ".hermes")
|
||||
)
|
||||
)
|
||||
candidates = [
|
||||
os.path.join(hermes_home, "node", "bin", resolved_command),
|
||||
os.path.join(os.path.expanduser("~"), ".local", "bin", resolved_command),
|
||||
]
|
||||
for candidate in candidates:
|
||||
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
|
||||
resolved_command = candidate
|
||||
break
|
||||
|
||||
command_dir = os.path.dirname(resolved_command)
|
||||
if command_dir:
|
||||
resolved_env = _prepend_path(resolved_env, command_dir)
|
||||
|
||||
return resolved_command, resolved_env
|
||||
|
||||
|
||||
def _format_connect_error(exc: BaseException) -> str:
    """Render nested MCP connection errors into an actionable short message.

    Walks exception groups (any object exposing a non-empty ``exceptions``
    attribute) and ``__cause__``/``__context__`` chains. If a
    ``FileNotFoundError`` is found, the result names the missing executable,
    with an extra Node.js hint when it is ``npx``, ``npm`` or ``node``.
    Otherwise up to three distinct messages are joined with ``"; "``.

    Each exception object is visited at most once per walk. ``raise X from Y``
    inside ``except Y`` makes ``__cause__`` and ``__context__`` the same
    object, so an untracked walk does exponential work on long chains and
    never terminates on a hand-built cycle.

    Args:
        exc: The exception (or exception group) raised while connecting.

    Returns:
        The message after being passed through ``_sanitize_error``.
    """

    def _find_missing(current: BaseException, seen: set) -> Optional[str]:
        """Return the first missing-file name found depth-first, else None."""
        if id(current) in seen:
            return None
        seen.add(id(current))

        nested = getattr(current, "exceptions", None)
        if nested:
            # Groups are containers only: inspect their children, nothing else.
            for child in nested:
                missing = _find_missing(child, seen)
                if missing:
                    return missing
            return None

        if isinstance(current, FileNotFoundError):
            if getattr(current, "filename", None):
                return str(current.filename)
            # Fall back to parsing the message when filename was not set.
            match = re.search(r"No such file or directory: '([^']+)'", str(current))
            if match:
                return match.group(1)

        for attr in ("__cause__", "__context__"):
            nested_exc = getattr(current, attr, None)
            if isinstance(nested_exc, BaseException):
                missing = _find_missing(nested_exc, seen)
                if missing:
                    return missing
        return None

    def _flatten_messages(current: BaseException, seen: set) -> List[str]:
        """Collect messages depth-first, skipping already-visited exceptions."""
        if id(current) in seen:
            return []
        seen.add(id(current))

        nested = getattr(current, "exceptions", None)
        if nested:
            flattened: List[str] = []
            for child in nested:
                flattened.extend(_flatten_messages(child, seen))
            return flattened

        chained = [
            linked
            for linked in (
                getattr(current, "__cause__", None),
                getattr(current, "__context__", None),
            )
            if isinstance(linked, BaseException)
        ]

        text = str(current).strip()
        messages = [text] if text else []
        for linked in chained:
            messages.extend(_flatten_messages(linked, seen))

        # The class name stands in only for a message-less exception that has
        # nothing chained to it; a chained one is described by its chain.
        if not messages and not chained:
            messages = [current.__class__.__name__]
        return messages

    missing = _find_missing(exc, set())
    if missing:
        message = f"missing executable '{missing}'"
        if os.path.basename(missing) in {"npx", "npm", "node"}:
            message += (
                " (ensure Node.js is installed and PATH includes its bin directory, "
                "or set mcp_servers.<name>.command to an absolute path and include "
                "that directory in mcp_servers.<name>.env.PATH)"
            )
        return _sanitize_error(message)

    # dict.fromkeys removes duplicates while keeping first-seen order.
    distinct = list(dict.fromkeys(_flatten_messages(exc, set())))
    if not distinct:
        # Only reachable for a message-less exception chained back to itself.
        distinct = [exc.__class__.__name__]
    return _sanitize_error("; ".join(distinct[:3]))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -608,6 +719,7 @@ class MCPServerTask:
|
||||
)
|
||||
|
||||
safe_env = _build_safe_env(user_env)
|
||||
command, safe_env = _resolve_stdio_command(command, safe_env)
|
||||
server_params = StdioServerParameters(
|
||||
command=command,
|
||||
args=args,
|
||||
@@ -1340,9 +1452,12 @@ def discover_mcp_tools() -> List[str]:
|
||||
for name, result in zip(server_names, results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
command = new_servers.get(name, {}).get("command")
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s': %s",
|
||||
name, result,
|
||||
"Failed to connect to MCP server '%s'%s: %s",
|
||||
name,
|
||||
f" (command={command})" if command else "",
|
||||
_format_connect_error(result),
|
||||
)
|
||||
elif isinstance(result, list):
|
||||
all_tools.extend(result)
|
||||
|
||||
Reference in New Issue
Block a user