#!/usr/bin/env python3 """ MCP OAuth 2.1 Client Support Implements the browser-based OAuth 2.1 authorization code flow with PKCE for MCP servers that require OAuth authentication instead of static bearer tokens. Uses the MCP Python SDK's ``OAuthClientProvider`` (an ``httpx.Auth`` subclass) which handles discovery, dynamic client registration, PKCE, token exchange, refresh, and step-up authorization automatically. This module provides the glue: - ``HermesTokenStorage``: persists tokens/client-info to disk so they survive across process restarts. - Callback server: ephemeral localhost HTTP server to capture the OAuth redirect with the authorization code. - ``build_oauth_auth()``: entry point called by ``mcp_tool.py`` that wires everything together and returns the ``httpx.Auth`` object. Configuration in config.yaml:: mcp_servers: my_server: url: "https://mcp.example.com/mcp" auth: oauth oauth: # all fields optional client_id: "pre-registered-id" # skip dynamic registration client_secret: "secret" # confidential clients only scope: "read write" # default: server-provided redirect_port: 0 # 0 = auto-pick free port client_name: "My Custom Client" # default: "Hermes Agent" """ import asyncio import json import logging import os import re import socket import sys import threading import webbrowser from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path from typing import Any, Optional from urllib.parse import parse_qs, urlparse logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Lazy imports -- MCP SDK with OAuth support is optional # --------------------------------------------------------------------------- _OAUTH_AVAILABLE = False try: from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ) from pydantic import AnyUrl _OAUTH_AVAILABLE = True except ImportError: logger.debug("MCP 
OAuth types not available -- OAuth MCP auth disabled") # --------------------------------------------------------------------------- # Exceptions # --------------------------------------------------------------------------- class OAuthNonInteractiveError(RuntimeError): """Raised when OAuth requires browser interaction in a non-interactive env.""" # --------------------------------------------------------------------------- # Module-level state # --------------------------------------------------------------------------- # Port used by the most recent build_oauth_auth() call. Exposed so that # tests can verify the callback server and the redirect_uri share a port. _oauth_port: int | None = None # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _get_token_dir() -> Path: """Return the directory for MCP OAuth token files. Uses HERMES_HOME so each profile gets its own OAuth tokens. 
Layout: ``HERMES_HOME/mcp-tokens/`` """ try: from hermes_constants import get_hermes_home base = Path(get_hermes_home()) except ImportError: base = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes"))) return base / "mcp-tokens" def _safe_filename(name: str) -> str: """Sanitize a server name for use as a filename (no path separators).""" return re.sub(r"[^\w\-]", "_", name).strip("_")[:128] or "default" def _find_free_port() -> int: """Find an available TCP port on localhost.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] def _is_interactive() -> bool: """Return True if we can reasonably expect to interact with a user.""" try: return sys.stdin.isatty() except (AttributeError, ValueError): return False def _can_open_browser() -> bool: """Return True if opening a browser is likely to work.""" # Explicit SSH session → no local display if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"): return False # macOS and Windows usually have a display if os.name == "nt": return True try: if os.uname().sysname == "Darwin": return True except AttributeError: pass # Linux/other posix: need DISPLAY or WAYLAND_DISPLAY if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"): return True return False def _read_json(path: Path) -> dict | None: """Read a JSON file, returning None if it doesn't exist or is invalid.""" if not path.exists(): return None try: return json.loads(path.read_text(encoding="utf-8")) except (json.JSONDecodeError, OSError) as exc: logger.warning("Failed to read %s: %s", path, exc) return None def _write_json(path: Path, data: dict) -> None: """Write a dict as JSON with restricted permissions (0o600).""" path.parent.mkdir(parents=True, exist_ok=True) tmp = path.with_suffix(".tmp") try: tmp.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8") os.chmod(tmp, 0o600) tmp.rename(path) except OSError: tmp.unlink(missing_ok=True) raise # 
# ---------------------------------------------------------------------------
# HermesTokenStorage -- persistent token/client-info on disk
# ---------------------------------------------------------------------------


