Implement tools/mcp_oauth.py — the OAuth adapter that the existing
`auth: oauth` hook in mcp_tool.py has been waiting for.
Components:
- HermesTokenStorage: persists tokens + client registration to
  HERMES_HOME/mcp-tokens/<server>.json with 0o600 permissions
- Callback handler factory: per-flow isolated HTTP handlers (safe for
  concurrent OAuth flows across multiple MCP servers)
- OAuthClientProvider integration: wraps the MCP SDK's httpx.Auth
  subclass, which handles discovery, DCR, PKCE, token exchange,
  refresh, and step-up auth (403 insufficient_scope) automatically
- Non-interactive detection: warns when gateway/cron environments
  try to start OAuth without cached tokens
- Pre-registered client support: injects client_id/secret from config
  for servers that don't support Dynamic Client Registration (e.g. Slack)
- Path traversal protection on server names
- remove_oauth_tokens() for cleanup
Config format:

    mcp_servers:
      sentry:
        url: 'https://mcp.sentry.dev/mcp'
        auth: oauth
        oauth:                    # all optional
          client_id: '...'        # skip DCR
          client_secret: '...'    # confidential client
          scope: 'read write'     # server-provided by default
Also passes the oauth config dict through from mcp_tool.py (previously
only server_name and url were passed).
E2E verified: full OAuth flow (401 → discovery → DCR → authorize →
token exchange → authenticated request → tokens persisted) against
local test servers. 23 unit tests + 186 MCP suite tests pass.
483 lines · 17 KiB · Python
#!/usr/bin/env python3
"""
MCP OAuth 2.1 Client Support

Implements the browser-based OAuth 2.1 authorization code flow with PKCE
for MCP servers that require OAuth authentication instead of static bearer
tokens.

Uses the MCP Python SDK's ``OAuthClientProvider`` (an ``httpx.Auth`` subclass)
which handles discovery, dynamic client registration, PKCE, token exchange,
refresh, and step-up authorization automatically.

This module provides the glue:

- ``HermesTokenStorage``: persists tokens/client-info to disk so they
  survive across process restarts.
- Callback server: ephemeral localhost HTTP server to capture the OAuth
  redirect with the authorization code.
- ``build_oauth_auth()``: entry point called by ``mcp_tool.py`` that wires
  everything together and returns the ``httpx.Auth`` object.
- ``remove_oauth_tokens()``: deletes the stored tokens and client info for
  one server.

Configuration in config.yaml::

    mcp_servers:
      my_server:
        url: "https://mcp.example.com/mcp"
        auth: oauth
        oauth:                              # all fields optional
          client_id: "pre-registered-id"    # skip dynamic registration
          client_secret: "secret"           # confidential clients only
          scope: "read write"               # default: server-provided
          redirect_port: 0                  # 0 = auto-pick free port
          client_name: "My Custom Client"   # default: "Hermes Agent"
          timeout: 300                      # passed to OAuthClientProvider
"""
|
|
|
|
import asyncio
import html
import json
import logging
import os
import re
import socket
import sys
import threading
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any, Optional
from urllib.parse import parse_qs, urlparse
|
|
|
|
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Lazy imports -- MCP SDK with OAuth support is optional
# ---------------------------------------------------------------------------

try:
    from mcp.client.auth import OAuthClientProvider, TokenStorage
    from mcp.shared.auth import (
        OAuthClientInformationFull,
        OAuthClientMetadata,
        OAuthToken,
    )
    from pydantic import AnyUrl
except ImportError:
    # Without the SDK auth types, build_oauth_auth() warns and returns None.
    _OAUTH_AVAILABLE = False
    logger.debug("MCP OAuth types not available -- OAuth MCP auth disabled")
