From ed805f57ffba09adedb3b53acdc12672e0a63e08 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sun, 22 Mar 2026 15:02:26 -0700 Subject: [PATCH] fix(mcp-oauth): port mismatch, path traversal, and shared handler state (salvage #2521) (#2552) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(mcp-oauth): port mismatch, path traversal, and shared state in OAuth flow Three bugs in the new MCP OAuth 2.1 PKCE implementation: 1. CRITICAL: OAuth redirect port mismatch — build_oauth_auth() calls _find_free_port() to register the redirect_uri, but _wait_for_callback() calls _find_free_port() again getting a DIFFERENT port. Browser redirects to port A, server listens on port B — callback never arrives, 120s timeout. Fix: share the port via module-level _oauth_port variable. 2. MEDIUM: Path traversal via unsanitized server_name — HermesTokenStorage uses server_name directly in filenames. A name like "../../.ssh/config" writes token files outside ~/.hermes/mcp-tokens/. Fix: sanitize server_name with the same regex pattern used elsewhere. 3. MEDIUM: Class-level auth_code/state on _CallbackHandler causes data races if concurrent OAuth flows run. Second callback overwrites first. Fix: factory function _make_callback_handler() returns a handler class with a closure-scoped result dict, isolating each flow. * test: add tests for MCP OAuth path traversal, handler isolation, and port sharing 7 new tests covering: - Path traversal blocked (../../.ssh/config stays in mcp-tokens/) - Dots/slashes sanitized and resolved within base dir - Normal server names preserved - Special characters sanitized (@, :, /) - Concurrent handler result dicts are independent - Handler writes to its own result dict, not class-level - build_oauth_auth stores port in module-level _oauth_port --------- Co-authored-by: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> --- tests/tools/test_mcp_oauth.py | 86 +++++++++++++++++++++++++++++++++++ tools/mcp_oauth.py | 66 ++++++++++++++++----------- 2 files changed, 126 insertions(+), 26 deletions(-) diff --git a/tests/tools/test_mcp_oauth.py b/tests/tools/test_mcp_oauth.py index 34c85b23e..66ac3b616 100644 --- a/tests/tools/test_mcp_oauth.py +++ b/tests/tools/test_mcp_oauth.py @@ -134,6 +134,92 @@ class TestUtilities: # remove_oauth_tokens # --------------------------------------------------------------------------- +class TestPathTraversal: + """Verify server_name is sanitized to prevent path traversal.""" + + def test_path_traversal_blocked(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("../../.ssh/config") + path = storage._tokens_path() + # Should stay within mcp-tokens directory + assert "mcp-tokens" in str(path) + assert ".ssh" not in str(path.resolve()) + + def test_dots_and_slashes_sanitized(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("../../../etc/passwd") + path = storage._tokens_path() + resolved = path.resolve() + assert resolved.is_relative_to((tmp_path / "mcp-tokens").resolve()) + + def test_normal_name_unchanged(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("my-mcp-server") + assert "my-mcp-server.json" in str(storage._tokens_path()) + + def test_special_chars_sanitized(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + storage = HermesTokenStorage("server@host:8080/path") + path = storage._tokens_path() + assert "@" not in path.name + assert ":" not in path.name + assert "/" not in path.stem + + +class TestCallbackHandlerIsolation: + """Verify concurrent OAuth flows don't share state.""" + + def test_independent_result_dicts(self): + from tools.mcp_oauth import _make_callback_handler + _, result_a = _make_callback_handler() + _, result_b = _make_callback_handler() + + result_a["auth_code"] = "code_A" + result_b["auth_code"] = "code_B" + + assert result_a["auth_code"] == "code_A" + assert result_b["auth_code"] == "code_B" + + def test_handler_writes_to_own_result(self): + from tools.mcp_oauth import _make_callback_handler + from io import BytesIO + from unittest.mock import MagicMock + + HandlerClass, result = _make_callback_handler() + assert result["auth_code"] is None + + # Simulate a GET request + handler = HandlerClass.__new__(HandlerClass) + handler.path = "/callback?code=test123&state=mystate" + handler.wfile = BytesIO() + handler.send_response = MagicMock() + handler.send_header = MagicMock() + handler.end_headers = MagicMock() + handler.do_GET() + + assert result["auth_code"] == "test123" + assert result["state"] == "mystate" + + +class TestOAuthPortSharing: + """Verify build_oauth_auth and _wait_for_callback use the same port.""" + + def test_port_stored_globally(self): + import tools.mcp_oauth as mod + # Reset + mod._oauth_port = None + + try: + from mcp.client.auth import OAuthClientProvider + except ImportError: + pytest.skip("MCP SDK auth not available") + + build_oauth_auth("test-port", "https://example.com/mcp") + assert mod._oauth_port is not None + assert isinstance(mod._oauth_port, int) + assert 1024 <= mod._oauth_port <= 65535 + + class TestRemoveOAuthTokens: def test_removes_files(self, tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) diff --git a/tools/mcp_oauth.py b/tools/mcp_oauth.py index d8c86ef28..fe5e07d7e 100644 --- a/tools/mcp_oauth.py +++ b/tools/mcp_oauth.py @@ -35,11 +35,19 @@ _TOKEN_DIR_NAME = "mcp-tokens" # Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/ # --------------------------------------------------------------------------- +def _sanitize_server_name(name: str) -> str: + """Sanitize server name for safe use as a filename.""" + import re + clean = re.sub(r"[^\w\-]", "-", name.strip().lower()) + clean = re.sub(r"-+", "-", clean).strip("-") + return clean[:60] or "unnamed" + + class HermesTokenStorage: """File-backed token storage implementing the MCP SDK's TokenStorage protocol.""" def __init__(self, server_name: str): - self._server_name = server_name + self._server_name = _sanitize_server_name(server_name) def _base_dir(self) -> Path: home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) @@ -119,21 +127,28 @@ def _find_free_port() -> int: return s.getsockname()[1] -class _CallbackHandler(BaseHTTPRequestHandler): - auth_code: str | None = None - state: str | None = None +def _make_callback_handler(): + """Create a callback handler class with instance-scoped result storage.""" + result = {"auth_code": None, "state": None} - def do_GET(self): - qs = parse_qs(urlparse(self.path).query) - _CallbackHandler.auth_code = (qs.get("code") or [None])[0] - _CallbackHandler.state = (qs.get("state") or [None])[0] - self.send_response(200) - self.send_header("Content-Type", "text/html") - self.end_headers() - self.wfile.write(b"