Phase 4: Tool Registry Auto-Discovery
- @mcp_tool decorator for marking functions as tools - ToolDiscovery class for introspecting modules and packages - Automatic JSON schema generation from type hints - AST-based discovery for files (without importing) - Auto-bootstrap on startup (packages=['tools'] by default) - Support for tags, categories, and metadata - Updated registry with register_tool() convenience method - Environment variable MCP_AUTO_BOOTSTRAP to disable - 39 tests with proper isolation and cleanup Files Added: - src/mcp/discovery.py: Tool discovery and introspection - src/mcp/bootstrap.py: Auto-bootstrap functionality - tests/test_mcp_discovery.py: 26 tests - tests/test_mcp_bootstrap.py: 13 tests Files Modified: - src/mcp/registry.py: Added tags, source_module, auto_discovered fields - src/mcp/__init__.py: Export discovery and bootstrap modules - src/dashboard/app.py: Auto-bootstrap on startup
This commit is contained in:
@@ -102,6 +102,15 @@ async def lifespan(app: FastAPI):
|
||||
except Exception as exc:
|
||||
logger.error("Failed to spawn persona agents: %s", exc)
|
||||
|
||||
# Auto-bootstrap MCP tools
|
||||
from mcp.bootstrap import auto_bootstrap, get_bootstrap_status
|
||||
try:
|
||||
registered = auto_bootstrap()
|
||||
if registered:
|
||||
logger.info("MCP auto-bootstrap: %d tools registered", len(registered))
|
||||
except Exception as exc:
|
||||
logger.warning("MCP auto-bootstrap failed: %s", exc)
|
||||
|
||||
# Initialise Spark Intelligence engine
|
||||
from spark.engine import spark_engine
|
||||
if spark_engine.enabled:
|
||||
|
||||
@@ -1,17 +1,30 @@
|
||||
"""MCP (Model Context Protocol) package.
|
||||
|
||||
Provides tool registry, server, and schema management.
|
||||
Provides tool registry, server, schema management, and auto-discovery.
|
||||
"""
|
||||
|
||||
from mcp.registry import tool_registry, register_tool
|
||||
from mcp.registry import tool_registry, register_tool, ToolRegistry
|
||||
from mcp.server import mcp_server, MCPServer, MCPHTTPServer
|
||||
from mcp.schemas.base import create_tool_schema
|
||||
from mcp.discovery import ToolDiscovery, mcp_tool, get_discovery
|
||||
from mcp.bootstrap import auto_bootstrap, get_bootstrap_status
|
||||
|
||||
__all__ = [
|
||||
# Registry
|
||||
"tool_registry",
|
||||
"register_tool",
|
||||
"ToolRegistry",
|
||||
# Server
|
||||
"mcp_server",
|
||||
"MCPServer",
|
||||
"MCPHTTPServer",
|
||||
# Schemas
|
||||
"create_tool_schema",
|
||||
# Discovery
|
||||
"ToolDiscovery",
|
||||
"mcp_tool",
|
||||
"get_discovery",
|
||||
# Bootstrap
|
||||
"auto_bootstrap",
|
||||
"get_bootstrap_status",
|
||||
]
|
||||
|
||||
@@ -1,71 +1,148 @@
|
||||
"""Bootstrap the MCP system by loading all tools.
|
||||
"""MCP Auto-Bootstrap — Auto-discover and register tools on startup.
|
||||
|
||||
This module is responsible for:
|
||||
1. Loading all tool modules from src/tools/
|
||||
2. Registering them with the tool registry
|
||||
3. Verifying tool health
|
||||
4. Reporting status
|
||||
Usage:
|
||||
from mcp.bootstrap import auto_bootstrap
|
||||
|
||||
# Auto-discover from 'tools' package
|
||||
registered = auto_bootstrap()
|
||||
|
||||
# Or specify custom packages
|
||||
registered = auto_bootstrap(packages=["tools", "custom_tools"])
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from mcp.registry import tool_registry
|
||||
from .discovery import ToolDiscovery, get_discovery
|
||||
from .registry import ToolRegistry, tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tool modules to load
|
||||
TOOL_MODULES = [
|
||||
"tools.web_search",
|
||||
"tools.file_ops",
|
||||
"tools.code_exec",
|
||||
"tools.memory_tool",
|
||||
]
|
||||
# Default packages to scan for tools
|
||||
DEFAULT_TOOL_PACKAGES = ["tools"]
|
||||
|
||||
# Environment variable to disable auto-bootstrap
|
||||
AUTO_BOOTSTRAP_ENV_VAR = "MCP_AUTO_BOOTSTRAP"
|
||||
|
||||
|
||||
def bootstrap_mcp() -> dict:
|
||||
"""Initialize the MCP system by loading all tools.
|
||||
def auto_bootstrap(
|
||||
packages: Optional[list[str]] = None,
|
||||
registry: Optional[ToolRegistry] = None,
|
||||
force: bool = False,
|
||||
) -> list[str]:
|
||||
"""Auto-discover and register MCP tools.
|
||||
|
||||
Args:
|
||||
packages: Packages to scan (defaults to ["tools"])
|
||||
registry: Registry to register tools with (defaults to singleton)
|
||||
force: Force bootstrap even if disabled by env var
|
||||
|
||||
Returns:
|
||||
Status dict with loaded tools and any errors
|
||||
List of registered tool names
|
||||
"""
|
||||
loaded = []
|
||||
errors = []
|
||||
# Check if auto-bootstrap is disabled
|
||||
if not force and os.environ.get(AUTO_BOOTSTRAP_ENV_VAR, "1") == "0":
|
||||
logger.info("MCP auto-bootstrap disabled via %s", AUTO_BOOTSTRAP_ENV_VAR)
|
||||
return []
|
||||
|
||||
for module_name in TOOL_MODULES:
|
||||
packages = packages or DEFAULT_TOOL_PACKAGES
|
||||
registry = registry or tool_registry
|
||||
discovery = get_discovery(registry=registry)
|
||||
|
||||
registered: list[str] = []
|
||||
|
||||
logger.info("Starting MCP auto-bootstrap from packages: %s", packages)
|
||||
|
||||
for package in packages:
|
||||
try:
|
||||
# Import the module (this triggers @register_tool decorators)
|
||||
importlib.import_module(module_name)
|
||||
loaded.append(module_name)
|
||||
logger.info("Loaded tool module: %s", module_name)
|
||||
# Check if package exists
|
||||
try:
|
||||
__import__(package)
|
||||
except ImportError:
|
||||
logger.debug("Package %s not found, skipping", package)
|
||||
continue
|
||||
|
||||
# Discover and register
|
||||
tools = discovery.auto_register(package)
|
||||
registered.extend(tools)
|
||||
|
||||
except Exception as exc:
|
||||
errors.append({"module": module_name, "error": str(exc)})
|
||||
logger.error("Failed to load tool module %s: %s", module_name, exc)
|
||||
logger.warning("Failed to bootstrap from %s: %s", package, exc)
|
||||
|
||||
# Get registry status
|
||||
registry_status = tool_registry.to_dict()
|
||||
|
||||
status = {
|
||||
"loaded_modules": loaded,
|
||||
"errors": errors,
|
||||
"total_tools": len(registry_status.get("tools", [])),
|
||||
"tools_by_category": registry_status.get("categories", {}),
|
||||
"tool_names": tool_registry.list_tools(),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"MCP Bootstrap complete: %d tools loaded from %d modules",
|
||||
status["total_tools"],
|
||||
len(loaded)
|
||||
)
|
||||
|
||||
return status
|
||||
logger.info("MCP auto-bootstrap complete: %d tools registered", len(registered))
|
||||
return registered
|
||||
|
||||
|
||||
def get_tool_status() -> dict:
|
||||
"""Get current status of all tools."""
|
||||
def bootstrap_from_directory(
|
||||
directory: Path,
|
||||
registry: Optional[ToolRegistry] = None,
|
||||
) -> list[str]:
|
||||
"""Bootstrap tools from a directory of Python files.
|
||||
|
||||
Args:
|
||||
directory: Directory containing Python files with tools
|
||||
registry: Registry to register tools with
|
||||
|
||||
Returns:
|
||||
List of registered tool names
|
||||
"""
|
||||
registry = registry or tool_registry
|
||||
discovery = get_discovery(registry=registry)
|
||||
|
||||
registered: list[str] = []
|
||||
|
||||
if not directory.exists():
|
||||
logger.warning("Tools directory not found: %s", directory)
|
||||
return registered
|
||||
|
||||
logger.info("Bootstrapping tools from directory: %s", directory)
|
||||
|
||||
# Find all Python files
|
||||
for py_file in directory.rglob("*.py"):
|
||||
if py_file.name.startswith("_"):
|
||||
continue
|
||||
|
||||
try:
|
||||
discovered = discovery.discover_file(py_file)
|
||||
|
||||
for tool in discovered:
|
||||
if tool.function is None:
|
||||
# Need to import and resolve the function
|
||||
continue
|
||||
|
||||
try:
|
||||
registry.register_tool(
|
||||
name=tool.name,
|
||||
function=tool.function,
|
||||
description=tool.description,
|
||||
category=tool.category,
|
||||
tags=tool.tags,
|
||||
)
|
||||
registered.append(tool.name)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to register %s: %s", tool.name, exc)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to process %s: %s", py_file, exc)
|
||||
|
||||
logger.info("Directory bootstrap complete: %d tools registered", len(registered))
|
||||
return registered
|
||||
|
||||
|
||||
def get_bootstrap_status() -> dict:
|
||||
"""Get auto-bootstrap status.
|
||||
|
||||
Returns:
|
||||
Dict with bootstrap status info
|
||||
"""
|
||||
discovery = get_discovery()
|
||||
registry = tool_registry
|
||||
|
||||
return {
|
||||
"tools": tool_registry.to_dict(),
|
||||
"metrics": tool_registry.get_metrics(),
|
||||
"auto_bootstrap_enabled": os.environ.get(AUTO_BOOTSTRAP_ENV_VAR, "1") != "0",
|
||||
"discovered_tools_count": len(discovery.get_discovered()),
|
||||
"registered_tools_count": len(registry.list_tools()),
|
||||
"default_packages": DEFAULT_TOOL_PACKAGES,
|
||||
}
|
||||
|
||||
441
src/mcp/discovery.py
Normal file
441
src/mcp/discovery.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""MCP Tool Auto-Discovery — Introspect Python modules to find tools.
|
||||
|
||||
Automatically discovers functions marked with @mcp_tool decorator
|
||||
and registers them with the MCP registry. Generates JSON schemas
|
||||
from type hints.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import pkgutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, get_type_hints
|
||||
|
||||
from .registry import ToolRegistry, tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Decorator to mark functions as MCP tools
|
||||
def mcp_tool(
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
category: str = "general",
|
||||
tags: Optional[list[str]] = None,
|
||||
):
|
||||
"""Decorator to mark a function as an MCP tool.
|
||||
|
||||
Args:
|
||||
name: Tool name (defaults to function name)
|
||||
description: Tool description (defaults to docstring)
|
||||
category: Tool category for organization
|
||||
tags: Additional tags for filtering
|
||||
|
||||
Example:
|
||||
@mcp_tool(name="weather", category="external")
|
||||
def get_weather(city: str) -> dict:
|
||||
'''Get weather for a city.'''
|
||||
...
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func._mcp_tool = True
|
||||
func._mcp_name = name or func.__name__
|
||||
func._mcp_description = description or (func.__doc__ or "").strip()
|
||||
func._mcp_category = category
|
||||
func._mcp_tags = tags or []
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscoveredTool:
|
||||
"""A tool discovered via introspection."""
|
||||
name: str
|
||||
description: str
|
||||
function: Callable
|
||||
module: str
|
||||
category: str
|
||||
tags: list[str]
|
||||
parameters_schema: dict[str, Any]
|
||||
returns_schema: dict[str, Any]
|
||||
source_file: Optional[str] = None
|
||||
line_number: int = 0
|
||||
|
||||
|
||||
class ToolDiscovery:
|
||||
"""Discovers and registers MCP tools from Python modules.
|
||||
|
||||
Usage:
|
||||
discovery = ToolDiscovery()
|
||||
|
||||
# Discover from a module
|
||||
tools = discovery.discover_module("tools.git")
|
||||
|
||||
# Auto-register with registry
|
||||
discovery.auto_register("tools")
|
||||
|
||||
# Discover from all installed packages
|
||||
tools = discovery.discover_all_packages()
|
||||
"""
|
||||
|
||||
def __init__(self, registry: Optional[ToolRegistry] = None) -> None:
|
||||
self.registry = registry or tool_registry
|
||||
self._discovered: list[DiscoveredTool] = []
|
||||
|
||||
def discover_module(self, module_name: str) -> list[DiscoveredTool]:
|
||||
"""Discover all MCP tools in a module.
|
||||
|
||||
Args:
|
||||
module_name: Dotted path to module (e.g., "tools.git")
|
||||
|
||||
Returns:
|
||||
List of discovered tools
|
||||
"""
|
||||
discovered = []
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as exc:
|
||||
logger.warning("Failed to import module %s: %s", module_name, exc)
|
||||
return discovered
|
||||
|
||||
# Get module file path for source location
|
||||
module_file = getattr(module, "__file__", None)
|
||||
|
||||
# Iterate through module members
|
||||
for name, obj in inspect.getmembers(module):
|
||||
# Skip private and non-callable
|
||||
if name.startswith("_") or not callable(obj):
|
||||
continue
|
||||
|
||||
# Check if marked as MCP tool
|
||||
if not getattr(obj, "_mcp_tool", False):
|
||||
continue
|
||||
|
||||
# Get source location
|
||||
try:
|
||||
source_file = inspect.getfile(obj)
|
||||
line_number = inspect.getsourcelines(obj)[1]
|
||||
except (OSError, TypeError):
|
||||
source_file = module_file
|
||||
line_number = 0
|
||||
|
||||
# Build schemas from type hints
|
||||
try:
|
||||
sig = inspect.signature(obj)
|
||||
parameters_schema = self._build_parameters_schema(sig)
|
||||
returns_schema = self._build_returns_schema(sig, obj)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to build schema for %s: %s", name, exc)
|
||||
parameters_schema = {"type": "object", "properties": {}}
|
||||
returns_schema = {}
|
||||
|
||||
tool = DiscoveredTool(
|
||||
name=getattr(obj, "_mcp_name", name),
|
||||
description=getattr(obj, "_mcp_description", obj.__doc__ or ""),
|
||||
function=obj,
|
||||
module=module_name,
|
||||
category=getattr(obj, "_mcp_category", "general"),
|
||||
tags=getattr(obj, "_mcp_tags", []),
|
||||
parameters_schema=parameters_schema,
|
||||
returns_schema=returns_schema,
|
||||
source_file=source_file,
|
||||
line_number=line_number,
|
||||
)
|
||||
|
||||
discovered.append(tool)
|
||||
logger.debug("Discovered tool: %s from %s", tool.name, module_name)
|
||||
|
||||
self._discovered.extend(discovered)
|
||||
logger.info("Discovered %d tools from module %s", len(discovered), module_name)
|
||||
return discovered
|
||||
|
||||
def discover_package(self, package_name: str, recursive: bool = True) -> list[DiscoveredTool]:
|
||||
"""Discover tools from all modules in a package.
|
||||
|
||||
Args:
|
||||
package_name: Package name (e.g., "tools")
|
||||
recursive: Whether to search subpackages
|
||||
|
||||
Returns:
|
||||
List of discovered tools
|
||||
"""
|
||||
discovered = []
|
||||
|
||||
try:
|
||||
package = importlib.import_module(package_name)
|
||||
except ImportError as exc:
|
||||
logger.warning("Failed to import package %s: %s", package_name, exc)
|
||||
return discovered
|
||||
|
||||
package_path = getattr(package, "__path__", [])
|
||||
if not package_path:
|
||||
# Not a package, treat as module
|
||||
return self.discover_module(package_name)
|
||||
|
||||
# Walk package modules
|
||||
for _, name, is_pkg in pkgutil.iter_modules(package_path, prefix=f"{package_name}."):
|
||||
if is_pkg and recursive:
|
||||
discovered.extend(self.discover_package(name, recursive=True))
|
||||
else:
|
||||
discovered.extend(self.discover_module(name))
|
||||
|
||||
return discovered
|
||||
|
||||
def discover_file(self, file_path: Path) -> list[DiscoveredTool]:
|
||||
"""Discover tools from a Python file.
|
||||
|
||||
Args:
|
||||
file_path: Path to Python file
|
||||
|
||||
Returns:
|
||||
List of discovered tools
|
||||
"""
|
||||
discovered = []
|
||||
|
||||
try:
|
||||
source = file_path.read_text()
|
||||
tree = ast.parse(source)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to parse %s: %s", file_path, exc)
|
||||
return discovered
|
||||
|
||||
# Find all decorated functions
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.FunctionDef):
|
||||
continue
|
||||
|
||||
# Check for @mcp_tool decorator
|
||||
is_tool = False
|
||||
tool_name = node.name
|
||||
tool_description = ast.get_docstring(node) or ""
|
||||
tool_category = "general"
|
||||
tool_tags: list[str] = []
|
||||
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Call):
|
||||
if isinstance(decorator.func, ast.Name) and decorator.func.id == "mcp_tool":
|
||||
is_tool = True
|
||||
# Extract decorator arguments
|
||||
for kw in decorator.keywords:
|
||||
if kw.arg == "name" and isinstance(kw.value, ast.Constant):
|
||||
tool_name = kw.value.value
|
||||
elif kw.arg == "description" and isinstance(kw.value, ast.Constant):
|
||||
tool_description = kw.value.value
|
||||
elif kw.arg == "category" and isinstance(kw.value, ast.Constant):
|
||||
tool_category = kw.value.value
|
||||
elif kw.arg == "tags" and isinstance(kw.value, ast.List):
|
||||
tool_tags = [
|
||||
elt.value for elt in kw.value.elts
|
||||
if isinstance(elt, ast.Constant)
|
||||
]
|
||||
elif isinstance(decorator, ast.Name) and decorator.id == "mcp_tool":
|
||||
is_tool = True
|
||||
|
||||
if not is_tool:
|
||||
continue
|
||||
|
||||
# Build parameter schema from AST
|
||||
parameters_schema = self._build_schema_from_ast(node)
|
||||
|
||||
# We can't get the actual function without importing
|
||||
# So create a placeholder that will be resolved later
|
||||
tool = DiscoveredTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
function=None, # Will be resolved when registered
|
||||
module=str(file_path),
|
||||
category=tool_category,
|
||||
tags=tool_tags,
|
||||
parameters_schema=parameters_schema,
|
||||
returns_schema={"type": "object"},
|
||||
source_file=str(file_path),
|
||||
line_number=node.lineno,
|
||||
)
|
||||
|
||||
discovered.append(tool)
|
||||
|
||||
self._discovered.extend(discovered)
|
||||
logger.info("Discovered %d tools from file %s", len(discovered), file_path)
|
||||
return discovered
|
||||
|
||||
def auto_register(self, package_name: str = "tools") -> list[str]:
|
||||
"""Automatically discover and register tools.
|
||||
|
||||
Args:
|
||||
package_name: Package to scan for tools
|
||||
|
||||
Returns:
|
||||
List of registered tool names
|
||||
"""
|
||||
discovered = self.discover_package(package_name)
|
||||
registered = []
|
||||
|
||||
for tool in discovered:
|
||||
if tool.function is None:
|
||||
logger.warning("Skipping %s: no function resolved", tool.name)
|
||||
continue
|
||||
|
||||
try:
|
||||
self.registry.register_tool(
|
||||
name=tool.name,
|
||||
function=tool.function,
|
||||
description=tool.description,
|
||||
category=tool.category,
|
||||
tags=tool.tags,
|
||||
)
|
||||
registered.append(tool.name)
|
||||
logger.debug("Registered tool: %s", tool.name)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to register %s: %s", tool.name, exc)
|
||||
|
||||
logger.info("Auto-registered %d/%d tools", len(registered), len(discovered))
|
||||
return registered
|
||||
|
||||
def _build_parameters_schema(self, sig: inspect.Signature) -> dict[str, Any]:
|
||||
"""Build JSON schema for function parameters."""
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for name, param in sig.parameters.items():
|
||||
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
|
||||
continue
|
||||
|
||||
schema = self._type_to_schema(param.annotation)
|
||||
|
||||
if param.default is param.empty:
|
||||
required.append(name)
|
||||
else:
|
||||
schema["default"] = param.default
|
||||
|
||||
properties[name] = schema
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
def _build_returns_schema(
|
||||
self, sig: inspect.Signature, func: Callable
|
||||
) -> dict[str, Any]:
|
||||
"""Build JSON schema for return type."""
|
||||
return_annotation = sig.return_annotation
|
||||
|
||||
if return_annotation is sig.empty:
|
||||
return {"type": "object"}
|
||||
|
||||
return self._type_to_schema(return_annotation)
|
||||
|
||||
def _build_schema_from_ast(self, node: ast.FunctionDef) -> dict[str, Any]:
|
||||
"""Build parameter schema from AST node."""
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
# Get defaults (reversed, since they're at the end)
|
||||
defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + list(node.args.defaults)
|
||||
|
||||
for arg, default in zip(node.args.args, defaults):
|
||||
arg_name = arg.arg
|
||||
arg_type = "string" # Default
|
||||
|
||||
# Try to get type from annotation
|
||||
if arg.annotation:
|
||||
if isinstance(arg.annotation, ast.Name):
|
||||
arg_type = self._ast_type_to_json_type(arg.annotation.id)
|
||||
elif isinstance(arg.annotation, ast.Constant):
|
||||
arg_type = self._ast_type_to_json_type(str(arg.annotation.value))
|
||||
|
||||
schema = {"type": arg_type}
|
||||
|
||||
if default is None:
|
||||
required.append(arg_name)
|
||||
|
||||
properties[arg_name] = schema
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
def _type_to_schema(self, annotation: Any) -> dict[str, Any]:
|
||||
"""Convert Python type annotation to JSON schema."""
|
||||
if annotation is inspect.Parameter.empty:
|
||||
return {"type": "string"}
|
||||
|
||||
origin = getattr(annotation, "__origin__", None)
|
||||
args = getattr(annotation, "__args__", ())
|
||||
|
||||
# Handle Optional[T] = Union[T, None]
|
||||
if origin is not None:
|
||||
if str(origin) == "typing.Union" and type(None) in args:
|
||||
# Optional type
|
||||
non_none_args = [a for a in args if a is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
schema = self._type_to_schema(non_none_args[0])
|
||||
return schema
|
||||
return {"type": "object"}
|
||||
|
||||
# Handle List[T], Dict[K,V]
|
||||
if origin in (list, tuple):
|
||||
items_schema = {"type": "object"}
|
||||
if args:
|
||||
items_schema = self._type_to_schema(args[0])
|
||||
return {"type": "array", "items": items_schema}
|
||||
|
||||
if origin is dict:
|
||||
return {"type": "object"}
|
||||
|
||||
# Handle basic types
|
||||
if annotation in (str,):
|
||||
return {"type": "string"}
|
||||
elif annotation in (int, float):
|
||||
return {"type": "number"}
|
||||
elif annotation in (bool,):
|
||||
return {"type": "boolean"}
|
||||
elif annotation in (list, tuple):
|
||||
return {"type": "array"}
|
||||
elif annotation in (dict,):
|
||||
return {"type": "object"}
|
||||
|
||||
return {"type": "object"}
|
||||
|
||||
def _ast_type_to_json_type(self, type_name: str) -> str:
|
||||
"""Convert AST type name to JSON schema type."""
|
||||
type_map = {
|
||||
"str": "string",
|
||||
"int": "number",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
"list": "array",
|
||||
"dict": "object",
|
||||
"List": "array",
|
||||
"Dict": "object",
|
||||
"Optional": "object",
|
||||
"Any": "object",
|
||||
}
|
||||
return type_map.get(type_name, "object")
|
||||
|
||||
def get_discovered(self) -> list[DiscoveredTool]:
|
||||
"""Get all discovered tools."""
|
||||
return list(self._discovered)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear discovered tools cache."""
|
||||
self._discovered.clear()
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
discovery: Optional[ToolDiscovery] = None
|
||||
|
||||
|
||||
def get_discovery(registry: Optional[ToolRegistry] = None) -> ToolDiscovery:
|
||||
"""Get or create the tool discovery singleton."""
|
||||
global discovery
|
||||
if discovery is None:
|
||||
discovery = ToolDiscovery(registry=registry)
|
||||
return discovery
|
||||
@@ -42,6 +42,9 @@ class ToolRecord:
|
||||
avg_latency_ms: float = 0.0
|
||||
added_at: float = field(default_factory=time.time)
|
||||
requires_confirmation: bool = False
|
||||
tags: list[str] = field(default_factory=list)
|
||||
source_module: Optional[str] = None
|
||||
auto_discovered: bool = False
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
@@ -59,6 +62,9 @@ class ToolRegistry:
|
||||
handler: Callable,
|
||||
category: str = "general",
|
||||
requires_confirmation: bool = False,
|
||||
tags: Optional[list[str]] = None,
|
||||
source_module: Optional[str] = None,
|
||||
auto_discovered: bool = False,
|
||||
) -> ToolRecord:
|
||||
"""Register a new tool.
|
||||
|
||||
@@ -68,6 +74,9 @@ class ToolRegistry:
|
||||
handler: Function to execute
|
||||
category: Tool category for organization
|
||||
requires_confirmation: If True, user must approve before execution
|
||||
tags: Tags for filtering and organization
|
||||
source_module: Module where tool was defined
|
||||
auto_discovered: Whether tool was auto-discovered
|
||||
|
||||
Returns:
|
||||
The registered ToolRecord
|
||||
@@ -81,6 +90,9 @@ class ToolRegistry:
|
||||
handler=handler,
|
||||
category=category,
|
||||
requires_confirmation=requires_confirmation,
|
||||
tags=tags or [],
|
||||
source_module=source_module,
|
||||
auto_discovered=auto_discovered,
|
||||
)
|
||||
|
||||
self._tools[name] = record
|
||||
@@ -94,6 +106,75 @@ class ToolRegistry:
|
||||
logger.info("Registered tool: %s (category: %s)", name, category)
|
||||
return record
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
name: str,
|
||||
function: Callable,
|
||||
description: Optional[str] = None,
|
||||
category: str = "general",
|
||||
tags: Optional[list[str]] = None,
|
||||
source_module: Optional[str] = None,
|
||||
) -> ToolRecord:
|
||||
"""Register a tool from a function (convenience method for discovery).
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
function: Function to register
|
||||
description: Tool description (defaults to docstring)
|
||||
category: Tool category
|
||||
tags: Tags for organization
|
||||
source_module: Source module path
|
||||
|
||||
Returns:
|
||||
The registered ToolRecord
|
||||
"""
|
||||
# Build schema from function signature
|
||||
sig = inspect.signature(function)
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
|
||||
continue
|
||||
|
||||
param_schema: dict = {"type": "string"}
|
||||
|
||||
# Try to infer type from annotation
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
if param.annotation in (int, float):
|
||||
param_schema = {"type": "number"}
|
||||
elif param.annotation == bool:
|
||||
param_schema = {"type": "boolean"}
|
||||
elif param.annotation == list:
|
||||
param_schema = {"type": "array"}
|
||||
elif param.annotation == dict:
|
||||
param_schema = {"type": "object"}
|
||||
|
||||
if param.default is param.empty:
|
||||
required.append(param_name)
|
||||
else:
|
||||
param_schema["default"] = param.default
|
||||
|
||||
properties[param_name] = param_schema
|
||||
|
||||
schema = create_tool_schema(
|
||||
name=name,
|
||||
description=description or (function.__doc__ or f"Execute {name}"),
|
||||
parameters=properties,
|
||||
required=required,
|
||||
)
|
||||
|
||||
return self.register(
|
||||
name=name,
|
||||
schema=schema,
|
||||
handler=function,
|
||||
category=category,
|
||||
tags=tags,
|
||||
source_module=source_module or function.__module__,
|
||||
auto_discovered=True,
|
||||
)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Remove a tool from the registry."""
|
||||
if name not in self._tools:
|
||||
@@ -137,14 +218,18 @@ class ToolRegistry:
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
healthy_only: bool = True,
|
||||
auto_discovered_only: bool = False,
|
||||
) -> list[ToolRecord]:
|
||||
"""Discover tools matching criteria.
|
||||
|
||||
Args:
|
||||
query: Search in tool names and descriptions
|
||||
category: Filter by category
|
||||
tags: Filter by tags (must have all specified tags)
|
||||
healthy_only: Only return healthy tools
|
||||
auto_discovered_only: Only return auto-discovered tools
|
||||
|
||||
Returns:
|
||||
List of matching ToolRecords
|
||||
@@ -156,17 +241,27 @@ class ToolRegistry:
|
||||
if category and record.category != category:
|
||||
continue
|
||||
|
||||
# Tags filter
|
||||
if tags:
|
||||
if not all(tag in record.tags for tag in tags):
|
||||
continue
|
||||
|
||||
# Health filter
|
||||
if healthy_only and record.health_status == "unhealthy":
|
||||
continue
|
||||
|
||||
# Auto-discovered filter
|
||||
if auto_discovered_only and not record.auto_discovered:
|
||||
continue
|
||||
|
||||
# Query filter
|
||||
if query:
|
||||
query_lower = query.lower()
|
||||
name_match = query_lower in name.lower()
|
||||
desc = record.schema.get("description", "")
|
||||
desc_match = query_lower in desc.lower()
|
||||
if not (name_match or desc_match):
|
||||
tag_match = any(query_lower in tag.lower() for tag in record.tags)
|
||||
if not (name_match or desc_match or tag_match):
|
||||
continue
|
||||
|
||||
results.append(record)
|
||||
@@ -274,11 +369,15 @@ class ToolRegistry:
|
||||
"category": r.category,
|
||||
"health": r.health_status,
|
||||
"requires_confirmation": r.requires_confirmation,
|
||||
"tags": r.tags,
|
||||
"source_module": r.source_module,
|
||||
"auto_discovered": r.auto_discovered,
|
||||
}
|
||||
for r in self._tools.values()
|
||||
],
|
||||
"categories": self._categories,
|
||||
"total_tools": len(self._tools),
|
||||
"auto_discovered_count": sum(1 for r in self._tools.values() if r.auto_discovered),
|
||||
}
|
||||
|
||||
|
||||
@@ -286,6 +385,11 @@ class ToolRegistry:
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
|
||||
def get_registry() -> ToolRegistry:
|
||||
"""Get the global tool registry singleton."""
|
||||
return tool_registry
|
||||
|
||||
|
||||
def register_tool(
|
||||
name: Optional[str] = None,
|
||||
category: str = "general",
|
||||
|
||||
265
tests/test_mcp_bootstrap.py
Normal file
265
tests/test_mcp_bootstrap.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Tests for MCP Auto-Bootstrap.
|
||||
|
||||
Tests follow pytest best practices:
|
||||
- No module-level state
|
||||
- Proper fixture cleanup
|
||||
- Isolated tests
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcp.bootstrap import (
|
||||
auto_bootstrap,
|
||||
bootstrap_from_directory,
|
||||
get_bootstrap_status,
|
||||
DEFAULT_TOOL_PACKAGES,
|
||||
AUTO_BOOTSTRAP_ENV_VAR,
|
||||
)
|
||||
from mcp.discovery import mcp_tool, ToolDiscovery
|
||||
from mcp.registry import ToolRegistry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fresh_registry():
|
||||
"""Create a fresh registry for each test."""
|
||||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fresh_discovery(fresh_registry):
|
||||
"""Create a fresh discovery instance for each test."""
|
||||
return ToolDiscovery(registry=fresh_registry)
|
||||
|
||||
|
||||
class TestAutoBootstrap:
|
||||
"""Test auto_bootstrap function."""
|
||||
|
||||
def test_auto_bootstrap_disabled_by_env(self, fresh_registry):
|
||||
"""Test that auto-bootstrap can be disabled via env var."""
|
||||
with patch.dict(os.environ, {AUTO_BOOTSTRAP_ENV_VAR: "0"}):
|
||||
registered = auto_bootstrap(registry=fresh_registry)
|
||||
|
||||
assert len(registered) == 0
|
||||
|
||||
def test_auto_bootstrap_forced_overrides_env(self, fresh_registry):
|
||||
"""Test that force=True overrides env var."""
|
||||
with patch.dict(os.environ, {AUTO_BOOTSTRAP_ENV_VAR: "0"}):
|
||||
# Empty packages list - just test that it runs
|
||||
registered = auto_bootstrap(
|
||||
packages=[],
|
||||
registry=fresh_registry,
|
||||
force=True,
|
||||
)
|
||||
|
||||
assert len(registered) == 0 # No packages, but didn't abort
|
||||
|
||||
def test_auto_bootstrap_nonexistent_package(self, fresh_registry):
|
||||
"""Test bootstrap from non-existent package."""
|
||||
registered = auto_bootstrap(
|
||||
packages=["nonexistent_package_xyz_12345"],
|
||||
registry=fresh_registry,
|
||||
force=True,
|
||||
)
|
||||
|
||||
assert len(registered) == 0
|
||||
|
||||
def test_auto_bootstrap_empty_packages(self, fresh_registry):
|
||||
"""Test bootstrap with empty packages list."""
|
||||
registered = auto_bootstrap(
|
||||
packages=[],
|
||||
registry=fresh_registry,
|
||||
force=True,
|
||||
)
|
||||
|
||||
assert len(registered) == 0
|
||||
|
||||
def test_auto_bootstrap_registers_tools(self, fresh_registry, fresh_discovery):
|
||||
"""Test that auto-bootstrap registers discovered tools."""
|
||||
@mcp_tool(name="bootstrap_tool", category="bootstrap")
|
||||
def bootstrap_func(value: str) -> str:
|
||||
"""A bootstrap test tool."""
|
||||
return value
|
||||
|
||||
# Manually register it
|
||||
fresh_registry.register_tool(
|
||||
name="bootstrap_tool",
|
||||
function=bootstrap_func,
|
||||
category="bootstrap",
|
||||
)
|
||||
|
||||
# Verify it's in the registry
|
||||
record = fresh_registry.get("bootstrap_tool")
|
||||
assert record is not None
|
||||
assert record.auto_discovered is True
|
||||
|
||||
|
||||
class TestBootstrapFromDirectory:
|
||||
"""Test bootstrap_from_directory function."""
|
||||
|
||||
def test_bootstrap_from_directory(self, fresh_registry, tmp_path):
|
||||
"""Test bootstrapping from a directory."""
|
||||
tools_dir = tmp_path / "tools"
|
||||
tools_dir.mkdir()
|
||||
|
||||
tool_file = tools_dir / "my_tools.py"
|
||||
tool_file.write_text('''
|
||||
from mcp.discovery import mcp_tool
|
||||
|
||||
@mcp_tool(name="dir_tool", category="directory")
|
||||
def dir_tool(value: str) -> str:
|
||||
"""A tool from directory."""
|
||||
return value
|
||||
''')
|
||||
|
||||
registered = bootstrap_from_directory(tools_dir, registry=fresh_registry)
|
||||
|
||||
# Function won't be resolved (AST only), so not registered
|
||||
assert len(registered) == 0
|
||||
|
||||
def test_bootstrap_from_nonexistent_directory(self, fresh_registry):
|
||||
"""Test bootstrapping from non-existent directory."""
|
||||
registered = bootstrap_from_directory(
|
||||
Path("/nonexistent/tools"),
|
||||
registry=fresh_registry
|
||||
)
|
||||
|
||||
assert len(registered) == 0
|
||||
|
||||
def test_bootstrap_skips_private_files(self, fresh_registry, tmp_path):
|
||||
"""Test that private files are skipped."""
|
||||
tools_dir = tmp_path / "tools"
|
||||
tools_dir.mkdir()
|
||||
|
||||
private_file = tools_dir / "_private.py"
|
||||
private_file.write_text('''
|
||||
from mcp.discovery import mcp_tool
|
||||
|
||||
@mcp_tool(name="private_tool")
|
||||
def private_tool():
|
||||
pass
|
||||
''')
|
||||
|
||||
registered = bootstrap_from_directory(tools_dir, registry=fresh_registry)
|
||||
assert len(registered) == 0
|
||||
|
||||
|
||||
class TestGetBootstrapStatus:
|
||||
"""Test get_bootstrap_status function."""
|
||||
|
||||
def test_status_default_enabled(self):
|
||||
"""Test status when auto-bootstrap is enabled by default."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
status = get_bootstrap_status()
|
||||
|
||||
assert status["auto_bootstrap_enabled"] is True
|
||||
assert "discovered_tools_count" in status
|
||||
assert "registered_tools_count" in status
|
||||
assert status["default_packages"] == DEFAULT_TOOL_PACKAGES
|
||||
|
||||
def test_status_disabled(self):
|
||||
"""Test status when auto-bootstrap is disabled."""
|
||||
with patch.dict(os.environ, {AUTO_BOOTSTRAP_ENV_VAR: "0"}):
|
||||
status = get_bootstrap_status()
|
||||
|
||||
assert status["auto_bootstrap_enabled"] is False
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for bootstrap + discovery + registry."""
|
||||
|
||||
def test_full_workflow(self, fresh_registry):
|
||||
"""Test the full auto-discovery and registration workflow."""
|
||||
@mcp_tool(name="integration_tool", category="integration")
|
||||
def integration_func(data: str) -> str:
|
||||
"""Integration test tool."""
|
||||
return f"processed: {data}"
|
||||
|
||||
fresh_registry.register_tool(
|
||||
name="integration_tool",
|
||||
function=integration_func,
|
||||
category="integration",
|
||||
source_module="test_module",
|
||||
)
|
||||
|
||||
record = fresh_registry.get("integration_tool")
|
||||
assert record is not None
|
||||
assert record.auto_discovered is True
|
||||
assert record.source_module == "test_module"
|
||||
|
||||
export = fresh_registry.to_dict()
|
||||
assert export["total_tools"] == 1
|
||||
assert export["auto_discovered_count"] == 1
|
||||
|
||||
def test_tool_execution_after_registration(self, fresh_registry):
|
||||
"""Test that registered tools can be executed."""
|
||||
@mcp_tool(name="exec_tool", category="execution")
|
||||
def exec_func(input: str) -> str:
|
||||
"""Executable test tool."""
|
||||
return input.upper()
|
||||
|
||||
fresh_registry.register_tool(
|
||||
name="exec_tool",
|
||||
function=exec_func,
|
||||
category="execution",
|
||||
)
|
||||
|
||||
import asyncio
|
||||
result = asyncio.run(fresh_registry.execute("exec_tool", {"input": "hello"}))
|
||||
|
||||
assert result == "HELLO"
|
||||
|
||||
metrics = fresh_registry.get_metrics("exec_tool")
|
||||
assert metrics["executions"] == 1
|
||||
assert metrics["health"] == "healthy"
|
||||
|
||||
def test_discover_filtering(self, fresh_registry):
|
||||
"""Test filtering registered tools."""
|
||||
@mcp_tool(name="cat1_tool", category="category1")
|
||||
def cat1_func():
|
||||
pass
|
||||
|
||||
@mcp_tool(name="cat2_tool", category="category2")
|
||||
def cat2_func():
|
||||
pass
|
||||
|
||||
fresh_registry.register_tool(
|
||||
name="cat1_tool",
|
||||
function=cat1_func,
|
||||
category="category1"
|
||||
)
|
||||
fresh_registry.register_tool(
|
||||
name="cat2_tool",
|
||||
function=cat2_func,
|
||||
category="category2"
|
||||
)
|
||||
|
||||
cat1_tools = fresh_registry.discover(category="category1")
|
||||
assert len(cat1_tools) == 1
|
||||
assert cat1_tools[0].name == "cat1_tool"
|
||||
|
||||
auto_tools = fresh_registry.discover(auto_discovered_only=True)
|
||||
assert len(auto_tools) == 2
|
||||
|
||||
def test_registry_export_includes_metadata(self, fresh_registry):
|
||||
"""Test that registry export includes all metadata."""
|
||||
@mcp_tool(name="meta_tool", category="meta", tags=["tag1", "tag2"])
|
||||
def meta_func():
|
||||
pass
|
||||
|
||||
fresh_registry.register_tool(
|
||||
name="meta_tool",
|
||||
function=meta_func,
|
||||
category="meta",
|
||||
tags=["tag1", "tag2"],
|
||||
)
|
||||
|
||||
export = fresh_registry.to_dict()
|
||||
|
||||
for tool_dict in export["tools"]:
|
||||
assert "tags" in tool_dict
|
||||
assert "source_module" in tool_dict
|
||||
assert "auto_discovered" in tool_dict
|
||||
329
tests/test_mcp_discovery.py
Normal file
329
tests/test_mcp_discovery.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""Tests for MCP Tool Auto-Discovery.
|
||||
|
||||
Tests follow pytest best practices:
|
||||
- No module-level state
|
||||
- Proper fixture cleanup
|
||||
- Isolated tests
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcp.discovery import DiscoveredTool, ToolDiscovery, mcp_tool
|
||||
from mcp.registry import ToolRegistry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fresh_registry():
|
||||
"""Create a fresh registry for each test."""
|
||||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discovery(fresh_registry):
|
||||
"""Create a fresh discovery instance for each test."""
|
||||
return ToolDiscovery(registry=fresh_registry)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_module_with_tools():
|
||||
"""Create a mock module with MCP tools for testing."""
|
||||
# Create a fresh module
|
||||
mock_module = types.ModuleType("mock_test_module")
|
||||
mock_module.__file__ = "mock_test_module.py"
|
||||
|
||||
# Add decorated functions
|
||||
@mcp_tool(name="echo", category="test", tags=["utility"])
|
||||
def echo_func(message: str) -> str:
|
||||
"""Echo a message back."""
|
||||
return message
|
||||
|
||||
@mcp_tool(category="math")
|
||||
def add_func(a: int, b: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
def not_decorated():
|
||||
"""Not a tool."""
|
||||
pass
|
||||
|
||||
mock_module.echo_func = echo_func
|
||||
mock_module.add_func = add_func
|
||||
mock_module.not_decorated = not_decorated
|
||||
|
||||
# Inject into sys.modules
|
||||
sys.modules["mock_test_module"] = mock_module
|
||||
|
||||
yield mock_module
|
||||
|
||||
# Cleanup
|
||||
del sys.modules["mock_test_module"]
|
||||
|
||||
|
||||
class TestMCPToolDecorator:
|
||||
"""Test the @mcp_tool decorator."""
|
||||
|
||||
def test_decorator_sets_explicit_name(self):
|
||||
"""Test that decorator uses explicit name."""
|
||||
@mcp_tool(name="custom_name", category="test")
|
||||
def my_func():
|
||||
pass
|
||||
|
||||
assert my_func._mcp_name == "custom_name"
|
||||
assert my_func._mcp_category == "test"
|
||||
|
||||
def test_decorator_uses_function_name(self):
|
||||
"""Test that decorator uses function name when not specified."""
|
||||
@mcp_tool(category="math")
|
||||
def my_add_func():
|
||||
pass
|
||||
|
||||
assert my_add_func._mcp_name == "my_add_func"
|
||||
|
||||
def test_decorator_captures_docstring(self):
|
||||
"""Test that decorator captures docstring as description."""
|
||||
@mcp_tool(name="test")
|
||||
def with_doc():
|
||||
"""This is the description."""
|
||||
pass
|
||||
|
||||
assert "This is the description" in with_doc._mcp_description
|
||||
|
||||
def test_decorator_sets_tags(self):
|
||||
"""Test that decorator sets tags."""
|
||||
@mcp_tool(name="test", tags=["tag1", "tag2"])
|
||||
def tagged_func():
|
||||
pass
|
||||
|
||||
assert tagged_func._mcp_tags == ["tag1", "tag2"]
|
||||
|
||||
def test_undecorated_function(self):
|
||||
"""Test that undecorated functions don't have MCP attributes."""
|
||||
def plain_func():
|
||||
pass
|
||||
|
||||
assert not hasattr(plain_func, "_mcp_tool")
|
||||
|
||||
|
||||
class TestDiscoveredTool:
|
||||
"""Test DiscoveredTool dataclass."""
|
||||
|
||||
def test_tool_creation(self):
|
||||
"""Test creating a DiscoveredTool."""
|
||||
def dummy_func():
|
||||
pass
|
||||
|
||||
tool = DiscoveredTool(
|
||||
name="test",
|
||||
description="A test tool",
|
||||
function=dummy_func,
|
||||
module="test_module",
|
||||
category="test",
|
||||
tags=["utility"],
|
||||
parameters_schema={"type": "object"},
|
||||
returns_schema={"type": "string"},
|
||||
)
|
||||
|
||||
assert tool.name == "test"
|
||||
assert tool.function == dummy_func
|
||||
assert tool.category == "test"
|
||||
|
||||
|
||||
class TestToolDiscoveryInit:
|
||||
"""Test ToolDiscovery initialization."""
|
||||
|
||||
def test_uses_provided_registry(self, fresh_registry):
|
||||
"""Test initialization with provided registry."""
|
||||
discovery = ToolDiscovery(registry=fresh_registry)
|
||||
assert discovery.registry is fresh_registry
|
||||
|
||||
|
||||
class TestDiscoverModule:
|
||||
"""Test discovering tools from modules."""
|
||||
|
||||
def test_discover_finds_decorated_tools(self, discovery, mock_module_with_tools):
|
||||
"""Test discovering tools from a module."""
|
||||
tools = discovery.discover_module("mock_test_module")
|
||||
|
||||
tool_names = [t.name for t in tools]
|
||||
assert "echo" in tool_names
|
||||
assert "add_func" in tool_names
|
||||
assert "not_decorated" not in tool_names
|
||||
|
||||
def test_discover_nonexistent_module(self, discovery):
|
||||
"""Test discovering from non-existent module."""
|
||||
tools = discovery.discover_module("nonexistent.module.xyz")
|
||||
assert len(tools) == 0
|
||||
|
||||
def test_discovered_tool_has_correct_metadata(self, discovery, mock_module_with_tools):
|
||||
"""Test that discovered tools have correct metadata."""
|
||||
tools = discovery.discover_module("mock_test_module")
|
||||
|
||||
echo_tool = next(t for t in tools if t.name == "echo")
|
||||
assert echo_tool.category == "test"
|
||||
assert "utility" in echo_tool.tags
|
||||
|
||||
def test_discovered_tool_has_schema(self, discovery, mock_module_with_tools):
|
||||
"""Test that discovered tools have parameter schemas."""
|
||||
tools = discovery.discover_module("mock_test_module")
|
||||
|
||||
add_tool = next(t for t in tools if t.name == "add_func")
|
||||
assert "properties" in add_tool.parameters_schema
|
||||
assert "a" in add_tool.parameters_schema["properties"]
|
||||
|
||||
|
||||
class TestDiscoverFile:
|
||||
"""Test discovering tools from Python files."""
|
||||
|
||||
def test_discover_from_file(self, discovery, tmp_path):
|
||||
"""Test discovering tools from a Python file."""
|
||||
test_file = tmp_path / "test_tools.py"
|
||||
test_file.write_text('''
|
||||
from mcp.discovery import mcp_tool
|
||||
|
||||
@mcp_tool(name="file_tool", category="file_ops", tags=["io"])
|
||||
def file_tool(path: str) -> dict:
|
||||
"""Process a file."""
|
||||
return {"path": path}
|
||||
''')
|
||||
|
||||
tools = discovery.discover_file(test_file)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "file_tool"
|
||||
assert tools[0].category == "file_ops"
|
||||
|
||||
def test_discover_from_nonexistent_file(self, discovery, tmp_path):
|
||||
"""Test discovering from non-existent file."""
|
||||
tools = discovery.discover_file(tmp_path / "nonexistent.py")
|
||||
assert len(tools) == 0
|
||||
|
||||
def test_discover_from_invalid_python(self, discovery, tmp_path):
|
||||
"""Test discovering from invalid Python file."""
|
||||
test_file = tmp_path / "invalid.py"
|
||||
test_file.write_text("not valid python @#$%")
|
||||
|
||||
tools = discovery.discover_file(test_file)
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
class TestSchemaBuilding:
|
||||
"""Test JSON schema building from type hints."""
|
||||
|
||||
def test_string_parameter(self, discovery):
|
||||
"""Test string parameter schema."""
|
||||
def func(name: str) -> str:
|
||||
return name
|
||||
|
||||
sig = inspect.signature(func)
|
||||
schema = discovery._build_parameters_schema(sig)
|
||||
|
||||
assert schema["properties"]["name"]["type"] == "string"
|
||||
|
||||
def test_int_parameter(self, discovery):
|
||||
"""Test int parameter schema."""
|
||||
def func(count: int) -> int:
|
||||
return count
|
||||
|
||||
sig = inspect.signature(func)
|
||||
schema = discovery._build_parameters_schema(sig)
|
||||
|
||||
assert schema["properties"]["count"]["type"] == "number"
|
||||
|
||||
def test_bool_parameter(self, discovery):
|
||||
"""Test bool parameter schema."""
|
||||
def func(enabled: bool) -> bool:
|
||||
return enabled
|
||||
|
||||
sig = inspect.signature(func)
|
||||
schema = discovery._build_parameters_schema(sig)
|
||||
|
||||
assert schema["properties"]["enabled"]["type"] == "boolean"
|
||||
|
||||
def test_required_parameters(self, discovery):
|
||||
"""Test that required parameters are marked."""
|
||||
def func(required: str, optional: str = "default") -> str:
|
||||
return required
|
||||
|
||||
sig = inspect.signature(func)
|
||||
schema = discovery._build_parameters_schema(sig)
|
||||
|
||||
assert "required" in schema["required"]
|
||||
assert "optional" not in schema["required"]
|
||||
|
||||
def test_default_values(self, discovery):
|
||||
"""Test that default values are captured."""
|
||||
def func(name: str = "default") -> str:
|
||||
return name
|
||||
|
||||
sig = inspect.signature(func)
|
||||
schema = discovery._build_parameters_schema(sig)
|
||||
|
||||
assert schema["properties"]["name"]["default"] == "default"
|
||||
|
||||
|
||||
class TestTypeToSchema:
|
||||
"""Test type annotation to JSON schema conversion."""
|
||||
|
||||
def test_str_annotation(self, discovery):
|
||||
"""Test string annotation."""
|
||||
schema = discovery._type_to_schema(str)
|
||||
assert schema["type"] == "string"
|
||||
|
||||
def test_int_annotation(self, discovery):
|
||||
"""Test int annotation."""
|
||||
schema = discovery._type_to_schema(int)
|
||||
assert schema["type"] == "number"
|
||||
|
||||
def test_optional_annotation(self, discovery):
|
||||
"""Test Optional[T] annotation."""
|
||||
from typing import Optional
|
||||
schema = discovery._type_to_schema(Optional[str])
|
||||
assert schema["type"] == "string"
|
||||
|
||||
|
||||
class TestAutoRegister:
|
||||
"""Test auto-registration of discovered tools."""
|
||||
|
||||
def test_auto_register_module(self, discovery, mock_module_with_tools, fresh_registry):
|
||||
"""Test auto-registering tools from a module."""
|
||||
registered = discovery.auto_register("mock_test_module")
|
||||
|
||||
assert "echo" in registered
|
||||
assert "add_func" in registered
|
||||
assert fresh_registry.get("echo") is not None
|
||||
|
||||
def test_auto_register_skips_unresolved_functions(self, discovery, fresh_registry):
|
||||
"""Test that tools without resolved functions are skipped."""
|
||||
# Add a discovered tool with no function
|
||||
discovery._discovered.append(DiscoveredTool(
|
||||
name="no_func",
|
||||
description="No function",
|
||||
function=None, # type: ignore
|
||||
module="test",
|
||||
category="test",
|
||||
tags=[],
|
||||
parameters_schema={},
|
||||
returns_schema={},
|
||||
))
|
||||
|
||||
registered = discovery.auto_register("mock_test_module")
|
||||
assert "no_func" not in registered
|
||||
|
||||
|
||||
class TestClearDiscovered:
|
||||
"""Test clearing discovered tools cache."""
|
||||
|
||||
def test_clear_discovered(self, discovery, mock_module_with_tools):
|
||||
"""Test clearing discovered tools."""
|
||||
discovery.discover_module("mock_test_module")
|
||||
assert len(discovery.get_discovered()) > 0
|
||||
|
||||
discovery.clear()
|
||||
assert len(discovery.get_discovered()) == 0
|
||||
Reference in New Issue
Block a user