class HermesTokenStorage:
    """Persist OAuth tokens and client registration to JSON files.

    Passed as the ``storage`` of the ``OAuthClientProvider`` built in
    ``build_oauth_auth``.

    File layout::

        HERMES_HOME/mcp-tokens/<server_name>.json         -- tokens
        HERMES_HOME/mcp-tokens/<server_name>.client.json  -- client info

    ``<server_name>`` is passed through ``_safe_filename`` so it cannot
    contain path separators.
    """

    def __init__(self, server_name: str):
        self._server_name = _safe_filename(server_name)

    def _tokens_path(self) -> Path:
        return _get_token_dir() / f"{self._server_name}.json"

    def _client_info_path(self) -> Path:
        return _get_token_dir() / f"{self._server_name}.client.json"

    # -- tokens ------------------------------------------------------------

    async def get_tokens(self) -> "OAuthToken | None":
        """Return cached tokens, or None if missing, unreadable, or invalid."""
        data = _read_json(self._tokens_path())
        if data is None:
            return None
        try:
            return OAuthToken.model_validate(data)
        except Exception:
            logger.warning("Corrupt tokens at %s -- ignoring", self._tokens_path())
            return None

    async def set_tokens(self, tokens: "OAuthToken") -> None:
        """Persist *tokens* to disk (fields that are None are omitted)."""
        _write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
        logger.debug("OAuth tokens saved for %s", self._server_name)

    # -- client info -------------------------------------------------------

    async def get_client_info(self) -> "OAuthClientInformationFull | None":
        """Return cached client registration, or None if missing/invalid."""
        data = _read_json(self._client_info_path())
        if data is None:
            return None
        try:
            return OAuthClientInformationFull.model_validate(data)
        except Exception:
            logger.warning("Corrupt client info at %s -- ignoring", self._client_info_path())
            return None

    async def set_client_info(self, client_info: "OAuthClientInformationFull") -> None:
        """Persist client registration info (None-valued fields omitted)."""
        _write_json(self._client_info_path(), client_info.model_dump(exclude_none=True))
        logger.debug("OAuth client info saved for %s", self._server_name)

    # -- cleanup -----------------------------------------------------------

    def remove(self) -> None:
        """Delete all stored OAuth state for this server."""
        for p in (self._tokens_path(), self._client_info_path()):
            p.unlink(missing_ok=True)

    def has_cached_tokens(self) -> bool:
        """Return True if we have tokens on disk (may be expired)."""
        return self._tokens_path().exists()


# ---------------------------------------------------------------------------
# Callback handler factory -- each invocation gets its own result dict
# ---------------------------------------------------------------------------


def _make_callback_handler() -> tuple[type, dict]:
    """Create a per-flow callback HTTP handler class with its own result dict.

    Returns ``(HandlerClass, result_dict)``. *result_dict* has the keys
    ``auth_code``, ``state`` and ``error``; the handler fills them in when
    the first request carrying a ``code`` or ``error`` query parameter
    arrives. Requests with neither (e.g. a browser's ``/favicon.ico`` probe)
    get a 404 and leave the result untouched. Each call returns a fresh pair
    so concurrent flows don't stomp on each other.
    """
    import html

    result: dict[str, Any] = {"auth_code": None, "state": None, "error": None}

    class _Handler(BaseHTTPRequestHandler):
        def do_GET(self) -> None:  # noqa: N802
            params = parse_qs(urlparse(self.path).query)
            code = params.get("code", [None])[0]
            state = params.get("state", [None])[0]
            error = params.get("error", [None])[0]

            if code is None and error is None:
                # Not an OAuth redirect; ignore so the real one is still captured.
                self.send_response(404)
                self.end_headers()
                return

            if result["auth_code"] is None and result["error"] is None:
                # Only the first redirect counts; state is written together
                # with the code so the poller never sees a half-filled result.
                result.update(state=state, error=error, auth_code=code)

            if code:
                body = (
                    "<html><body><h2>Authorization Successful</h2>"
                    "<p>You can close this tab and return to Hermes.</p></body></html>"
                )
            else:
                # ``error`` comes from the URL, so escape it before echoing it
                # back into HTML.
                body = (
                    "<html><body><h2>Authorization Failed</h2>"
                    f"<p>Error: {html.escape(error or 'unknown')}</p></body></html>"
                )
            self.send_response(200)
            self.send_header("Content-Type", "text/html; charset=utf-8")
            self.end_headers()
            self.wfile.write(body.encode())

        def log_message(self, fmt: str, *args: Any) -> None:
            logger.debug("OAuth callback: %s", fmt % args)

    return _Handler, result


