"""Thin OAuth adapter for MCP HTTP servers.

Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
``httpx.Auth``) with Hermes-specific token storage and browser-based
authorization.  The SDK handles all of the heavy lifting: PKCE generation,
metadata discovery, dynamic client registration, token exchange, and refresh.

Usage in mcp_tool.py::

    from tools.mcp_oauth import build_oauth_auth

    auth = build_oauth_auth(server_name, server_url)
    # pass ``auth`` as the httpx auth parameter
"""

from __future__ import annotations

import asyncio
import base64
import hashlib
import hmac
import json
import logging
import os
import secrets
import socket
import threading
import time
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any, Dict
from urllib.parse import parse_qs, urlparse

logger = logging.getLogger(__name__)

_TOKEN_DIR_NAME = "mcp-tokens"

# ---------------------------------------------------------------------------
# Secure OAuth State Management (V-006 Fix)
# ---------------------------------------------------------------------------
#
# SECURITY: This module previously used pickle.loads() for OAuth state
# deserialization, which is a CRITICAL vulnerability (CVSS 8.8) allowing
# remote code execution. The implementation below uses:
#
# 1. JSON serialization instead of pickle (prevents RCE)
# 2. HMAC-SHA256 signatures for integrity verification
# 3. Cryptographically secure random state tokens
# 4. Strict structure validation
# 5. Timestamp-based expiration (10 minutes)
# 6. Constant-time comparison to prevent timing attacks


class OAuthStateError(Exception):
    """Raised when OAuth state validation fails, indicating potential tampering or CSRF attack."""


