feat(cli): MCP server management CLI + OAuth 2.1 PKCE auth
Add hermes mcp add/remove/list/test/configure CLI for managing MCP
server connections interactively. Discovery-first 'add' flow connects,
discovers tools, and lets users select which to enable via curses checklist.
Add OAuth 2.1 PKCE authentication for MCP HTTP servers (RFC 7636).
Supports browser-based and manual (headless) authorization, token
caching with 0600 permissions, automatic refresh. Zero external deps.
Add ${ENV_VAR} interpolation in MCP server config values, resolved
from os.environ + ~/.hermes/.env at load time.
Core OAuth module from PR #2021 by @imnotdev25. CLI and mcp_tool
wiring rewritten against current main. Closes #497, #690.
This commit is contained in:
@@ -2958,7 +2958,7 @@ def _coalesce_session_name_args(argv: list) -> list:
|
||||
_SUBCOMMANDS = {
|
||||
"chat", "model", "gateway", "setup", "whatsapp", "login", "logout",
|
||||
"status", "cron", "doctor", "config", "pairing", "skills", "tools",
|
||||
"sessions", "insights", "version", "update", "uninstall",
|
||||
"mcp", "sessions", "insights", "version", "update", "uninstall",
|
||||
}
|
||||
_SESSION_FLAGS = {"-c", "--continue", "-r", "--resume"}
|
||||
|
||||
@@ -3702,6 +3702,45 @@ For more help on a command:
|
||||
tools_command(args)
|
||||
|
||||
tools_parser.set_defaults(func=cmd_tools)
|
||||
# =========================================================================
|
||||
# mcp command — manage MCP server connections
|
||||
# =========================================================================
|
||||
mcp_parser = subparsers.add_parser(
|
||||
"mcp",
|
||||
help="Manage MCP server connections",
|
||||
description=(
|
||||
"Add, remove, list, test, and configure MCP server connections.\n\n"
|
||||
"MCP servers provide additional tools via the Model Context Protocol.\n"
|
||||
"Use 'hermes mcp add' to connect to a new server with interactive\n"
|
||||
"tool discovery. Run 'hermes mcp' with no subcommand to list servers."
|
||||
),
|
||||
)
|
||||
mcp_sub = mcp_parser.add_subparsers(dest="mcp_action")
|
||||
|
||||
mcp_add_p = mcp_sub.add_parser("add", help="Add an MCP server (discovery-first install)")
|
||||
mcp_add_p.add_argument("name", help="Server name (used as config key)")
|
||||
mcp_add_p.add_argument("--url", help="HTTP/SSE endpoint URL")
|
||||
mcp_add_p.add_argument("--command", help="Stdio command (e.g. npx)")
|
||||
mcp_add_p.add_argument("--args", nargs="*", default=[], help="Arguments for stdio command")
|
||||
mcp_add_p.add_argument("--auth", choices=["oauth", "header"], help="Auth method")
|
||||
|
||||
mcp_rm_p = mcp_sub.add_parser("remove", aliases=["rm"], help="Remove an MCP server")
|
||||
mcp_rm_p.add_argument("name", help="Server name to remove")
|
||||
|
||||
mcp_sub.add_parser("list", aliases=["ls"], help="List configured MCP servers")
|
||||
|
||||
mcp_test_p = mcp_sub.add_parser("test", help="Test MCP server connection")
|
||||
mcp_test_p.add_argument("name", help="Server name to test")
|
||||
|
||||
mcp_cfg_p = mcp_sub.add_parser("configure", aliases=["config"], help="Toggle tool selection")
|
||||
mcp_cfg_p.add_argument("name", help="Server name to configure")
|
||||
|
||||
def cmd_mcp(args):
|
||||
from hermes_cli.mcp_config import mcp_command
|
||||
mcp_command(args)
|
||||
|
||||
mcp_parser.set_defaults(func=cmd_mcp)
|
||||
|
||||
# =========================================================================
|
||||
# sessions command
|
||||
# =========================================================================
|
||||
|
||||
635
hermes_cli/mcp_config.py
Normal file
635
hermes_cli/mcp_config.py
Normal file
@@ -0,0 +1,635 @@
|
||||
"""
|
||||
MCP Server Management CLI — ``hermes mcp`` subcommand.
|
||||
|
||||
Implements ``hermes mcp add/remove/list/test/configure`` for interactive
|
||||
MCP server lifecycle management (issue #690 Phase 2).
|
||||
|
||||
Relies on tools/mcp_tool.py for connection/discovery and keeps
|
||||
configuration in ~/.hermes/config.yaml under the ``mcp_servers`` key.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import getpass
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from hermes_cli.config import (
|
||||
load_config,
|
||||
save_config,
|
||||
get_env_value,
|
||||
save_env_value,
|
||||
get_hermes_home,
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── UI Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def _info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
def _success(text: str):
|
||||
print(color(f" ✓ {text}", Colors.GREEN))
|
||||
|
||||
def _warning(text: str):
|
||||
print(color(f" ⚠ {text}", Colors.YELLOW))
|
||||
|
||||
def _error(text: str):
|
||||
print(color(f" ✗ {text}", Colors.RED))
|
||||
|
||||
|
||||
def _confirm(question: str, default: bool = True) -> bool:
|
||||
default_str = "Y/n" if default else "y/N"
|
||||
try:
|
||||
val = input(color(f" {question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
if not val:
|
||||
return default
|
||||
return val in ("y", "yes")
|
||||
|
||||
|
||||
def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
|
||||
display = f" {question}"
|
||||
if default:
|
||||
display += f" [{default}]"
|
||||
display += ": "
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _get_mcp_servers(config: Optional[dict] = None) -> Dict[str, dict]:
|
||||
"""Return the ``mcp_servers`` dict from config, or empty dict."""
|
||||
if config is None:
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers")
|
||||
if not servers or not isinstance(servers, dict):
|
||||
return {}
|
||||
return servers
|
||||
|
||||
|
||||
def _save_mcp_server(name: str, server_config: dict):
|
||||
"""Add or update a server entry in config.yaml."""
|
||||
config = load_config()
|
||||
config.setdefault("mcp_servers", {})[name] = server_config
|
||||
save_config(config)
|
||||
|
||||
|
||||
def _remove_mcp_server(name: str) -> bool:
|
||||
"""Remove a server from config.yaml. Returns True if it existed."""
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers", {})
|
||||
if name not in servers:
|
||||
return False
|
||||
del servers[name]
|
||||
if not servers:
|
||||
config.pop("mcp_servers", None)
|
||||
save_config(config)
|
||||
return True
|
||||
|
||||
|
||||
def _env_key_for_server(name: str) -> str:
|
||||
"""Convert server name to an env-var key like ``MCP_MYSERVER_API_KEY``."""
|
||||
return f"MCP_{name.upper().replace('-', '_')}_API_KEY"
|
||||
|
||||
|
||||
# ─── Discovery (temporary connect) ───────────────────────────────────────────
|
||||
|
||||
def _probe_single_server(
|
||||
name: str, config: dict, connect_timeout: float = 30
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Temporarily connect to one MCP server, list its tools, disconnect.
|
||||
|
||||
Returns list of ``(tool_name, description)`` tuples.
|
||||
Raises on connection failure.
|
||||
"""
|
||||
from tools.mcp_tool import (
|
||||
_ensure_mcp_loop,
|
||||
_run_on_mcp_loop,
|
||||
_connect_server,
|
||||
_stop_mcp_loop,
|
||||
)
|
||||
|
||||
_ensure_mcp_loop()
|
||||
|
||||
tools_found: List[Tuple[str, str]] = []
|
||||
|
||||
async def _probe():
|
||||
server = await asyncio.wait_for(
|
||||
_connect_server(name, config), timeout=connect_timeout
|
||||
)
|
||||
for t in server._tools:
|
||||
desc = getattr(t, "description", "") or ""
|
||||
# Truncate long descriptions for display
|
||||
if len(desc) > 80:
|
||||
desc = desc[:77] + "..."
|
||||
tools_found.append((t.name, desc))
|
||||
await server.shutdown()
|
||||
|
||||
try:
|
||||
_run_on_mcp_loop(_probe(), timeout=connect_timeout + 10)
|
||||
except BaseException as exc:
|
||||
raise _unwrap_exception_group(exc) from None
|
||||
finally:
|
||||
_stop_mcp_loop()
|
||||
|
||||
return tools_found
|
||||
|
||||
|
||||
def _unwrap_exception_group(exc: BaseException) -> Exception:
|
||||
"""Extract the root-cause exception from anyio TaskGroup wrappers.
|
||||
|
||||
The MCP SDK uses anyio task groups, which wrap errors in
|
||||
``BaseExceptionGroup`` / ``ExceptionGroup``. This makes error
|
||||
messages opaque ("unhandled errors in a TaskGroup"). We unwrap
|
||||
to surface the real cause (e.g. "401 Unauthorized").
|
||||
"""
|
||||
while isinstance(exc, BaseExceptionGroup) and exc.exceptions:
|
||||
exc = exc.exceptions[0]
|
||||
# Return a plain Exception so callers can catch normally
|
||||
if isinstance(exc, Exception):
|
||||
return exc
|
||||
return RuntimeError(str(exc))
|
||||
|
||||
|
||||
# ─── hermes mcp add ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_add(args):
|
||||
"""Add a new MCP server with discovery-first tool selection."""
|
||||
name = args.name
|
||||
url = getattr(args, "url", None)
|
||||
command = getattr(args, "command", None)
|
||||
cmd_args = getattr(args, "args", None) or []
|
||||
auth_type = getattr(args, "auth", None)
|
||||
|
||||
# Validate transport
|
||||
if not url and not command:
|
||||
_error("Must specify --url <endpoint> or --command <cmd>")
|
||||
_info("Examples:")
|
||||
_info(' hermes mcp add ink --url "https://mcp.ml.ink/mcp"')
|
||||
_info(' hermes mcp add github --command npx --args @modelcontextprotocol/server-github')
|
||||
return
|
||||
|
||||
# Check if server already exists
|
||||
existing = _get_mcp_servers()
|
||||
if name in existing:
|
||||
if not _confirm(f"Server '{name}' already exists. Overwrite?", default=False):
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
# Build initial config
|
||||
server_config: Dict[str, Any] = {}
|
||||
if url:
|
||||
server_config["url"] = url
|
||||
else:
|
||||
server_config["command"] = command
|
||||
if cmd_args:
|
||||
server_config["args"] = cmd_args
|
||||
|
||||
# ── Authentication ────────────────────────────────────────────────
|
||||
|
||||
if url and auth_type == "oauth":
|
||||
print()
|
||||
_info(f"Starting OAuth flow for '{name}'...")
|
||||
oauth_ok = False
|
||||
try:
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
oauth_auth = build_oauth_auth(name, url)
|
||||
if oauth_auth:
|
||||
server_config["auth"] = "oauth"
|
||||
_success("OAuth configured (tokens will be acquired on first connection)")
|
||||
oauth_ok=True
|
||||
else:
|
||||
_warning("OAuth setup failed — MCP SDK auth module not available")
|
||||
except Exception as exc:
|
||||
_warning(f"OAuth error: {exc}")
|
||||
|
||||
if not oauth_ok:
|
||||
_info("This server may not support OAuth.")
|
||||
if _confirm("Continue without authentication?", default=True):
|
||||
# Don't store auth: oauth — server doesn't support it
|
||||
pass
|
||||
else:
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
elif url:
|
||||
# Prompt for API key / Bearer token for HTTP servers
|
||||
print()
|
||||
_info(f"Connecting to {url}")
|
||||
needs_auth = _confirm("Does this server require authentication?", default=True)
|
||||
if needs_auth:
|
||||
if auth_type == "header" or not auth_type:
|
||||
env_key = _env_key_for_server(name)
|
||||
existing_key = get_env_value(env_key)
|
||||
if existing_key:
|
||||
_success(f"{env_key}: already configured")
|
||||
api_key = existing_key
|
||||
else:
|
||||
api_key = _prompt("API key / Bearer token", password=True)
|
||||
if api_key:
|
||||
save_env_value(env_key, api_key)
|
||||
_success(f"Saved to ~/.hermes/.env as {env_key}")
|
||||
|
||||
# Set header with env var interpolation
|
||||
if api_key or existing_key:
|
||||
server_config["headers"] = {
|
||||
"Authorization": f"Bearer ${{{env_key}}}"
|
||||
}
|
||||
|
||||
# ── Discovery: connect and list tools ─────────────────────────────
|
||||
|
||||
print()
|
||||
print(color(f" Connecting to '{name}'...", Colors.CYAN))
|
||||
|
||||
try:
|
||||
tools = _probe_single_server(name, server_config)
|
||||
except Exception as exc:
|
||||
_error(f"Failed to connect: {exc}")
|
||||
if _confirm("Save config anyway (you can test later)?", default=False):
|
||||
server_config["enabled"] = False
|
||||
_save_mcp_server(name, server_config)
|
||||
_success(f"Saved '{name}' to config (disabled)")
|
||||
_info("Fix the issue, then: hermes mcp test " + name)
|
||||
return
|
||||
|
||||
if not tools:
|
||||
_warning("Server connected but reported no tools.")
|
||||
if _confirm("Save config anyway?", default=True):
|
||||
_save_mcp_server(name, server_config)
|
||||
_success(f"Saved '{name}' to config")
|
||||
return
|
||||
|
||||
# ── Tool selection ────────────────────────────────────────────────
|
||||
|
||||
print()
|
||||
_success(f"Connected! Found {len(tools)} tool(s) from '{name}':")
|
||||
print()
|
||||
for tool_name, desc in tools:
|
||||
short = desc[:60] + "..." if len(desc) > 60 else desc
|
||||
print(f" {color(tool_name, Colors.GREEN):40s} {short}")
|
||||
print()
|
||||
|
||||
# Ask: enable all, select, or cancel
|
||||
try:
|
||||
choice = input(
|
||||
color(f" Enable all {len(tools)} tools? [Y/n/select]: ", Colors.YELLOW)
|
||||
).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
if choice in ("n", "no"):
|
||||
_info("Cancelled — server not saved.")
|
||||
return
|
||||
|
||||
if choice in ("s", "select"):
|
||||
# Interactive tool selection
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
labels = [f"{t[0]} — {t[1]}" for t in tools]
|
||||
pre_selected = set(range(len(tools)))
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Select tools for '{name}'",
|
||||
labels,
|
||||
pre_selected,
|
||||
)
|
||||
|
||||
if not chosen:
|
||||
_info("No tools selected — server not saved.")
|
||||
return
|
||||
|
||||
chosen_names = [tools[i][0] for i in sorted(chosen)]
|
||||
server_config.setdefault("tools", {})["include"] = chosen_names
|
||||
|
||||
tool_count = len(chosen_names)
|
||||
total = len(tools)
|
||||
else:
|
||||
# Enable all (no filter needed — default behaviour)
|
||||
tool_count = len(tools)
|
||||
total = len(tools)
|
||||
|
||||
# ── Save ──────────────────────────────────────────────────────────
|
||||
|
||||
server_config["enabled"] = True
|
||||
_save_mcp_server(name, server_config)
|
||||
|
||||
print()
|
||||
_success(f"Saved '{name}' to ~/.hermes/config.yaml ({tool_count}/{total} tools enabled)")
|
||||
_info("Start a new session to use these tools.")
|
||||
|
||||
|
||||
# ─── hermes mcp remove ───────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_remove(args):
|
||||
"""Remove an MCP server from config."""
|
||||
name = args.name
|
||||
existing = _get_mcp_servers()
|
||||
|
||||
if name not in existing:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
servers = list(existing.keys())
|
||||
if servers:
|
||||
_info(f"Available servers: {', '.join(servers)}")
|
||||
return
|
||||
|
||||
if not _confirm(f"Remove server '{name}'?", default=True):
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
_remove_mcp_server(name)
|
||||
_success(f"Removed '{name}' from config")
|
||||
|
||||
# Clean up OAuth tokens if they exist
|
||||
try:
|
||||
from tools.mcp_oauth import remove_oauth_tokens
|
||||
remove_oauth_tokens(name)
|
||||
_success("Cleaned up OAuth tokens")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ─── hermes mcp list ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_list(args=None):
|
||||
"""List all configured MCP servers."""
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if not servers:
|
||||
print()
|
||||
_info("No MCP servers configured.")
|
||||
print()
|
||||
_info("Add one with:")
|
||||
_info(' hermes mcp add <name> --url <endpoint>')
|
||||
_info(' hermes mcp add <name> --command <cmd> --args <args...>')
|
||||
print()
|
||||
return
|
||||
|
||||
print()
|
||||
print(color(" MCP Servers:", Colors.CYAN + Colors.BOLD))
|
||||
print()
|
||||
|
||||
# Table header
|
||||
print(f" {'Name':<16} {'Transport':<30} {'Tools':<12} {'Status':<10}")
|
||||
print(f" {'─' * 16} {'─' * 30} {'─' * 12} {'─' * 10}")
|
||||
|
||||
for name, cfg in servers.items():
|
||||
# Transport info
|
||||
if "url" in cfg:
|
||||
url = cfg["url"]
|
||||
# Truncate long URLs
|
||||
if len(url) > 28:
|
||||
url = url[:25] + "..."
|
||||
transport = url
|
||||
elif "command" in cfg:
|
||||
cmd = cfg["command"]
|
||||
cmd_args = cfg.get("args", [])
|
||||
if isinstance(cmd_args, list) and cmd_args:
|
||||
transport = f"{cmd} {' '.join(str(a) for a in cmd_args[:2])}"
|
||||
else:
|
||||
transport = cmd
|
||||
if len(transport) > 28:
|
||||
transport = transport[:25] + "..."
|
||||
else:
|
||||
transport = "?"
|
||||
|
||||
# Tool count
|
||||
tools_cfg = cfg.get("tools", {})
|
||||
if isinstance(tools_cfg, dict):
|
||||
include = tools_cfg.get("include")
|
||||
exclude = tools_cfg.get("exclude")
|
||||
if include and isinstance(include, list):
|
||||
tools_str = f"{len(include)} selected"
|
||||
elif exclude and isinstance(exclude, list):
|
||||
tools_str = f"-{len(exclude)} excluded"
|
||||
else:
|
||||
tools_str = "all"
|
||||
else:
|
||||
tools_str = "all"
|
||||
|
||||
# Enabled status
|
||||
enabled = cfg.get("enabled", True)
|
||||
if isinstance(enabled, str):
|
||||
enabled = enabled.lower() in ("true", "1", "yes")
|
||||
status = color("✓ enabled", Colors.GREEN) if enabled else color("✗ disabled", Colors.DIM)
|
||||
|
||||
print(f" {name:<16} {transport:<30} {tools_str:<12} {status}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ─── hermes mcp test ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_test(args):
|
||||
"""Test connection to an MCP server."""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
available = list(servers.keys())
|
||||
if available:
|
||||
_info(f"Available: {', '.join(available)}")
|
||||
return
|
||||
|
||||
cfg = servers[name]
|
||||
print()
|
||||
print(color(f" Testing '{name}'...", Colors.CYAN))
|
||||
|
||||
# Show transport info
|
||||
if "url" in cfg:
|
||||
_info(f"Transport: HTTP → {cfg['url']}")
|
||||
else:
|
||||
cmd = cfg.get("command", "?")
|
||||
_info(f"Transport: stdio → {cmd}")
|
||||
|
||||
# Show auth info (masked)
|
||||
auth_type = cfg.get("auth", "")
|
||||
headers = cfg.get("headers", {})
|
||||
if auth_type == "oauth":
|
||||
_info("Auth: OAuth 2.1 PKCE")
|
||||
elif headers:
|
||||
for k, v in headers.items():
|
||||
if isinstance(v, str) and ("key" in k.lower() or "auth" in k.lower()):
|
||||
# Mask the value
|
||||
resolved = _interpolate_value(v)
|
||||
if len(resolved) > 8:
|
||||
masked = resolved[:4] + "***" + resolved[-4:]
|
||||
else:
|
||||
masked = "***"
|
||||
print(f" {k}: {masked}")
|
||||
else:
|
||||
_info("Auth: none")
|
||||
|
||||
# Attempt connection
|
||||
start = time.monotonic()
|
||||
try:
|
||||
tools = _probe_single_server(name, cfg)
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
_error(f"Connection failed ({elapsed_ms:.0f}ms): {exc}")
|
||||
return
|
||||
|
||||
_success(f"Connected ({elapsed_ms:.0f}ms)")
|
||||
_success(f"Tools discovered: {len(tools)}")
|
||||
|
||||
if tools:
|
||||
print()
|
||||
for tool_name, desc in tools:
|
||||
short = desc[:55] + "..." if len(desc) > 55 else desc
|
||||
print(f" {color(tool_name, Colors.GREEN):36s} {short}")
|
||||
print()
|
||||
|
||||
|
||||
def _interpolate_value(value: str) -> str:
|
||||
"""Resolve ``${ENV_VAR}`` references in a string."""
|
||||
def _replace(m):
|
||||
return os.getenv(m.group(1), "")
|
||||
return re.sub(r"\$\{(\w+)\}", _replace, value)
|
||||
|
||||
|
||||
# ─── hermes mcp configure ────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_configure(args):
|
||||
"""Reconfigure which tools are enabled for an existing MCP server."""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
available = list(servers.keys())
|
||||
if available:
|
||||
_info(f"Available: {', '.join(available)}")
|
||||
return
|
||||
|
||||
cfg = servers[name]
|
||||
|
||||
# Discover all available tools
|
||||
print()
|
||||
print(color(f" Connecting to '{name}' to discover tools...", Colors.CYAN))
|
||||
|
||||
try:
|
||||
all_tools = _probe_single_server(name, cfg)
|
||||
except Exception as exc:
|
||||
_error(f"Failed to connect: {exc}")
|
||||
return
|
||||
|
||||
if not all_tools:
|
||||
_warning("Server reports no tools.")
|
||||
return
|
||||
|
||||
# Determine which are currently enabled
|
||||
tools_cfg = cfg.get("tools", {})
|
||||
if isinstance(tools_cfg, dict):
|
||||
include = tools_cfg.get("include")
|
||||
exclude = tools_cfg.get("exclude")
|
||||
else:
|
||||
include = None
|
||||
exclude = None
|
||||
|
||||
tool_names = [t[0] for t in all_tools]
|
||||
|
||||
if include and isinstance(include, list):
|
||||
include_set = set(include)
|
||||
pre_selected = {
|
||||
i for i, tn in enumerate(tool_names) if tn in include_set
|
||||
}
|
||||
elif exclude and isinstance(exclude, list):
|
||||
exclude_set = set(exclude)
|
||||
pre_selected = {
|
||||
i for i, tn in enumerate(tool_names) if tn not in exclude_set
|
||||
}
|
||||
else:
|
||||
pre_selected = set(range(len(all_tools)))
|
||||
|
||||
currently = len(pre_selected)
|
||||
total = len(all_tools)
|
||||
_info(f"Currently {currently}/{total} tools enabled for '{name}'.")
|
||||
print()
|
||||
|
||||
# Interactive checklist
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
labels = [f"{t[0]} — {t[1]}" for t in all_tools]
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Select tools for '{name}'",
|
||||
labels,
|
||||
pre_selected,
|
||||
)
|
||||
|
||||
if chosen == pre_selected:
|
||||
_info("No changes made.")
|
||||
return
|
||||
|
||||
# Update config
|
||||
config = load_config()
|
||||
server_entry = config.get("mcp_servers", {}).get(name, {})
|
||||
|
||||
if len(chosen) == total:
|
||||
# All selected → remove include/exclude (register all)
|
||||
server_entry.pop("tools", None)
|
||||
else:
|
||||
chosen_names = [tool_names[i] for i in sorted(chosen)]
|
||||
server_entry.setdefault("tools", {})
|
||||
server_entry["tools"]["include"] = chosen_names
|
||||
server_entry["tools"].pop("exclude", None)
|
||||
|
||||
config.setdefault("mcp_servers", {})[name] = server_entry
|
||||
save_config(config)
|
||||
|
||||
new_count = len(chosen)
|
||||
_success(f"Updated config: {new_count}/{total} tools enabled")
|
||||
_info("Start a new session for changes to take effect.")
|
||||
|
||||
|
||||
# ─── Dispatcher ───────────────────────────────────────────────────────────────
|
||||
|
||||
def mcp_command(args):
|
||||
"""Main dispatcher for ``hermes mcp`` subcommands."""
|
||||
action = getattr(args, "mcp_action", None)
|
||||
|
||||
handlers = {
|
||||
"add": cmd_mcp_add,
|
||||
"remove": cmd_mcp_remove,
|
||||
"rm": cmd_mcp_remove,
|
||||
"list": cmd_mcp_list,
|
||||
"ls": cmd_mcp_list,
|
||||
"test": cmd_mcp_test,
|
||||
"configure": cmd_mcp_configure,
|
||||
"config": cmd_mcp_configure,
|
||||
}
|
||||
|
||||
handler = handlers.get(action)
|
||||
if handler:
|
||||
handler(args)
|
||||
else:
|
||||
# No subcommand — show list
|
||||
cmd_mcp_list()
|
||||
print(color(" Commands:", Colors.CYAN))
|
||||
_info("hermes mcp add <name> --url <endpoint> Add an MCP server")
|
||||
_info("hermes mcp add <name> --command <cmd> Add a stdio server")
|
||||
_info("hermes mcp remove <name> Remove a server")
|
||||
_info("hermes mcp list List servers")
|
||||
_info("hermes mcp test <name> Test connection")
|
||||
_info("hermes mcp configure <name> Toggle tools")
|
||||
print()
|
||||
400
tests/hermes_cli/test_mcp_config.py
Normal file
400
tests/hermes_cli/test_mcp_config.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
Tests for hermes_cli.mcp_config — ``hermes mcp`` subcommands.
|
||||
|
||||
These tests mock the MCP server connection layer so they run without
|
||||
any actual MCP servers or API keys.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_config(tmp_path, monkeypatch):
|
||||
"""Redirect all config I/O to a temp directory."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.get_hermes_home", lambda: tmp_path
|
||||
)
|
||||
config_path = tmp_path / "config.yaml"
|
||||
env_path = tmp_path / ".env"
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.get_config_path", lambda: config_path
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.get_env_path", lambda: env_path
|
||||
)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def _make_args(**kwargs):
|
||||
"""Build a minimal argparse.Namespace."""
|
||||
defaults = {
|
||||
"name": "test-server",
|
||||
"url": None,
|
||||
"command": None,
|
||||
"args": None,
|
||||
"auth": None,
|
||||
"mcp_action": None,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return argparse.Namespace(**defaults)
|
||||
|
||||
|
||||
def _seed_config(tmp_path: Path, mcp_servers: dict):
|
||||
"""Write a config.yaml with the given mcp_servers."""
|
||||
import yaml
|
||||
|
||||
config = {"mcp_servers": mcp_servers, "_config_version": 9}
|
||||
config_path = tmp_path / "config.yaml"
|
||||
with open(config_path, "w") as f:
|
||||
yaml.safe_dump(config, f)
|
||||
|
||||
|
||||
class FakeTool:
|
||||
"""Mimics an MCP tool object returned by the SDK."""
|
||||
|
||||
def __init__(self, name: str, description: str = ""):
|
||||
self.name = name
|
||||
self.description = description
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpList:
|
||||
def test_list_empty_config(self, tmp_path, capsys):
|
||||
from hermes_cli.mcp_config import cmd_mcp_list
|
||||
|
||||
cmd_mcp_list()
|
||||
out = capsys.readouterr().out
|
||||
assert "No MCP servers configured" in out
|
||||
|
||||
def test_list_with_servers(self, tmp_path, capsys):
|
||||
_seed_config(tmp_path, {
|
||||
"ink": {
|
||||
"url": "https://mcp.ml.ink/mcp",
|
||||
"enabled": True,
|
||||
"tools": {"include": ["create_service", "get_service"]},
|
||||
},
|
||||
"github": {
|
||||
"command": "npx",
|
||||
"args": ["@mcp/github"],
|
||||
"enabled": False,
|
||||
},
|
||||
})
|
||||
from hermes_cli.mcp_config import cmd_mcp_list
|
||||
|
||||
cmd_mcp_list()
|
||||
out = capsys.readouterr().out
|
||||
assert "ink" in out
|
||||
assert "github" in out
|
||||
assert "2 selected" in out # ink has 2 in include
|
||||
assert "disabled" in out # github is disabled
|
||||
|
||||
def test_list_enabled_default_true(self, tmp_path, capsys):
|
||||
"""Server without explicit enabled key defaults to enabled."""
|
||||
_seed_config(tmp_path, {
|
||||
"myserver": {"url": "https://example.com/mcp"},
|
||||
})
|
||||
from hermes_cli.mcp_config import cmd_mcp_list
|
||||
|
||||
cmd_mcp_list()
|
||||
out = capsys.readouterr().out
|
||||
assert "myserver" in out
|
||||
assert "enabled" in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_remove
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpRemove:
|
||||
def test_remove_existing_server(self, tmp_path, capsys, monkeypatch):
|
||||
_seed_config(tmp_path, {
|
||||
"myserver": {"url": "https://example.com/mcp"},
|
||||
})
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
from hermes_cli.mcp_config import cmd_mcp_remove
|
||||
|
||||
cmd_mcp_remove(_make_args(name="myserver"))
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Removed" in out
|
||||
|
||||
# Verify config updated
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
assert "myserver" not in config.get("mcp_servers", {})
|
||||
|
||||
def test_remove_nonexistent(self, tmp_path, capsys):
|
||||
_seed_config(tmp_path, {})
|
||||
from hermes_cli.mcp_config import cmd_mcp_remove
|
||||
|
||||
cmd_mcp_remove(_make_args(name="ghost"))
|
||||
out = capsys.readouterr().out
|
||||
assert "not found" in out
|
||||
|
||||
def test_remove_cleans_oauth_tokens(self, tmp_path, capsys, monkeypatch):
|
||||
_seed_config(tmp_path, {
|
||||
"oauth-srv": {"url": "https://example.com/mcp", "auth": "oauth"},
|
||||
})
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
# Also patch get_hermes_home in the mcp_config module namespace
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config.get_hermes_home", lambda: tmp_path
|
||||
)
|
||||
|
||||
# Create a fake token file
|
||||
token_dir = tmp_path / "mcp-tokens"
|
||||
token_dir.mkdir()
|
||||
token_file = token_dir / "oauth-srv.json"
|
||||
token_file.write_text("{}")
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_remove
|
||||
|
||||
cmd_mcp_remove(_make_args(name="oauth-srv"))
|
||||
assert not token_file.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_add
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpAdd:
|
||||
def test_add_no_transport(self, capsys):
|
||||
"""Must specify --url or --command."""
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(name="bad"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Must specify" in out
|
||||
|
||||
def test_add_http_server_all_tools(self, tmp_path, capsys, monkeypatch):
|
||||
"""Add an HTTP server, accept all tools."""
|
||||
fake_tools = [
|
||||
FakeTool("create_service", "Deploy from repo"),
|
||||
FakeTool("list_services", "List all services"),
|
||||
]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
# No auth, accept all tools
|
||||
inputs = iter(["n", ""]) # no auth needed, enable all
|
||||
monkeypatch.setattr("builtins.input", lambda _: next(inputs))
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(name="ink", url="https://mcp.ml.ink/mcp"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
assert "2/2 tools" in out
|
||||
|
||||
# Verify config written
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
assert "ink" in config.get("mcp_servers", {})
|
||||
assert config["mcp_servers"]["ink"]["url"] == "https://mcp.ml.ink/mcp"
|
||||
|
||||
def test_add_stdio_server(self, tmp_path, capsys, monkeypatch):
|
||||
"""Add a stdio server."""
|
||||
fake_tools = [FakeTool("search", "Search repos")]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
inputs = iter([""]) # accept all tools
|
||||
monkeypatch.setattr("builtins.input", lambda _: next(inputs))
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="github",
|
||||
command="npx",
|
||||
args=["@mcp/github"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
srv = config["mcp_servers"]["github"]
|
||||
assert srv["command"] == "npx"
|
||||
assert srv["args"] == ["@mcp/github"]
|
||||
|
||||
def test_add_connection_failure_save_disabled(
|
||||
self, tmp_path, capsys, monkeypatch
|
||||
):
|
||||
"""Failed connection → option to save as disabled."""
|
||||
|
||||
def mock_probe_fail(name, config, **kw):
|
||||
raise ConnectionError("Connection refused")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe_fail
|
||||
)
|
||||
inputs = iter(["n", "y"]) # no auth, yes save disabled
|
||||
monkeypatch.setattr("builtins.input", lambda _: next(inputs))
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(name="broken", url="https://bad.host/mcp"))
|
||||
out = capsys.readouterr().out
|
||||
assert "disabled" in out
|
||||
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
assert config["mcp_servers"]["broken"]["enabled"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpTest:
|
||||
def test_test_not_found(self, tmp_path, capsys):
|
||||
_seed_config(tmp_path, {})
|
||||
from hermes_cli.mcp_config import cmd_mcp_test
|
||||
|
||||
cmd_mcp_test(_make_args(name="ghost"))
|
||||
out = capsys.readouterr().out
|
||||
assert "not found" in out
|
||||
|
||||
def test_test_success(self, tmp_path, capsys, monkeypatch):
|
||||
_seed_config(tmp_path, {
|
||||
"ink": {"url": "https://mcp.ml.ink/mcp"},
|
||||
})
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
return [("create_service", "Deploy"), ("list_services", "List all")]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
from hermes_cli.mcp_config import cmd_mcp_test
|
||||
|
||||
cmd_mcp_test(_make_args(name="ink"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Connected" in out
|
||||
assert "Tools discovered: 2" in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: env var interpolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnvVarInterpolation:
|
||||
def test_interpolate_simple(self, monkeypatch):
|
||||
monkeypatch.setenv("MY_KEY", "secret123")
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
result = _interpolate_env_vars("Bearer ${MY_KEY}")
|
||||
assert result == "Bearer secret123"
|
||||
|
||||
def test_interpolate_missing_var(self, monkeypatch):
|
||||
monkeypatch.delenv("MISSING_VAR", raising=False)
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
result = _interpolate_env_vars("Bearer ${MISSING_VAR}")
|
||||
assert result == "Bearer ${MISSING_VAR}"
|
||||
|
||||
def test_interpolate_nested_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("API_KEY", "abc")
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
result = _interpolate_env_vars({
|
||||
"url": "https://example.com",
|
||||
"headers": {"Authorization": "Bearer ${API_KEY}"},
|
||||
})
|
||||
assert result["headers"]["Authorization"] == "Bearer abc"
|
||||
assert result["url"] == "https://example.com"
|
||||
|
||||
def test_interpolate_list(self, monkeypatch):
|
||||
monkeypatch.setenv("ARG1", "hello")
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
result = _interpolate_env_vars(["${ARG1}", "static"])
|
||||
assert result == ["hello", "static"]
|
||||
|
||||
def test_interpolate_non_string(self):
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
assert _interpolate_env_vars(42) == 42
|
||||
assert _interpolate_env_vars(True) is True
|
||||
assert _interpolate_env_vars(None) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigHelpers:
|
||||
def test_save_and_load_mcp_server(self, tmp_path):
|
||||
from hermes_cli.mcp_config import _save_mcp_server, _get_mcp_servers
|
||||
|
||||
_save_mcp_server("mysvr", {"url": "https://example.com/mcp"})
|
||||
servers = _get_mcp_servers()
|
||||
assert "mysvr" in servers
|
||||
assert servers["mysvr"]["url"] == "https://example.com/mcp"
|
||||
|
||||
def test_remove_mcp_server(self, tmp_path):
|
||||
from hermes_cli.mcp_config import (
|
||||
_save_mcp_server,
|
||||
_remove_mcp_server,
|
||||
_get_mcp_servers,
|
||||
)
|
||||
|
||||
_save_mcp_server("s1", {"command": "test"})
|
||||
_save_mcp_server("s2", {"command": "test2"})
|
||||
result = _remove_mcp_server("s1")
|
||||
assert result is True
|
||||
assert "s1" not in _get_mcp_servers()
|
||||
assert "s2" in _get_mcp_servers()
|
||||
|
||||
def test_remove_nonexistent(self, tmp_path):
|
||||
from hermes_cli.mcp_config import _remove_mcp_server
|
||||
|
||||
assert _remove_mcp_server("ghost") is False
|
||||
|
||||
def test_env_key_for_server(self):
|
||||
from hermes_cli.mcp_config import _env_key_for_server
|
||||
|
||||
assert _env_key_for_server("ink") == "MCP_INK_API_KEY"
|
||||
assert _env_key_for_server("my-server") == "MCP_MY_SERVER_API_KEY"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDispatcher:
|
||||
def test_no_action_shows_list(self, tmp_path, capsys):
|
||||
from hermes_cli.mcp_config import mcp_command
|
||||
|
||||
_seed_config(tmp_path, {})
|
||||
mcp_command(_make_args(mcp_action=None))
|
||||
out = capsys.readouterr().out
|
||||
assert "Commands:" in out or "No MCP servers" in out
|
||||
152
tests/tools/test_mcp_oauth.py
Normal file
152
tests/tools/test_mcp_oauth.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Tests for tools/mcp_oauth.py — thin OAuth adapter over MCP SDK."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.mcp_oauth import (
|
||||
HermesTokenStorage,
|
||||
build_oauth_auth,
|
||||
remove_oauth_tokens,
|
||||
_find_free_port,
|
||||
_can_open_browser,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HermesTokenStorage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHermesTokenStorage:
|
||||
def test_roundtrip_tokens(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("test-server")
|
||||
|
||||
import asyncio
|
||||
|
||||
# Initially empty
|
||||
assert asyncio.run(storage.get_tokens()) is None
|
||||
|
||||
# Save and retrieve
|
||||
mock_token = MagicMock()
|
||||
mock_token.model_dump.return_value = {
|
||||
"access_token": "abc123",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "ref456",
|
||||
}
|
||||
asyncio.run(storage.set_tokens(mock_token))
|
||||
|
||||
# File exists with correct permissions
|
||||
token_path = tmp_path / "mcp-tokens" / "test-server.json"
|
||||
assert token_path.exists()
|
||||
data = json.loads(token_path.read_text())
|
||||
assert data["access_token"] == "abc123"
|
||||
|
||||
def test_roundtrip_client_info(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("test-server")
|
||||
import asyncio
|
||||
|
||||
assert asyncio.run(storage.get_client_info()) is None
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.model_dump.return_value = {
|
||||
"client_id": "hermes-123",
|
||||
"client_secret": "secret",
|
||||
}
|
||||
asyncio.run(storage.set_client_info(mock_client))
|
||||
|
||||
client_path = tmp_path / "mcp-tokens" / "test-server.client.json"
|
||||
assert client_path.exists()
|
||||
|
||||
def test_remove_cleans_up(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("test-server")
|
||||
|
||||
# Create files
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "test-server.json").write_text("{}")
|
||||
(d / "test-server.client.json").write_text("{}")
|
||||
|
||||
storage.remove()
|
||||
assert not (d / "test-server.json").exists()
|
||||
assert not (d / "test-server.client.json").exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_oauth_auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildOAuthAuth:
|
||||
def test_returns_oauth_provider(self):
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
auth = build_oauth_auth("test", "https://example.com/mcp")
|
||||
assert isinstance(auth, OAuthClientProvider)
|
||||
|
||||
def test_returns_none_without_sdk(self, monkeypatch):
|
||||
import tools.mcp_oauth as mod
|
||||
orig_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__
|
||||
|
||||
def _block_import(name, *args, **kwargs):
|
||||
if "mcp.client.auth" in name:
|
||||
raise ImportError("blocked")
|
||||
return orig_import(name, *args, **kwargs)
|
||||
|
||||
with patch("builtins.__import__", side_effect=_block_import):
|
||||
result = build_oauth_auth("test", "https://example.com")
|
||||
# May or may not be None depending on import caching, but shouldn't crash
|
||||
assert result is None or result is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUtilities:
|
||||
def test_find_free_port_returns_int(self):
|
||||
port = _find_free_port()
|
||||
assert isinstance(port, int)
|
||||
assert 1024 <= port <= 65535
|
||||
|
||||
def test_can_open_browser_false_in_ssh(self, monkeypatch):
|
||||
monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 1234 22")
|
||||
assert _can_open_browser() is False
|
||||
|
||||
def test_can_open_browser_false_without_display(self, monkeypatch):
|
||||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.delenv("DISPLAY", raising=False)
|
||||
# Mock os.name and uname for non-macOS, non-Windows
|
||||
monkeypatch.setattr(os, "name", "posix")
|
||||
monkeypatch.setattr(os, "uname", lambda: type("", (), {"sysname": "Linux"})())
|
||||
assert _can_open_browser() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remove_oauth_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRemoveOAuthTokens:
|
||||
def test_removes_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir()
|
||||
(d / "myserver.json").write_text("{}")
|
||||
(d / "myserver.client.json").write_text("{}")
|
||||
|
||||
remove_oauth_tokens("myserver")
|
||||
|
||||
assert not (d / "myserver.json").exists()
|
||||
assert not (d / "myserver.client.json").exists()
|
||||
|
||||
def test_no_error_when_files_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
remove_oauth_tokens("nonexistent") # should not raise
|
||||
235
tools/mcp_oauth.py
Normal file
235
tools/mcp_oauth.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Thin OAuth adapter for MCP HTTP servers.
|
||||
|
||||
Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
|
||||
``httpx.Auth``) with Hermes-specific token storage and browser-based
|
||||
authorization. The SDK handles all of the heavy lifting: PKCE generation,
|
||||
metadata discovery, dynamic client registration, token exchange, and refresh.
|
||||
|
||||
Usage in mcp_tool.py::
|
||||
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
auth = build_oauth_auth(server_name, server_url)
|
||||
# pass ``auth`` as the httpx auth parameter
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TOKEN_DIR_NAME = "mcp-tokens"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class HermesTokenStorage:
|
||||
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
|
||||
|
||||
def __init__(self, server_name: str):
|
||||
self._server_name = server_name
|
||||
|
||||
def _base_dir(self) -> Path:
|
||||
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
d = home / _TOKEN_DIR_NAME
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
|
||||
def _tokens_path(self) -> Path:
|
||||
return self._base_dir() / f"{self._server_name}.json"
|
||||
|
||||
def _client_path(self) -> Path:
|
||||
return self._base_dir() / f"{self._server_name}.client.json"
|
||||
|
||||
# -- TokenStorage protocol (async) --
|
||||
|
||||
async def get_tokens(self):
|
||||
data = self._read_json(self._tokens_path())
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
from mcp.shared.auth import OAuthToken
|
||||
return OAuthToken(**data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens) -> None:
|
||||
self._write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
|
||||
|
||||
async def get_client_info(self):
|
||||
data = self._read_json(self._client_path())
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
return OAuthClientInformationFull(**data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def set_client_info(self, client_info) -> None:
|
||||
self._write_json(self._client_path(), client_info.model_dump(exclude_none=True))
|
||||
|
||||
# -- helpers --
|
||||
|
||||
@staticmethod
|
||||
def _read_json(path: Path) -> dict | None:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _write_json(path: Path, data: dict) -> None:
|
||||
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
try:
|
||||
path.chmod(0o600)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def remove(self) -> None:
|
||||
"""Delete stored tokens and client info for this server."""
|
||||
for p in (self._tokens_path(), self._client_path()):
|
||||
try:
|
||||
p.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Browser-based callback handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
class _CallbackHandler(BaseHTTPRequestHandler):
|
||||
auth_code: str | None = None
|
||||
state: str | None = None
|
||||
|
||||
def do_GET(self):
|
||||
qs = parse_qs(urlparse(self.path).query)
|
||||
_CallbackHandler.auth_code = (qs.get("code") or [None])[0]
|
||||
_CallbackHandler.state = (qs.get("state") or [None])[0]
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
|
||||
|
||||
def log_message(self, *_args: Any) -> None:
|
||||
pass # suppress HTTP log noise
|
||||
|
||||
|
||||
async def _redirect_to_browser(auth_url: str) -> None:
|
||||
"""Open the authorization URL in the user's browser."""
|
||||
try:
|
||||
if _can_open_browser():
|
||||
webbrowser.open(auth_url)
|
||||
print(f" Opened browser for authorization...")
|
||||
else:
|
||||
print(f"\n Open this URL to authorize:\n {auth_url}\n")
|
||||
except Exception:
|
||||
print(f"\n Open this URL to authorize:\n {auth_url}\n")
|
||||
|
||||
|
||||
async def _wait_for_callback() -> tuple[str, str | None]:
|
||||
"""Start a local HTTP server and wait for the OAuth redirect callback."""
|
||||
port = _find_free_port()
|
||||
server = HTTPServer(("127.0.0.1", port), _CallbackHandler)
|
||||
_CallbackHandler.auth_code = None
|
||||
_CallbackHandler.state = None
|
||||
|
||||
def _serve():
|
||||
server.timeout = 120
|
||||
server.handle_request()
|
||||
|
||||
thread = threading.Thread(target=_serve, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Wait for the callback
|
||||
for _ in range(1200): # 120 seconds
|
||||
await asyncio.sleep(0.1)
|
||||
if _CallbackHandler.auth_code is not None:
|
||||
break
|
||||
|
||||
server.server_close()
|
||||
code = _CallbackHandler.auth_code or ""
|
||||
state = _CallbackHandler.state
|
||||
if not code:
|
||||
# Fallback to manual entry
|
||||
print(" Browser callback timed out. Paste the authorization code manually:")
|
||||
code = input(" Code: ").strip()
|
||||
return code, state
|
||||
|
||||
|
||||
def _can_open_browser() -> bool:
|
||||
if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"):
|
||||
return False
|
||||
if not os.environ.get("DISPLAY") and os.name != "nt" and "darwin" not in os.uname().sysname.lower():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_oauth_auth(server_name: str, server_url: str):
|
||||
"""Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE.
|
||||
|
||||
Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery,
|
||||
registration, PKCE, token exchange, and refresh automatically.
|
||||
|
||||
Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``),
|
||||
or ``None`` if the MCP SDK auth module is not available.
|
||||
"""
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
except ImportError:
|
||||
logger.warning("MCP SDK auth module not available — OAuth disabled")
|
||||
return None
|
||||
|
||||
port = _find_free_port()
|
||||
redirect_uri = f"http://127.0.0.1:{port}/callback"
|
||||
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name="Hermes Agent",
|
||||
redirect_uris=[redirect_uri],
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
scope="openid profile email offline_access",
|
||||
token_endpoint_auth_method="none",
|
||||
)
|
||||
|
||||
storage = HermesTokenStorage(server_name)
|
||||
|
||||
return OAuthClientProvider(
|
||||
server_url=server_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=_redirect_to_browser,
|
||||
callback_handler=_wait_for_callback,
|
||||
timeout=120.0,
|
||||
)
|
||||
|
||||
|
||||
def remove_oauth_tokens(server_name: str) -> None:
|
||||
"""Delete stored OAuth tokens and client info for a server."""
|
||||
HermesTokenStorage(server_name).remove()
|
||||
@@ -690,7 +690,7 @@ class MCPServerTask:
|
||||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_sampling", "_registered_tool_names",
|
||||
"_sampling", "_registered_tool_names", "_auth_type",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
@@ -705,6 +705,7 @@ class MCPServerTask:
|
||||
self._config: dict = {}
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
self._registered_tool_names: list[str] = []
|
||||
self._auth_type: str = ""
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
@@ -748,15 +749,28 @@ class MCPServerTask:
|
||||
)
|
||||
|
||||
url = config["url"]
|
||||
headers = config.get("headers")
|
||||
headers = dict(config.get("headers") or {})
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
|
||||
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK
|
||||
_oauth_auth = None
|
||||
if self._auth_type == "oauth":
|
||||
try:
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
_oauth_auth = build_oauth_auth(self.name, url)
|
||||
except Exception as exc:
|
||||
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
async with streamablehttp_client(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=float(connect_timeout),
|
||||
) as (read_stream, write_stream, _get_session_id):
|
||||
_http_kwargs: dict = {
|
||||
"headers": headers,
|
||||
"timeout": float(connect_timeout),
|
||||
}
|
||||
if _oauth_auth is not None:
|
||||
_http_kwargs["auth"] = _oauth_auth
|
||||
async with streamablehttp_client(url, **_http_kwargs) as (
|
||||
read_stream, write_stream, _get_session_id,
|
||||
):
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
@@ -783,6 +797,7 @@ class MCPServerTask:
|
||||
"""
|
||||
self._config = config
|
||||
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
||||
self._auth_type = config.get("auth", "").lower().strip()
|
||||
|
||||
# Set up sampling handler if enabled and SDK types are available
|
||||
sampling_config = config.get("sampling", {})
|
||||
@@ -920,13 +935,30 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
|
||||
# Config loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _interpolate_env_vars(value):
|
||||
"""Recursively resolve ``${VAR}`` placeholders from ``os.environ``."""
|
||||
if isinstance(value, str):
|
||||
import re
|
||||
def _replace(m):
|
||||
return os.environ.get(m.group(1), m.group(0))
|
||||
return re.sub(r"\$\{([^}]+)\}", _replace, value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _interpolate_env_vars(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_interpolate_env_vars(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def _load_mcp_config() -> Dict[str, dict]:
|
||||
"""Read ``mcp_servers`` from the Hermes config file.
|
||||
|
||||
Returns a dict of ``{server_name: server_config}`` or empty dict.
|
||||
Server config can contain either ``command``/``args``/``env`` for stdio
|
||||
transport or ``url``/``headers`` for HTTP transport, plus optional
|
||||
``timeout`` and ``connect_timeout`` overrides.
|
||||
``timeout``, ``connect_timeout``, and ``auth`` overrides.
|
||||
|
||||
``${ENV_VAR}`` placeholders in string values are resolved from
|
||||
``os.environ`` (which includes ``~/.hermes/.env`` loaded at startup).
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
@@ -934,7 +966,13 @@ def _load_mcp_config() -> Dict[str, dict]:
|
||||
servers = config.get("mcp_servers")
|
||||
if not servers or not isinstance(servers, dict):
|
||||
return {}
|
||||
return servers
|
||||
# Ensure .env vars are available for interpolation
|
||||
try:
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
load_hermes_dotenv()
|
||||
except Exception:
|
||||
pass
|
||||
return {name: _interpolate_env_vars(cfg) for name, cfg in servers.items()}
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to load MCP config: %s", exc)
|
||||
return {}
|
||||
|
||||
Reference in New Issue
Block a user