# ---------------------------------------------------------------------------
# Async redirect + callback handlers for OAuthClientProvider
# ---------------------------------------------------------------------------


async def _redirect_handler(authorization_url: str) -> None:
    """Show the authorization URL to the user.

    Opens the browser automatically when possible; always prints the URL as
    a fallback for headless/SSH/gateway environments.
    """
    msg = (
        f"\n  MCP OAuth: authorization required.\n"
        f"  Open this URL in your browser:\n\n"
        f"    {authorization_url}\n"
    )
    print(msg, file=sys.stderr)

    if _can_open_browser():
        try:
            opened = webbrowser.open(authorization_url)
            if opened:
                print("  (Browser opened automatically.)\n", file=sys.stderr)
            else:
                print("  (Could not open browser — please open the URL manually.)\n", file=sys.stderr)
        except Exception:
            print("  (Could not open browser — please open the URL manually.)\n", file=sys.stderr)
    else:
        print("  (Headless environment detected — open the URL manually.)\n", file=sys.stderr)


async def _wait_for_callback(
    port: int | None = None,
    timeout: float = 300.0,
) -> tuple[str, str | None]:
    """Run a localhost callback server and wait for the OAuth redirect.

    Starts an ``HTTPServer`` on ``127.0.0.1:<port>`` in a daemon thread and
    polls for the redirect without blocking the event loop. The server is
    always shut down and its socket closed, including on timeout or task
    cancellation.

    Args:
        port: Callback port. Defaults to the module-level ``_oauth_port``
            set by the most recent ``build_oauth_auth()`` call.
        timeout: Seconds to wait for the redirect.

    Returns:
        ``(authorization_code, state)`` taken from the redirect query string.

    Raises:
        RuntimeError: If no port is known, or the authorization server
            redirected back with an ``error`` parameter.
        OAuthNonInteractiveError: If the callback port cannot be bound, or
            no redirect arrives within *timeout* (no user present to
            complete the browser authorization).
    """
    if port is None:
        port = _oauth_port
    if port is None:
        raise RuntimeError("OAuth callback port not set")

    handler_cls, result = _make_callback_handler()
    try:
        server = HTTPServer(("127.0.0.1", port), handler_cls)
    except OSError as exc:
        raise OAuthNonInteractiveError(
            f"OAuth callback failed — could not bind callback port {port} ({exc}). "
            "Free the port or configure a different oauth.redirect_port, then retry."
        ) from exc

    # serve_forever (not a single handle_request) so stray requests such as
    # /favicon.ico cannot consume the only request the server would accept.
    server_thread = threading.Thread(
        target=server.serve_forever, kwargs={"poll_interval": 0.25}, daemon=True
    )
    server_thread.start()

    poll_interval = 0.5
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    try:
        while result["auth_code"] is None and result["error"] is None:
            if loop.time() >= deadline:
                break
            await asyncio.sleep(poll_interval)
    finally:
        try:
            # shutdown() blocks until serve_forever returns (and thus until any
            # in-flight handler has finished); keep that wait off the loop.
            await asyncio.to_thread(server.shutdown)
        finally:
            server.server_close()

    if result["error"]:
        raise RuntimeError(f"OAuth authorization failed: {result['error']}")
    if result["auth_code"] is None:
        raise OAuthNonInteractiveError(
            "OAuth callback timed out — no authorization code received. "
            "Ensure you completed the browser authorization flow."
        )
    return result["auth_code"], result["state"]


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def remove_oauth_tokens(server_name: str) -> None:
    """Delete stored OAuth tokens and client info for a server."""
    storage = HermesTokenStorage(server_name)
    storage.remove()
    logger.info("OAuth tokens removed for '%s'", server_name)


