feat(mcp): dynamic tool discovery via notifications/tools/list_changed (#3812)

When a connected MCP server sends a ToolListChangedNotification (per the
MCP spec), Hermes now automatically re-fetches the tool list, deregisters
removed tools, and registers new ones — without requiring a restart.

This enables MCP servers with dynamic toolsets (e.g. GitHub MCP with
GITHUB_DYNAMIC_TOOLSETS=1) to add/remove tools at runtime.

Changes:
- registry.py: add ToolRegistry.deregister() for nuke-and-repave refresh
- mcp_tool.py: extract _register_server_tools() from
  _discover_and_register_server() as a shared helper for both initial
  discovery and dynamic refresh
- mcp_tool.py: add _make_message_handler() and _refresh_tools() on
  MCPServerTask, wired into all 3 ClientSession sites (stdio, new HTTP,
  deprecated HTTP)
- Graceful degradation: falls back to static discovery (logging only at
  debug level) when the MCP SDK lacks notification types or
  message_handler support
- 8 new tests covering registration, refresh, handler dispatch, and
  deregister

Salvaged from PR #1794 by shivvor2.
This commit is contained in:
Teknium
2026-03-29 15:52:54 -07:00
committed by GitHub
parent bf84cdfa5e
commit d5d22fe7ba
3 changed files with 328 additions and 17 deletions

View File

@@ -0,0 +1,170 @@
"""Tests for MCP dynamic tool discovery (notifications/tools/list_changed)."""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from tools.mcp_tool import MCPServerTask, _register_server_tools
from tools.registry import ToolRegistry
def _make_mcp_tool(name: str, desc: str = ""):
return SimpleNamespace(name=name, description=desc, inputSchema=None)
class TestRegisterServerTools:
    """Exercise the shared ``_register_server_tools`` helper in isolation."""

    @pytest.fixture
    def mock_registry(self):
        # A fresh registry per test so registrations never leak between tests.
        return ToolRegistry()

    @pytest.fixture
    def mock_toolsets(self):
        # Two hermes-* umbrella toolsets plus one unrelated toolset.
        def _toolset(tools, description):
            return {"tools": tools, "description": description, "includes": []}

        return {
            "hermes-cli": _toolset(["terminal"], "CLI"),
            "hermes-telegram": _toolset(["terminal"], "TG"),
            "custom-toolset": _toolset([], "Other"),
        }

    def test_injects_hermes_toolsets(self, mock_registry, mock_toolsets):
        """Registered tools land in hermes-* toolsets and nowhere else."""
        expected_name = "mcp_my_srv_my_tool"

        server = MCPServerTask("my_srv")
        server._tools = [_make_mcp_tool("my_tool", "desc")]
        server.session = MagicMock()

        with patch("tools.registry.registry", mock_registry):
            with patch("toolsets.create_custom_toolset"):
                with patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True):
                    registered = _register_server_tools("my_srv", server, {})

        assert expected_name in registered
        assert expected_name in mock_registry.get_all_tool_names()

        # Every hermes-* umbrella toolset received the new tool...
        for umbrella in ("hermes-cli", "hermes-telegram"):
            assert expected_name in mock_toolsets[umbrella]["tools"]
        # ...while the non-hermes toolset was left untouched.
        assert expected_name not in mock_toolsets["custom-toolset"]["tools"]
