fix(mcp): resolve npx stdio connection failures (#1291)

Salvaged from PR #977 onto current main.
Preserves the MCP stdio command resolution and improved error diagnostics,
with deterministic regression tests for the npx/node PATH cases.

Co-authored-by: kshitij <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
Teknium
2026-03-14 05:44:00 -07:00
committed by GitHub
parent 1a857123b3
commit b646440ca0
2 changed files with 203 additions and 2 deletions

View File

@@ -0,0 +1,86 @@
import asyncio
import os
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from tools.mcp_tool import MCPServerTask, _format_connect_error, _resolve_stdio_command
def test_resolve_stdio_command_falls_back_to_hermes_node_bin(tmp_path):
node_bin = tmp_path / "node" / "bin"
node_bin.mkdir(parents=True)
npx_path = node_bin / "npx"
npx_path.write_text("#!/bin/sh\nexit 0\n", encoding="utf-8")
npx_path.chmod(0o755)
with patch("tools.mcp_tool.shutil.which", return_value=None), \
patch.dict("os.environ", {"HERMES_HOME": str(tmp_path)}, clear=False):
command, env = _resolve_stdio_command("npx", {"PATH": "/usr/bin"})
assert command == str(npx_path)
assert env["PATH"].split(os.pathsep)[0] == str(node_bin)
def test_resolve_stdio_command_respects_explicit_empty_path():
seen_paths = []
def _fake_which(_cmd, path=None):
seen_paths.append(path)
return None
with patch("tools.mcp_tool.shutil.which", side_effect=_fake_which):
command, env = _resolve_stdio_command("python", {"PATH": ""})
assert command == "python"
assert env["PATH"] == ""
assert seen_paths == [""]
def test_format_connect_error_unwraps_exception_group():
error = ExceptionGroup(
"unhandled errors in a TaskGroup",
[FileNotFoundError(2, "No such file or directory", "node")],
)
message = _format_connect_error(error)
assert "missing executable 'node'" in message
def test_run_stdio_uses_resolved_command_and_prepended_path(tmp_path):
node_bin = tmp_path / "node" / "bin"
node_bin.mkdir(parents=True)
npx_path = node_bin / "npx"
npx_path.write_text("#!/bin/sh\nexit 0\n", encoding="utf-8")
npx_path.chmod(0o755)
mock_session = MagicMock()
mock_session.initialize = AsyncMock()
mock_session.list_tools = AsyncMock(return_value=SimpleNamespace(tools=[]))
mock_stdio_cm = MagicMock()
mock_stdio_cm.__aenter__ = AsyncMock(return_value=(object(), object()))
mock_stdio_cm.__aexit__ = AsyncMock(return_value=False)
mock_session_cm = MagicMock()
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_cm.__aexit__ = AsyncMock(return_value=False)
async def _test():
with patch("tools.mcp_tool.shutil.which", return_value=None), \
patch.dict("os.environ", {"HERMES_HOME": str(tmp_path), "PATH": "/usr/bin", "HOME": str(tmp_path)}, clear=False), \
patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
patch("tools.mcp_tool.stdio_client", return_value=mock_stdio_cm), \
patch("tools.mcp_tool.ClientSession", return_value=mock_session_cm):
server = MCPServerTask("srv")
await server.start({"command": "npx", "args": ["-y", "pkg"], "env": {"PATH": "/usr/bin"}})
call_kwargs = mock_params.call_args.kwargs
assert call_kwargs["command"] == str(npx_path)
assert call_kwargs["env"]["PATH"].split(os.pathsep)[0] == str(node_bin)
await server.shutdown()
asyncio.run(_test())

View File

@@ -75,6 +75,7 @@ import logging
import math
import os
import re
import shutil
import threading
import time
from typing import Any, Dict, List, Optional
@@ -176,6 +177,116 @@ def _sanitize_error(text: str) -> str:
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
def _prepend_path(env: dict, directory: str) -> dict:
"""Prepend *directory* to env PATH if it is not already present."""
updated = dict(env or {})
if not directory:
return updated
existing = updated.get("PATH", "")
parts = [part for part in existing.split(os.pathsep) if part]
if directory not in parts:
parts = [directory, *parts]
updated["PATH"] = os.pathsep.join(parts) if parts else directory
return updated
def _resolve_stdio_command(command: str, env: dict) -> tuple[str, dict]:
"""Resolve a stdio MCP command against the exact subprocess environment.
This primarily exists to make bare ``npx``/``npm``/``node`` commands work
reliably even when MCP subprocesses run under a filtered PATH.
"""
resolved_command = os.path.expanduser(str(command).strip())
resolved_env = dict(env or {})
if os.sep not in resolved_command:
path_arg = resolved_env["PATH"] if "PATH" in resolved_env else None
which_hit = shutil.which(resolved_command, path=path_arg)
if which_hit:
resolved_command = which_hit
elif resolved_command in {"npx", "npm", "node"}:
hermes_home = os.path.expanduser(
os.getenv(
"HERMES_HOME", os.path.join(os.path.expanduser("~"), ".hermes")
)
)
candidates = [
os.path.join(hermes_home, "node", "bin", resolved_command),
os.path.join(os.path.expanduser("~"), ".local", "bin", resolved_command),
]
for candidate in candidates:
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
resolved_command = candidate
break
command_dir = os.path.dirname(resolved_command)
if command_dir:
resolved_env = _prepend_path(resolved_env, command_dir)
return resolved_command, resolved_env
def _format_connect_error(exc: BaseException) -> str:
"""Render nested MCP connection errors into an actionable short message."""
def _find_missing(current: BaseException) -> Optional[str]:
nested = getattr(current, "exceptions", None)
if nested:
for child in nested:
missing = _find_missing(child)
if missing:
return missing
return None
if isinstance(current, FileNotFoundError):
if getattr(current, "filename", None):
return str(current.filename)
match = re.search(r"No such file or directory: '([^']+)'", str(current))
if match:
return match.group(1)
for attr in ("__cause__", "__context__"):
nested_exc = getattr(current, attr, None)
if isinstance(nested_exc, BaseException):
missing = _find_missing(nested_exc)
if missing:
return missing
return None
def _flatten_messages(current: BaseException) -> List[str]:
nested = getattr(current, "exceptions", None)
if nested:
flattened: List[str] = []
for child in nested:
flattened.extend(_flatten_messages(child))
return flattened
messages = []
text = str(current).strip()
if text:
messages.append(text)
for attr in ("__cause__", "__context__"):
nested_exc = getattr(current, attr, None)
if isinstance(nested_exc, BaseException):
messages.extend(_flatten_messages(nested_exc))
return messages or [current.__class__.__name__]
missing = _find_missing(exc)
if missing:
message = f"missing executable '{missing}'"
if os.path.basename(missing) in {"npx", "npm", "node"}:
message += (
" (ensure Node.js is installed and PATH includes its bin directory, "
"or set mcp_servers.<name>.command to an absolute path and include "
"that directory in mcp_servers.<name>.env.PATH)"
)
return _sanitize_error(message)
deduped: List[str] = []
for item in _flatten_messages(exc):
if item not in deduped:
deduped.append(item)
return _sanitize_error("; ".join(deduped[:3]))
# ---------------------------------------------------------------------------
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------
@@ -608,6 +719,7 @@ class MCPServerTask:
)
safe_env = _build_safe_env(user_env)
command, safe_env = _resolve_stdio_command(command, safe_env)
server_params = StdioServerParameters(
command=command,
args=args,
@@ -1340,9 +1452,12 @@ def discover_mcp_tools() -> List[str]:
for name, result in zip(server_names, results):
if isinstance(result, Exception):
failed_count += 1
command = new_servers.get(name, {}).get("command")
logger.warning(
"Failed to connect to MCP server '%s': %s",
name, result,
"Failed to connect to MCP server '%s'%s: %s",
name,
f" (command={command})" if command else "",
_format_connect_error(result),
)
elif isinstance(result, list):
all_tools.extend(result)