def build_oauth_auth(
    server_name: str,
    server_url: str,
    oauth_config: dict | None = None,
) -> "OAuthClientProvider | None":
    """Build an ``httpx.Auth``-compatible OAuth handler for an MCP server.

    Called from ``mcp_tool.py`` when a server has ``auth: oauth`` in config.

    Args:
        server_name: Server key in mcp_servers config (used for storage).
        server_url: MCP server endpoint URL.
        oauth_config: Optional dict from the ``oauth:`` block in config.yaml
            (``client_id``, ``client_secret``, ``scope``, ``redirect_port``,
            ``client_name``, ``timeout``).

    Returns:
        An ``OAuthClientProvider`` instance, or None if the MCP SDK lacks
        OAuth support.
    """
    if not _OAUTH_AVAILABLE:
        logger.warning(
            "MCP OAuth requested for '%s' but SDK auth types are not available. "
            "Install with: pip install 'mcp>=1.10.0'",
            server_name,
        )
        return None

    global _oauth_port

    cfg = oauth_config or {}

    # --- Storage ---
    storage = HermesTokenStorage(server_name)

    # --- Non-interactive warning ---
    if not _is_interactive() and not storage.has_cached_tokens():
        logger.warning(
            "MCP OAuth for '%s': non-interactive environment and no cached tokens found. "
            "The OAuth flow requires browser authorization. Run interactively first "
            "to complete the initial authorization, then cached tokens will be reused.",
            server_name,
        )

    # --- Pick callback port ---
    # ``or 0`` also tolerates an explicit ``redirect_port: null`` in YAML.
    redirect_port = int(cfg.get("redirect_port") or 0)
    if redirect_port == 0:
        redirect_port = _find_free_port()
    _oauth_port = redirect_port

    raw_timeout = cfg.get("timeout")
    timeout = 300.0 if raw_timeout is None else float(raw_timeout)

    # --- Client metadata ---
    client_name = cfg.get("client_name", "Hermes Agent")
    scope = cfg.get("scope")
    redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"

    metadata_kwargs: dict[str, Any] = {
        "client_name": client_name,
        "redirect_uris": [AnyUrl(redirect_uri)],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "none",
    }
    if scope:
        metadata_kwargs["scope"] = scope

    client_secret = cfg.get("client_secret")
    if client_secret:
        metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"

    client_metadata = OAuthClientMetadata.model_validate(metadata_kwargs)

    # --- Pre-registered client ---
    # Written to storage so the provider finds it and skips dynamic
    # registration; this overwrites any previously cached client info.
    client_id = cfg.get("client_id")
    if client_id:
        info_dict: dict[str, Any] = {
            "client_id": client_id,
            "redirect_uris": [redirect_uri],
            "grant_types": client_metadata.grant_types,
            "response_types": client_metadata.response_types,
            "token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
        }
        if client_secret:
            info_dict["client_secret"] = client_secret
        if client_name:
            info_dict["client_name"] = client_name
        if scope:
            info_dict["scope"] = scope

        client_info = OAuthClientInformationFull.model_validate(info_dict)
        _write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True))
        logger.debug("Pre-registered client_id=%s for '%s'", client_id, server_name)

    # --- Base URL for discovery ---
    parsed = urlparse(server_url)
    base_url = f"{parsed.scheme}://{parsed.netloc}"

    # Bind this provider's own port and timeout. ``_oauth_port`` is
    # overwritten by every build_oauth_auth() call, so reading it lazily at
    # callback time would make all but the last configured server listen
    # on the wrong port.
    async def _callback_handler() -> tuple[str, str | None]:
        return await _wait_for_callback(port=redirect_port, timeout=timeout)

    # --- Build provider ---
    provider = OAuthClientProvider(
        server_url=base_url,
        client_metadata=client_metadata,
        storage=storage,
        redirect_handler=_redirect_handler,
        callback_handler=_callback_handler,
        timeout=timeout,
    )
    return provider