class TestRefreshTools:
    """Cover the deregister-then-reregister cycle of ``MCPServerTask._refresh_tools``."""

    @pytest.fixture
    def mock_registry(self):
        # Isolated registry so the seeded tool does not touch the global one.
        return ToolRegistry()

    @pytest.fixture
    def mock_toolsets(self):
        # Only hermes-* umbrella toolsets are needed for this scenario.
        return {
            "hermes-cli": dict(tools=["terminal"], description="CLI", includes=[]),
            "hermes-telegram": dict(tools=["terminal"], description="TG", includes=[]),
        }

    @pytest.mark.asyncio
    async def test_nuke_and_repave(self, mock_registry, mock_toolsets):
        """A refresh drops the previously registered tool and adds the new one."""
        stale_name = "mcp_live_srv_old_tool"
        fresh_name = "mcp_live_srv_new_tool"

        server = MCPServerTask("live_srv")
        server._refresh_lock = asyncio.Lock()
        server._config = {}

        # Seed the pre-refresh state: one tool known to the registry, to the
        # server task, and to the hermes-cli umbrella toolset.
        mock_registry.register(
            name=stale_name,
            toolset="mcp-live_srv",
            schema={},
            handler=lambda x: x,
            check_fn=lambda: True,
            is_async=False,
            description="",
            emoji="",
        )
        server._registered_tool_names = [stale_name]
        mock_toolsets["hermes-cli"]["tools"].append(stale_name)

        # The fake session answers list_tools() with a single, different tool.
        listing = SimpleNamespace(tools=[_make_mcp_tool("new_tool", "new behavior")])
        server.session = SimpleNamespace(list_tools=AsyncMock(return_value=listing))

        with patch("tools.registry.registry", mock_registry):
            with patch("toolsets.create_custom_toolset"):
                with patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True):
                    await server._refresh_tools()

        cli_tools = mock_toolsets["hermes-cli"]["tools"]

        # The old tool is gone from both the registry and the umbrella toolset.
        assert stale_name not in mock_registry.get_all_tool_names()
        assert stale_name not in cli_tools

        # The new tool took its place everywhere.
        assert fresh_name in mock_registry.get_all_tool_names()
        assert fresh_name in cli_tools
        assert server._registered_tool_names == [fresh_name]
class TestMessageHandler:
    """Check which incoming messages make the handler trigger a tool refresh."""

    @pytest.mark.asyncio
    async def test_dispatches_tool_list_changed(self):
        """A tools/list_changed notification awaits ``_refresh_tools`` exactly once."""
        from tools.mcp_tool import _MCP_NOTIFICATION_TYPES

        if not _MCP_NOTIFICATION_TYPES:
            pytest.skip("MCP SDK ToolListChangedNotification not available")

        from mcp.types import ServerNotification, ToolListChangedNotification

        list_changed = ToolListChangedNotification(method="notifications/tools/list_changed")
        server = MCPServerTask("notif_srv")

        with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh:
            on_message = server._make_message_handler()
            await on_message(ServerNotification(root=list_changed))

        mock_refresh.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_ignores_exceptions_and_other_messages(self):
        """Exceptions and unrecognised payloads never trigger a refresh."""
        server = MCPServerTask("notif_srv")
        non_triggering = (
            RuntimeError("connection dead"),          # transport-level exception
            {"jsonrpc": "2.0", "result": "ok"},       # not a ServerNotification
        )

        with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh:
            on_message = server._make_message_handler()
            for incoming in non_triggering:
                await on_message(incoming)

        mock_refresh.assert_not_awaited()
class TestDeregister:
    """Unit tests for ``ToolRegistry.deregister``."""

    def test_removes_tool(self):
        """A registered tool disappears from the name listing once deregistered."""
        tool_registry = ToolRegistry()
        tool_registry.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x)
        assert "foo" in tool_registry.get_all_tool_names()

        tool_registry.deregister("foo")

        assert "foo" not in tool_registry.get_all_tool_names()

    def test_cleans_up_toolset_check(self):
        """Removing the last tool of a toolset also drops that toolset's check."""
        tool_registry = ToolRegistry()
        always_available = lambda: True  # noqa: E731
        tool_registry.register(
            name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=always_available
        )
        assert tool_registry.is_toolset_available("ts1")

        tool_registry.deregister("foo")

        # No tools remain in ts1, so its check entry must be gone too.
        assert "ts1" not in tool_registry._toolset_checks

    def test_preserves_toolset_check_if_other_tools_remain(self):
        """The toolset check survives while another tool still uses the toolset."""
        tool_registry = ToolRegistry()
        always_available = lambda: True  # noqa: E731
        tool_registry.register(
            name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=always_available
        )
        tool_registry.register(name="bar", toolset="ts1", schema={}, handler=lambda x: x)

        tool_registry.deregister("foo")

        # "bar" still belongs to ts1, so the check has to stay.
        assert "ts1" in tool_registry._toolset_checks

    def test_noop_for_unknown_tool(self):
        """Deregistering a name that was never registered must not raise."""
        tool_registry = ToolRegistry()
        tool_registry.deregister("nonexistent")