class SecureOAuthState:
    """Secure OAuth state container with JSON serialization and HMAC verification.

    VULNERABILITY FIX (V-006): Replaces insecure pickle deserialization with
    JSON + HMAC to prevent remote code execution.

    Structure::

        {
            "token": "<random URL-safe token>",
            "timestamp": <creation time, seconds since the epoch>,
            "nonce": "<random URL-safe nonce>",
            "data": {...}
        }

    Serialized format (URL-safe base64 with the ``=`` padding stripped)::

        <base64url(canonical JSON)>.<base64url(HMAC-SHA256 of the JSON bytes)>
    """

    _MAX_AGE_SECONDS = 600  # 10 minutes
    _TOKEN_BYTES = 32
    _NONCE_BYTES = 16

    def __init__(
        self,
        token: str | None = None,
        timestamp: float | None = None,
        nonce: str | None = None,
        data: dict | None = None,
    ):
        """Create a state; any component left as ``None`` is freshly generated.

        Args:
            token: Random state token (generated when falsy).
            timestamp: Creation time in seconds since the epoch (``time.time()``
                when ``None``).
            nonce: Random nonce (generated when falsy).
            data: Extra JSON-serializable payload (``{}`` when falsy).
        """
        self.token = token or self._generate_token()
        # Explicit None check: a caller-supplied timestamp is kept as given.
        self.timestamp = time.time() if timestamp is None else timestamp
        self.nonce = nonce or self._generate_nonce()
        self.data = data or {}

    @classmethod
    def _generate_token(cls) -> str:
        """Generate a cryptographically secure random token."""
        return secrets.token_urlsafe(cls._TOKEN_BYTES)

    @classmethod
    def _generate_nonce(cls) -> str:
        """Generate a unique nonce to prevent replay attacks."""
        return secrets.token_urlsafe(cls._NONCE_BYTES)

    @classmethod
    def _get_secret_key(cls) -> bytes:
        """Return the HMAC key used to sign state, creating it on first use.

        ``HERMES_OAUTH_SECRET`` takes precedence when set.  Otherwise the key
        lives in ``$HERMES_HOME/.secrets/oauth_state.key`` (``HERMES_HOME``
        defaults to ``~/.hermes``).  A missing key file, or one shorter than
        32 bytes, is replaced by 64 fresh random bytes.  The file is created
        with mode 0o600 from the start, so the key is never readable by other
        users, even briefly.
        """
        env_key = os.environ.get("HERMES_OAUTH_SECRET")
        if env_key:
            return env_key.encode("utf-8")

        home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
        key_dir = home / ".secrets"
        # 0o700 on creation so other local users cannot list the key directory.
        key_dir.mkdir(parents=True, exist_ok=True, mode=0o700)
        key_file = key_dir / "oauth_state.key"

        try:
            key_data = key_file.read_bytes()
        except FileNotFoundError:
            key_data = b""
        # Ensure minimum key length; otherwise regenerate below.
        if len(key_data) >= 32:
            return key_data

        key = secrets.token_bytes(64)
        # O_BINARY (Windows only) prevents newline translation of the raw key bytes.
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
        fd = os.open(key_file, flags, 0o600)
        with os.fdopen(fd, "wb") as fh:
            fh.write(key)
        try:
            # The creation mode is ignored when a (too short) file already existed.
            key_file.chmod(0o600)
        except OSError:
            pass
        return key

    @staticmethod
    def _b64decode_unpadded(value: str) -> bytes:
        """Decode URL-safe base64 whose ``=`` padding was stripped by ``serialize``."""
        return base64.urlsafe_b64decode(value + "=" * (-len(value) % 4))

    def to_dict(self) -> dict:
        """Convert state to dictionary."""
        return {
            "token": self.token,
            "timestamp": self.timestamp,
            "nonce": self.nonce,
            "data": self.data,
        }

    def serialize(self) -> str:
        """Serialize state to the signed ``<data>.<signature>`` string format.

        The payload is canonical JSON (sorted keys, no whitespace) so that the
        HMAC is computed over a deterministic byte string.

        Returns:
            URL-safe base64 encoded signed state.
        """
        json_data = json.dumps(self.to_dict(), separators=(",", ":"), sort_keys=True)
        data_bytes = json_data.encode("utf-8")

        key = self._get_secret_key()
        signature = hmac.new(key, data_bytes, hashlib.sha256).digest()

        encoded_data = base64.urlsafe_b64encode(data_bytes).rstrip(b"=").decode("ascii")
        encoded_sig = base64.urlsafe_b64encode(signature).rstrip(b"=").decode("ascii")
        return f"{encoded_data}.{encoded_sig}"

    @classmethod
    def deserialize(cls, serialized: str) -> "SecureOAuthState":
        """Deserialize and verify a signed state string.

        SECURITY: This method replaces the vulnerable pickle.loads() implementation.
        The signature is checked (in constant time) before the payload is parsed.

        Args:
            serialized: The signed state string to deserialize.

        Returns:
            SecureOAuthState instance.

        Raises:
            OAuthStateError: If the state is invalid, tampered with, expired, or malformed.
        """
        if not serialized or not isinstance(serialized, str):
            raise OAuthStateError("Invalid state: empty or wrong type")

        parts = serialized.split(".")
        if len(parts) != 2:
            raise OAuthStateError("Invalid state format: missing signature")
        encoded_data, encoded_sig = parts

        try:
            data_bytes = cls._b64decode_unpadded(encoded_data)
            provided_sig = cls._b64decode_unpadded(encoded_sig)
        except ValueError as e:  # binascii.Error and non-ASCII input are ValueErrors
            raise OAuthStateError(f"Invalid state encoding: {e}") from e

        key = cls._get_secret_key()
        expected_sig = hmac.new(key, data_bytes, hashlib.sha256).digest()
        # Constant-time comparison to prevent timing attacks.
        if not hmac.compare_digest(expected_sig, provided_sig):
            raise OAuthStateError("Invalid state signature: possible tampering detected")

        try:
            data = json.loads(data_bytes.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            raise OAuthStateError(f"Invalid state JSON: {e}") from e

        # Validate structure
        if not isinstance(data, dict):
            raise OAuthStateError("Invalid state structure: not a dictionary")

        required_fields = {"token", "timestamp", "nonce"}
        missing = required_fields - set(data.keys())
        if missing:
            raise OAuthStateError(f"Invalid state structure: missing fields {missing}")

        # Validate field types
        if not isinstance(data["token"], str) or len(data["token"]) < 16:
            raise OAuthStateError("Invalid state: token must be a string of at least 16 characters")
        if not isinstance(data["timestamp"], (int, float)):
            raise OAuthStateError("Invalid state: timestamp must be numeric")
        if not isinstance(data["nonce"], str) or len(data["nonce"]) < 8:
            raise OAuthStateError("Invalid state: nonce must be a string of at least 8 characters")
        if "data" in data and not isinstance(data["data"], dict):
            raise OAuthStateError("Invalid state: data must be a dictionary")

        # Check expiration
        elapsed = time.time() - data["timestamp"]
        if elapsed > cls._MAX_AGE_SECONDS:
            raise OAuthStateError(
                f"State expired: {elapsed:.0f}s > {cls._MAX_AGE_SECONDS}s (max age)"
            )

        return cls(
            token=data["token"],
            timestamp=data["timestamp"],
            nonce=data["nonce"],
            data=data.get("data", {}),
        )

    def validate_against(self, other_token: str) -> bool:
        """Compare this state's token with *other_token* in constant time.

        Args:
            other_token: The token to compare against.

        Returns:
            True if tokens match, False otherwise (including non-str input).
        """
        if not isinstance(other_token, str):
            return False
        return secrets.compare_digest(self.token, other_token)


class OAuthStateManager:
    """Thread-safe manager for OAuth state parameters with secure serialization.

    VULNERABILITY FIX (V-006): Uses SecureOAuthState with JSON + HMAC instead
    of pickle for state serialization.

    At most one state is pending at a time; a successful validation consumes
    it.  Issued nonces are remembered so that a consumed state cannot be
    replayed.  The memory is bounded: expired entries are dropped first, then
    the oldest entries once ``_max_used_nonces`` is exceeded.  The newest
    timestamp ever dropped is kept as a watermark, and any state at or before
    that watermark whose nonce is unknown is rejected, so eviction never
    reopens a replay window.
    """

    def __init__(self):
        self._state: SecureOAuthState | None = None
        self._lock = threading.Lock()
        # nonce -> issue timestamp, in issue (insertion) order.
        self._used_nonces: dict[str, float] = {}
        self._max_used_nonces = 1000  # Prevent memory growth
        # Largest timestamp of any nonce dropped from _used_nonces.
        self._evicted_through: float = float("-inf")

    def generate_state(self, extra_data: dict | None = None) -> str:
        """Generate a new pending OAuth state with secure serialization.

        Replaces any previously pending state.

        Args:
            extra_data: Optional additional data to include in state.

        Returns:
            Serialized signed state string.
        """
        state = SecureOAuthState(data=extra_data or {})

        with self._lock:
            self._state = state
            # Track nonce to prevent replay
            self._used_nonces[state.nonce] = state.timestamp
            self._prune_nonces(time.time())

        logger.debug("OAuth state generated (nonce=%s...)", state.nonce[:8])
        return state.serialize()

    def _prune_nonces(self, now: float) -> None:
        """Bound ``_used_nonces``, oldest first; the caller must hold ``self._lock``.

        Entries whose state has expired are dropped first.  Then, while the map
        is above the cap, the oldest remaining entries are dropped as well.  The
        newest entry (the pending state just added) is never dropped.  Every
        dropped timestamp raises ``_evicted_through``.
        """
        nonces = self._used_nonces
        cutoff = now - SecureOAuthState._MAX_AGE_SECONDS
        cap = max(self._max_used_nonces, 1)

        def _evict_oldest() -> None:
            oldest = next(iter(nonces))
            self._evicted_through = max(self._evicted_through, nonces.pop(oldest))

        while nonces and nonces[next(iter(nonces))] < cutoff:
            _evict_oldest()
        while len(nonces) > cap:
            _evict_oldest()

    def validate_and_extract(
        self, returned_state: str | None
    ) -> tuple[bool, dict | None]:
        """Validate the state returned by the provider and extract its data.

        Args:
            returned_state: The state string returned by OAuth provider.

        Returns:
            ``(True, data)`` when the state is authentic, unexpired, not a
            replay and (if one is pending) matches the pending state;
            otherwise ``(False, None)``.
        """
        if returned_state is None:
            logger.error("OAuth state validation failed: no state returned")
            return False, None

        try:
            state = SecureOAuthState.deserialize(returned_state)
        except OAuthStateError as e:
            logger.error("OAuth state validation failed: %s", e)
            with self._lock:
                self._clear_state()
            return False, None

        with self._lock:
            pending = self._state
            if state.nonce in self._used_nonces:
                # Issued by us: only the currently pending state may be redeemed.
                if pending is None or state.nonce != pending.nonce:
                    logger.error("OAuth state validation failed: nonce replay detected")
                    return False, None
            elif state.timestamp <= self._evicted_through:
                # Its nonce may have been pruned from memory: a replay cannot be ruled out.
                logger.error("OAuth state validation failed: nonce replay detected")
                return False, None

            if pending is not None and not state.validate_against(pending.token):
                logger.error("OAuth state validation failed: token mismatch")
                self._clear_state()
                return False, None

            # Valid state - clear stored state to prevent replay
            self._clear_state()

        logger.debug("OAuth state validated successfully")
        return True, state.data

    def _clear_state(self) -> None:
        """Clear the pending state; the caller must hold ``self._lock``."""
        self._state = None

    def invalidate(self) -> None:
        """Explicitly invalidate the current pending state."""
        with self._lock:
            self._clear_state()


# Global state manager instance
_state_manager = OAuthStateManager()

# ---------------------------------------------------------------------------
# NOTE (V-006): earlier versions serialized OAuth state with pickle and
# restored it with pickle.loads(), which allowed remote code execution.
# Never reintroduce pickle (or any code-executing deserializer) for state or
# token data.
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
# ---------------------------------------------------------------------------
#
# SECURITY FIX (V-006): Token storage now implements:
# 1. JSON schema validation for token data structure
# 2. HMAC-SHA256 signing of stored tokens to detect tampering
# 3. Strict type validation of all fields
# 4. Protection against malicious token files crafted by local attackers
def _sanitize_server_name(name: str) -> str:
    """Sanitize server name for safe use as a filename.

    Lower-cases and strips *name*, then replaces every character that is not a
    word character or ``-`` with ``-``.  Runs of ``-`` are collapsed and
    trimmed from both ends, and the result is truncated to 60 characters.  An
    empty result becomes ``"unnamed"``.
    """
    import re

    clean = re.sub(r"[^\w\-]", "-", name.strip().lower())
    clean = re.sub(r"-+", "-", clean).strip("-")
    return clean[:60] or "unnamed"


# Expected schema for OAuth token data (for validation)
_OAUTH_TOKEN_SCHEMA = {
    "required": {"access_token", "token_type"},
    "optional": {"refresh_token", "expires_in", "expires_at", "scope", "id_token"},
    "types": {
        "access_token": str,
        "token_type": str,
        "refresh_token": (str, type(None)),
        "expires_in": (int, float, type(None)),
        "expires_at": (int, float, type(None)),
        "scope": (str, type(None)),
        "id_token": (str, type(None)),
    },
}

# Expected schema for OAuth client info (for validation)
_OAUTH_CLIENT_SCHEMA = {
    "required": {"client_id"},
    "optional": {
        "client_secret", "client_id_issued_at", "client_secret_expires_at",
        "token_endpoint_auth_method", "grant_types", "response_types",
        "client_name", "client_uri", "logo_uri", "scope", "contacts",
        "tos_uri", "policy_uri", "jwks_uri", "jwks", "redirect_uris",
    },
    "types": {
        "client_id": str,
        "client_secret": (str, type(None)),
        "client_id_issued_at": (int, float, type(None)),
        "client_secret_expires_at": (int, float, type(None)),
        "token_endpoint_auth_method": (str, type(None)),
        "grant_types": (list, type(None)),
        "response_types": (list, type(None)),
        "client_name": (str, type(None)),
        "client_uri": (str, type(None)),
        "logo_uri": (str, type(None)),
        "scope": (str, type(None)),
        "contacts": (list, type(None)),
        "tos_uri": (str, type(None)),
        "policy_uri": (str, type(None)),
        "jwks_uri": (str, type(None)),
        "jwks": (dict, type(None)),
        "redirect_uris": (list, type(None)),
    },
}


def _validate_token_schema(data: dict, schema: dict, context: str) -> None:
    """Validate *data* against a schema.

    Unknown fields are logged at debug level and otherwise ignored, for
    forward compatibility.  ``None`` values skip the type check.

    Args:
        data: The data to validate.
        schema: Schema definition with 'required', 'optional', and 'types' keys.
        context: Context string for error messages.

    Raises:
        OAuthStateError: If validation fails.
    """
    if not isinstance(data, dict):
        raise OAuthStateError(f"{context}: data must be a dictionary")

    missing = schema["required"] - set(data.keys())
    if missing:
        raise OAuthStateError(f"{context}: missing required fields: {missing}")

    all_fields = schema["required"] | schema["optional"]
    for field, value in data.items():
        if field not in all_fields:
            logger.debug("%s: unknown field '%s' ignored", context, field)
            continue
        expected_type = schema["types"].get(field)
        if expected_type and value is not None and not isinstance(value, expected_type):
            raise OAuthStateError(
                f"{context}: field '{field}' has wrong type, expected {expected_type}"
            )


def _get_token_storage_key() -> bytes:
    """Return the HMAC key used to sign stored token files, creating it on first use.

    ``HERMES_TOKEN_STORAGE_SECRET`` takes precedence when set.  Otherwise the
    key lives in ``$HERMES_HOME/.secrets/token_storage.key`` (``HERMES_HOME``
    defaults to ``~/.hermes``).  A missing key file, or one shorter than
    32 bytes, is replaced by 64 fresh random bytes.  The file is created with
    mode 0o600, so it is never readable by other users, even briefly.
    """
    env_key = os.environ.get("HERMES_TOKEN_STORAGE_SECRET")
    if env_key:
        return env_key.encode("utf-8")

    home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
    key_dir = home / ".secrets"
    key_dir.mkdir(parents=True, exist_ok=True, mode=0o700)
    key_file = key_dir / "token_storage.key"

    try:
        existing = key_file.read_bytes()
    except FileNotFoundError:
        existing = b""
    if len(existing) >= 32:
        return existing

    key = secrets.token_bytes(64)
    # O_BINARY (Windows only) prevents newline translation of the raw key bytes.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(key_file, flags, 0o600)
    with os.fdopen(fd, "wb") as fh:
        fh.write(key)
    try:
        # The creation mode is ignored when a (too short) file already existed.
        key_file.chmod(0o600)
    except OSError:
        pass
    return key


def _sign_token_data(data: dict) -> str:
    """Return the unpadded URL-safe base64 HMAC-SHA256 of *data*'s canonical JSON."""
    key = _get_token_storage_key()
    # Canonical JSON (sorted keys, no whitespace) so signing is independent of on-disk formatting.
    json_bytes = json.dumps(data, separators=(",", ":"), sort_keys=True).encode("utf-8")
    signature = hmac.new(key, json_bytes, hashlib.sha256).digest()
    return base64.urlsafe_b64encode(signature).decode("ascii").rstrip("=")


def _verify_token_signature(data: dict, signature: str) -> bool:
    """Check *signature* against *data* in constant time.

    Returns False for an empty or non-matching signature.
    """
    if not signature:
        return False
    expected = _sign_token_data(data)
    # Compare bytes: hmac.compare_digest raises TypeError for non-ASCII str,
    # which a tampered .sig file may well contain.
    return hmac.compare_digest(expected.encode("ascii"), signature.encode("utf-8"))


def _write_private_file_atomic(path: Path, text: str) -> None:
    """Atomically replace *path* with *text*, never exposing it with loose permissions.

    The text goes to a ``tempfile.mkstemp`` file in the same directory, which
    is created readable and writable by the owner only.  That file is then
    renamed over *path* with ``os.replace``, so readers see either the old or
    the new content, never a partial write.
    """
    import tempfile

    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=f".{path.name}.", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(text)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp_name, path)
    except BaseException:
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise


