Files
hermes-agent/tests/tools/test_mcp_oauth.py
Teknium ed805f57ff fix(mcp-oauth): port mismatch, path traversal, and shared handler state (salvage #2521) (#2552)
* fix(mcp-oauth): port mismatch, path traversal, and shared state in OAuth flow

Three bugs in the new MCP OAuth 2.1 PKCE implementation:

1. CRITICAL: OAuth redirect port mismatch — build_oauth_auth() calls
   _find_free_port() to register the redirect_uri, but _wait_for_callback()
   calls _find_free_port() again getting a DIFFERENT port. Browser redirects
   to port A, server listens on port B — callback never arrives, 120s timeout.
   Fix: share the port via module-level _oauth_port variable.

2. MEDIUM: Path traversal via unsanitized server_name — HermesTokenStorage
   uses server_name directly in filenames. A name like "../../.ssh/config"
   writes token files outside ~/.hermes/mcp-tokens/.
   Fix: sanitize server_name with the same regex pattern used elsewhere.

3. MEDIUM: Class-level auth_code/state on _CallbackHandler causes data
   races if concurrent OAuth flows run. Second callback overwrites first.
   Fix: factory function _make_callback_handler() returns a handler class
   with a closure-scoped result dict, isolating each flow.

* test: add tests for MCP OAuth path traversal, handler isolation, and port sharing

7 new tests covering:
- Path traversal blocked (../../.ssh/config stays in mcp-tokens/)
- Dots/slashes sanitized and resolved within base dir
- Normal server names preserved
- Special characters sanitized (@, :, /)
- Concurrent handler result dicts are independent
- Handler writes to its own result dict, not class-level
- build_oauth_auth stores port in module-level _oauth_port

---------

Co-authored-by: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com>
2026-03-22 15:02:26 -07:00

239 lines
8.4 KiB
Python

"""Tests for tools/mcp_oauth.py — thin OAuth adapter over MCP SDK."""
import json
import os
from pathlib import Path
from unittest.mock import patch, MagicMock, AsyncMock
import pytest
from tools.mcp_oauth import (
HermesTokenStorage,
build_oauth_auth,
remove_oauth_tokens,
_find_free_port,
_can_open_browser,
)
# ---------------------------------------------------------------------------
# HermesTokenStorage
# ---------------------------------------------------------------------------
class TestHermesTokenStorage:
    """Exercise HermesTokenStorage against a temporary HERMES_HOME."""

    def test_roundtrip_tokens(self, tmp_path, monkeypatch):
        """set_tokens() writes the dumped payload to mcp-tokens/<name>.json."""
        import asyncio

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("test-server")

        # Nothing has been stored yet.
        assert asyncio.run(storage.get_tokens()) is None

        # Persist a stand-in token whose model_dump() yields a plain dict.
        fake_token = MagicMock()
        fake_token.model_dump.return_value = {
            "access_token": "abc123",
            "token_type": "Bearer",
            "refresh_token": "ref456",
        }
        asyncio.run(storage.set_tokens(fake_token))

        # The token file must now exist and carry the access token.
        saved_file = tmp_path / "mcp-tokens" / "test-server.json"
        assert saved_file.exists()
        stored = json.loads(saved_file.read_text())
        assert stored["access_token"] == "abc123"

    def test_roundtrip_client_info(self, tmp_path, monkeypatch):
        """set_client_info() creates mcp-tokens/<name>.client.json."""
        import asyncio

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("test-server")

        # No client registration exists before the first save.
        assert asyncio.run(storage.get_client_info()) is None

        fake_client = MagicMock()
        fake_client.model_dump.return_value = {
            "client_id": "hermes-123",
            "client_secret": "secret",
        }
        asyncio.run(storage.set_client_info(fake_client))

        saved_file = tmp_path / "mcp-tokens" / "test-server.client.json"
        assert saved_file.exists()

    def test_remove_cleans_up(self, tmp_path, monkeypatch):
        """remove() deletes both the token file and the client-info file."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("test-server")

        # Seed the two files that remove() is expected to delete.
        token_dir = tmp_path / "mcp-tokens"
        token_dir.mkdir(parents=True)
        token_file = token_dir / "test-server.json"
        client_file = token_dir / "test-server.client.json"
        token_file.write_text("{}")
        client_file.write_text("{}")

        storage.remove()

        assert not token_file.exists()
        assert not client_file.exists()
# ---------------------------------------------------------------------------
# build_oauth_auth
# ---------------------------------------------------------------------------
class TestBuildOAuthAuth:
    """Behaviour of build_oauth_auth with and without the MCP SDK auth module."""

    def test_returns_oauth_provider(self):
        """With the SDK importable, the result is an OAuthClientProvider."""
        try:
            from mcp.client.auth import OAuthClientProvider
        except ImportError:
            pytest.skip("MCP SDK auth not available")
        auth = build_oauth_auth("test", "https://example.com/mcp")
        assert isinstance(auth, OAuthClientProvider)

    def test_returns_none_without_sdk(self, monkeypatch):
        """Smoke test: build_oauth_auth must not raise when the SDK import fails."""
        # ``builtins.__import__`` is the reliable handle on the import hook;
        # the ``__builtins__`` global is a dict or a module depending on context.
        import builtins

        real_import = builtins.__import__

        def _block_import(name, *args, **kwargs):
            # Fail only the SDK auth module; delegate every other import.
            if "mcp.client.auth" in name:
                raise ImportError("blocked")
            return real_import(name, *args, **kwargs)

        with patch("builtins.__import__", side_effect=_block_import):
            # The return value is deliberately not asserted: passing means only
            # "did not raise".
            # NOTE(review): tighten to ``assert result is None`` once it is
            # confirmed that the SDK import happens inside build_oauth_auth.
            build_oauth_auth("test", "https://example.com")
# ---------------------------------------------------------------------------
# Utility functions
# ---------------------------------------------------------------------------
class TestUtilities:
    """Checks for the small helper functions imported from tools.mcp_oauth."""

    def test_find_free_port_returns_int(self):
        """_find_free_port() yields an int in the non-privileged port range."""
        candidate = _find_free_port()
        assert isinstance(candidate, int)
        assert 1024 <= candidate <= 65535

    def test_can_open_browser_false_in_ssh(self, monkeypatch):
        """An SSH_CLIENT variable in the environment disables browser launch."""
        monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 1234 22")
        assert _can_open_browser() is False

    def test_can_open_browser_false_without_display(self, monkeypatch):
        """No SSH variables and no DISPLAY on a faked Linux host gives False."""
        for variable in ("SSH_CLIENT", "SSH_TTY", "DISPLAY"):
            monkeypatch.delenv(variable, raising=False)

        class _FakeUname:
            # Only the attribute the test needs: report a Linux system.
            sysname = "Linux"

        # Pretend to be a non-macOS, non-Windows platform.
        monkeypatch.setattr(os, "name", "posix")
        monkeypatch.setattr(os, "uname", lambda: _FakeUname())
        assert _can_open_browser() is False
# ---------------------------------------------------------------------------
# Path sanitization, callback handler isolation, port sharing, remove_oauth_tokens
# ---------------------------------------------------------------------------
class TestPathTraversal:
    """Verify server_name is sanitized to prevent path traversal."""

    def test_path_traversal_blocked(self, tmp_path, monkeypatch):
        """A '../'-laden name must resolve inside the mcp-tokens directory."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("../../.ssh/config")
        resolved = storage._tokens_path().resolve()
        base = (tmp_path / "mcp-tokens").resolve()
        # Check containment on the resolved path: a substring test for
        # "mcp-tokens" would also pass for "<base>/../../.ssh/config.json".
        assert resolved.is_relative_to(base)
        assert ".ssh" not in str(resolved)

    def test_dots_and_slashes_sanitized(self, tmp_path, monkeypatch):
        """Deeper traversal attempts are contained as well."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("../../../etc/passwd")
        resolved = storage._tokens_path().resolve()
        assert resolved.is_relative_to((tmp_path / "mcp-tokens").resolve())

    def test_normal_name_unchanged(self, tmp_path, monkeypatch):
        """A plain server name is used verbatim as the file stem."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("my-mcp-server")
        assert "my-mcp-server.json" in str(storage._tokens_path())

    def test_special_chars_sanitized(self, tmp_path, monkeypatch):
        """'@', ':' and '/' must not survive into the token file name."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        storage = HermesTokenStorage("server@host:8080/path")
        path = storage._tokens_path()
        assert "@" not in path.name
        assert ":" not in path.name
        # A path stem can never contain "/", so testing the stem proves
        # nothing. An unsanitized "/" would create a sub-directory, so require
        # the file to sit directly in mcp-tokens.
        assert path.resolve().parent == (tmp_path / "mcp-tokens").resolve()
class TestCallbackHandlerIsolation:
    """Verify concurrent OAuth flows don't share state."""

    def test_independent_result_dicts(self):
        """Each factory call hands back its own result dict."""
        from tools.mcp_oauth import _make_callback_handler

        _handler_a, first = _make_callback_handler()
        _handler_b, second = _make_callback_handler()

        first["auth_code"] = "code_A"
        second["auth_code"] = "code_B"

        # Writing to one dict must not be visible through the other.
        assert first["auth_code"] == "code_A"
        assert second["auth_code"] == "code_B"

    def test_handler_writes_to_own_result(self):
        """do_GET() stores the callback's code and state in the paired dict."""
        from io import BytesIO

        from tools.mcp_oauth import _make_callback_handler

        handler_cls, captured = _make_callback_handler()
        assert captured["auth_code"] is None

        # Build the handler without running __init__ (which expects a live
        # socket) and stub the attributes needed to serve a GET.
        request = handler_cls.__new__(handler_cls)
        request.path = "/callback?code=test123&state=mystate"
        request.wfile = BytesIO()
        for stub_name in ("send_response", "send_header", "end_headers"):
            setattr(request, stub_name, MagicMock())

        request.do_GET()

        assert captured["auth_code"] == "test123"
        assert captured["state"] == "mystate"
class TestOAuthPortSharing:
    """Verify build_oauth_auth records its port in the module-level _oauth_port."""

    def test_port_stored_globally(self, monkeypatch):
        """After build_oauth_auth(), tools.mcp_oauth._oauth_port is a valid port."""
        # Decide whether to skip before touching any module state.
        try:
            from mcp.client.auth import OAuthClientProvider  # noqa: F401
        except ImportError:
            pytest.skip("MCP SDK auth not available")

        import tools.mcp_oauth as mod

        # Reset through monkeypatch so the previous value is restored at
        # teardown and this test cannot leak a port into later tests.
        monkeypatch.setattr(mod, "_oauth_port", None, raising=False)

        build_oauth_auth("test-port", "https://example.com/mcp")

        port = mod._oauth_port
        assert port is not None
        assert isinstance(port, int)
        assert 1024 <= port <= 65535
class TestRemoveOAuthTokens:
    """remove_oauth_tokens deletes a server's files and tolerates their absence."""

    def test_removes_files(self, tmp_path, monkeypatch):
        """Both the token file and the client-info file are deleted."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        token_dir = tmp_path / "mcp-tokens"
        token_dir.mkdir()
        token_file = token_dir / "myserver.json"
        client_file = token_dir / "myserver.client.json"
        token_file.write_text("{}")
        client_file.write_text("{}")

        remove_oauth_tokens("myserver")

        assert not token_file.exists()
        assert not client_file.exists()

    def test_no_error_when_files_missing(self, tmp_path, monkeypatch):
        """Removing tokens for an unknown server is a silent no-op."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        # Passing means only that this call does not raise.
        remove_oauth_tokens("nonexistent")