else:
    _OAUTH_AVAILABLE = True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Exceptions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class OAuthNonInteractiveError(RuntimeError):
    """The OAuth flow needed a person at a browser and did not get one.

    Raised while waiting for the local OAuth callback when the callback port
    cannot be bound or no redirect arrives in time. This is the typical
    outcome of starting the initial authorization from a gateway, a cron job,
    or any other environment with nobody to complete the consent page.
    """
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level state
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Redirect port chosen by the most recent build_oauth_auth() call.
# _wait_for_callback() binds its local callback server to this port, and
# tests read it to check that the server and the redirect_uri agree.
# This is one module-wide slot, not a per-server value.
_oauth_port: Optional[int] = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _get_token_dir() -> Path:
    """Return the directory for MCP OAuth token files.

    The base is ``hermes_constants.get_hermes_home()`` when that module is
    importable, so each profile gets its own OAuth tokens. Otherwise the
    ``HERMES_HOME`` environment variable is used, falling back to
    ``~/.hermes`` when it is unset *or empty*.

    Layout: ``<base>/mcp-tokens/``. The directory is not created here.
    """
    try:
        from hermes_constants import get_hermes_home
        base = Path(get_hermes_home())
    except ImportError:
        # An empty HERMES_HOME is treated as unset: Path("") means the current
        # working directory, which would scatter secrets wherever the process
        # happened to be started.
        env_home = os.environ.get("HERMES_HOME")
        base = Path(env_home) if env_home else Path.home() / ".hermes"
    return base / "mcp-tokens"
|
|
|
|
|
|
def _safe_filename(name: str) -> str:
    """Map an arbitrary server name to a file-name stem with no path separators.

    Anything other than word characters (letters, digits, ``_``) and ``-``
    becomes ``_``. Leading and trailing underscores are then trimmed and the
    result is capped at 128 characters. A name that ends up empty maps to
    ``"default"``.
    """
    replaced = re.sub(r"[^\w\-]", "_", name)
    stem = replaced.strip("_")[:128]
    if not stem:
        return "default"
    return stem
