diff --git a/tests/tools/test_mcp_dynamic_discovery.py b/tests/tools/test_mcp_dynamic_discovery.py new file mode 100644 index 000000000..c7c4ae86c --- /dev/null +++ b/tests/tools/test_mcp_dynamic_discovery.py @@ -0,0 +1,170 @@ +"""Tests for MCP dynamic tool discovery (notifications/tools/list_changed).""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tools.mcp_tool import MCPServerTask, _register_server_tools +from tools.registry import ToolRegistry + + +def _make_mcp_tool(name: str, desc: str = ""): + return SimpleNamespace(name=name, description=desc, inputSchema=None) + + +class TestRegisterServerTools: + """Tests for the extracted _register_server_tools helper.""" + + @pytest.fixture + def mock_registry(self): + return ToolRegistry() + + @pytest.fixture + def mock_toolsets(self): + return { + "hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []}, + "hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []}, + "custom-toolset": {"tools": [], "description": "Other", "includes": []}, + } + + def test_injects_hermes_toolsets(self, mock_registry, mock_toolsets): + """Tools are injected into hermes-* toolsets but not custom ones.""" + server = MCPServerTask("my_srv") + server._tools = [_make_mcp_tool("my_tool", "desc")] + server.session = MagicMock() + + with patch("tools.registry.registry", mock_registry), \ + patch("toolsets.create_custom_toolset"), \ + patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True): + + registered = _register_server_tools("my_srv", server, {}) + + assert "mcp_my_srv_my_tool" in registered + assert "mcp_my_srv_my_tool" in mock_registry.get_all_tool_names() + + # Injected into hermes-* toolsets + assert "mcp_my_srv_my_tool" in mock_toolsets["hermes-cli"]["tools"] + assert "mcp_my_srv_my_tool" in mock_toolsets["hermes-telegram"]["tools"] + # NOT into non-hermes toolsets + assert "mcp_my_srv_my_tool" not in mock_toolsets["custom-toolset"]["tools"] + + +class TestRefreshTools: + """Tests for MCPServerTask._refresh_tools nuke-and-repave cycle.""" + + @pytest.fixture + def mock_registry(self): + return ToolRegistry() + + @pytest.fixture + def mock_toolsets(self): + return { + "hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []}, + "hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []}, + } + + @pytest.mark.asyncio + async def test_nuke_and_repave(self, mock_registry, mock_toolsets): + """Old tools are removed and new tools registered on refresh.""" + server = MCPServerTask("live_srv") + server._refresh_lock = asyncio.Lock() + server._config = {} + + # Seed initial state: one old tool registered + mock_registry.register( + name="mcp_live_srv_old_tool", toolset="mcp-live_srv", schema={}, + handler=lambda x: x, check_fn=lambda: True, is_async=False, + description="", emoji="", + ) + server._registered_tool_names = ["mcp_live_srv_old_tool"] + mock_toolsets["hermes-cli"]["tools"].append("mcp_live_srv_old_tool") + + # New tool list from server + new_tool = _make_mcp_tool("new_tool", "new behavior") + server.session = SimpleNamespace( + list_tools=AsyncMock( + return_value=SimpleNamespace(tools=[new_tool]) + ) + ) + + with patch("tools.registry.registry", mock_registry), \ + patch("toolsets.create_custom_toolset"), \ + patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True): + + await server._refresh_tools() + + # Old tool completely gone + assert "mcp_live_srv_old_tool" not in mock_registry.get_all_tool_names() + assert "mcp_live_srv_old_tool" not in mock_toolsets["hermes-cli"]["tools"] + + # New tool registered + assert "mcp_live_srv_new_tool" in mock_registry.get_all_tool_names() + assert "mcp_live_srv_new_tool" in mock_toolsets["hermes-cli"]["tools"] + assert server._registered_tool_names == ["mcp_live_srv_new_tool"] + + +class TestMessageHandler: + """Tests for MCPServerTask._make_message_handler dispatch.""" + + @pytest.mark.asyncio + async def test_dispatches_tool_list_changed(self): + from tools.mcp_tool import _MCP_NOTIFICATION_TYPES + if not _MCP_NOTIFICATION_TYPES: + pytest.skip("MCP SDK ToolListChangedNotification not available") + + from mcp.types import ServerNotification, ToolListChangedNotification + + server = MCPServerTask("notif_srv") + with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh: + handler = server._make_message_handler() + notification = ServerNotification( + root=ToolListChangedNotification(method="notifications/tools/list_changed") + ) + await handler(notification) + mock_refresh.assert_awaited_once() + + @pytest.mark.asyncio + async def test_ignores_exceptions_and_other_messages(self): + server = MCPServerTask("notif_srv") + with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh: + handler = server._make_message_handler() + # Exceptions should not trigger refresh + await handler(RuntimeError("connection dead")) + # Unknown message types should not trigger refresh + await handler({"jsonrpc": "2.0", "result": "ok"}) + mock_refresh.assert_not_awaited() + + +class TestDeregister: + """Tests for ToolRegistry.deregister.""" + + def test_removes_tool(self): + reg = ToolRegistry() + reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x) + assert "foo" in reg.get_all_tool_names() + reg.deregister("foo") + assert "foo" not in reg.get_all_tool_names() + + def test_cleans_up_toolset_check(self): + reg = ToolRegistry() + check = lambda: True # noqa: E731 + reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=check) + assert reg.is_toolset_available("ts1") + reg.deregister("foo") + # Toolset check should be gone since no tools remain + assert "ts1" not in reg._toolset_checks + + def test_preserves_toolset_check_if_other_tools_remain(self): + reg = ToolRegistry() + check = lambda: True # noqa: E731 + reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=check) + reg.register(name="bar", toolset="ts1", schema={}, handler=lambda x: x) + reg.deregister("foo") + # bar still in ts1, so check should remain + assert "ts1" in reg._toolset_checks + + def test_noop_for_unknown_tool(self): + reg = ToolRegistry() + reg.deregister("nonexistent") # Should not raise diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 5ce1ee192..4c762150e 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -70,6 +70,7 @@ Thread safety: """ import asyncio +import inspect import json import logging import math @@ -89,6 +90,8 @@ logger = logging.getLogger(__name__) _MCP_AVAILABLE = False _MCP_HTTP_AVAILABLE = False _MCP_SAMPLING_TYPES = False +_MCP_NOTIFICATION_TYPES = False +_MCP_MESSAGE_HANDLER_SUPPORTED = False try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client @@ -119,9 +122,39 @@ try: _MCP_SAMPLING_TYPES = True except ImportError: logger.debug("MCP sampling types not available -- sampling disabled") + # Notification types for dynamic tool discovery (tools/list_changed) + try: + from mcp.types import ( + ServerNotification, + ToolListChangedNotification, + PromptListChangedNotification, + ResourceListChangedNotification, + ) + _MCP_NOTIFICATION_TYPES = True + except ImportError: + logger.debug("MCP notification types not available -- dynamic tool discovery disabled") except ImportError: logger.debug("mcp package not installed -- MCP tool support disabled") + +def _check_message_handler_support() -> bool: + """Check if ClientSession accepts ``message_handler`` kwarg. + + Inspects the constructor signature for backward compatibility with older + MCP SDK versions that don't support notification handlers. + """ + if not _MCP_AVAILABLE: + return False + try: + return "message_handler" in inspect.signature(ClientSession).parameters + except (TypeError, ValueError): + return False + + +_MCP_MESSAGE_HANDLER_SUPPORTED = _check_message_handler_support() +if _MCP_AVAILABLE and not _MCP_MESSAGE_HANDLER_SUPPORTED: + logger.debug("MCP SDK does not support message_handler -- dynamic tool discovery disabled") + # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- @@ -697,7 +730,7 @@ class MCPServerTask: __slots__ = ( "name", "session", "tool_timeout", "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config", - "_sampling", "_registered_tool_names", "_auth_type", + "_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock", ) def __init__(self, name: str): @@ -713,11 +746,80 @@ class MCPServerTask: self._sampling: Optional[SamplingHandler] = None self._registered_tool_names: list[str] = [] self._auth_type: str = "" + self._refresh_lock = asyncio.Lock() def _is_http(self) -> bool: """Check if this server uses HTTP transport.""" return "url" in self._config + # ----- Dynamic tool discovery (notifications/tools/list_changed) ----- + + def _make_message_handler(self): + """Build a ``message_handler`` callback for ``ClientSession``. + + Dispatches on notification type. Only ``ToolListChangedNotification`` + triggers a refresh; prompt and resource change notifications are + logged as stubs for future work. + """ + async def _handler(message): + try: + if isinstance(message, Exception): + logger.debug("MCP message handler (%s): exception: %s", self.name, message) + return + if _MCP_NOTIFICATION_TYPES and isinstance(message, ServerNotification): + match message.root: + case ToolListChangedNotification(): + logger.info( + "MCP server '%s': received tools/list_changed notification", + self.name, + ) + await self._refresh_tools() + case PromptListChangedNotification(): + logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name) + case ResourceListChangedNotification(): + logger.debug("MCP server '%s': resources/list_changed (ignored)", self.name) + case _: + pass + except Exception: + logger.exception("Error in MCP message handler for '%s'", self.name) + return _handler + + async def _refresh_tools(self): + """Re-fetch tools from the server and update the registry. + + Called when the server sends ``notifications/tools/list_changed``. + The lock prevents overlapping refreshes from rapid-fire notifications. + After the initial ``await`` (list_tools), all mutations are synchronous + — atomic from the event loop's perspective. + """ + from tools.registry import registry + from toolsets import TOOLSETS + + async with self._refresh_lock: + # 1. Fetch current tool list from server + tools_result = await self.session.list_tools() + new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else [] + + # 2. Remove old tools from hermes-* umbrella toolsets + for ts_name, ts in TOOLSETS.items(): + if ts_name.startswith("hermes-"): + ts["tools"] = [t for t in ts["tools"] if t not in self._registered_tool_names] + + # 3. Deregister old tools from the central registry + for prefixed_name in self._registered_tool_names: + registry.deregister(prefixed_name) + + # 4. Re-register with fresh tool list + self._tools = new_mcp_tools + self._registered_tool_names = _register_server_tools( + self.name, self, self._config + ) + + logger.info( + "MCP server '%s': dynamically refreshed %d tool(s)", + self.name, len(self._registered_tool_names), + ) + async def _run_stdio(self, config: dict): """Run the server using stdio transport.""" command = config.get("command") @@ -738,6 +840,8 @@ class MCPServerTask: ) sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} + if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: + sampling_kwargs["message_handler"] = self._make_message_handler() async with stdio_client(server_params) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: await session.initialize() @@ -769,6 +873,8 @@ class MCPServerTask: logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc) sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} + if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: + sampling_kwargs["message_handler"] = self._make_message_handler() if _MCP_NEW_HTTP: # New API (mcp >= 1.24.0): build an explicit httpx.AsyncClient @@ -1522,24 +1628,19 @@ def _existing_tool_names() -> List[str]: return names -async def _discover_and_register_server(name: str, config: dict) -> List[str]: - """Connect to a single MCP server, discover tools, and register them. +def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> List[str]: + """Register tools from an already-connected server into the registry. - Also registers utility tools for MCP Resources and Prompts support - (list_resources, read_resource, list_prompts, get_prompt). + Handles include/exclude filtering, utility tools, toolset creation, + and hermes-* umbrella toolset injection. - Returns list of registered tool names. + Used by both initial discovery and dynamic refresh (list_changed). + + Returns: + List of registered prefixed tool names. """ from tools.registry import registry - from toolsets import create_custom_toolset - - connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) - server = await asyncio.wait_for( - _connect_server(name, config), - timeout=connect_timeout, - ) - with _lock: - _servers[name] = server + from toolsets import create_custom_toolset, TOOLSETS registered_names: List[str] = [] toolset_name = f"mcp-{name}" @@ -1625,8 +1726,6 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: ) registered_names.append(util_name) - server._registered_tool_names = list(registered_names) - # Create a custom toolset so these tools are discoverable if registered_names: create_custom_toolset( @@ -1634,6 +1733,31 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: description=f"MCP tools from {name} server", tools=registered_names, ) + # Inject into hermes-* umbrella toolsets for default behavior + for ts_name, ts in TOOLSETS.items(): + if ts_name.startswith("hermes-"): + for tool_name in registered_names: + if tool_name not in ts["tools"]: + ts["tools"].append(tool_name) + + return registered_names + + +async def _discover_and_register_server(name: str, config: dict) -> List[str]: + """Connect to a single MCP server, discover tools, and register them. + + Returns list of registered tool names. + """ + connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) + server = await asyncio.wait_for( + _connect_server(name, config), + timeout=connect_timeout, + ) + with _lock: + _servers[name] = server + + registered_names = _register_server_tools(name, server, config) + server._registered_tool_names = list(registered_names) transport_type = "HTTP" if "url" in config else "stdio" logger.info( diff --git a/tools/registry.py b/tools/registry.py index b124faf6b..432e1f074 100644 --- a/tools/registry.py +++ b/tools/registry.py @@ -87,6 +87,23 @@ class ToolRegistry: if check_fn and toolset not in self._toolset_checks: self._toolset_checks[toolset] = check_fn + def deregister(self, name: str) -> None: + """Remove a tool from the registry. + + Also cleans up the toolset check if no other tools remain in the + same toolset. Used by MCP dynamic tool discovery to nuke-and-repave + when a server sends ``notifications/tools/list_changed``. + """ + entry = self._tools.pop(name, None) + if entry is None: + return + # Drop the toolset check if this was the last tool in that toolset + if entry.toolset in self._toolset_checks and not any( + e.toolset == entry.toolset for e in self._tools.values() + ): + self._toolset_checks.pop(entry.toolset, None) + logger.debug("Deregistered tool: %s", name) + # ------------------------------------------------------------------ # Schema retrieval # ------------------------------------------------------------------