View File

@@ -70,6 +70,7 @@ Thread safety:
""" """
import asyncio import asyncio
import inspect
import json import json
import logging import logging
import math import math
@@ -89,6 +90,8 @@ logger = logging.getLogger(__name__)
_MCP_AVAILABLE = False _MCP_AVAILABLE = False
_MCP_HTTP_AVAILABLE = False _MCP_HTTP_AVAILABLE = False
_MCP_SAMPLING_TYPES = False _MCP_SAMPLING_TYPES = False
_MCP_NOTIFICATION_TYPES = False
_MCP_MESSAGE_HANDLER_SUPPORTED = False
try: try:
from mcp import ClientSession, StdioServerParameters from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client from mcp.client.stdio import stdio_client
@@ -119,9 +122,39 @@ try:
_MCP_SAMPLING_TYPES = True _MCP_SAMPLING_TYPES = True
except ImportError: except ImportError:
logger.debug("MCP sampling types not available -- sampling disabled") logger.debug("MCP sampling types not available -- sampling disabled")
# Notification types for dynamic tool discovery (tools/list_changed)
try:
from mcp.types import (
ServerNotification,
ToolListChangedNotification,
PromptListChangedNotification,
ResourceListChangedNotification,
)
_MCP_NOTIFICATION_TYPES = True
except ImportError:
logger.debug("MCP notification types not available -- dynamic tool discovery disabled")
except ImportError: except ImportError:
logger.debug("mcp package not installed -- MCP tool support disabled") logger.debug("mcp package not installed -- MCP tool support disabled")
def _check_message_handler_support() -> bool:
    """Report whether ``ClientSession`` accepts a ``message_handler`` keyword.

    The constructor signature is inspected so that MCP SDK versions without
    notification-handler support are detected instead of failing at call time.

    Returns:
        ``True`` when ``message_handler`` is a constructor parameter; ``False``
        when the MCP package is unavailable, the parameter is missing, or the
        signature cannot be introspected.
    """
    supported = False
    if _MCP_AVAILABLE:
        try:
            ctor_params = inspect.signature(ClientSession).parameters
        except (TypeError, ValueError):
            # The signature could not be determined; treat as unsupported.
            ctor_params = {}
        supported = "message_handler" in ctor_params
    return supported