|
|
|
|
|
|
def _find_free_port() -> int:
    """Return a TCP port on 127.0.0.1 that the OS reported as free.

    A throwaway socket is bound to port 0 so the OS picks an ephemeral port.
    The socket is closed before returning, so the port is only "free as of
    just now".
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
    finally:
        probe.close()
    return port
|
|
|
|
|
|
def _is_interactive() -> bool:
    """Guess whether a user is attached, based on whether stdin is a TTY.

    A missing (``None``) or closed stdin counts as non-interactive.
    """
    try:
        attached = sys.stdin.isatty()
    except (ValueError, AttributeError):
        attached = False
    return attached
|
|
|
|
|
|
def _can_open_browser() -> bool:
    """Heuristically decide whether launching a local browser can work.

    The checks are applied in order:

    - An SSH session (``SSH_CLIENT``/``SSH_TTY``) means no local display.
    - Windows and macOS are assumed to have a display.
    - Other POSIX systems need ``DISPLAY`` or ``WAYLAND_DISPLAY``.
    """
    env = os.environ
    if env.get("SSH_CLIENT") or env.get("SSH_TTY"):
        return False
    if os.name == "nt":
        return True
    uname = getattr(os, "uname", None)
    if uname is not None and uname().sysname == "Darwin":
        return True
    return bool(env.get("DISPLAY") or env.get("WAYLAND_DISPLAY"))
|
|
|
|
|
|
def _read_json(path: Path) -> dict | None:
|
|
"""Read a JSON file, returning None if it doesn't exist or is invalid."""
|
|
if not path.exists():
|
|
return None
|
|
try:
|
|
return json.loads(path.read_text(encoding="utf-8"))
|
|
except (json.JSONDecodeError, OSError) as exc:
|
|
logger.warning("Failed to read %s: %s", path, exc)
|
|
return None
|
|
|
|
|
|
def _write_json(path: Path, data: dict) -> None:
    """Atomically write *data* as JSON to *path* with owner-only permissions.

    The payload goes to a sibling ``.tmp`` file, which is opened with mode
    0o600 and explicitly chmod-ed *before* any content is written. That way
    the secret is never on disk with looser permissions. The file is then
    moved over *path* with :func:`os.replace`, which overwrites an existing
    target on every platform. ``Path.rename`` raises on Windows when the
    target exists.

    Raises:
        OSError: If the directory cannot be created or the write or replace
            fails. The temporary file is removed before re-raising.
    """
    # Serialize first so a serialization error never leaves a stray .tmp file.
    payload = json.dumps(data, indent=2, default=str)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(".tmp")
    try:
        fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        # os.open applies the mode only when it creates the file. A leftover
        # .tmp keeps its old mode, so tighten it before writing the secret.
        os.chmod(tmp, 0o600)
        with open(fd, "w", encoding="utf-8") as fh:
            fh.write(payload)
        os.replace(tmp, path)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HermesTokenStorage -- persistent token/client-info on disk
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class HermesTokenStorage:
    """Persist OAuth tokens and client registration for one MCP server.

    Provides the async ``get_tokens``/``set_tokens``/``get_client_info``/
    ``set_client_info`` methods of the MCP SDK's ``TokenStorage`` interface,
    backed by two JSON files::

        HERMES_HOME/mcp-tokens/<server_name>.json         -- tokens
        HERMES_HOME/mcp-tokens/<server_name>.client.json  -- client info

    ``<server_name>`` is passed through ``_safe_filename``, so a config key
    cannot introduce path separators. Paths are resolved on every access
    rather than cached.
    """

    def __init__(self, server_name: str):
        # Sanitized once; both file names derive from this stem.
        self._server_name = _safe_filename(server_name)

    def _tokens_path(self) -> Path:
        """Path of this server's token file."""
        return _get_token_dir() / f"{self._server_name}.json"

    def _client_info_path(self) -> Path:
        """Path of this server's client-registration file."""
        return _get_token_dir() / f"{self._server_name}.client.json"

    # -- tokens ------------------------------------------------------------

    async def get_tokens(self) -> "OAuthToken | None":
        """Return the stored tokens, or None if absent or unparseable.

        The returned tokens may be expired.
        """
        path = self._tokens_path()
        data = _read_json(path)
        if data is None:
            return None
        try:
            return OAuthToken.model_validate(data)
        except Exception:
            logger.warning("Corrupt tokens at %s -- ignoring", path)
            return None

    async def set_tokens(self, tokens: "OAuthToken") -> None:
        """Write *tokens* to this server's token file."""
        # mode="json" makes pydantic emit JSON-native values (e.g. URLs as
        # plain strings), so the file round-trips through model_validate
        # without relying on json.dumps(default=str) stringification.
        payload = tokens.model_dump(mode="json", exclude_none=True)
        _write_json(self._tokens_path(), payload)
        logger.debug("OAuth tokens saved for %s", self._server_name)

    # -- client info -------------------------------------------------------

    async def get_client_info(self) -> "OAuthClientInformationFull | None":
        """Return the stored client registration, or None if absent/unparseable."""
        path = self._client_info_path()
        data = _read_json(path)
        if data is None:
            return None
        try:
            return OAuthClientInformationFull.model_validate(data)
        except Exception:
            logger.warning("Corrupt client info at %s -- ignoring", path)
            return None

    async def set_client_info(self, client_info: "OAuthClientInformationFull") -> None:
        """Write *client_info* to this server's client-registration file."""
        payload = client_info.model_dump(mode="json", exclude_none=True)
        _write_json(self._client_info_path(), payload)
        logger.debug("OAuth client info saved for %s", self._server_name)

    # -- cleanup -----------------------------------------------------------

    def remove(self) -> None:
        """Delete all stored OAuth state for this server (missing files are fine)."""
        for p in (self._tokens_path(), self._client_info_path()):
            p.unlink(missing_ok=True)

    def has_cached_tokens(self) -> bool:
        """Return True if a token file exists on disk.

        The file is not parsed, and the tokens in it may be expired.
        """
        return self._tokens_path().exists()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Callback handler factory -- each invocation gets its own result dict
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_callback_handler(callback_path: str = "/callback") -> tuple[type, dict]:
    """Create a per-flow callback HTTP handler class with its own result dict.

    Args:
        callback_path: URL path the OAuth redirect arrives on. It must match
            the path of the ``redirect_uri`` sent to the authorization server.

    Returns:
        ``(HandlerClass, result_dict)``. *result_dict* starts as
        ``{"auth_code": None, "state": None, "error": None}``.

        - It is filled in by the first request to *callback_path* that carries
          a ``code`` or an ``error`` query parameter. Later callbacks do not
          overwrite it.
        - Requests to any other path (e.g. a browser's ``/favicon.ico``) get
          a 404 and leave it untouched.

        Each call returns a fresh pair, so concurrent flows don't stomp on
        each other.
    """
    result: dict[str, Any] = {"auth_code": None, "state": None, "error": None}

    class _Handler(BaseHTTPRequestHandler):
        def do_GET(self) -> None:  # noqa: N802
            url = urlparse(self.path)
            if url.path != callback_path:
                self._respond(404, "<html><body><h2>Not Found</h2></body></html>")
                return

            params = parse_qs(url.query)
            code = params.get("code", [None])[0]
            state = params.get("state", [None])[0]
            error = params.get("error", [None])[0]
            description = params.get("error_description", [None])[0]

            captured = result["auth_code"] is not None or result["error"] is not None
            if (code or error) and not captured:
                # auth_code is written last: anyone polling for it then also
                # sees the matching state and error.
                result["state"] = state
                result["error"] = error
                result["auth_code"] = code

            if code:
                body = (
                    "<html><body><h2>Authorization Successful</h2>"
                    "<p>You can close this tab and return to Hermes.</p></body></html>"
                )
            else:
                detail = error or "unknown"
                if description:
                    detail = f"{detail} ({description})"
                # Query-string values are attacker-controllable, so escape
                # them before reflecting them into HTML.
                body = (
                    "<html><body><h2>Authorization Failed</h2>"
                    f"<p>Error: {html.escape(detail)}</p></body></html>"
                )
            self._respond(200, body)

        def _respond(self, status: int, body: str) -> None:
            """Send a complete text/html response with *status* and *body*."""
            payload = body.encode("utf-8")
            self.send_response(status)
            self.send_header("Content-Type", "text/html; charset=utf-8")
            self.send_header("Content-Length", str(len(payload)))
            self.end_headers()
            self.wfile.write(payload)

        def log_message(self, fmt: str, *args: Any) -> None:
            # Route http.server's access log to debug logging instead of stderr.
            logger.debug("OAuth callback: %s", fmt % args)

    return _Handler, result
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Async redirect + callback handlers for OAuthClientProvider
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _redirect_handler(authorization_url: str) -> None:
    """Present *authorization_url* to the user (the provider's redirect hook).

    The URL is always printed to stderr so it can be opened by hand. When
    ``_can_open_browser()`` reports a usable display, the default browser is
    also launched, and a short note says whether that worked.
    """
    banner = "".join((
        "\n MCP OAuth: authorization required.\n",
        " Open this URL in your browser:\n\n",
        f" {authorization_url}\n",
    ))
    print(banner, file=sys.stderr)

    manual_hint = " (Could not open browser — please open the URL manually.)\n"
    if not _can_open_browser():
        note = " (Headless environment detected — open the URL manually.)\n"
    else:
        try:
            launched = webbrowser.open(authorization_url)
        except Exception:
            launched = False
        note = " (Browser opened automatically.)\n" if launched else manual_hint
    print(note, file=sys.stderr)
