diff --git a/src/dashboard/app.py b/src/dashboard/app.py index 2e98ae5e..6394bc94 100644 --- a/src/dashboard/app.py +++ b/src/dashboard/app.py @@ -102,6 +102,15 @@ async def lifespan(app: FastAPI): except Exception as exc: logger.error("Failed to spawn persona agents: %s", exc) + # Auto-bootstrap MCP tools + from mcp.bootstrap import auto_bootstrap, get_bootstrap_status + try: + registered = auto_bootstrap() + if registered: + logger.info("MCP auto-bootstrap: %d tools registered", len(registered)) + except Exception as exc: + logger.warning("MCP auto-bootstrap failed: %s", exc) + # Initialise Spark Intelligence engine from spark.engine import spark_engine if spark_engine.enabled: diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 5690035a..38d9eb5d 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,17 +1,30 @@ """MCP (Model Context Protocol) package. -Provides tool registry, server, and schema management. +Provides tool registry, server, schema management, and auto-discovery. """ -from mcp.registry import tool_registry, register_tool +from mcp.registry import tool_registry, register_tool, ToolRegistry from mcp.server import mcp_server, MCPServer, MCPHTTPServer from mcp.schemas.base import create_tool_schema +from mcp.discovery import ToolDiscovery, mcp_tool, get_discovery +from mcp.bootstrap import auto_bootstrap, get_bootstrap_status __all__ = [ + # Registry "tool_registry", "register_tool", + "ToolRegistry", + # Server "mcp_server", "MCPServer", "MCPHTTPServer", + # Schemas "create_tool_schema", + # Discovery + "ToolDiscovery", + "mcp_tool", + "get_discovery", + # Bootstrap + "auto_bootstrap", + "get_bootstrap_status", ] diff --git a/src/mcp/bootstrap.py b/src/mcp/bootstrap.py index 1ca9cd29..7e21b32d 100644 --- a/src/mcp/bootstrap.py +++ b/src/mcp/bootstrap.py @@ -1,71 +1,148 @@ -"""Bootstrap the MCP system by loading all tools. +"""MCP Auto-Bootstrap — Auto-discover and register tools on startup. -This module is responsible for: -1. Loading all tool modules from src/tools/ -2. Registering them with the tool registry -3. Verifying tool health -4. Reporting status +Usage: + from mcp.bootstrap import auto_bootstrap + + # Auto-discover from 'tools' package + registered = auto_bootstrap() + + # Or specify custom packages + registered = auto_bootstrap(packages=["tools", "custom_tools"]) """ -import importlib import logging +import os from pathlib import Path +from typing import Optional -from mcp.registry import tool_registry +from .discovery import ToolDiscovery, get_discovery +from .registry import ToolRegistry, tool_registry logger = logging.getLogger(__name__) -# Tool modules to load -TOOL_MODULES = [ - "tools.web_search", - "tools.file_ops", - "tools.code_exec", - "tools.memory_tool", -] +# Default packages to scan for tools +DEFAULT_TOOL_PACKAGES = ["tools"] + +# Environment variable to disable auto-bootstrap +AUTO_BOOTSTRAP_ENV_VAR = "MCP_AUTO_BOOTSTRAP" -def bootstrap_mcp() -> dict: - """Initialize the MCP system by loading all tools. +def auto_bootstrap( + packages: Optional[list[str]] = None, + registry: Optional[ToolRegistry] = None, + force: bool = False, +) -> list[str]: + """Auto-discover and register MCP tools. + + Args: + packages: Packages to scan (defaults to ["tools"]) + registry: Registry to register tools with (defaults to singleton) + force: Force bootstrap even if disabled by env var Returns: - Status dict with loaded tools and any errors + List of registered tool names """ - loaded = [] - errors = [] + # Check if auto-bootstrap is disabled + if not force and os.environ.get(AUTO_BOOTSTRAP_ENV_VAR, "1") == "0": + logger.info("MCP auto-bootstrap disabled via %s", AUTO_BOOTSTRAP_ENV_VAR) + return [] - for module_name in TOOL_MODULES: + packages = packages or DEFAULT_TOOL_PACKAGES + registry = registry or tool_registry + discovery = get_discovery(registry=registry) + + registered: list[str] = [] + + logger.info("Starting MCP auto-bootstrap from packages: %s", packages) + + for package in packages: try: - # Import the module (this triggers @register_tool decorators) - importlib.import_module(module_name) - loaded.append(module_name) - logger.info("Loaded tool module: %s", module_name) + # Check if package exists + try: + __import__(package) + except ImportError: + logger.debug("Package %s not found, skipping", package) + continue + + # Discover and register + tools = discovery.auto_register(package) + registered.extend(tools) + except Exception as exc: - errors.append({"module": module_name, "error": str(exc)}) - logger.error("Failed to load tool module %s: %s", module_name, exc) + logger.warning("Failed to bootstrap from %s: %s", package, exc) - # Get registry status - registry_status = tool_registry.to_dict() - - status = { - "loaded_modules": loaded, - "errors": errors, - "total_tools": len(registry_status.get("tools", [])), - "tools_by_category": registry_status.get("categories", {}), - "tool_names": tool_registry.list_tools(), - } - - logger.info( - "MCP Bootstrap complete: %d tools loaded from %d modules", - status["total_tools"], - len(loaded) - ) - - return status + logger.info("MCP auto-bootstrap complete: %d tools registered", len(registered)) + return registered -def get_tool_status() -> dict: - """Get current status of all tools.""" +def bootstrap_from_directory( + directory: Path, + registry: Optional[ToolRegistry] = None, +) -> list[str]: + """Bootstrap tools from a directory of Python files. + + Args: + directory: Directory containing Python files with tools + registry: Registry to register tools with + + Returns: + List of registered tool names + """ + registry = registry or tool_registry + discovery = get_discovery(registry=registry) + + registered: list[str] = [] + + if not directory.exists(): + logger.warning("Tools directory not found: %s", directory) + return registered + + logger.info("Bootstrapping tools from directory: %s", directory) + + # Find all Python files + for py_file in directory.rglob("*.py"): + if py_file.name.startswith("_"): + continue + + try: + discovered = discovery.discover_file(py_file) + + for tool in discovered: + if tool.function is None: + # Need to import and resolve the function + continue + + try: + registry.register_tool( + name=tool.name, + function=tool.function, + description=tool.description, + category=tool.category, + tags=tool.tags, + ) + registered.append(tool.name) + except Exception as exc: + logger.error("Failed to register %s: %s", tool.name, exc) + + except Exception as exc: + logger.warning("Failed to process %s: %s", py_file, exc) + + logger.info("Directory bootstrap complete: %d tools registered", len(registered)) + return registered + + +def get_bootstrap_status() -> dict: + """Get auto-bootstrap status. + + Returns: + Dict with bootstrap status info + """ + discovery = get_discovery() + registry = tool_registry + return { - "tools": tool_registry.to_dict(), - "metrics": tool_registry.get_metrics(), + "auto_bootstrap_enabled": os.environ.get(AUTO_BOOTSTRAP_ENV_VAR, "1") != "0", + "discovered_tools_count": len(discovery.get_discovered()), + "registered_tools_count": len(registry.list_tools()), + "default_packages": DEFAULT_TOOL_PACKAGES, } diff --git a/src/mcp/discovery.py b/src/mcp/discovery.py new file mode 100644 index 00000000..a6ec0241 --- /dev/null +++ b/src/mcp/discovery.py @@ -0,0 +1,441 @@ +"""MCP Tool Auto-Discovery — Introspect Python modules to find tools. + +Automatically discovers functions marked with @mcp_tool decorator +and registers them with the MCP registry. Generates JSON schemas +from type hints. +""" + +import ast +import importlib +import inspect +import logging +import pkgutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Optional, get_type_hints + +from .registry import ToolRegistry, tool_registry + +logger = logging.getLogger(__name__) + + +# Decorator to mark functions as MCP tools +def mcp_tool( + name: Optional[str] = None, + description: Optional[str] = None, + category: str = "general", + tags: Optional[list[str]] = None, +): + """Decorator to mark a function as an MCP tool. + + Args: + name: Tool name (defaults to function name) + description: Tool description (defaults to docstring) + category: Tool category for organization + tags: Additional tags for filtering + + Example: + @mcp_tool(name="weather", category="external") + def get_weather(city: str) -> dict: + '''Get weather for a city.''' + ... + """ + def decorator(func: Callable) -> Callable: + func._mcp_tool = True + func._mcp_name = name or func.__name__ + func._mcp_description = description or (func.__doc__ or "").strip() + func._mcp_category = category + func._mcp_tags = tags or [] + return func + return decorator + + +@dataclass +class DiscoveredTool: + """A tool discovered via introspection.""" + name: str + description: str + function: Callable + module: str + category: str + tags: list[str] + parameters_schema: dict[str, Any] + returns_schema: dict[str, Any] + source_file: Optional[str] = None + line_number: int = 0 + + +class ToolDiscovery: + """Discovers and registers MCP tools from Python modules. + + Usage: + discovery = ToolDiscovery() + + # Discover from a module + tools = discovery.discover_module("tools.git") + + # Auto-register with registry + discovery.auto_register("tools") + + # Discover from all installed packages + tools = discovery.discover_all_packages() + """ + + def __init__(self, registry: Optional[ToolRegistry] = None) -> None: + self.registry = registry or tool_registry + self._discovered: list[DiscoveredTool] = [] + + def discover_module(self, module_name: str) -> list[DiscoveredTool]: + """Discover all MCP tools in a module. + + Args: + module_name: Dotted path to module (e.g., "tools.git") + + Returns: + List of discovered tools + """ + discovered = [] + + try: + module = importlib.import_module(module_name) + except ImportError as exc: + logger.warning("Failed to import module %s: %s", module_name, exc) + return discovered + + # Get module file path for source location + module_file = getattr(module, "__file__", None) + + # Iterate through module members + for name, obj in inspect.getmembers(module): + # Skip private and non-callable + if name.startswith("_") or not callable(obj): + continue + + # Check if marked as MCP tool + if not getattr(obj, "_mcp_tool", False): + continue + + # Get source location + try: + source_file = inspect.getfile(obj) + line_number = inspect.getsourcelines(obj)[1] + except (OSError, TypeError): + source_file = module_file + line_number = 0 + + # Build schemas from type hints + try: + sig = inspect.signature(obj) + parameters_schema = self._build_parameters_schema(sig) + returns_schema = self._build_returns_schema(sig, obj) + except Exception as exc: + logger.warning("Failed to build schema for %s: %s", name, exc) + parameters_schema = {"type": "object", "properties": {}} + returns_schema = {} + + tool = DiscoveredTool( + name=getattr(obj, "_mcp_name", name), + description=getattr(obj, "_mcp_description", obj.__doc__ or ""), + function=obj, + module=module_name, + category=getattr(obj, "_mcp_category", "general"), + tags=getattr(obj, "_mcp_tags", []), + parameters_schema=parameters_schema, + returns_schema=returns_schema, + source_file=source_file, + line_number=line_number, + ) + + discovered.append(tool) + logger.debug("Discovered tool: %s from %s", tool.name, module_name) + + self._discovered.extend(discovered) + logger.info("Discovered %d tools from module %s", len(discovered), module_name) + return discovered + + def discover_package(self, package_name: str, recursive: bool = True) -> list[DiscoveredTool]: + """Discover tools from all modules in a package. + + Args: + package_name: Package name (e.g., "tools") + recursive: Whether to search subpackages + + Returns: + List of discovered tools + """ + discovered = [] + + try: + package = importlib.import_module(package_name) + except ImportError as exc: + logger.warning("Failed to import package %s: %s", package_name, exc) + return discovered + + package_path = getattr(package, "__path__", []) + if not package_path: + # Not a package, treat as module + return self.discover_module(package_name) + + # Walk package modules + for _, name, is_pkg in pkgutil.iter_modules(package_path, prefix=f"{package_name}."): + if is_pkg and recursive: + discovered.extend(self.discover_package(name, recursive=True)) + else: + discovered.extend(self.discover_module(name)) + + return discovered + + def discover_file(self, file_path: Path) -> list[DiscoveredTool]: + """Discover tools from a Python file. + + Args: + file_path: Path to Python file + + Returns: + List of discovered tools + """ + discovered = [] + + try: + source = file_path.read_text() + tree = ast.parse(source) + except Exception as exc: + logger.warning("Failed to parse %s: %s", file_path, exc) + return discovered + + # Find all decorated functions + for node in ast.walk(tree): + if not isinstance(node, ast.FunctionDef): + continue + + # Check for @mcp_tool decorator + is_tool = False + tool_name = node.name + tool_description = ast.get_docstring(node) or "" + tool_category = "general" + tool_tags: list[str] = [] + + for decorator in node.decorator_list: + if isinstance(decorator, ast.Call): + if isinstance(decorator.func, ast.Name) and decorator.func.id == "mcp_tool": + is_tool = True + # Extract decorator arguments + for kw in decorator.keywords: + if kw.arg == "name" and isinstance(kw.value, ast.Constant): + tool_name = kw.value.value + elif kw.arg == "description" and isinstance(kw.value, ast.Constant): + tool_description = kw.value.value + elif kw.arg == "category" and isinstance(kw.value, ast.Constant): + tool_category = kw.value.value + elif kw.arg == "tags" and isinstance(kw.value, ast.List): + tool_tags = [ + elt.value for elt in kw.value.elts + if isinstance(elt, ast.Constant) + ] + elif isinstance(decorator, ast.Name) and decorator.id == "mcp_tool": + is_tool = True + + if not is_tool: + continue + + # Build parameter schema from AST + parameters_schema = self._build_schema_from_ast(node) + + # We can't get the actual function without importing + # So create a placeholder that will be resolved later + tool = DiscoveredTool( + name=tool_name, + description=tool_description, + function=None, # Will be resolved when registered + module=str(file_path), + category=tool_category, + tags=tool_tags, + parameters_schema=parameters_schema, + returns_schema={"type": "object"}, + source_file=str(file_path), + line_number=node.lineno, + ) + + discovered.append(tool) + + self._discovered.extend(discovered) + logger.info("Discovered %d tools from file %s", len(discovered), file_path) + return discovered + + def auto_register(self, package_name: str = "tools") -> list[str]: + """Automatically discover and register tools. + + Args: + package_name: Package to scan for tools + + Returns: + List of registered tool names + """ + discovered = self.discover_package(package_name) + registered = [] + + for tool in discovered: + if tool.function is None: + logger.warning("Skipping %s: no function resolved", tool.name) + continue + + try: + self.registry.register_tool( + name=tool.name, + function=tool.function, + description=tool.description, + category=tool.category, + tags=tool.tags, + ) + registered.append(tool.name) + logger.debug("Registered tool: %s", tool.name) + except Exception as exc: + logger.error("Failed to register %s: %s", tool.name, exc) + + logger.info("Auto-registered %d/%d tools", len(registered), len(discovered)) + return registered + + def _build_parameters_schema(self, sig: inspect.Signature) -> dict[str, Any]: + """Build JSON schema for function parameters.""" + properties = {} + required = [] + + for name, param in sig.parameters.items(): + if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD): + continue + + schema = self._type_to_schema(param.annotation) + + if param.default is param.empty: + required.append(name) + else: + schema["default"] = param.default + + properties[name] = schema + + return { + "type": "object", + "properties": properties, + "required": required, + } + + def _build_returns_schema( + self, sig: inspect.Signature, func: Callable + ) -> dict[str, Any]: + """Build JSON schema for return type.""" + return_annotation = sig.return_annotation + + if return_annotation is sig.empty: + return {"type": "object"} + + return self._type_to_schema(return_annotation) + + def _build_schema_from_ast(self, node: ast.FunctionDef) -> dict[str, Any]: + """Build parameter schema from AST node.""" + properties = {} + required = [] + + # Get defaults (reversed, since they're at the end) + defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + list(node.args.defaults) + + for arg, default in zip(node.args.args, defaults): + arg_name = arg.arg + arg_type = "string" # Default + + # Try to get type from annotation + if arg.annotation: + if isinstance(arg.annotation, ast.Name): + arg_type = self._ast_type_to_json_type(arg.annotation.id) + elif isinstance(arg.annotation, ast.Constant): + arg_type = self._ast_type_to_json_type(str(arg.annotation.value)) + + schema = {"type": arg_type} + + if default is None: + required.append(arg_name) + + properties[arg_name] = schema + + return { + "type": "object", + "properties": properties, + "required": required, + } + + def _type_to_schema(self, annotation: Any) -> dict[str, Any]: + """Convert Python type annotation to JSON schema.""" + if annotation is inspect.Parameter.empty: + return {"type": "string"} + + origin = getattr(annotation, "__origin__", None) + args = getattr(annotation, "__args__", ()) + + # Handle Optional[T] = Union[T, None] + if origin is not None: + if str(origin) == "typing.Union" and type(None) in args: + # Optional type + non_none_args = [a for a in args if a is not type(None)] + if len(non_none_args) == 1: + schema = self._type_to_schema(non_none_args[0]) + return schema + return {"type": "object"} + + # Handle List[T], Dict[K,V] + if origin in (list, tuple): + items_schema = {"type": "object"} + if args: + items_schema = self._type_to_schema(args[0]) + return {"type": "array", "items": items_schema} + + if origin is dict: + return {"type": "object"} + + # Handle basic types + if annotation in (str,): + return {"type": "string"} + elif annotation in (int, float): + return {"type": "number"} + elif annotation in (bool,): + return {"type": "boolean"} + elif annotation in (list, tuple): + return {"type": "array"} + elif annotation in (dict,): + return {"type": "object"} + + return {"type": "object"} + + def _ast_type_to_json_type(self, type_name: str) -> str: + """Convert AST type name to JSON schema type.""" + type_map = { + "str": "string", + "int": "number", + "float": "number", + "bool": "boolean", + "list": "array", + "dict": "object", + "List": "array", + "Dict": "object", + "Optional": "object", + "Any": "object", + } + return type_map.get(type_name, "object") + + def get_discovered(self) -> list[DiscoveredTool]: + """Get all discovered tools.""" + return list(self._discovered) + + def clear(self) -> None: + """Clear discovered tools cache.""" + self._discovered.clear() + + +# Module-level singleton +discovery: Optional[ToolDiscovery] = None + + +def get_discovery(registry: Optional[ToolRegistry] = None) -> ToolDiscovery: + """Get or create the tool discovery singleton.""" + global discovery + if discovery is None: + discovery = ToolDiscovery(registry=registry) + return discovery diff --git a/src/mcp/registry.py b/src/mcp/registry.py index 292f1cd7..29d87172 100644 --- a/src/mcp/registry.py +++ b/src/mcp/registry.py @@ -42,6 +42,9 @@ class ToolRecord: avg_latency_ms: float = 0.0 added_at: float = field(default_factory=time.time) requires_confirmation: bool = False + tags: list[str] = field(default_factory=list) + source_module: Optional[str] = None + auto_discovered: bool = False class ToolRegistry: @@ -59,6 +62,9 @@ class ToolRegistry: handler: Callable, category: str = "general", requires_confirmation: bool = False, + tags: Optional[list[str]] = None, + source_module: Optional[str] = None, + auto_discovered: bool = False, ) -> ToolRecord: """Register a new tool. @@ -68,6 +74,9 @@ class ToolRegistry: handler: Function to execute category: Tool category for organization requires_confirmation: If True, user must approve before execution + tags: Tags for filtering and organization + source_module: Module where tool was defined + auto_discovered: Whether tool was auto-discovered Returns: The registered ToolRecord @@ -81,6 +90,9 @@ class ToolRegistry: handler=handler, category=category, requires_confirmation=requires_confirmation, + tags=tags or [], + source_module=source_module, + auto_discovered=auto_discovered, ) self._tools[name] = record @@ -94,6 +106,75 @@ class ToolRegistry: logger.info("Registered tool: %s (category: %s)", name, category) return record + def register_tool( + self, + name: str, + function: Callable, + description: Optional[str] = None, + category: str = "general", + tags: Optional[list[str]] = None, + source_module: Optional[str] = None, + ) -> ToolRecord: + """Register a tool from a function (convenience method for discovery). + + Args: + name: Tool name + function: Function to register + description: Tool description (defaults to docstring) + category: Tool category + tags: Tags for organization + source_module: Source module path + + Returns: + The registered ToolRecord + """ + # Build schema from function signature + sig = inspect.signature(function) + + properties = {} + required = [] + + for param_name, param in sig.parameters.items(): + if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD): + continue + + param_schema: dict = {"type": "string"} + + # Try to infer type from annotation + if param.annotation != inspect.Parameter.empty: + if param.annotation in (int, float): + param_schema = {"type": "number"} + elif param.annotation == bool: + param_schema = {"type": "boolean"} + elif param.annotation == list: + param_schema = {"type": "array"} + elif param.annotation == dict: + param_schema = {"type": "object"} + + if param.default is param.empty: + required.append(param_name) + else: + param_schema["default"] = param.default + + properties[param_name] = param_schema + + schema = create_tool_schema( + name=name, + description=description or (function.__doc__ or f"Execute {name}"), + parameters=properties, + required=required, + ) + + return self.register( + name=name, + schema=schema, + handler=function, + category=category, + tags=tags, + source_module=source_module or function.__module__, + auto_discovered=True, + ) + def unregister(self, name: str) -> bool: """Remove a tool from the registry.""" if name not in self._tools: @@ -137,14 +218,18 @@ class ToolRegistry: self, query: Optional[str] = None, category: Optional[str] = None, + tags: Optional[list[str]] = None, healthy_only: bool = True, + auto_discovered_only: bool = False, ) -> list[ToolRecord]: """Discover tools matching criteria. Args: query: Search in tool names and descriptions category: Filter by category + tags: Filter by tags (must have all specified tags) healthy_only: Only return healthy tools + auto_discovered_only: Only return auto-discovered tools Returns: List of matching ToolRecords @@ -156,17 +241,27 @@ class ToolRegistry: if category and record.category != category: continue + # Tags filter + if tags: + if not all(tag in record.tags for tag in tags): + continue + # Health filter if healthy_only and record.health_status == "unhealthy": continue + # Auto-discovered filter + if auto_discovered_only and not record.auto_discovered: + continue + # Query filter if query: query_lower = query.lower() name_match = query_lower in name.lower() desc = record.schema.get("description", "") desc_match = query_lower in desc.lower() - if not (name_match or desc_match): + tag_match = any(query_lower in tag.lower() for tag in record.tags) + if not (name_match or desc_match or tag_match): continue results.append(record) @@ -274,11 +369,15 @@ class ToolRegistry: "category": r.category, "health": r.health_status, "requires_confirmation": r.requires_confirmation, + "tags": r.tags, + "source_module": r.source_module, + "auto_discovered": r.auto_discovered, } for r in self._tools.values() ], "categories": self._categories, "total_tools": len(self._tools), + "auto_discovered_count": sum(1 for r in self._tools.values() if r.auto_discovered), } @@ -286,6 +385,11 @@ class ToolRegistry: tool_registry = ToolRegistry() +def get_registry() -> ToolRegistry: + """Get the global tool registry singleton.""" + return tool_registry + + def register_tool( name: Optional[str] = None, category: str = "general", diff --git a/tests/test_mcp_bootstrap.py b/tests/test_mcp_bootstrap.py new file mode 100644 index 00000000..6b12db4e --- /dev/null +++ b/tests/test_mcp_bootstrap.py @@ -0,0 +1,265 @@ +"""Tests for MCP Auto-Bootstrap. + +Tests follow pytest best practices: +- No module-level state +- Proper fixture cleanup +- Isolated tests +""" + +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +from mcp.bootstrap import ( + auto_bootstrap, + bootstrap_from_directory, + get_bootstrap_status, + DEFAULT_TOOL_PACKAGES, + AUTO_BOOTSTRAP_ENV_VAR, +) +from mcp.discovery import mcp_tool, ToolDiscovery +from mcp.registry import ToolRegistry + + +@pytest.fixture +def fresh_registry(): + """Create a fresh registry for each test.""" + return ToolRegistry() + + +@pytest.fixture +def fresh_discovery(fresh_registry): + """Create a fresh discovery instance for each test.""" + return ToolDiscovery(registry=fresh_registry) + + +class TestAutoBootstrap: + """Test auto_bootstrap function.""" + + def test_auto_bootstrap_disabled_by_env(self, fresh_registry): + """Test that auto-bootstrap can be disabled via env var.""" + with patch.dict(os.environ, {AUTO_BOOTSTRAP_ENV_VAR: "0"}): + registered = auto_bootstrap(registry=fresh_registry) + + assert len(registered) == 0 + + def test_auto_bootstrap_forced_overrides_env(self, fresh_registry): + """Test that force=True overrides env var.""" + with patch.dict(os.environ, {AUTO_BOOTSTRAP_ENV_VAR: "0"}): + # Empty packages list - just test that it runs + registered = auto_bootstrap( + packages=[], + registry=fresh_registry, + force=True, + ) + + assert len(registered) == 0 # No packages, but didn't abort + + def test_auto_bootstrap_nonexistent_package(self, fresh_registry): + """Test bootstrap from non-existent package.""" + registered = auto_bootstrap( + packages=["nonexistent_package_xyz_12345"], + registry=fresh_registry, + force=True, + ) + + assert len(registered) == 0 + + def test_auto_bootstrap_empty_packages(self, fresh_registry): + """Test bootstrap with empty packages list.""" + registered = auto_bootstrap( + packages=[], + registry=fresh_registry, + force=True, + ) + + assert len(registered) == 0 + + def test_auto_bootstrap_registers_tools(self, fresh_registry, fresh_discovery): + """Test that auto-bootstrap registers discovered tools.""" + @mcp_tool(name="bootstrap_tool", category="bootstrap") + def bootstrap_func(value: str) -> str: + """A bootstrap test tool.""" + return value + + # Manually register it + fresh_registry.register_tool( + name="bootstrap_tool", + function=bootstrap_func, + category="bootstrap", + ) + + # Verify it's in the registry + record = fresh_registry.get("bootstrap_tool") + assert record is not None + assert record.auto_discovered is True + + +class TestBootstrapFromDirectory: + """Test bootstrap_from_directory function.""" + + def test_bootstrap_from_directory(self, fresh_registry, tmp_path): + """Test bootstrapping from a directory.""" + tools_dir = tmp_path / "tools" + tools_dir.mkdir() + + tool_file = tools_dir / "my_tools.py" + tool_file.write_text(''' +from mcp.discovery import mcp_tool + +@mcp_tool(name="dir_tool", category="directory") +def dir_tool(value: str) -> str: + """A tool from directory.""" + return value +''') + + registered = bootstrap_from_directory(tools_dir, registry=fresh_registry) + + # Function won't be resolved (AST only), so not registered + assert len(registered) == 0 + + def test_bootstrap_from_nonexistent_directory(self, fresh_registry): + """Test bootstrapping from non-existent directory.""" + registered = bootstrap_from_directory( + Path("/nonexistent/tools"), + registry=fresh_registry + ) + + assert len(registered) == 0 + + def test_bootstrap_skips_private_files(self, fresh_registry, tmp_path): + """Test that private files are skipped.""" + tools_dir = tmp_path / "tools" + tools_dir.mkdir() + + private_file = tools_dir / "_private.py" + private_file.write_text(''' +from mcp.discovery import mcp_tool + +@mcp_tool(name="private_tool") +def private_tool(): + pass +''') + + registered = bootstrap_from_directory(tools_dir, registry=fresh_registry) + assert len(registered) == 0 + + +class TestGetBootstrapStatus: + """Test get_bootstrap_status function.""" + + def test_status_default_enabled(self): + """Test status when auto-bootstrap is enabled by default.""" + with patch.dict(os.environ, {}, clear=True): + status = get_bootstrap_status() + + assert status["auto_bootstrap_enabled"] is True + assert "discovered_tools_count" in status + assert "registered_tools_count" in status + assert status["default_packages"] == DEFAULT_TOOL_PACKAGES + + def test_status_disabled(self): + """Test status when auto-bootstrap is disabled.""" + with patch.dict(os.environ, {AUTO_BOOTSTRAP_ENV_VAR: "0"}): + status = get_bootstrap_status() + + assert status["auto_bootstrap_enabled"] is False + + +class TestIntegration: + """Integration tests for bootstrap + discovery + registry.""" + + def test_full_workflow(self, fresh_registry): + """Test the full auto-discovery and registration workflow.""" + @mcp_tool(name="integration_tool", category="integration") + def integration_func(data: str) -> str: + """Integration test tool.""" + return f"processed: {data}" + + fresh_registry.register_tool( + name="integration_tool", + function=integration_func, + category="integration", + source_module="test_module", + ) + + record = fresh_registry.get("integration_tool") + assert record is not None + assert record.auto_discovered is True + assert record.source_module == "test_module" + + export = fresh_registry.to_dict() + assert export["total_tools"] == 1 + assert export["auto_discovered_count"] == 1 + + def test_tool_execution_after_registration(self, fresh_registry): + """Test that registered tools can be executed.""" + @mcp_tool(name="exec_tool", category="execution") + def exec_func(input: str) -> str: + """Executable test tool.""" + return input.upper() + + fresh_registry.register_tool( + name="exec_tool", + function=exec_func, + category="execution", + ) + + import asyncio + result = asyncio.run(fresh_registry.execute("exec_tool", {"input": "hello"})) + + assert result == "HELLO" + + metrics = fresh_registry.get_metrics("exec_tool") + assert metrics["executions"] == 1 + assert metrics["health"] == "healthy" + + def test_discover_filtering(self, fresh_registry): + """Test filtering registered tools.""" + @mcp_tool(name="cat1_tool", category="category1") + def cat1_func(): + pass + + @mcp_tool(name="cat2_tool", category="category2") + def cat2_func(): + pass + + fresh_registry.register_tool( + name="cat1_tool", + function=cat1_func, + category="category1" + ) + fresh_registry.register_tool( + name="cat2_tool", + function=cat2_func, + category="category2" + ) + + cat1_tools = fresh_registry.discover(category="category1") + assert len(cat1_tools) == 1 + assert cat1_tools[0].name == "cat1_tool" + + auto_tools = fresh_registry.discover(auto_discovered_only=True) + assert len(auto_tools) == 2 + + def test_registry_export_includes_metadata(self, fresh_registry): + """Test that registry export includes all metadata.""" + @mcp_tool(name="meta_tool", category="meta", tags=["tag1", "tag2"]) + def meta_func(): + pass + + fresh_registry.register_tool( + name="meta_tool", + function=meta_func, + category="meta", + tags=["tag1", "tag2"], + ) + + export = fresh_registry.to_dict() + + for tool_dict in export["tools"]: + assert "tags" in tool_dict + assert "source_module" in tool_dict + assert "auto_discovered" in tool_dict diff --git a/tests/test_mcp_discovery.py b/tests/test_mcp_discovery.py new file mode 100644 index 00000000..ca14fbf5 --- /dev/null +++ b/tests/test_mcp_discovery.py @@ -0,0 +1,329 @@ +"""Tests for MCP Tool Auto-Discovery. + +Tests follow pytest best practices: +- No module-level state +- Proper fixture cleanup +- Isolated tests +""" + +import ast +import inspect +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from mcp.discovery import DiscoveredTool, ToolDiscovery, mcp_tool +from mcp.registry import ToolRegistry + + +@pytest.fixture +def fresh_registry(): + """Create a fresh registry for each test.""" + return ToolRegistry() + + +@pytest.fixture +def discovery(fresh_registry): + """Create a fresh discovery instance for each test.""" + return ToolDiscovery(registry=fresh_registry) + + +@pytest.fixture +def mock_module_with_tools(): + """Create a mock module with MCP tools for testing.""" + # Create a fresh module + mock_module = types.ModuleType("mock_test_module") + mock_module.__file__ = "mock_test_module.py" + + # Add decorated functions + @mcp_tool(name="echo", category="test", tags=["utility"]) + def echo_func(message: str) -> str: + """Echo a message back.""" + return message + + @mcp_tool(category="math") + def add_func(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + def not_decorated(): + """Not a tool.""" + pass + + mock_module.echo_func = echo_func + mock_module.add_func = add_func + mock_module.not_decorated = not_decorated + + # Inject into sys.modules + sys.modules["mock_test_module"] = mock_module + + yield mock_module + + # Cleanup + del sys.modules["mock_test_module"] + + +class TestMCPToolDecorator: + """Test the @mcp_tool decorator.""" + + def test_decorator_sets_explicit_name(self): + """Test that decorator uses explicit name.""" + @mcp_tool(name="custom_name", category="test") + def my_func(): + pass + + assert my_func._mcp_name == "custom_name" + assert my_func._mcp_category == "test" + + def test_decorator_uses_function_name(self): + """Test that decorator uses function name when not specified.""" + @mcp_tool(category="math") + def my_add_func(): + pass + + assert my_add_func._mcp_name == "my_add_func" + + def test_decorator_captures_docstring(self): + """Test that decorator captures docstring as description.""" + @mcp_tool(name="test") + def with_doc(): + """This is the description.""" + pass + + assert "This is the description" in with_doc._mcp_description + + def test_decorator_sets_tags(self): + """Test that decorator sets tags.""" + @mcp_tool(name="test", tags=["tag1", "tag2"]) + def tagged_func(): + pass + + assert tagged_func._mcp_tags == ["tag1", "tag2"] + + def test_undecorated_function(self): + """Test that undecorated functions don't have MCP attributes.""" + def plain_func(): + pass + + assert not hasattr(plain_func, "_mcp_tool") + + +class TestDiscoveredTool: + """Test DiscoveredTool dataclass.""" + + def test_tool_creation(self): + """Test creating a DiscoveredTool.""" + def dummy_func(): + pass + + tool = DiscoveredTool( + name="test", + description="A test tool", + function=dummy_func, + module="test_module", + category="test", + tags=["utility"], + parameters_schema={"type": "object"}, + returns_schema={"type": "string"}, + ) + + assert tool.name == "test" + assert tool.function == dummy_func + assert tool.category == "test" + + +class TestToolDiscoveryInit: + """Test ToolDiscovery initialization.""" + + def test_uses_provided_registry(self, fresh_registry): + """Test initialization with provided registry.""" + discovery = ToolDiscovery(registry=fresh_registry) + assert discovery.registry is fresh_registry + + +class TestDiscoverModule: + """Test discovering tools from modules.""" + + def test_discover_finds_decorated_tools(self, discovery, mock_module_with_tools): + """Test discovering tools from a module.""" + tools = discovery.discover_module("mock_test_module") + + tool_names = [t.name for t in tools] + assert "echo" in tool_names + assert "add_func" in tool_names + assert "not_decorated" not in tool_names + + def test_discover_nonexistent_module(self, discovery): + """Test discovering from non-existent module.""" + tools = discovery.discover_module("nonexistent.module.xyz") + assert len(tools) == 0 + + def test_discovered_tool_has_correct_metadata(self, discovery, mock_module_with_tools): + """Test that discovered tools have correct metadata.""" + tools = discovery.discover_module("mock_test_module") + + echo_tool = next(t for t in tools if t.name == "echo") + assert echo_tool.category == "test" + assert "utility" in echo_tool.tags + + def test_discovered_tool_has_schema(self, discovery, mock_module_with_tools): + """Test that discovered tools have parameter schemas.""" + tools = discovery.discover_module("mock_test_module") + + add_tool = next(t for t in tools if t.name == "add_func") + assert "properties" in add_tool.parameters_schema + assert "a" in add_tool.parameters_schema["properties"] + + +class TestDiscoverFile: + """Test discovering tools from Python files.""" + + def test_discover_from_file(self, discovery, tmp_path): + """Test discovering tools from a Python file.""" + test_file = tmp_path / "test_tools.py" + test_file.write_text(''' +from mcp.discovery import mcp_tool + +@mcp_tool(name="file_tool", category="file_ops", tags=["io"]) +def file_tool(path: str) -> dict: + """Process a file.""" + return {"path": path} +''') + + tools = discovery.discover_file(test_file) + + assert len(tools) == 1 + assert tools[0].name == "file_tool" + assert tools[0].category == "file_ops" + + def test_discover_from_nonexistent_file(self, discovery, tmp_path): + """Test discovering from non-existent file.""" + tools = discovery.discover_file(tmp_path / "nonexistent.py") + assert len(tools) == 0 + + def test_discover_from_invalid_python(self, discovery, tmp_path): + """Test discovering from invalid Python file.""" + test_file = tmp_path / "invalid.py" + test_file.write_text("not valid python @#$%") + + tools = discovery.discover_file(test_file) + assert len(tools) == 0 + + +class TestSchemaBuilding: + """Test JSON schema building from type hints.""" + + def test_string_parameter(self, discovery): + """Test string parameter schema.""" + def func(name: str) -> str: + return name + + sig = inspect.signature(func) + schema = discovery._build_parameters_schema(sig) + + assert schema["properties"]["name"]["type"] == "string" + + def test_int_parameter(self, discovery): + """Test int parameter schema.""" + def func(count: int) -> int: + return count + + sig = inspect.signature(func) + schema = discovery._build_parameters_schema(sig) + + assert schema["properties"]["count"]["type"] == "number" + + def test_bool_parameter(self, discovery): + """Test bool parameter schema.""" + def func(enabled: bool) -> bool: + return enabled + + sig = inspect.signature(func) + schema = discovery._build_parameters_schema(sig) + + assert schema["properties"]["enabled"]["type"] == "boolean" + + def test_required_parameters(self, discovery): + """Test that required parameters are marked.""" + def func(required: str, optional: str = "default") -> str: + return required + + sig = inspect.signature(func) + schema = discovery._build_parameters_schema(sig) + + assert "required" in schema["required"] + assert "optional" not in schema["required"] + + def test_default_values(self, discovery): + """Test that default values are captured.""" + def func(name: str = "default") -> str: + return name + + sig = inspect.signature(func) + schema = discovery._build_parameters_schema(sig) + + assert schema["properties"]["name"]["default"] == "default" + + +class TestTypeToSchema: + """Test type annotation to JSON schema conversion.""" + + def test_str_annotation(self, discovery): + """Test string annotation.""" + schema = discovery._type_to_schema(str) + assert schema["type"] == "string" + + def test_int_annotation(self, discovery): + """Test int annotation.""" + schema = discovery._type_to_schema(int) + assert schema["type"] == "number" + + def test_optional_annotation(self, discovery): + """Test Optional[T] annotation.""" + from typing import Optional + schema = discovery._type_to_schema(Optional[str]) + assert schema["type"] == "string" + + +class TestAutoRegister: + """Test auto-registration of discovered tools.""" + + def test_auto_register_module(self, discovery, mock_module_with_tools, fresh_registry): + """Test auto-registering tools from a module.""" + registered = discovery.auto_register("mock_test_module") + + assert "echo" in registered + assert "add_func" in registered + assert fresh_registry.get("echo") is not None + + def test_auto_register_skips_unresolved_functions(self, discovery, fresh_registry): + """Test that tools without resolved functions are skipped.""" + # Add a discovered tool with no function + discovery._discovered.append(DiscoveredTool( + name="no_func", + description="No function", + function=None, # type: ignore + module="test", + category="test", + tags=[], + parameters_schema={}, + returns_schema={}, + )) + + registered = discovery.auto_register("mock_test_module") + assert "no_func" not in registered + + +class TestClearDiscovered: + """Test clearing discovered tools cache.""" + + def test_clear_discovered(self, discovery, mock_module_with_tools): + """Test clearing discovered tools.""" + discovery.discover_module("mock_test_module") + assert len(discovery.get_discovered()) > 0 + + discovery.clear() + assert len(discovery.get_discovered()) == 0