feat(mcp): dynamic tool discovery via notifications/tools/list_changed (#3812)
When a connected MCP server sends a ToolListChangedNotification (per the MCP spec), Hermes now automatically re-fetches the tool list, deregisters removed tools, and registers new ones — without requiring a restart. This enables MCP servers with dynamic toolsets (e.g. GitHub MCP with GITHUB_DYNAMIC_TOOLSETS=1) to add/remove tools at runtime.

Changes:
- registry.py: add ToolRegistry.deregister() for nuke-and-repave refresh
- mcp_tool.py: extract _register_server_tools() from _discover_and_register_server() as a shared helper for both initial discovery and dynamic refresh
- mcp_tool.py: add _make_message_handler() and _refresh_tools() on MCPServerTask, wired into all 3 ClientSession sites (stdio, new HTTP, deprecated HTTP)
- Graceful degradation: silently falls back to static discovery when the MCP SDK lacks notification types or message_handler support
- 8 new tests covering registration, refresh, handler dispatch, and deregister

Salvaged from PR #1794 by shivvor2.
This commit is contained in:
@@ -70,6 +70,7 @@ Thread safety:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
@@ -89,6 +90,8 @@ logger = logging.getLogger(__name__)
|
||||
_MCP_AVAILABLE = False
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
_MCP_SAMPLING_TYPES = False
|
||||
_MCP_NOTIFICATION_TYPES = False
|
||||
_MCP_MESSAGE_HANDLER_SUPPORTED = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
@@ -119,9 +122,39 @@ try:
|
||||
_MCP_SAMPLING_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
||||
# Notification types for dynamic tool discovery (tools/list_changed)
|
||||
try:
|
||||
from mcp.types import (
|
||||
ServerNotification,
|
||||
ToolListChangedNotification,
|
||||
PromptListChangedNotification,
|
||||
ResourceListChangedNotification,
|
||||
)
|
||||
_MCP_NOTIFICATION_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP notification types not available -- dynamic tool discovery disabled")
|
||||
except ImportError:
|
||||
logger.debug("mcp package not installed -- MCP tool support disabled")
|
||||
|
||||
|
||||
def _check_message_handler_support() -> bool:
    """Report whether ``ClientSession`` can be given a ``message_handler``.

    Older MCP SDK releases do not take that keyword, so the constructor
    signature is inspected rather than assumed.

    Returns:
        ``True`` when ``_MCP_AVAILABLE`` is set and ``message_handler`` is
        one of the constructor's parameters; ``False`` otherwise, including
        when the signature cannot be introspected.
    """
    if _MCP_AVAILABLE:
        try:
            ctor_params = inspect.signature(ClientSession).parameters
        except (TypeError, ValueError):
            # No introspectable signature -- treat the kwarg as unsupported.
            return False
        return "message_handler" in ctor_params
    return False
|
||||
|
||||
|
||||
_MCP_MESSAGE_HANDLER_SUPPORTED = _check_message_handler_support()
|
||||
if _MCP_AVAILABLE and not _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
logger.debug("MCP SDK does not support message_handler -- dynamic tool discovery disabled")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -697,7 +730,7 @@ class MCPServerTask:
|
||||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_sampling", "_registered_tool_names", "_auth_type",
|
||||
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
@@ -713,11 +746,80 @@ class MCPServerTask:
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
self._registered_tool_names: list[str] = []
|
||||
self._auth_type: str = ""
|
||||
self._refresh_lock = asyncio.Lock()
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
return "url" in self._config
|
||||
|
||||
# ----- Dynamic tool discovery (notifications/tools/list_changed) -----
|
||||
|
||||
def _make_message_handler(self):
|
||||
"""Build a ``message_handler`` callback for ``ClientSession``.
|
||||
|
||||
Dispatches on notification type. Only ``ToolListChangedNotification``
|
||||
triggers a refresh; prompt and resource change notifications are
|
||||
logged as stubs for future work.
|
||||
"""
|
||||
async def _handler(message):
|
||||
try:
|
||||
if isinstance(message, Exception):
|
||||
logger.debug("MCP message handler (%s): exception: %s", self.name, message)
|
||||
return
|
||||
if _MCP_NOTIFICATION_TYPES and isinstance(message, ServerNotification):
|
||||
match message.root:
|
||||
case ToolListChangedNotification():
|
||||
logger.info(
|
||||
"MCP server '%s': received tools/list_changed notification",
|
||||
self.name,
|
||||
)
|
||||
await self._refresh_tools()
|
||||
case PromptListChangedNotification():
|
||||
logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
|
||||
case ResourceListChangedNotification():
|
||||
logger.debug("MCP server '%s': resources/list_changed (ignored)", self.name)
|
||||
case _:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Error in MCP message handler for '%s'", self.name)
|
||||
return _handler
|
||||
|
||||
async def _refresh_tools(self):
    """Re-fetch tools from the server and update the registry.

    Called when the server sends ``notifications/tools/list_changed``.
    The lock prevents overlapping refreshes from rapid-fire notifications.
    After the initial ``await`` (list_tools), all mutations are synchronous
    — atomic from the event loop's perspective.

    Steps, in order:

    1. Fetch the current tool list from the server.
    2. Drop this server's previously registered names from every
       ``hermes-*`` umbrella toolset.
    3. Deregister those names from the central registry.
    4. Store the new tool list and re-register through
       ``_register_server_tools``, recording the names it returns.

    If there is no session to query, the refresh is skipped and nothing
    is modified.
    """
    from tools.registry import registry
    from toolsets import TOOLSETS

    async with self._refresh_lock:
        # presumably ``session`` is None before connect / after teardown --
        # TODO confirm; the guard is harmless if it never is.
        session = self.session
        if session is None:
            logger.debug(
                "MCP server '%s': no active session, skipping tool refresh",
                self.name,
            )
            return

        # 1. Fetch current tool list from server
        tools_result = await session.list_tools()
        new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else []

        # 2. Remove old tools from hermes-* umbrella toolsets.
        # Build the lookup set once so each toolset is filtered in one pass.
        stale_names = set(self._registered_tool_names)
        for ts_name, ts in TOOLSETS.items():
            if ts_name.startswith("hermes-"):
                ts["tools"] = [t for t in ts["tools"] if t not in stale_names]

        # 3. Deregister old tools from the central registry
        for prefixed_name in self._registered_tool_names:
            registry.deregister(prefixed_name)

        # 4. Re-register with fresh tool list
        self._tools = new_mcp_tools
        self._registered_tool_names = _register_server_tools(
            self.name, self, self._config
        )

        logger.info(
            "MCP server '%s': dynamically refreshed %d tool(s)",
            self.name, len(self._registered_tool_names),
        )
|
||||
|
||||
async def _run_stdio(self, config: dict):
|
||||
"""Run the server using stdio transport."""
|
||||
command = config.get("command")
|
||||
@@ -738,6 +840,8 @@ class MCPServerTask:
|
||||
)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
sampling_kwargs["message_handler"] = self._make_message_handler()
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
@@ -769,6 +873,8 @@ class MCPServerTask:
|
||||
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
sampling_kwargs["message_handler"] = self._make_message_handler()
|
||||
|
||||
if _MCP_NEW_HTTP:
|
||||
# New API (mcp >= 1.24.0): build an explicit httpx.AsyncClient
|
||||
@@ -1522,24 +1628,19 @@ def _existing_tool_names() -> List[str]:
|
||||
return names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
"""Connect to a single MCP server, discover tools, and register them.
|
||||
def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> List[str]:
|
||||
"""Register tools from an already-connected server into the registry.
|
||||
|
||||
Also registers utility tools for MCP Resources and Prompts support
|
||||
(list_resources, read_resource, list_prompts, get_prompt).
|
||||
Handles include/exclude filtering, utility tools, toolset creation,
|
||||
and hermes-* umbrella toolset injection.
|
||||
|
||||
Returns list of registered tool names.
|
||||
Used by both initial discovery and dynamic refresh (list_changed).
|
||||
|
||||
Returns:
|
||||
List of registered prefixed tool names.
|
||||
"""
|
||||
from tools.registry import registry
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
server = await asyncio.wait_for(
|
||||
_connect_server(name, config),
|
||||
timeout=connect_timeout,
|
||||
)
|
||||
with _lock:
|
||||
_servers[name] = server
|
||||
from toolsets import create_custom_toolset, TOOLSETS
|
||||
|
||||
registered_names: List[str] = []
|
||||
toolset_name = f"mcp-{name}"
|
||||
@@ -1625,8 +1726,6 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
)
|
||||
registered_names.append(util_name)
|
||||
|
||||
server._registered_tool_names = list(registered_names)
|
||||
|
||||
# Create a custom toolset so these tools are discoverable
|
||||
if registered_names:
|
||||
create_custom_toolset(
|
||||
@@ -1634,6 +1733,31 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
description=f"MCP tools from {name} server",
|
||||
tools=registered_names,
|
||||
)
|
||||
# Inject into hermes-* umbrella toolsets for default behavior
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if ts_name.startswith("hermes-"):
|
||||
for tool_name in registered_names:
|
||||
if tool_name not in ts["tools"]:
|
||||
ts["tools"].append(tool_name)
|
||||
|
||||
return registered_names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
"""Connect to a single MCP server, discover tools, and register them.
|
||||
|
||||
Returns list of registered tool names.
|
||||
"""
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
server = await asyncio.wait_for(
|
||||
_connect_server(name, config),
|
||||
timeout=connect_timeout,
|
||||
)
|
||||
with _lock:
|
||||
_servers[name] = server
|
||||
|
||||
registered_names = _register_server_tools(name, server, config)
|
||||
server._registered_tool_names = list(registered_names)
|
||||
|
||||
transport_type = "HTTP" if "url" in config else "stdio"
|
||||
logger.info(
|
||||
|
||||
@@ -87,6 +87,23 @@ class ToolRegistry:
|
||||
if check_fn and toolset not in self._toolset_checks:
|
||||
self._toolset_checks[toolset] = check_fn
|
||||
|
||||
def deregister(self, name: str) -> None:
    """Remove the tool called *name* from the registry, if present.

    When the removed tool was the last one belonging to its toolset, the
    toolset's availability check is dropped as well. Unknown names are
    ignored silently. MCP dynamic tool discovery uses this to
    nuke-and-repave when a server sends
    ``notifications/tools/list_changed``.
    """
    removed = self._tools.pop(name, None)
    if removed is None:
        return

    toolset = removed.toolset
    if toolset in self._toolset_checks:
        # Keep the check while any remaining tool still belongs to the toolset.
        still_in_use = any(other.toolset == toolset for other in self._tools.values())
        if not still_in_use:
            self._toolset_checks.pop(toolset, None)

    logger.debug("Deregistered tool: %s", name)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Schema retrieval
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user