|
|
|
|
|
|
async def _wait_for_callback() -> tuple[str, str | None]:
    """Wait for the OAuth redirect on a short-lived localhost HTTP server.

    Used as the ``callback_handler`` of ``OAuthClientProvider``. It:

    - binds ``127.0.0.1:<_oauth_port>``, the port ``build_oauth_auth`` put
      into the ``redirect_uri``;
    - serves requests from a daemon thread until one of them carries an
      authorization code or an error;
    - polls for that outcome with ``asyncio.sleep``, so the event loop is
      never blocked.

    ``_oauth_port`` is read once, before the first ``await``.

    Returns:
        ``(auth_code, state)`` taken from the redirect's query string.

    Raises:
        RuntimeError: If ``_oauth_port`` was never set, or the authorization
            server redirected back with an ``error`` parameter.
        OAuthNonInteractiveError: If the callback port cannot be bound, or no
            redirect arrives within 300 seconds (no user present to complete
            the browser auth).
    """
    timeout = 300.0
    poll_interval = 0.5

    port = _oauth_port
    if port is None:
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        raise RuntimeError("OAuth callback port not set")

    handler_cls, result = _make_callback_handler()
    # Per-connection socket timeout honoured by socketserver's
    # StreamRequestHandler. Without it, a client that connects but never sends
    # a request would wedge the single serving thread and the real redirect
    # queued behind it.
    handler_cls.timeout = 10.0

    # build_oauth_auth() only picks the port number; nothing listens on it
    # yet, so this flow binds its own server.
    try:
        server = HTTPServer(("127.0.0.1", port), handler_cls)
    except OSError as exc:
        raise OAuthNonInteractiveError(
            f"Could not bind OAuth callback port {port}: {exc}. "
            "Free the port or set a different oauth.redirect_port, then retry."
        ) from exc
    # Bounds each handle_request() call, so the serving thread re-checks the
    # stop flag regularly even when no traffic arrives.
    server.timeout = poll_interval

    stop = threading.Event()

    def _have_result() -> bool:
        return result["auth_code"] is not None or result["error"] is not None

    def _serve() -> None:
        # Keep serving until the redirect is captured. The first request is
        # not necessarily the callback (favicon fetches, probes, reloads).
        while not stop.is_set() and not _have_result():
            server.handle_request()

    thread = threading.Thread(
        target=_serve, name=f"mcp-oauth-callback-{port}", daemon=True
    )
    thread.start()

    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    try:
        # The thread exits only after the request that filled *result* has
        # been fully handled, so once it is gone every key is final.
        while thread.is_alive() and loop.time() < deadline:
            await asyncio.sleep(poll_interval)
    finally:
        stop.set()
        try:
            # Let an in-progress handle_request() return before the listening
            # socket is closed underneath it.
            await asyncio.to_thread(thread.join, poll_interval * 2)
        finally:
            server.server_close()

    if result["error"]:
        raise RuntimeError(f"OAuth authorization failed: {result['error']}")
    if result["auth_code"] is None:
        raise OAuthNonInteractiveError(
            "OAuth callback timed out — no authorization code received. "
            "Ensure you completed the browser authorization flow."
        )
    return result["auth_code"], result["state"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Public API
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def remove_oauth_tokens(server_name: str) -> None:
    """Forget the OAuth state for *server_name*.

    Deletes both the cached tokens and the stored client registration via
    ``HermesTokenStorage.remove()``, then logs the removal at INFO level.
    """
    HermesTokenStorage(server_name).remove()
    logger.info("OAuth tokens removed for '%s'", server_name)
|
|
|
|
|
|
def build_oauth_auth(
    server_name: str,
    server_url: str,
    oauth_config: dict | None = None,
) -> "OAuthClientProvider | None":
    """Build an ``httpx.Auth``-compatible OAuth handler for an MCP server.

    Called from ``mcp_tool.py`` when a server has ``auth: oauth`` in config.

    Args:
        server_name: Server key in mcp_servers config (used for storage).
        server_url: MCP server endpoint URL. Only its scheme and host[:port]
            are handed to the provider.
        oauth_config: Optional dict from the ``oauth:`` block in config.yaml.
            Recognised keys: ``client_id``, ``client_secret``, ``scope``,
            ``redirect_port``, ``client_name`` and ``timeout``. A key that is
            present but ``null`` falls back to its default.

    Returns:
        An ``OAuthClientProvider`` instance, or None if the MCP SDK lacks
        OAuth support.

    Raises:
        ValueError: If ``redirect_port`` or ``timeout`` is not numeric.

    Side effects:
        - Sets the module-level ``_oauth_port``.
        - When ``client_id`` is configured, (over)writes the stored client
          registration for this server.
    """
    if not _OAUTH_AVAILABLE:
        logger.warning(
            "MCP OAuth requested for '%s' but SDK auth types are not available. "
            "Install with: pip install 'mcp>=1.10.0'",
            server_name,
        )
        return None

    global _oauth_port

    cfg = oauth_config or {}

    # --- Storage ---
    storage = HermesTokenStorage(server_name)

    # --- Non-interactive warning (the provider is still built) ---
    if not _is_interactive() and not storage.has_cached_tokens():
        logger.warning(
            "MCP OAuth for '%s': non-interactive environment and no cached tokens found. "
            "The OAuth flow requires browser authorization. Run interactively first "
            "to complete the initial authorization, then cached tokens will be reused.",
            server_name,
        )

    # --- Pick callback port ---
    # A YAML ``redirect_port:`` with no value arrives as None; treat it like 0.
    redirect_port = int(cfg.get("redirect_port") or 0)
    if redirect_port == 0:
        redirect_port = _find_free_port()
    _oauth_port = redirect_port

    # --- Client metadata ---
    client_name = cfg.get("client_name") or "Hermes Agent"
    scope = cfg.get("scope")
    redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"

    metadata_kwargs: dict[str, Any] = {
        "client_name": client_name,
        "redirect_uris": [AnyUrl(redirect_uri)],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "none",
    }
    if scope:
        metadata_kwargs["scope"] = scope

    client_secret = cfg.get("client_secret")
    if client_secret:
        metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"

    client_metadata = OAuthClientMetadata.model_validate(metadata_kwargs)

    # --- Pre-registered client ---
    client_id = cfg.get("client_id")
    if client_id:
        info_dict: dict[str, Any] = {
            "client_id": client_id,
            "client_name": client_name,
            "redirect_uris": [redirect_uri],
            "grant_types": client_metadata.grant_types,
            "response_types": client_metadata.response_types,
            "token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
        }
        if client_secret:
            info_dict["client_secret"] = client_secret
        if scope:
            info_dict["scope"] = scope

        client_info = OAuthClientInformationFull.model_validate(info_dict)
        # Written synchronously (the storage setter is async) so the stored
        # registration is already on disk when the provider first looks for
        # it, skipping dynamic registration. mode="json" keeps URLs as plain
        # strings.
        _write_json(
            storage._client_info_path(),
            client_info.model_dump(mode="json", exclude_none=True),
        )
        logger.debug("Pre-registered client_id=%s for '%s'", client_id, server_name)

    # --- Per-provider callback handler ---
    async def _callback_handler() -> tuple[str, str | None]:
        # _oauth_port is a single module-level slot. Building providers for
        # several servers leaves it at the *last* server's port, which would
        # make every other flow listen on the wrong port and time out.
        # Re-point it at this provider's redirect port right before waiting.
        # _wait_for_callback reads the global before its first await, so on
        # one event loop no other flow can run between the two.
        global _oauth_port
        _oauth_port = redirect_port
        return await _wait_for_callback()

    # --- Base URL for discovery ---
    # NOTE(review): this drops any path component of server_url; confirm it
    # matches the resource identifier the target servers expect.
    parsed = urlparse(server_url)
    base_url = f"{parsed.scheme}://{parsed.netloc}"

    # --- Timeout (``timeout: null`` falls back to the default) ---
    raw_timeout = cfg.get("timeout")
    timeout = 300.0 if raw_timeout is None else float(raw_timeout)

    # --- Build provider ---
    provider = OAuthClientProvider(
        server_url=base_url,
        client_metadata=client_metadata,
        storage=storage,
        redirect_handler=_redirect_handler,
        callback_handler=_callback_handler,
        timeout=timeout,
    )

    return provider
|