class HermesTokenStorage:
    """File-backed token storage implementing the MCP SDK's TokenStorage protocol.

    SECURITY FIX (V-006): Implements JSON schema validation and HMAC signing
    to prevent malicious token file injection by local attackers.

    Files live under ``$HERMES_HOME/mcp-tokens/`` (default ``~/.hermes``):
    ``<server>.json`` holds tokens, ``<server>.client.json`` holds client info,
    and each has a sibling ``.sig`` file with its HMAC.
    """

    def __init__(self, server_name: str):
        self._server_name = _sanitize_server_name(server_name)
        # Not read anywhere in this module; kept for compatibility.
        self._token_signatures: dict[str, str] = {}

    def _base_dir(self) -> Path:
        home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
        d = home / _TOKEN_DIR_NAME
        d.mkdir(parents=True, exist_ok=True, mode=0o700)
        return d

    def _tokens_path(self) -> Path:
        return self._base_dir() / f"{self._server_name}.json"

    def _client_path(self) -> Path:
        return self._base_dir() / f"{self._server_name}.client.json"

    def _signature_path(self, base_path: Path) -> Path:
        """Get path for signature file (the base path with its suffix replaced by ``.sig``)."""
        return base_path.with_suffix(".sig")

    # -- TokenStorage protocol (async) --

    async def get_tokens(self):
        """Return the stored ``OAuthToken``, or None if absent, unsigned, tampered or invalid.

        SECURITY: Verifies the HMAC signature and the JSON schema before the
        model is constructed.
        """
        try:
            data = self._read_signed_json(self._tokens_path())
            if not data:
                return None
            _validate_token_schema(data, _OAUTH_TOKEN_SCHEMA, "token data")
            from mcp.shared.auth import OAuthToken

            return OAuthToken(**data)
        except OAuthStateError as e:
            logger.error("Token validation failed: %s", e)
            return None
        except Exception as e:
            logger.error("Failed to load tokens: %s", e)
            return None

    async def set_tokens(self, tokens) -> None:
        """Persist *tokens* together with an HMAC signature."""
        # mode="json" turns pydantic-specific values (URLs, etc.) into plain JSON types.
        data = tokens.model_dump(mode="json", exclude_none=True)
        self._write_signed_json(self._tokens_path(), data)

    async def get_client_info(self):
        """Return the stored ``OAuthClientInformationFull``, or None if absent or invalid.

        SECURITY: Verifies the HMAC signature and the JSON schema before the
        model is constructed.
        """
        try:
            data = self._read_signed_json(self._client_path())
            if not data:
                return None
            _validate_token_schema(data, _OAUTH_CLIENT_SCHEMA, "client info")
            from mcp.shared.auth import OAuthClientInformationFull

            return OAuthClientInformationFull(**data)
        except OAuthStateError as e:
            logger.error("Client info validation failed: %s", e)
            return None
        except Exception as e:
            logger.error("Failed to load client info: %s", e)
            return None

    async def set_client_info(self, client_info) -> None:
        """Persist *client_info* together with an HMAC signature."""
        # mode="json": URL fields such as redirect_uris are pydantic URL objects
        # in the default python mode, and json.dumps cannot serialize those.
        data = client_info.model_dump(mode="json", exclude_none=True)
        self._write_signed_json(self._client_path(), data)

    # -- Secure storage helpers --

    def _read_signed_json(self, path: Path) -> dict | None:
        """Read JSON from *path* and return it only if its ``.sig`` file verifies.

        Returns None when the file or its signature is missing, the JSON is
        invalid, the signature does not match, or reading fails.
        """
        if not path.exists():
            return None

        sig_path = self._signature_path(path)
        if not sig_path.exists():
            logger.warning("Missing signature file for %s, rejecting data", path)
            return None

        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            stored_sig = sig_path.read_text(encoding="utf-8").strip()

            if not _verify_token_signature(data, stored_sig):
                logger.error("Signature verification failed for %s - possible tampering!", path)
                return None

            return data
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            logger.error("Invalid JSON in %s: %s", path, e)
            return None
        except Exception as e:
            logger.error("Error reading %s: %s", path, e)
            return None

    def _write_signed_json(self, path: Path, data: dict) -> None:
        """Write *data* as JSON to *path* and its HMAC to the sibling ``.sig`` file.

        The payload and signature are computed before anything is written, so
        a serialization or key error leaves both files untouched.  Each file is
        replaced atomically with owner-only permissions.  The pair is not
        updated as a single transaction, but a crash between the two writes
        only leaves a mismatched signature, which readers reject.
        """
        sig_path = self._signature_path(path)
        payload = json.dumps(data, indent=2)
        signature = _sign_token_data(data)
        _write_private_file_atomic(path, payload)
        _write_private_file_atomic(sig_path, signature)

    def remove(self) -> None:
        """Delete stored tokens, client info, and signatures for this server."""
        for base_path in (self._tokens_path(), self._client_path()):
            for p in (base_path, self._signature_path(base_path)):
                try:
                    p.unlink(missing_ok=True)
                except OSError:
                    pass


