feat(mcp): banner integration, /reload-mcp command, resources & prompts
Banner integration: - MCP Servers section in CLI startup banner between Tools and Skills - Shows each server with transport type, tool count, connection status - Failed servers shown in red; section hidden when no MCP configured - Summary line includes MCP server count - Removed raw print() calls from discovery (banner handles display) /reload-mcp command: - New slash command in both CLI and gateway - Disconnects all MCP servers, re-reads config.yaml, reconnects - Reports what changed (added/removed/reconnected servers) - Allows adding/removing MCP servers without restarting Resources & Prompts support: - 4 utility tools registered per server: list_resources, read_resource, list_prompts, get_prompt - Exposes MCP Resources (data sources) and Prompts (templates) as tools - Proper parameter schemas (uri for read_resource, name for get_prompt) - Handles text and binary resource content - 23 new tests covering schemas, handlers, and registration Test coverage: 74 MCP tests total, 1186 tests pass overall.
This commit is contained in:
42
cli.py
42
cli.py
@@ -690,6 +690,7 @@ COMMANDS = {
|
||||
"/cron": "Manage scheduled tasks (list, add, remove)",
|
||||
"/skills": "Search, install, inspect, or manage skills from online registries",
|
||||
"/platforms": "Show gateway/messaging platform status",
|
||||
"/reload-mcp": "Reload MCP servers from config.yaml",
|
||||
"/quit": "Exit the CLI (also: /exit, /q)",
|
||||
}
|
||||
|
||||
@@ -1770,6 +1771,8 @@ class HermesCLI:
|
||||
self._manual_compress()
|
||||
elif cmd_lower == "/usage":
|
||||
self._show_usage()
|
||||
elif cmd_lower == "/reload-mcp":
|
||||
self._reload_mcp()
|
||||
else:
|
||||
# Check for skill slash commands (/gif-search, /axolotl, etc.)
|
||||
base_cmd = cmd_lower.split()[0]
|
||||
@@ -1891,6 +1894,45 @@ class HermesCLI:
|
||||
for quiet_logger in ('tools', 'minisweagent', 'run_agent', 'trajectory_compressor', 'cron', 'hermes_cli'):
|
||||
logging.getLogger(quiet_logger).setLevel(logging.ERROR)
|
||||
|
||||
def _reload_mcp(self):
|
||||
"""Reload MCP servers: disconnect all, re-read config.yaml, reconnect."""
|
||||
try:
|
||||
from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _load_mcp_config, _servers, _lock
|
||||
|
||||
# Capture old server names
|
||||
with _lock:
|
||||
old_servers = set(_servers.keys())
|
||||
|
||||
print("🔄 Reloading MCP servers...")
|
||||
|
||||
# Shutdown existing connections
|
||||
shutdown_mcp_servers()
|
||||
|
||||
# Reconnect (reads config.yaml fresh)
|
||||
new_tools = discover_mcp_tools()
|
||||
|
||||
# Compute what changed
|
||||
with _lock:
|
||||
connected_servers = set(_servers.keys())
|
||||
|
||||
added = connected_servers - old_servers
|
||||
removed = old_servers - connected_servers
|
||||
reconnected = connected_servers & old_servers
|
||||
|
||||
if reconnected:
|
||||
print(f" ♻️ Reconnected: {', '.join(sorted(reconnected))}")
|
||||
if added:
|
||||
print(f" ➕ Added: {', '.join(sorted(added))}")
|
||||
if removed:
|
||||
print(f" ➖ Removed: {', '.join(sorted(removed))}")
|
||||
if not connected_servers:
|
||||
print(" (._.) No MCP servers connected.")
|
||||
else:
|
||||
print(f" 🔧 {len(new_tools)} tool(s) available from {len(connected_servers)} server(s)")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ MCP reload failed: {e}")
|
||||
|
||||
def _clarify_callback(self, question, choices):
|
||||
"""
|
||||
Platform callback for the clarify tool. Called from the agent thread.
|
||||
|
||||
@@ -645,7 +645,7 @@ class GatewayRunner:
|
||||
# Emit command:* hook for any recognized slash command
|
||||
_known_commands = {"new", "reset", "help", "status", "stop", "model",
|
||||
"personality", "retry", "undo", "sethome", "set-home",
|
||||
"compress", "usage"}
|
||||
"compress", "usage", "reload-mcp"}
|
||||
if command and command in _known_commands:
|
||||
await self.hooks.emit(f"command:{command}", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
@@ -686,6 +686,9 @@ class GatewayRunner:
|
||||
|
||||
if command == "usage":
|
||||
return await self._handle_usage_command(event)
|
||||
|
||||
if command == "reload-mcp":
|
||||
return await self._handle_reload_mcp_command(event)
|
||||
|
||||
# Skill slash commands: /skill-name loads the skill and sends to agent
|
||||
if command:
|
||||
@@ -1086,6 +1089,7 @@ class GatewayRunner:
|
||||
"`/sethome` — Set this chat as the home channel",
|
||||
"`/compress` — Compress conversation context",
|
||||
"`/usage` — Show token usage for this session",
|
||||
"`/reload-mcp` — Reload MCP servers from config",
|
||||
"`/help` — Show this message",
|
||||
]
|
||||
try:
|
||||
@@ -1379,6 +1383,51 @@ class GatewayRunner:
|
||||
)
|
||||
return "No usage data available for this session."
|
||||
|
||||
async def _handle_reload_mcp_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /reload-mcp command -- disconnect and reconnect all MCP servers."""
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _load_mcp_config, _servers, _lock
|
||||
|
||||
# Capture old server names before shutdown
|
||||
with _lock:
|
||||
old_servers = set(_servers.keys())
|
||||
|
||||
# Read new config before shutting down, so we know what will be added/removed
|
||||
new_config = _load_mcp_config()
|
||||
new_server_names = set(new_config.keys())
|
||||
|
||||
# Shutdown existing connections
|
||||
await loop.run_in_executor(None, shutdown_mcp_servers)
|
||||
|
||||
# Reconnect by discovering tools (reads config.yaml fresh)
|
||||
new_tools = await loop.run_in_executor(None, discover_mcp_tools)
|
||||
|
||||
# Compute what changed
|
||||
with _lock:
|
||||
connected_servers = set(_servers.keys())
|
||||
|
||||
added = connected_servers - old_servers
|
||||
removed = old_servers - connected_servers
|
||||
reconnected = connected_servers & old_servers
|
||||
|
||||
lines = ["🔄 **MCP Servers Reloaded**\n"]
|
||||
if reconnected:
|
||||
lines.append(f"♻️ Reconnected: {', '.join(sorted(reconnected))}")
|
||||
if added:
|
||||
lines.append(f"➕ Added: {', '.join(sorted(added))}")
|
||||
if removed:
|
||||
lines.append(f"➖ Removed: {', '.join(sorted(removed))}")
|
||||
if not connected_servers:
|
||||
lines.append("No MCP servers connected.")
|
||||
else:
|
||||
lines.append(f"\n🔧 {len(new_tools)} tool(s) available from {len(connected_servers)} server(s)")
|
||||
return "\n".join(lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("MCP reload failed: %s", e)
|
||||
return f"❌ MCP reload failed: {e}"
|
||||
|
||||
def _set_session_env(self, context: SessionContext) -> None:
|
||||
"""Set environment variables for the current session."""
|
||||
os.environ["HERMES_SESSION_PLATFORM"] = context.source.platform.value
|
||||
|
||||
@@ -196,6 +196,28 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
if remaining_toolsets > 0:
|
||||
right_lines.append(f"[dim #B8860B](and {remaining_toolsets} more toolsets...)[/]")
|
||||
|
||||
# MCP Servers section (only if configured)
|
||||
try:
|
||||
from tools.mcp_tool import get_mcp_status
|
||||
mcp_status = get_mcp_status()
|
||||
except Exception:
|
||||
mcp_status = []
|
||||
|
||||
if mcp_status:
|
||||
right_lines.append("")
|
||||
right_lines.append("[bold #FFBF00]MCP Servers[/]")
|
||||
for srv in mcp_status:
|
||||
if srv["connected"]:
|
||||
right_lines.append(
|
||||
f"[dim #B8860B]{srv['name']}[/] [#FFF8DC]({srv['transport']})[/] "
|
||||
f"[dim #B8860B]—[/] [#FFF8DC]{srv['tools']} tool(s)[/]"
|
||||
)
|
||||
else:
|
||||
right_lines.append(
|
||||
f"[red]{srv['name']}[/] [dim]({srv['transport']})[/] "
|
||||
f"[red]— failed[/]"
|
||||
)
|
||||
|
||||
right_lines.append("")
|
||||
right_lines.append("[bold #FFBF00]Available Skills[/]")
|
||||
skills_by_category = get_available_skills()
|
||||
@@ -216,7 +238,12 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
right_lines.append("[dim #B8860B]No skills installed[/]")
|
||||
|
||||
right_lines.append("")
|
||||
right_lines.append(f"[dim #B8860B]{len(tools)} tools · {total_skills} skills · /help for commands[/]")
|
||||
mcp_connected = sum(1 for s in mcp_status if s["connected"]) if mcp_status else 0
|
||||
summary_parts = [f"{len(tools)} tools", f"{total_skills} skills"]
|
||||
if mcp_connected:
|
||||
summary_parts.append(f"{mcp_connected} MCP servers")
|
||||
summary_parts.append("/help for commands")
|
||||
right_lines.append(f"[dim #B8860B]{' · '.join(summary_parts)}[/]")
|
||||
|
||||
right_content = "\n".join(right_lines)
|
||||
layout_table.add_row(left_content, right_content)
|
||||
|
||||
@@ -1063,3 +1063,429 @@ class TestConfigurableTimeouts:
|
||||
call_kwargs[1].get("timeout") == 180
|
||||
finally:
|
||||
_servers.pop("test_srv", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility tool schemas (Resources & Prompts)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUtilitySchemas:
|
||||
"""Tests for _build_utility_schemas() and the schema format of utility tools."""
|
||||
|
||||
def test_builds_four_utility_schemas(self):
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("myserver")
|
||||
assert len(schemas) == 4
|
||||
names = [s["schema"]["name"] for s in schemas]
|
||||
assert "mcp_myserver_list_resources" in names
|
||||
assert "mcp_myserver_read_resource" in names
|
||||
assert "mcp_myserver_list_prompts" in names
|
||||
assert "mcp_myserver_get_prompt" in names
|
||||
|
||||
def test_hyphens_sanitized_in_utility_names(self):
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("my-server")
|
||||
names = [s["schema"]["name"] for s in schemas]
|
||||
for name in names:
|
||||
assert "-" not in name
|
||||
assert "mcp_my_server_list_resources" in names
|
||||
|
||||
def test_list_resources_schema_no_required_params(self):
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("srv")
|
||||
lr = next(s for s in schemas if s["handler_key"] == "list_resources")
|
||||
params = lr["schema"]["parameters"]
|
||||
assert params["type"] == "object"
|
||||
assert params["properties"] == {}
|
||||
assert "required" not in params
|
||||
|
||||
def test_read_resource_schema_requires_uri(self):
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("srv")
|
||||
rr = next(s for s in schemas if s["handler_key"] == "read_resource")
|
||||
params = rr["schema"]["parameters"]
|
||||
assert "uri" in params["properties"]
|
||||
assert params["properties"]["uri"]["type"] == "string"
|
||||
assert params["required"] == ["uri"]
|
||||
|
||||
def test_list_prompts_schema_no_required_params(self):
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("srv")
|
||||
lp = next(s for s in schemas if s["handler_key"] == "list_prompts")
|
||||
params = lp["schema"]["parameters"]
|
||||
assert params["type"] == "object"
|
||||
assert params["properties"] == {}
|
||||
assert "required" not in params
|
||||
|
||||
def test_get_prompt_schema_requires_name(self):
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("srv")
|
||||
gp = next(s for s in schemas if s["handler_key"] == "get_prompt")
|
||||
params = gp["schema"]["parameters"]
|
||||
assert "name" in params["properties"]
|
||||
assert params["properties"]["name"]["type"] == "string"
|
||||
assert "arguments" in params["properties"]
|
||||
assert params["properties"]["arguments"]["type"] == "object"
|
||||
assert params["required"] == ["name"]
|
||||
|
||||
def test_schemas_have_descriptions(self):
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("test_srv")
|
||||
for entry in schemas:
|
||||
desc = entry["schema"]["description"]
|
||||
assert desc and len(desc) > 0
|
||||
assert "test_srv" in desc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility tool handlers (Resources & Prompts)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUtilityHandlers:
|
||||
"""Tests for the MCP Resources & Prompts handler functions."""
|
||||
|
||||
def _patch_mcp_loop(self):
|
||||
"""Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
|
||||
def fake_run(coro, timeout=30):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
|
||||
|
||||
# -- list_resources --
|
||||
|
||||
def test_list_resources_success(self):
|
||||
from tools.mcp_tool import _make_list_resources_handler, _servers
|
||||
|
||||
mock_resource = SimpleNamespace(
|
||||
uri="file:///tmp/test.txt", name="test.txt",
|
||||
description="A test file", mimeType="text/plain",
|
||||
)
|
||||
mock_session = MagicMock()
|
||||
mock_session.list_resources = AsyncMock(
|
||||
return_value=SimpleNamespace(resources=[mock_resource])
|
||||
)
|
||||
server = _make_mock_server("srv", session=mock_session)
|
||||
_servers["srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_list_resources_handler("srv", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({}))
|
||||
assert "resources" in result
|
||||
assert len(result["resources"]) == 1
|
||||
assert result["resources"][0]["uri"] == "file:///tmp/test.txt"
|
||||
assert result["resources"][0]["name"] == "test.txt"
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_list_resources_empty(self):
|
||||
from tools.mcp_tool import _make_list_resources_handler, _servers
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.list_resources = AsyncMock(
|
||||
return_value=SimpleNamespace(resources=[])
|
||||
)
|
||||
server = _make_mock_server("srv", session=mock_session)
|
||||
_servers["srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_list_resources_handler("srv", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({}))
|
||||
assert result["resources"] == []
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_list_resources_disconnected(self):
|
||||
from tools.mcp_tool import _make_list_resources_handler, _servers
|
||||
_servers.pop("ghost", None)
|
||||
handler = _make_list_resources_handler("ghost", 120)
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
assert "not connected" in result["error"]
|
||||
|
||||
# -- read_resource --
|
||||
|
||||
def test_read_resource_success(self):
|
||||
from tools.mcp_tool import _make_read_resource_handler, _servers
|
||||
|
||||
content_block = SimpleNamespace(text="Hello from resource")
|
||||
mock_session = MagicMock()
|
||||
mock_session.read_resource = AsyncMock(
|
||||
return_value=SimpleNamespace(contents=[content_block])
|
||||
)
|
||||
server = _make_mock_server("srv", session=mock_session)
|
||||
_servers["srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_read_resource_handler("srv", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({"uri": "file:///tmp/test.txt"}))
|
||||
assert result["result"] == "Hello from resource"
|
||||
mock_session.read_resource.assert_called_once_with("file:///tmp/test.txt")
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_read_resource_missing_uri(self):
|
||||
from tools.mcp_tool import _make_read_resource_handler, _servers
|
||||
|
||||
server = _make_mock_server("srv", session=MagicMock())
|
||||
_servers["srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_read_resource_handler("srv", 120)
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
assert "uri" in result["error"].lower()
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_read_resource_disconnected(self):
|
||||
from tools.mcp_tool import _make_read_resource_handler, _servers
|
||||
_servers.pop("ghost", None)
|
||||
handler = _make_read_resource_handler("ghost", 120)
|
||||
result = json.loads(handler({"uri": "test://x"}))
|
||||
assert "error" in result
|
||||
assert "not connected" in result["error"]
|
||||
|
||||
# -- list_prompts --
|
||||
|
||||
def test_list_prompts_success(self):
|
||||
from tools.mcp_tool import _make_list_prompts_handler, _servers
|
||||
|
||||
mock_prompt = SimpleNamespace(
|
||||
name="summarize", description="Summarize text",
|
||||
arguments=[
|
||||
SimpleNamespace(name="text", description="Text to summarize", required=True),
|
||||
],
|
||||
)
|
||||
mock_session = MagicMock()
|
||||
mock_session.list_prompts = AsyncMock(
|
||||
return_value=SimpleNamespace(prompts=[mock_prompt])
|
||||
)
|
||||
server = _make_mock_server("srv", session=mock_session)
|
||||
_servers["srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_list_prompts_handler("srv", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({}))
|
||||
assert "prompts" in result
|
||||
assert len(result["prompts"]) == 1
|
||||
assert result["prompts"][0]["name"] == "summarize"
|
||||
assert result["prompts"][0]["arguments"][0]["name"] == "text"
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_list_prompts_empty(self):
|
||||
from tools.mcp_tool import _make_list_prompts_handler, _servers
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.list_prompts = AsyncMock(
|
||||
return_value=SimpleNamespace(prompts=[])
|
||||
)
|
||||
server = _make_mock_server("srv", session=mock_session)
|
||||
_servers["srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_list_prompts_handler("srv", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({}))
|
||||
assert result["prompts"] == []
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_list_prompts_disconnected(self):
|
||||
from tools.mcp_tool import _make_list_prompts_handler, _servers
|
||||
_servers.pop("ghost", None)
|
||||
handler = _make_list_prompts_handler("ghost", 120)
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
assert "not connected" in result["error"]
|
||||
|
||||
# -- get_prompt --
|
||||
|
||||
def test_get_prompt_success(self):
|
||||
from tools.mcp_tool import _make_get_prompt_handler, _servers
|
||||
|
||||
mock_msg = SimpleNamespace(
|
||||
role="assistant",
|
||||
content=SimpleNamespace(text="Here is a summary of your text."),
|
||||
)
|
||||
mock_session = MagicMock()
|
||||
mock_session.get_prompt = AsyncMock(
|
||||
return_value=SimpleNamespace(messages=[mock_msg], description=None)
|
||||
)
|
||||
server = _make_mock_server("srv", session=mock_session)
|
||||
_servers["srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_get_prompt_handler("srv", 120)
|
||||
with self._patch_mcp_loop():
|
||||
result = json.loads(handler({"name": "summarize", "arguments": {"text": "hello"}}))
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) == 1
|
||||
assert result["messages"][0]["role"] == "assistant"
|
||||
assert "summary" in result["messages"][0]["content"].lower()
|
||||
mock_session.get_prompt.assert_called_once_with(
|
||||
"summarize", arguments={"text": "hello"}
|
||||
)
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_get_prompt_missing_name(self):
|
||||
from tools.mcp_tool import _make_get_prompt_handler, _servers
|
||||
|
||||
server = _make_mock_server("srv", session=MagicMock())
|
||||
_servers["srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_get_prompt_handler("srv", 120)
|
||||
result = json.loads(handler({}))
|
||||
assert "error" in result
|
||||
assert "name" in result["error"].lower()
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_get_prompt_disconnected(self):
|
||||
from tools.mcp_tool import _make_get_prompt_handler, _servers
|
||||
_servers.pop("ghost", None)
|
||||
handler = _make_get_prompt_handler("ghost", 120)
|
||||
result = json.loads(handler({"name": "test"}))
|
||||
assert "error" in result
|
||||
assert "not connected" in result["error"]
|
||||
|
||||
def test_get_prompt_default_arguments(self):
|
||||
from tools.mcp_tool import _make_get_prompt_handler, _servers
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get_prompt = AsyncMock(
|
||||
return_value=SimpleNamespace(messages=[], description=None)
|
||||
)
|
||||
server = _make_mock_server("srv", session=mock_session)
|
||||
_servers["srv"] = server
|
||||
|
||||
try:
|
||||
handler = _make_get_prompt_handler("srv", 120)
|
||||
with self._patch_mcp_loop():
|
||||
handler({"name": "test_prompt"})
|
||||
# arguments defaults to {} when not provided
|
||||
mock_session.get_prompt.assert_called_once_with(
|
||||
"test_prompt", arguments={}
|
||||
)
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility tools registration in _discover_and_register_server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUtilityToolRegistration:
|
||||
"""Verify utility tools are registered alongside regular MCP tools."""
|
||||
|
||||
def test_utility_tools_registered(self):
|
||||
"""_discover_and_register_server registers all 4 utility tools."""
|
||||
from tools.registry import ToolRegistry
|
||||
from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask
|
||||
|
||||
mock_registry = ToolRegistry()
|
||||
mock_tools = [_make_mcp_tool("read_file", "Read a file")]
|
||||
mock_session = MagicMock()
|
||||
|
||||
async def fake_connect(name, config):
|
||||
server = MCPServerTask(name)
|
||||
server.session = mock_session
|
||||
server._tools = mock_tools
|
||||
return server
|
||||
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.registry.registry", mock_registry):
|
||||
registered = asyncio.run(
|
||||
_discover_and_register_server("fs", {"command": "npx", "args": []})
|
||||
)
|
||||
|
||||
# Regular tool + 4 utility tools
|
||||
assert "mcp_fs_read_file" in registered
|
||||
assert "mcp_fs_list_resources" in registered
|
||||
assert "mcp_fs_read_resource" in registered
|
||||
assert "mcp_fs_list_prompts" in registered
|
||||
assert "mcp_fs_get_prompt" in registered
|
||||
assert len(registered) == 5
|
||||
|
||||
# All in the registry
|
||||
all_names = mock_registry.get_all_tool_names()
|
||||
for name in registered:
|
||||
assert name in all_names
|
||||
|
||||
_servers.pop("fs", None)
|
||||
|
||||
def test_utility_tools_in_same_toolset(self):
|
||||
"""Utility tools belong to the same mcp-{server} toolset."""
|
||||
from tools.registry import ToolRegistry
|
||||
from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask
|
||||
|
||||
mock_registry = ToolRegistry()
|
||||
mock_session = MagicMock()
|
||||
|
||||
async def fake_connect(name, config):
|
||||
server = MCPServerTask(name)
|
||||
server.session = mock_session
|
||||
server._tools = []
|
||||
return server
|
||||
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.registry.registry", mock_registry):
|
||||
asyncio.run(
|
||||
_discover_and_register_server("myserv", {"command": "test"})
|
||||
)
|
||||
|
||||
# Check that utility tools are in the right toolset
|
||||
for tool_name in ["mcp_myserv_list_resources", "mcp_myserv_read_resource",
|
||||
"mcp_myserv_list_prompts", "mcp_myserv_get_prompt"]:
|
||||
entry = mock_registry._tools.get(tool_name)
|
||||
assert entry is not None, f"{tool_name} not found in registry"
|
||||
assert entry.toolset == "mcp-myserv"
|
||||
|
||||
_servers.pop("myserv", None)
|
||||
|
||||
def test_utility_tools_have_check_fn(self):
|
||||
"""Utility tools have a working check_fn."""
|
||||
from tools.registry import ToolRegistry
|
||||
from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask
|
||||
|
||||
mock_registry = ToolRegistry()
|
||||
mock_session = MagicMock()
|
||||
|
||||
async def fake_connect(name, config):
|
||||
server = MCPServerTask(name)
|
||||
server.session = mock_session
|
||||
server._tools = []
|
||||
return server
|
||||
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.registry.registry", mock_registry):
|
||||
asyncio.run(
|
||||
_discover_and_register_server("chk", {"command": "test"})
|
||||
)
|
||||
|
||||
entry = mock_registry._tools.get("mcp_chk_list_resources")
|
||||
assert entry is not None
|
||||
# Server is connected, check_fn should return True
|
||||
assert entry.check_fn() is True
|
||||
|
||||
# Disconnect the server
|
||||
_servers["chk"].session = None
|
||||
assert entry.check_fn() is False
|
||||
|
||||
_servers.pop("chk", None)
|
||||
|
||||
@@ -475,6 +475,190 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
return _handler
|
||||
|
||||
|
||||
def _make_list_resources_handler(server_name: str, tool_timeout: float):
|
||||
"""Return a sync handler that lists resources from an MCP server."""
|
||||
|
||||
def _handler(args: dict, **kwargs) -> str:
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_resources()
|
||||
resources = []
|
||||
for r in (result.resources if hasattr(result, "resources") else []):
|
||||
entry = {}
|
||||
if hasattr(r, "uri"):
|
||||
entry["uri"] = str(r.uri)
|
||||
if hasattr(r, "name"):
|
||||
entry["name"] = r.name
|
||||
if hasattr(r, "description") and r.description:
|
||||
entry["description"] = r.description
|
||||
if hasattr(r, "mimeType") and r.mimeType:
|
||||
entry["mimeType"] = r.mimeType
|
||||
resources.append(entry)
|
||||
return json.dumps({"resources": resources})
|
||||
|
||||
try:
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/list_resources failed: %s", server_name, exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
||||
"""Return a sync handler that reads a resource by URI from an MCP server."""
|
||||
|
||||
def _handler(args: dict, **kwargs) -> str:
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
uri = args.get("uri")
|
||||
if not uri:
|
||||
return json.dumps({"error": "Missing required parameter 'uri'"})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.read_resource(uri)
|
||||
# read_resource returns ReadResourceResult with .contents list
|
||||
parts: List[str] = []
|
||||
contents = result.contents if hasattr(result, "contents") else []
|
||||
for block in contents:
|
||||
if hasattr(block, "text"):
|
||||
parts.append(block.text)
|
||||
elif hasattr(block, "blob"):
|
||||
parts.append(f"[binary data, {len(block.blob)} bytes]")
|
||||
return json.dumps({"result": "\n".join(parts) if parts else ""})
|
||||
|
||||
try:
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/read_resource failed: %s", server_name, exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
def _make_list_prompts_handler(server_name: str, tool_timeout: float):
|
||||
"""Return a sync handler that lists prompts from an MCP server."""
|
||||
|
||||
def _handler(args: dict, **kwargs) -> str:
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_prompts()
|
||||
prompts = []
|
||||
for p in (result.prompts if hasattr(result, "prompts") else []):
|
||||
entry = {}
|
||||
if hasattr(p, "name"):
|
||||
entry["name"] = p.name
|
||||
if hasattr(p, "description") and p.description:
|
||||
entry["description"] = p.description
|
||||
if hasattr(p, "arguments") and p.arguments:
|
||||
entry["arguments"] = [
|
||||
{
|
||||
"name": a.name,
|
||||
**({"description": a.description} if hasattr(a, "description") and a.description else {}),
|
||||
**({"required": a.required} if hasattr(a, "required") else {}),
|
||||
}
|
||||
for a in p.arguments
|
||||
]
|
||||
prompts.append(entry)
|
||||
return json.dumps({"prompts": prompts})
|
||||
|
||||
try:
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/list_prompts failed: %s", server_name, exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
||||
"""Return a sync handler that gets a prompt by name from an MCP server."""
|
||||
|
||||
def _handler(args: dict, **kwargs) -> str:
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
name = args.get("name")
|
||||
if not name:
|
||||
return json.dumps({"error": "Missing required parameter 'name'"})
|
||||
arguments = args.get("arguments", {})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.get_prompt(name, arguments=arguments)
|
||||
# GetPromptResult has .messages list
|
||||
messages = []
|
||||
for msg in (result.messages if hasattr(result, "messages") else []):
|
||||
entry = {}
|
||||
if hasattr(msg, "role"):
|
||||
entry["role"] = msg.role
|
||||
if hasattr(msg, "content"):
|
||||
content = msg.content
|
||||
if hasattr(content, "text"):
|
||||
entry["content"] = content.text
|
||||
elif isinstance(content, str):
|
||||
entry["content"] = content
|
||||
else:
|
||||
entry["content"] = str(content)
|
||||
messages.append(entry)
|
||||
resp = {"messages": messages}
|
||||
if hasattr(result, "description") and result.description:
|
||||
resp["description"] = result.description
|
||||
return json.dumps(resp)
|
||||
|
||||
try:
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/get_prompt failed: %s", server_name, exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
def _make_check_fn(server_name: str):
|
||||
"""Return a check function that verifies the MCP connection is alive."""
|
||||
|
||||
@@ -515,6 +699,77 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
"""Build schemas for the MCP utility tools (resources & prompts).
|
||||
|
||||
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
|
||||
with keys: schema, handler_key.
|
||||
"""
|
||||
safe_name = server_name.replace("-", "_").replace(".", "_")
|
||||
return [
|
||||
{
|
||||
"schema": {
|
||||
"name": f"mcp_{safe_name}_list_resources",
|
||||
"description": f"List available resources from MCP server '{server_name}'",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
},
|
||||
"handler_key": "list_resources",
|
||||
},
|
||||
{
|
||||
"schema": {
|
||||
"name": f"mcp_{safe_name}_read_resource",
|
||||
"description": f"Read a resource by URI from MCP server '{server_name}'",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"uri": {
|
||||
"type": "string",
|
||||
"description": "URI of the resource to read",
|
||||
},
|
||||
},
|
||||
"required": ["uri"],
|
||||
},
|
||||
},
|
||||
"handler_key": "read_resource",
|
||||
},
|
||||
{
|
||||
"schema": {
|
||||
"name": f"mcp_{safe_name}_list_prompts",
|
||||
"description": f"List available prompts from MCP server '{server_name}'",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
},
|
||||
"handler_key": "list_prompts",
|
||||
},
|
||||
{
|
||||
"schema": {
|
||||
"name": f"mcp_{safe_name}_get_prompt",
|
||||
"description": f"Get a prompt by name from MCP server '{server_name}'",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name of the prompt to retrieve",
|
||||
},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"description": "Optional arguments to pass to the prompt",
|
||||
},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
"handler_key": "get_prompt",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _existing_tool_names() -> List[str]:
|
||||
"""Return tool names for all currently connected servers."""
|
||||
names: List[str] = []
|
||||
@@ -522,12 +777,18 @@ def _existing_tool_names() -> List[str]:
|
||||
for mcp_tool in server._tools:
|
||||
schema = _convert_mcp_schema(sname, mcp_tool)
|
||||
names.append(schema["name"])
|
||||
# Also include utility tool names
|
||||
for entry in _build_utility_schemas(sname):
|
||||
names.append(entry["schema"]["name"])
|
||||
return names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
"""Connect to a single MCP server, discover tools, and register them.
|
||||
|
||||
Also registers utility tools for MCP Resources and Prompts support
|
||||
(list_resources, read_resource, list_prompts, get_prompt).
|
||||
|
||||
Returns list of registered tool names.
|
||||
"""
|
||||
from tools.registry import registry
|
||||
@@ -559,6 +820,30 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
)
|
||||
registered_names.append(tool_name_prefixed)
|
||||
|
||||
# Register MCP Resources & Prompts utility tools
|
||||
_handler_factories = {
|
||||
"list_resources": _make_list_resources_handler,
|
||||
"read_resource": _make_read_resource_handler,
|
||||
"list_prompts": _make_list_prompts_handler,
|
||||
"get_prompt": _make_get_prompt_handler,
|
||||
}
|
||||
check_fn = _make_check_fn(name)
|
||||
for entry in _build_utility_schemas(name):
|
||||
schema = entry["schema"]
|
||||
handler_key = entry["handler_key"]
|
||||
handler = _handler_factories[handler_key](name, server.tool_timeout)
|
||||
|
||||
registry.register(
|
||||
name=schema["name"],
|
||||
toolset=toolset_name,
|
||||
schema=schema,
|
||||
handler=handler,
|
||||
check_fn=check_fn,
|
||||
is_async=False,
|
||||
description=schema["description"],
|
||||
)
|
||||
registered_names.append(schema["name"])
|
||||
|
||||
# Create a custom toolset so these tools are discoverable
|
||||
if registered_names:
|
||||
create_custom_toolset(
|
||||
@@ -620,10 +905,8 @@ def discover_mcp_tools() -> List[str]:
|
||||
try:
|
||||
registered = await _discover_and_register_server(name, cfg)
|
||||
transport_type = "HTTP" if "url" in cfg else "stdio"
|
||||
print(f" MCP: '{name}' ({transport_type}) — {len(registered)} tool(s)")
|
||||
return registered
|
||||
except Exception as exc:
|
||||
print(f" MCP: '{name}' — FAILED: {exc}")
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s': %s",
|
||||
name, exc,
|
||||
@@ -666,12 +949,49 @@ def discover_mcp_tools() -> List[str]:
|
||||
summary = f" MCP: {len(all_tools)} tool(s) from {ok_servers} server(s)"
|
||||
if failed_count:
|
||||
summary += f" ({failed_count} failed)"
|
||||
print(summary)
|
||||
logger.info(summary)
|
||||
|
||||
# Return ALL registered tools (existing + newly discovered)
|
||||
return _existing_tool_names()
|
||||
|
||||
|
||||
def get_mcp_status() -> List[dict]:
|
||||
"""Return status of all configured MCP servers for banner display.
|
||||
|
||||
Returns a list of dicts with keys: name, transport, tools, connected.
|
||||
Includes both successfully connected servers and configured-but-failed ones.
|
||||
"""
|
||||
result: List[dict] = []
|
||||
|
||||
# Get configured servers from config
|
||||
configured = _load_mcp_config()
|
||||
if not configured:
|
||||
return result
|
||||
|
||||
with _lock:
|
||||
active_servers = dict(_servers)
|
||||
|
||||
for name, cfg in configured.items():
|
||||
transport = "http" if "url" in cfg else "stdio"
|
||||
server = active_servers.get(name)
|
||||
if server and server.session is not None:
|
||||
result.append({
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": len(server._tools),
|
||||
"connected": True,
|
||||
})
|
||||
else:
|
||||
result.append({
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": 0,
|
||||
"connected": False,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def shutdown_mcp_servers():
|
||||
"""Close all MCP server connections and stop the background loop.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user