# Evaluated once at import time; the transport runners consult this flag
# before passing ``message_handler`` to ``ClientSession``.
_MCP_MESSAGE_HANDLER_SUPPORTED = _check_message_handler_support()
if _MCP_AVAILABLE and not _MCP_MESSAGE_HANDLER_SUPPORTED:
    logger.debug("MCP SDK does not support message_handler -- dynamic tool discovery disabled")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Constants # Constants
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -697,7 +730,7 @@ class MCPServerTask:
__slots__ = ( __slots__ = (
"name", "session", "tool_timeout", "name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config", "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
"_sampling", "_registered_tool_names", "_auth_type", "_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
) )
def __init__(self, name: str): def __init__(self, name: str):
@@ -713,11 +746,80 @@ class MCPServerTask:
self._sampling: Optional[SamplingHandler] = None self._sampling: Optional[SamplingHandler] = None
self._registered_tool_names: list[str] = [] self._registered_tool_names: list[str] = []
self._auth_type: str = "" self._auth_type: str = ""
self._refresh_lock = asyncio.Lock()
def _is_http(self) -> bool: def _is_http(self) -> bool:
"""Check if this server uses HTTP transport.""" """Check if this server uses HTTP transport."""
return "url" in self._config return "url" in self._config
# ----- Dynamic tool discovery (notifications/tools/list_changed) -----
def _make_message_handler(self):
    """Create the ``message_handler`` coroutine handed to ``ClientSession``.

    The returned handler inspects each incoming message:

    * exceptions are logged at debug level and otherwise ignored;
    * a ``ToolListChangedNotification`` triggers ``_refresh_tools()``;
    * prompt and resource list-changed notifications are only logged;
    * anything else is dropped without logging.

    Any error raised while handling a message is logged and swallowed so it
    does not propagate to the caller of the handler.
    """

    async def _handler(message):
        try:
            if isinstance(message, Exception):
                logger.debug("MCP message handler (%s): exception: %s", self.name, message)
                return

            # Without the SDK notification types there is nothing to dispatch on.
            if not (_MCP_NOTIFICATION_TYPES and isinstance(message, ServerNotification)):
                return

            payload = message.root
            if isinstance(payload, ToolListChangedNotification):
                logger.info(
                    "MCP server '%s': received tools/list_changed notification",
                    self.name,
                )
                # NOTE(review): this awaits a session request from inside the
                # message handler. If the SDK invokes the handler inline on
                # its receive loop, the list_tools() response could not be
                # read until the handler returns -- confirm this cannot stall.
                await self._refresh_tools()
            elif isinstance(payload, PromptListChangedNotification):
                logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name)
            elif isinstance(payload, ResourceListChangedNotification):
                logger.debug("MCP server '%s': resources/list_changed (ignored)", self.name)
        except Exception:
            logger.exception("Error in MCP message handler for '%s'", self.name)

    return _handler
async def _refresh_tools(self):
    """Re-fetch the server's tool list and rebuild its registry entries.

    Called when the server sends ``notifications/tools/list_changed``.
    ``_refresh_lock`` serialises overlapping refreshes caused by rapid-fire
    notifications. The only ``await`` inside the lock is the ``list_tools()``
    call; every mutation after it is synchronous, so no other coroutine on
    the same event loop can observe a half-updated state.

    Steps: fetch the new tool list, strip this server's previously
    registered names from every ``hermes-*`` toolset, deregister them from
    the central registry, then register the fresh list via
    ``_register_server_tools``.

    Raises:
        Whatever ``self.session.list_tools()`` or ``_register_server_tools``
        raises; nothing is caught here.
    """
    from tools.registry import registry
    from toolsets import TOOLSETS

    async with self._refresh_lock:
        # 1. Fetch current tool list from server
        tools_result = await self.session.list_tools()
        new_mcp_tools = getattr(tools_result, "tools", [])

        # Snapshot the outgoing names once. The frozenset gives O(1)
        # membership tests in the pruning loop instead of rescanning the
        # list for every tool of every umbrella toolset.
        stale_names = list(self._registered_tool_names)
        stale_lookup = frozenset(stale_names)

        # 2. Remove old tools from hermes-* umbrella toolsets
        for ts_name, ts in TOOLSETS.items():
            if ts_name.startswith("hermes-"):
                ts["tools"] = [t for t in ts["tools"] if t not in stale_lookup]

        # 3. Deregister old tools from the central registry
        for prefixed_name in stale_names:
            registry.deregister(prefixed_name)

        # 4. Re-register with fresh tool list
        self._tools = new_mcp_tools
        self._registered_tool_names = _register_server_tools(
            self.name, self, self._config
        )

        logger.info(
            "MCP server '%s': dynamically refreshed %d tool(s)",
            self.name, len(self._registered_tool_names),
        )
