feat(mcp): make selective tool loading capability-aware
Extend the salvaged MCP filtering work so utility tools are also governed by policy and server capabilities. Store the registered tool subset per server so rediscovery and status reporting stay accurate after filtering.
This commit is contained in:
@@ -2450,76 +2450,226 @@ class TestDiscoveryFailedCount:
|
||||
|
||||
|
||||
class TestMCPSelectiveToolLoading:
|
||||
"""Tests for tools.include / tools.exclude / enabled config keys."""
|
||||
"""Tests for per-server MCP filtering and utility tool policies."""
|
||||
|
||||
def _make_server(self, name, tool_names):
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
server = MCPServerTask(name)
|
||||
server.session = MagicMock()
|
||||
server._tools = [_make_mcp_tool(n, n) for n in tool_names]
|
||||
def _make_server(self, name, tool_names, session=None):
|
||||
server = _make_mock_server(
|
||||
name,
|
||||
session=session or SimpleNamespace(),
|
||||
tools=[_make_mcp_tool(n, n) for n in tool_names],
|
||||
)
|
||||
return server
|
||||
|
||||
def _run_discover(self, name, tool_names, config):
|
||||
"""Run _discover_and_register_server directly and return registered names."""
|
||||
import asyncio
|
||||
from tools.mcp_tool import _discover_and_register_server
|
||||
server = self._make_server(name, tool_names)
|
||||
def _run_discover(self, name, tool_names, config, session=None):
|
||||
from tools.registry import ToolRegistry
|
||||
from tools.mcp_tool import _discover_and_register_server, _servers
|
||||
|
||||
async def fake_connect(n, c):
|
||||
mock_registry = ToolRegistry()
|
||||
server = self._make_server(name, tool_names, session=session)
|
||||
|
||||
async def fake_connect(_name, _config):
|
||||
return server
|
||||
|
||||
async def run():
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), patch("tools.mcp_tool._servers", {}):
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.registry.registry", mock_registry), \
|
||||
patch("toolsets.create_custom_toolset"):
|
||||
return await _discover_and_register_server(name, config)
|
||||
|
||||
return asyncio.run(run())
|
||||
try:
|
||||
registered = asyncio.run(run())
|
||||
finally:
|
||||
_servers.pop(name, None)
|
||||
return registered, mock_registry
|
||||
|
||||
def test_include_filter_registers_only_listed_tools(self):
|
||||
"""tools.include whitelist: only specified tools are registered."""
|
||||
tool_names = ["create_service", "delete_service", "list_services"]
|
||||
config = {"url": "https://mcp.example.com", "tools": {"include": ["create_service", "list_services"]}}
|
||||
result = self._run_discover("ink", tool_names, config)
|
||||
assert "mcp_ink_create_service" in result
|
||||
assert "mcp_ink_list_services" in result
|
||||
assert "mcp_ink_delete_service" not in result
|
||||
|
||||
def test_exclude_filter_skips_listed_tools(self):
|
||||
"""tools.exclude blacklist: all tools except specified are registered."""
|
||||
tool_names = ["create_service", "delete_service", "list_services"]
|
||||
config = {"url": "https://mcp.example.com", "tools": {"exclude": ["delete_service"]}}
|
||||
result = self._run_discover("ink2", tool_names, config)
|
||||
assert "mcp_ink2_create_service" in result
|
||||
assert "mcp_ink2_list_services" in result
|
||||
assert "mcp_ink2_delete_service" not in result
|
||||
|
||||
def test_no_filter_registers_all_tools(self):
|
||||
"""No tools filter: all tools registered (backward compatible)."""
|
||||
tool_names = ["create_service", "delete_service", "list_services"]
|
||||
config = {"url": "https://mcp.example.com"}
|
||||
result = self._run_discover("ink3", tool_names, config)
|
||||
assert "mcp_ink3_create_service" in result
|
||||
assert "mcp_ink3_delete_service" in result
|
||||
assert "mcp_ink3_list_services" in result
|
||||
|
||||
def test_enabled_false_skips_server(self):
|
||||
"""enabled: false skips the server entirely."""
|
||||
fresh_servers = {}
|
||||
fake_config = {
|
||||
"ink": {
|
||||
"url": "https://mcp.example.com",
|
||||
"enabled": False,
|
||||
}
|
||||
def test_include_takes_precedence_over_exclude(self):
|
||||
config = {
|
||||
"url": "https://mcp.example.com",
|
||||
"tools": {
|
||||
"include": ["create_service"],
|
||||
"exclude": ["create_service", "delete_service"],
|
||||
},
|
||||
}
|
||||
registered, _ = self._run_discover(
|
||||
"ink",
|
||||
["create_service", "delete_service", "list_services"],
|
||||
config,
|
||||
session=SimpleNamespace(),
|
||||
)
|
||||
assert registered == ["mcp_ink_create_service"]
|
||||
|
||||
def test_exclude_filter_registers_all_except_listed_tools(self):
|
||||
config = {
|
||||
"url": "https://mcp.example.com",
|
||||
"tools": {"exclude": ["delete_service"]},
|
||||
}
|
||||
registered, _ = self._run_discover(
|
||||
"ink_exclude",
|
||||
["create_service", "delete_service", "list_services"],
|
||||
config,
|
||||
session=SimpleNamespace(),
|
||||
)
|
||||
assert registered == [
|
||||
"mcp_ink_exclude_create_service",
|
||||
"mcp_ink_exclude_list_services",
|
||||
]
|
||||
|
||||
def test_include_filter_skips_utility_tools_without_capabilities(self):
|
||||
config = {
|
||||
"url": "https://mcp.example.com",
|
||||
"tools": {"include": ["create_service"]},
|
||||
}
|
||||
registered, mock_registry = self._run_discover(
|
||||
"ink_no_caps",
|
||||
["create_service", "delete_service"],
|
||||
config,
|
||||
session=SimpleNamespace(),
|
||||
)
|
||||
assert registered == ["mcp_ink_no_caps_create_service"]
|
||||
assert set(mock_registry.get_all_tool_names()) == {"mcp_ink_no_caps_create_service"}
|
||||
|
||||
def test_no_filter_registers_all_server_tools_when_no_utilities_supported(self):
|
||||
registered, _ = self._run_discover(
|
||||
"ink_no_filter",
|
||||
["create_service", "delete_service", "list_services"],
|
||||
{"url": "https://mcp.example.com"},
|
||||
session=SimpleNamespace(),
|
||||
)
|
||||
assert registered == [
|
||||
"mcp_ink_no_filter_create_service",
|
||||
"mcp_ink_no_filter_delete_service",
|
||||
"mcp_ink_no_filter_list_services",
|
||||
]
|
||||
|
||||
def test_resources_and_prompts_can_be_disabled_explicitly(self):
|
||||
session = SimpleNamespace(
|
||||
list_resources=AsyncMock(),
|
||||
read_resource=AsyncMock(),
|
||||
list_prompts=AsyncMock(),
|
||||
get_prompt=AsyncMock(),
|
||||
)
|
||||
config = {
|
||||
"url": "https://mcp.example.com",
|
||||
"tools": {
|
||||
"resources": False,
|
||||
"prompts": False,
|
||||
},
|
||||
}
|
||||
registered, _ = self._run_discover(
|
||||
"ink_disabled_utils",
|
||||
["create_service"],
|
||||
config,
|
||||
session=session,
|
||||
)
|
||||
assert registered == ["mcp_ink_disabled_utils_create_service"]
|
||||
|
||||
def test_registers_only_utility_tools_supported_by_server_capabilities(self):
|
||||
session = SimpleNamespace(
|
||||
list_resources=AsyncMock(return_value=SimpleNamespace(resources=[])),
|
||||
read_resource=AsyncMock(return_value=SimpleNamespace(contents=[])),
|
||||
)
|
||||
registered, _ = self._run_discover(
|
||||
"ink_resources_only",
|
||||
["create_service"],
|
||||
{"url": "https://mcp.example.com"},
|
||||
session=session,
|
||||
)
|
||||
assert "mcp_ink_resources_only_create_service" in registered
|
||||
assert "mcp_ink_resources_only_list_resources" in registered
|
||||
assert "mcp_ink_resources_only_read_resource" in registered
|
||||
assert "mcp_ink_resources_only_list_prompts" not in registered
|
||||
assert "mcp_ink_resources_only_get_prompt" not in registered
|
||||
|
||||
def test_existing_tool_names_reflect_registered_subset(self):
|
||||
from tools.mcp_tool import _existing_tool_names, _servers, _discover_and_register_server
|
||||
from tools.registry import ToolRegistry
|
||||
|
||||
mock_registry = ToolRegistry()
|
||||
server = self._make_server(
|
||||
"ink_existing",
|
||||
["create_service", "delete_service"],
|
||||
session=SimpleNamespace(),
|
||||
)
|
||||
|
||||
async def fake_connect(_name, _config):
|
||||
return server
|
||||
|
||||
async def run():
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.registry.registry", mock_registry), \
|
||||
patch("toolsets.create_custom_toolset"):
|
||||
return await _discover_and_register_server(
|
||||
"ink_existing",
|
||||
{"url": "https://mcp.example.com", "tools": {"include": ["create_service"]}},
|
||||
)
|
||||
|
||||
try:
|
||||
registered = asyncio.run(run())
|
||||
assert registered == ["mcp_ink_existing_create_service"]
|
||||
assert _existing_tool_names() == ["mcp_ink_existing_create_service"]
|
||||
finally:
|
||||
_servers.pop("ink_existing", None)
|
||||
|
||||
def test_no_toolset_created_when_everything_is_filtered_out(self):
|
||||
from tools.registry import ToolRegistry
|
||||
from tools.mcp_tool import _discover_and_register_server, _servers
|
||||
|
||||
mock_registry = ToolRegistry()
|
||||
server = self._make_server("ink_none", ["create_service"], session=SimpleNamespace())
|
||||
mock_create = MagicMock()
|
||||
|
||||
async def fake_connect(_name, _config):
|
||||
return server
|
||||
|
||||
async def run():
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.registry.registry", mock_registry), \
|
||||
patch("toolsets.create_custom_toolset", mock_create):
|
||||
return await _discover_and_register_server(
|
||||
"ink_none",
|
||||
{
|
||||
"url": "https://mcp.example.com",
|
||||
"tools": {
|
||||
"include": ["missing_tool"],
|
||||
"resources": False,
|
||||
"prompts": False,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
registered = asyncio.run(run())
|
||||
assert registered == []
|
||||
mock_create.assert_not_called()
|
||||
assert mock_registry.get_all_tool_names() == []
|
||||
finally:
|
||||
_servers.pop("ink_none", None)
|
||||
|
||||
def test_enabled_false_skips_connection_attempt(self):
|
||||
from tools.mcp_tool import discover_mcp_tools
|
||||
|
||||
connect_called = []
|
||||
|
||||
async def fake_connect(name, config):
|
||||
connect_called.append(name)
|
||||
return self._make_server(name, ["create_service"])
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), patch("tools.mcp_tool._servers", fresh_servers), patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), patch("tools.mcp_tool._connect_server", side_effect=fake_connect):
|
||||
from tools.mcp_tool import discover_mcp_tools
|
||||
fake_config = {
|
||||
"ink": {
|
||||
"url": "https://mcp.example.com",
|
||||
"enabled": False,
|
||||
}
|
||||
}
|
||||
fake_toolsets = {
|
||||
"hermes-cli": {"tools": [], "description": "CLI", "includes": []},
|
||||
}
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._servers", {}), \
|
||||
patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
|
||||
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("toolsets.TOOLSETS", fake_toolsets):
|
||||
result = discover_mcp_tools()
|
||||
|
||||
assert connect_called == []
|
||||
assert "mcp_ink_create_service" not in result
|
||||
|
||||
assert result == []
|
||||
|
||||
@@ -688,7 +688,7 @@ class MCPServerTask:
|
||||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_sampling",
|
||||
"_sampling", "_registered_tool_names",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
@@ -702,6 +702,7 @@ class MCPServerTask:
|
||||
self._error: Optional[Exception] = None
|
||||
self._config: dict = {}
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
self._registered_tool_names: list[str] = []
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
@@ -1308,16 +1309,81 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
]
|
||||
|
||||
|
||||
def _normalize_name_filter(value: Any, label: str) -> set[str]:
|
||||
"""Normalize include/exclude config to a set of tool names."""
|
||||
if value is None:
|
||||
return set()
|
||||
if isinstance(value, str):
|
||||
return {value}
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return {str(item) for item in value}
|
||||
logger.warning("MCP config %s must be a string or list of strings; ignoring %r", label, value)
|
||||
return set()
|
||||
|
||||
|
||||
def _parse_boolish(value: Any, default: bool = True) -> bool:
|
||||
"""Parse a bool-like config value with safe fallback."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
lowered = value.strip().lower()
|
||||
if lowered in {"true", "1", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"false", "0", "no", "off"}:
|
||||
return False
|
||||
logger.warning("MCP config expected a boolean-ish value, got %r; using default=%s", value, default)
|
||||
return default
|
||||
|
||||
|
||||
_UTILITY_CAPABILITY_METHODS = {
|
||||
"list_resources": "list_resources",
|
||||
"read_resource": "read_resource",
|
||||
"list_prompts": "list_prompts",
|
||||
"get_prompt": "get_prompt",
|
||||
}
|
||||
|
||||
|
||||
def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]:
|
||||
"""Select utility schemas based on config and server capabilities."""
|
||||
tools_filter = config.get("tools") or {}
|
||||
resources_enabled = _parse_boolish(tools_filter.get("resources"), default=True)
|
||||
prompts_enabled = _parse_boolish(tools_filter.get("prompts"), default=True)
|
||||
|
||||
selected: List[dict] = []
|
||||
for entry in _build_utility_schemas(server_name):
|
||||
handler_key = entry["handler_key"]
|
||||
if handler_key in {"list_resources", "read_resource"} and not resources_enabled:
|
||||
logger.debug("MCP server '%s': skipping utility '%s' (resources disabled)", server_name, handler_key)
|
||||
continue
|
||||
if handler_key in {"list_prompts", "get_prompt"} and not prompts_enabled:
|
||||
logger.debug("MCP server '%s': skipping utility '%s' (prompts disabled)", server_name, handler_key)
|
||||
continue
|
||||
|
||||
required_method = _UTILITY_CAPABILITY_METHODS[handler_key]
|
||||
if not hasattr(server.session, required_method):
|
||||
logger.debug(
|
||||
"MCP server '%s': skipping utility '%s' (session lacks %s)",
|
||||
server_name,
|
||||
handler_key,
|
||||
required_method,
|
||||
)
|
||||
continue
|
||||
selected.append(entry)
|
||||
return selected
|
||||
|
||||
|
||||
def _existing_tool_names() -> List[str]:
|
||||
"""Return tool names for all currently connected servers."""
|
||||
names: List[str] = []
|
||||
for sname, server in _servers.items():
|
||||
for _sname, server in _servers.items():
|
||||
if hasattr(server, "_registered_tool_names"):
|
||||
names.extend(server._registered_tool_names)
|
||||
continue
|
||||
for mcp_tool in server._tools:
|
||||
schema = _convert_mcp_schema(sname, mcp_tool)
|
||||
schema = _convert_mcp_schema(server.name, mcp_tool)
|
||||
names.append(schema["name"])
|
||||
# Also include utility tool names
|
||||
for entry in _build_utility_schemas(sname):
|
||||
names.append(entry["schema"]["name"])
|
||||
return names
|
||||
|
||||
|
||||
@@ -1347,11 +1413,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
# Rules (matching issue #690 spec):
|
||||
# tools.include — whitelist: only these tool names are registered
|
||||
# tools.exclude — blacklist: all tools EXCEPT these are registered
|
||||
# include and exclude are mutually exclusive; include takes precedence
|
||||
# include takes precedence over exclude
|
||||
# Neither set → register all tools (backward-compatible default)
|
||||
tools_filter = config.get("tools") or {}
|
||||
include_set = set(tools_filter.get("include") or [])
|
||||
exclude_set = set(tools_filter.get("exclude") or [])
|
||||
include_set = _normalize_name_filter(tools_filter.get("include"), f"mcp_servers.{name}.tools.include")
|
||||
exclude_set = _normalize_name_filter(tools_filter.get("exclude"), f"mcp_servers.{name}.tools.exclude")
|
||||
|
||||
def _should_register(tool_name: str) -> bool:
|
||||
if include_set:
|
||||
@@ -1378,7 +1444,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
)
|
||||
registered_names.append(tool_name_prefixed)
|
||||
|
||||
# Register MCP Resources & Prompts utility tools
|
||||
# Register MCP Resources & Prompts utility tools, filtered by config and
|
||||
# only when the server actually supports the corresponding capability.
|
||||
_handler_factories = {
|
||||
"list_resources": _make_list_resources_handler,
|
||||
"read_resource": _make_read_resource_handler,
|
||||
@@ -1386,7 +1453,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
"get_prompt": _make_get_prompt_handler,
|
||||
}
|
||||
check_fn = _make_check_fn(name)
|
||||
for entry in _build_utility_schemas(name):
|
||||
for entry in _select_utility_schemas(name, server, config):
|
||||
schema = entry["schema"]
|
||||
handler_key = entry["handler_key"]
|
||||
handler = _handler_factories[handler_key](name, server.tool_timeout)
|
||||
@@ -1402,6 +1469,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
)
|
||||
registered_names.append(schema["name"])
|
||||
|
||||
server._registered_tool_names = list(registered_names)
|
||||
|
||||
# Create a custom toolset so these tools are discoverable
|
||||
if registered_names:
|
||||
create_custom_toolset(
|
||||
@@ -1448,8 +1517,9 @@ def discover_mcp_tools() -> List[str]:
|
||||
# (enabled: false skips the server entirely without removing its config)
|
||||
with _lock:
|
||||
new_servers = {
|
||||
k: v for k, v in servers.items()
|
||||
if k not in _servers and v.get("enabled", True) is not False
|
||||
k: v
|
||||
for k, v in servers.items()
|
||||
if k not in _servers and _parse_boolish(v.get("enabled", True), default=True)
|
||||
}
|
||||
|
||||
if not new_servers:
|
||||
@@ -1537,7 +1607,7 @@ def get_mcp_status() -> List[dict]:
|
||||
entry = {
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": len(server._tools),
|
||||
"tools": len(server._registered_tool_names) if hasattr(server, "_registered_tool_names") else len(server._tools),
|
||||
"connected": True,
|
||||
}
|
||||
if server._sampling:
|
||||
|
||||
Reference in New Issue
Block a user