# ---------------------------------------------------------------------------
# Browser-based callback handler
# ---------------------------------------------------------------------------


def _find_free_port() -> int:
    """Return a TCP port on 127.0.0.1 that was free at the time of the call."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _make_callback_handler():
    """Create a callback handler class with instance-scoped result storage.

    Returns:
        ``(Handler, result)``.  ``result`` is a dict that the handler fills
        with ``auth_code``, ``state`` and ``error`` from the query string.
        When the state validates, it also gets ``state_data``.
    """
    import html

    result: Dict[str, Any] = {"auth_code": None, "state": None, "error": None}

    def _page(message: str) -> bytes:
        """Wrap an already-escaped *message* in a minimal HTML page."""
        return f"<html><body><h3>{message}</h3></body></html>".encode("utf-8")

    class Handler(BaseHTTPRequestHandler):
        def _respond(self, status: int, body: bytes) -> None:
            self.send_response(status)
            self.send_header("Content-Type", "text/html; charset=utf-8")
            self.end_headers()
            self.wfile.write(body)

        def do_GET(self):
            qs = parse_qs(urlparse(self.path).query)
            result["auth_code"] = (qs.get("code") or [None])[0]
            result["state"] = (qs.get("state") or [None])[0]
            result["error"] = (qs.get("error") or [None])[0]

            if result["state"] is None:
                logger.error("OAuth callback received without state parameter")
                self._respond(400, _page("Error: Missing state parameter. Authorization failed."))
                return

            # Validate state using secure deserialization (V-006 Fix).
            # NOTE(review): _redirect_to_browser opens the authorization URL
            # unchanged, so the state echoed here is whatever the URL builder
            # (presumably OAuthClientProvider) put there, not a value from
            # _state_manager. If so, this check rejects it (403 page) even
            # though the code recorded above is still handed back. Confirm
            # the intended design.
            is_valid, state_data = _state_manager.validate_and_extract(result["state"])
            if not is_valid:
                self._respond(
                    403,
                    _page(
                        "Error: Invalid or expired state. Possible CSRF attack. "
                        "Authorization failed."
                    ),
                )
                return

            result["state_data"] = state_data

            if result["error"]:
                logger.error("OAuth authorization error: %s", result["error"])
                # The error value comes from the query string: escape it (reflected XSS).
                self._respond(400, _page(f"Authorization error: {html.escape(result['error'])}"))
                return

            self._respond(200, _page("Authorization complete. You can close this tab."))

        def log_message(self, *_args: Any) -> None:
            pass

    return Handler, result


# Port of the most recently built provider. Kept for backward compatibility;
# build_oauth_auth hands each provider its own port explicitly.
_oauth_port: int | None = None


async def _redirect_to_browser(auth_url: str, state: str) -> None:
    """Open *auth_url* in the user's browser, or print it when that is not possible.

    ``state`` is accepted for interface compatibility but is not used: the
    URL is opened exactly as given.
    """
    try:
        # webbrowser.open returns False when no usable browser was found.
        if _can_open_browser() and webbrowser.open(auth_url):
            print(" Opened browser for authorization...")
            return
    except Exception:
        pass  # deliberate best effort: fall back to printing the URL
    print(f"\n Open this URL to authorize:\n {auth_url}\n")


async def _wait_for_callback(port: int | None = None) -> tuple[str, str | None, dict | None]:
    """Serve the OAuth redirect on 127.0.0.1 and return what it delivered.

    Implements secure state validation using JSON + HMAC (V-006 Fix) and
    session regeneration after successful auth (V-014 Fix).

    Args:
        port: Port to listen on.  Defaults to the module-level ``_oauth_port``,
            or a fresh free port if that is unset.

    Returns:
        ``(code, state, state_data)``.  ``code`` is "" when the provider
        reported an error.  It comes from manual entry when no callback
        arrived within 120 seconds.
    """
    port = port or _oauth_port or _find_free_port()
    HandlerClass, result = _make_callback_handler()
    server = HTTPServer(("127.0.0.1", port), HandlerClass)
    # Short poll interval so the serving thread notices completion or cancellation.
    server.timeout = 0.5
    stop = threading.Event()

    def _finished() -> bool:
        return result["auth_code"] is not None or result.get("error") is not None

    def _serve() -> None:
        # Keep serving until a callback delivers a code or an error, so stray
        # requests (probes, a path without query) do not end the wait early.
        # The socket is closed here, after the last response has been written.
        try:
            while not stop.is_set() and not _finished():
                server.handle_request()
        finally:
            server.server_close()

    thread = threading.Thread(target=_serve, daemon=True)
    thread.start()
    try:
        for _ in range(1200):  # 120 seconds
            await asyncio.sleep(0.1)
            if _finished():
                break
    finally:
        stop.set()

    code = result["auth_code"] or ""
    state = result["state"]
    state_data = result.get("state_data")
    error = result.get("error")

    if code and state_data is not None:
        # V-014 Fix: regenerate session after successful authentication so the
        # post-auth session is distinct from any pre-auth one (session fixation).
        regenerate_session_after_auth()
        logger.info("OAuth authentication successful - session regenerated (V-014 fix)")
    elif error:
        # The provider redirected back with an error: there is no code to paste.
        print(f" Authorization failed: {error}")
        _state_manager.invalidate()
    elif not code:
        print(" Browser callback timed out. Paste the authorization code manually:")
        # Read in a worker thread so the event loop keeps running meanwhile.
        code = (await asyncio.to_thread(input, " Code: ")).strip()
        # For manual entry, we can't validate state
        _state_manager.invalidate()

    return code, state, state_data


def regenerate_session_after_auth() -> None:
    """Regenerate session context after successful OAuth authentication.

    Invalidates the pending OAuth state so that the session context after
    authentication is distinct from any pre-authentication session
    (session-fixation defence).
    """
    _state_manager.invalidate()
    logger.debug("Session regenerated after OAuth authentication")


def _can_open_browser() -> bool:
    """Heuristically decide whether a local browser can be launched.

    Returns False in SSH sessions.  Returns True on Windows, when an X11 or
    Wayland display is advertised, or on macOS.  Returns False otherwise.
    """
    if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"):
        return False
    if os.name == "nt":
        return True
    if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
        return True
    try:
        return "darwin" in os.uname().sysname.lower()
    except AttributeError:
        return False


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def build_oauth_auth(server_name: str, server_url: str):
    """Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE.

    Uses the MCP SDK's ``OAuthClientProvider``, which handles discovery,
    registration, PKCE, token exchange, and refresh automatically.

    SECURITY FIXES:
    - V-006: Uses secure JSON + HMAC state serialization instead of pickle to
      prevent remote code execution (Insecure Deserialization fix).
    - V-014: Regenerates session context after OAuth callback to prevent
      session fixation attacks (CVSS 7.6 HIGH).

    Returns:
        An ``OAuthClientProvider`` instance (implements ``httpx.Auth``), or
        ``None`` if the MCP SDK auth module is not available.
    """
    try:
        from mcp.client.auth import OAuthClientProvider
        from mcp.shared.auth import OAuthClientMetadata
    except ImportError:
        logger.warning("MCP SDK auth module not available — OAuth disabled")
        return None

    global _oauth_port
    port = _find_free_port()
    # The global is still updated for existing readers, but the callback below
    # uses this provider's own ``port``: building a second provider must not
    # move this provider's listener away from its registered redirect URI.
    _oauth_port = port
    redirect_uri = f"http://127.0.0.1:{port}/callback"

    client_metadata = OAuthClientMetadata(
        client_name="Hermes Agent",
        redirect_uris=[redirect_uri],
        grant_types=["authorization_code", "refresh_token"],
        response_types=["code"],
        scope="openid profile email offline_access",
        token_endpoint_auth_method="none",
    )

    storage = HermesTokenStorage(server_name)

    # Secure state carrying server_name; it is passed to _redirect_to_browser,
    # which does not currently put it into the URL.
    state = _state_manager.generate_state(extra_data={"server_name": server_name})

    async def redirect_handler(auth_url: str) -> None:
        await _redirect_to_browser(auth_url, state)

    async def callback_handler() -> tuple[str, str | None]:
        # OAuthClientProvider unpacks exactly (authorization_code, state); the
        # state_data element from _wait_for_callback is Hermes-internal.
        code, returned_state, _state_data = await _wait_for_callback(port)
        return code, returned_state

    return OAuthClientProvider(
        server_url=server_url,
        client_metadata=client_metadata,
        storage=storage,
        redirect_handler=redirect_handler,
        callback_handler=callback_handler,
        timeout=120.0,
    )


def remove_oauth_tokens(server_name: str) -> None:
    """Delete stored OAuth tokens and client info for a server."""
    HermesTokenStorage(server_name).remove()


def get_state_manager() -> OAuthStateManager:
    """Get the global OAuth state manager instance (for testing)."""
    return _state_manager