async def _run_stdio(self, config: dict): async def _run_stdio(self, config: dict):
"""Run the server using stdio transport.""" """Run the server using stdio transport."""
command = config.get("command") command = config.get("command")
@@ -738,6 +840,8 @@ class MCPServerTask:
) )
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
sampling_kwargs["message_handler"] = self._make_message_handler()
async with stdio_client(server_params) as (read_stream, write_stream): async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
await session.initialize() await session.initialize()
@@ -769,6 +873,8 @@ class MCPServerTask:
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc) logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
sampling_kwargs["message_handler"] = self._make_message_handler()
if _MCP_NEW_HTTP: if _MCP_NEW_HTTP:
# New API (mcp >= 1.24.0): build an explicit httpx.AsyncClient # New API (mcp >= 1.24.0): build an explicit httpx.AsyncClient
@@ -1522,24 +1628,19 @@ def _existing_tool_names() -> List[str]:
return names return names
async def _discover_and_register_server(name: str, config: dict) -> List[str]: def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> List[str]:
"""Connect to a single MCP server, discover tools, and register them. """Register tools from an already-connected server into the registry.
Also registers utility tools for MCP Resources and Prompts support Handles include/exclude filtering, utility tools, toolset creation,
(list_resources, read_resource, list_prompts, get_prompt). and hermes-* umbrella toolset injection.
Returns list of registered tool names. Used by both initial discovery and dynamic refresh (list_changed).
Returns:
List of registered prefixed tool names.
""" """
from tools.registry import registry from tools.registry import registry
from toolsets import create_custom_toolset from toolsets import create_custom_toolset, TOOLSETS
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
server = await asyncio.wait_for(
_connect_server(name, config),
timeout=connect_timeout,
)
with _lock:
_servers[name] = server
registered_names: List[str] = [] registered_names: List[str] = []
toolset_name = f"mcp-{name}" toolset_name = f"mcp-{name}"
@@ -1625,8 +1726,6 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
) )
registered_names.append(util_name) registered_names.append(util_name)
server._registered_tool_names = list(registered_names)
# Create a custom toolset so these tools are discoverable # Create a custom toolset so these tools are discoverable
if registered_names: if registered_names:
create_custom_toolset( create_custom_toolset(
@@ -1634,6 +1733,31 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
description=f"MCP tools from {name} server", description=f"MCP tools from {name} server",
tools=registered_names, tools=registered_names,
) )
# Inject into hermes-* umbrella toolsets for default behavior
for ts_name, ts in TOOLSETS.items():
if ts_name.startswith("hermes-"):
for tool_name in registered_names:
if tool_name not in ts["tools"]:
ts["tools"].append(tool_name)
return registered_names
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
"""Connect to a single MCP server, discover tools, and register them.
Returns list of registered tool names.
"""
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
server = await asyncio.wait_for(
_connect_server(name, config),
timeout=connect_timeout,
)
with _lock:
_servers[name] = server
registered_names = _register_server_tools(name, server, config)
server._registered_tool_names = list(registered_names)
transport_type = "HTTP" if "url" in config else "stdio" transport_type = "HTTP" if "url" in config else "stdio"
logger.info( logger.info(

View File

@@ -87,6 +87,23 @@ class ToolRegistry:
if check_fn and toolset not in self._toolset_checks: if check_fn and toolset not in self._toolset_checks:
self._toolset_checks[toolset] = check_fn self._toolset_checks[toolset] = check_fn
def deregister(self, name: str) -> None:
    """Remove the tool called *name* from the registry, if present.

    When the removed tool was the last one belonging to its toolset, the
    toolset's availability check is dropped as well. Unknown names are
    ignored. MCP dynamic tool discovery relies on this to tear down a
    server's tools when it sends ``notifications/tools/list_changed``.
    """
    removed = self._tools.pop(name, None)
    if removed is None:
        # Nothing registered under that name -- nothing to clean up.
        return

    owning_toolset = removed.toolset
    if owning_toolset in self._toolset_checks:
        still_populated = any(
            other.toolset == owning_toolset for other in self._tools.values()
        )
        if not still_populated:
            # Last tool of the toolset is gone, so its check is obsolete.
            self._toolset_checks.pop(owning_toolset, None)

    logger.debug("Deregistered tool: %s", name)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Schema retrieval # Schema retrieval
# ------------------------------------------------------------------ # ------------------------------------------------------------------