forked from Rockachopa/Timmy-time-dashboard
Merge pull request #44 from AlexanderWhitestone/feature/memory-layers-and-conversational-ai
Phase 3-4: Cascade LLM Router + Tool Registry Auto-Discovery
This commit is contained in:
21
src/agents/__init__.py
Normal file
21
src/agents/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Agents package — Timmy and sub-agents.
|
||||
"""
|
||||
|
||||
from agents.timmy import TimmyOrchestrator, create_timmy_swarm
|
||||
from agents.base import BaseAgent
|
||||
from agents.seer import SeerAgent
|
||||
from agents.forge import ForgeAgent
|
||||
from agents.quill import QuillAgent
|
||||
from agents.echo import EchoAgent
|
||||
from agents.helm import HelmAgent
|
||||
|
||||
# Names exported by ``from agents import *`` — the package's public API.
__all__ = [
    "BaseAgent",
    "TimmyOrchestrator",
    "create_timmy_swarm",
    "SeerAgent",
    "ForgeAgent",
    "QuillAgent",
    "EchoAgent",
    "HelmAgent",
]
|
||||
139
src/agents/base.py
Normal file
139
src/agents/base.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Base agent class for all Timmy sub-agents.
|
||||
|
||||
All sub-agents inherit from BaseAgent and get:
|
||||
- MCP tool registry access
|
||||
- Event bus integration
|
||||
- Memory integration
|
||||
- Structured logging
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from agno.agent import Agent
|
||||
from agno.models.ollama import Ollama
|
||||
|
||||
from config import settings
|
||||
from events.bus import EventBus, Event
|
||||
from mcp.registry import tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgent(ABC):
    """Base class for all Timmy sub-agents.

    Sub-agents are specialized agents that handle specific tasks:
    - Seer: Research and information gathering
    - Mace: Security and validation
    - Quill: Writing and content
    - Forge: Code and tool building
    - Echo: Memory and context
    - Helm: Routing and orchestration

    Each instance wraps an Agno ``Agent`` (see ``_create_agent``), may be
    attached to an ``EventBus`` via ``connect_event_bus``, and resolves its
    tool names through the shared ``tool_registry``.
    """

    def __init__(
        self,
        agent_id: str,
        name: str,
        role: str,
        system_prompt: str,
        tools: list[str] | None = None,
    ) -> None:
        # agent_id is used as the event-topic key ("agent.<id>.*") and as
        # the routing identifier; name/role are human-readable metadata.
        self.agent_id = agent_id
        self.name = name
        self.role = role
        # Names of tools to resolve from tool_registry (empty list if None).
        self.tools = tools or []

        # Create Agno agent
        self.agent = self._create_agent(system_prompt)

        # Event bus for communication
        self.event_bus: Optional[EventBus] = None

        logger.info("%s agent initialized (id: %s)", name, agent_id)

    def _create_agent(self, system_prompt: str) -> Agent:
        """Create the underlying Agno agent."""
        # Get tools from registry
        # NOTE(review): tool names with no registered handler are silently
        # dropped here — consider logging a warning for unknown names.
        tool_instances = []
        for tool_name in self.tools:
            handler = tool_registry.get_handler(tool_name)
            if handler:
                tool_instances.append(handler)

        return Agent(
            name=self.name,
            model=Ollama(id=settings.ollama_model, host=settings.ollama_url),
            description=system_prompt,
            tools=tool_instances if tool_instances else None,
            add_history_to_context=True,
            num_history_runs=10,  # how many prior runs Agno replays as context
            markdown=True,
            telemetry=settings.telemetry_enabled,
        )

    def connect_event_bus(self, bus: EventBus) -> None:
        """Connect to the event bus for inter-agent communication."""
        self.event_bus = bus

        # Subscribe to relevant events
        # NOTE(review): EventBus.subscribe appends unconditionally, so
        # calling this twice registers duplicate handlers — confirm each
        # agent is connected exactly once.
        bus.subscribe(f"agent.{self.agent_id}.*")(self._handle_direct_message)
        bus.subscribe("agent.task.assigned")(self._handle_task_assignment)

    async def _handle_direct_message(self, event: Event) -> None:
        """Handle direct messages to this agent."""
        # Default implementation only logs; subclasses may override.
        logger.debug("%s received message: %s", self.name, event.type)

    async def _handle_task_assignment(self, event: Event) -> None:
        """Handle task assignment events."""
        # Every agent receives every "agent.task.assigned" event; only act
        # when the event's agent_id names this agent.
        assigned_agent = event.data.get("agent_id")
        if assigned_agent == self.agent_id:
            task_id = event.data.get("task_id")
            description = event.data.get("description", "")
            logger.info("%s assigned task %s: %s", self.name, task_id, description[:50])

            # Execute the task
            await self.execute_task(task_id, description, event.data)

    @abstractmethod
    async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
        """Execute a task assigned to this agent.

        Must be implemented by subclasses.
        """
        pass

    async def run(self, message: str) -> str:
        """Run the agent with a message.

        Returns:
            Agent response
        """
        # NOTE(review): Agent.run is called synchronously inside this
        # coroutine; if it blocks on the LLM call it stalls the event loop —
        # confirm whether agno provides an async variant.
        result = self.agent.run(message, stream=False)
        response = result.content if hasattr(result, "content") else str(result)

        # Emit completion event
        if self.event_bus:
            await self.event_bus.publish(Event(
                type=f"agent.{self.agent_id}.response",
                source=self.agent_id,
                data={"input": message, "output": response},
            ))

        return response

    def get_capabilities(self) -> list[str]:
        """Get list of capabilities this agent provides."""
        # Capabilities are currently just the configured tool names.
        return self.tools

    def get_status(self) -> dict:
        """Get current agent status."""
        # "status" is hard-coded; no busy/error tracking exists yet.
        return {
            "agent_id": self.agent_id,
            "name": self.name,
            "role": self.role,
            "status": "ready",
            "tools": self.tools,
        }
|
||||
81
src/agents/echo.py
Normal file
81
src/agents/echo.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Echo Agent — Memory and context management.
|
||||
|
||||
Capabilities:
|
||||
- Memory retrieval
|
||||
- Context synthesis
|
||||
- User profile management
|
||||
- Conversation history
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agents.base import BaseAgent
|
||||
|
||||
|
||||
ECHO_SYSTEM_PROMPT = """You are Echo, a memory and context management specialist.
|
||||
|
||||
Your role is to remember, retrieve, and synthesize information from the past.
|
||||
|
||||
## Capabilities
|
||||
|
||||
- Search past conversations
|
||||
- Retrieve user preferences
|
||||
- Synthesize context from multiple sources
|
||||
- Manage user profile
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. **Be accurate** — Only state what we actually know
|
||||
2. **Be relevant** — Filter for context that matters now
|
||||
3. **Be concise** — Summarize, don't dump everything
|
||||
4. **Acknowledge uncertainty** — Say when memory is unclear
|
||||
|
||||
## Tool Usage
|
||||
|
||||
- Use memory_search to find relevant past context
|
||||
- Use read_file to access vault files
|
||||
- Use write_file to update user profile
|
||||
|
||||
## Response Format
|
||||
|
||||
Provide memory retrieval in this structure:
|
||||
- Direct answer (what we know)
|
||||
- Context (relevant past discussions)
|
||||
- Confidence (certain/likely/speculative)
|
||||
- Source (where this came from)
|
||||
|
||||
You work for Timmy, the sovereign AI orchestrator. Be the keeper of institutional knowledge.
|
||||
"""
|
||||
|
||||
|
||||
class EchoAgent(BaseAgent):
    """Memory and context specialist.

    Thin wrapper around BaseAgent configured with the Echo system prompt
    and the memory/file tools.
    """

    def __init__(self, agent_id: str = "echo") -> None:
        super().__init__(
            agent_id=agent_id,
            name="Echo",
            role="memory",
            system_prompt=ECHO_SYSTEM_PROMPT,
            tools=["memory_search", "read_file", "write_file"],
        )

    async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
        """Run a memory-retrieval task and report a completion record."""
        # Ask the underlying agent to search memory for the task.
        task_prompt = (
            "Search memory and provide relevant context:\n\n"
            f"Task: {description}\n\n"
            "Synthesize findings clearly."
        )
        answer = await self.run(task_prompt)

        report = {"task_id": task_id, "agent": self.agent_id}
        report["result"] = answer
        report["status"] = "completed"
        return report

    async def recall(self, query: str, include_sources: bool = True) -> str:
        """Recall past context relevant to *query* (optionally with sources)."""
        if include_sources:
            qualifier = "with sources"
        else:
            qualifier = ""
        return await self.run(
            f"Recall information about: {query} {qualifier}\n\nProvide relevant context from memory."
        )
|
||||
92
src/agents/forge.py
Normal file
92
src/agents/forge.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Forge Agent — Code generation and tool building.
|
||||
|
||||
Capabilities:
|
||||
- Code generation
|
||||
- Tool/script creation
|
||||
- System modifications
|
||||
- Debugging assistance
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agents.base import BaseAgent
|
||||
|
||||
|
||||
FORGE_SYSTEM_PROMPT = """You are Forge, a code generation and tool building specialist.
|
||||
|
||||
Your role is to write code, create tools, and modify systems.
|
||||
|
||||
## Capabilities
|
||||
|
||||
- Python code generation
|
||||
- Tool/script creation
|
||||
- File operations
|
||||
- Code explanation and debugging
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. **Write clean code** — Follow PEP 8, add docstrings
|
||||
2. **Be safe** — Never execute destructive operations without confirmation
|
||||
3. **Explain your work** — Provide context for what the code does
|
||||
4. **Test mentally** — Walk through the logic before presenting
|
||||
|
||||
## Tool Usage
|
||||
|
||||
- Use python for code execution and testing
|
||||
- Use write_file to save code (requires confirmation)
|
||||
- Use read_file to examine existing code
|
||||
- Use shell for system operations (requires confirmation)
|
||||
|
||||
## Response Format
|
||||
|
||||
Provide code in this structure:
|
||||
- Purpose (what this code does)
|
||||
- Code block (with language tag)
|
||||
- Usage example
|
||||
- Notes (any important considerations)
|
||||
|
||||
You work for Timmy, the sovereign AI orchestrator. Build reliable, well-documented tools.
|
||||
"""
|
||||
|
||||
|
||||
class ForgeAgent(BaseAgent):
    """Code and tool building specialist.

    Configured with the Forge system prompt plus the code/file tools
    (python, write_file, read_file, list_directory).
    """

    def __init__(self, agent_id: str = "forge") -> None:
        super().__init__(
            agent_id=agent_id,
            name="Forge",
            role="code",
            system_prompt=FORGE_SYSTEM_PROMPT,
            tools=["python", "write_file", "read_file", "list_directory"],
        )

    async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
        """Execute a code/tool building task.

        Args:
            task_id: Identifier of the assigned task.
            description: Natural-language description of what to build.
            context: Extra task data (currently unused here).

        Returns:
            Dict with task_id, agent, result text, and status.
        """
        prompt = f"Create the requested code or tool:\n\nTask: {description}\n\nProvide complete, working code with documentation."

        result = await self.run(prompt)

        return {
            "task_id": task_id,
            "agent": self.agent_id,
            "result": result,
            "status": "completed",
        }

    async def generate_tool(self, name: str, purpose: str, parameters: list) -> str:
        """Generate a new MCP tool.

        Args:
            name: Tool name to generate.
            purpose: What the tool should do.
            parameters: Parameter descriptions; entries may be any type.

        Returns:
            The agent's generated code/response text.
        """
        # Fix: str() each entry so a list of non-string parameter specs
        # (dicts, tuples, ...) no longer raises TypeError in join().
        params_str = ", ".join(str(param) for param in parameters)
        prompt = f"""Create a new MCP tool named '{name}'.

Purpose: {purpose}
Parameters: {params_str}

Generate:
1. The tool function with proper error handling
2. The MCP schema
3. Registration code

Follow the MCP pattern used in existing tools."""

        return await self.run(prompt)
|
||||
106
src/agents/helm.py
Normal file
106
src/agents/helm.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Helm Agent — Routing and orchestration decisions.
|
||||
|
||||
Capabilities:
|
||||
- Task analysis
|
||||
- Agent selection
|
||||
- Workflow planning
|
||||
- Priority management
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agents.base import BaseAgent
|
||||
|
||||
|
||||
HELM_SYSTEM_PROMPT = """You are Helm, a routing and orchestration specialist.
|
||||
|
||||
Your role is to analyze tasks and decide how to route them to other agents.
|
||||
|
||||
## Capabilities
|
||||
|
||||
- Task analysis and decomposition
|
||||
- Agent selection for tasks
|
||||
- Workflow planning
|
||||
- Priority assessment
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. **Analyze carefully** — Understand what the task really needs
|
||||
2. **Route wisely** — Match tasks to agent strengths
|
||||
3. **Consider dependencies** — Some tasks need sequencing
|
||||
4. **Be efficient** — Don't over-complicate simple tasks
|
||||
|
||||
## Agent Roster
|
||||
|
||||
- Seer: Research, information gathering
|
||||
- Forge: Code, tools, system changes
|
||||
- Quill: Writing, documentation
|
||||
- Echo: Memory, context retrieval
|
||||
- Mace: Security, validation (use for sensitive operations)
|
||||
|
||||
## Response Format
|
||||
|
||||
Provide routing decisions as:
|
||||
- Task breakdown (subtasks if needed)
|
||||
- Agent assignment (who does what)
|
||||
- Execution order (sequence if relevant)
|
||||
- Rationale (why this routing)
|
||||
|
||||
You work for Timmy, the sovereign AI orchestrator. Be the dispatcher that keeps everything flowing.
|
||||
"""
|
||||
|
||||
|
||||
class HelmAgent(BaseAgent):
    """Routing and orchestration specialist.

    Analyzes requests and decides which sub-agent should handle them.
    """

    def __init__(self, agent_id: str = "helm") -> None:
        super().__init__(
            agent_id=agent_id,
            name="Helm",
            role="routing",
            system_prompt=HELM_SYSTEM_PROMPT,
            tools=["memory_search"],  # May need to check past routing decisions
        )

    async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
        """Execute a routing task.

        Args:
            task_id: Identifier of the assigned task.
            description: Task description to analyze and route.
            context: Extra task data (currently unused here).

        Returns:
            Dict with task_id, agent, result text, and status.
        """
        prompt = f"Analyze and route this task:\n\nTask: {description}\n\nProvide routing decision with rationale."

        result = await self.run(prompt)

        return {
            "task_id": task_id,
            "agent": self.agent_id,
            "result": result,
            "status": "completed",
        }

    async def route_request(self, request: str) -> dict:
        """Analyze a request and suggest routing.

        Returns:
            Dict with the raw LLM "analysis" text and the parsed
            "primary_agent" roster id (lowercase), defaulting to "timmy".
        """
        prompt = f"""Analyze this request and determine the best agent(s) to handle it:

Request: {request}

Respond in this format:
Primary Agent: [agent name]
Reason: [why this agent]
Secondary Agents: [if needed]
Complexity: [simple/moderate/complex]
"""
        result = await self.run(prompt)

        # Parse result into structured format
        # This is simplified - in production, use structured output
        return {
            "analysis": result,
            "primary_agent": self._extract_agent(result),
        }

    def _extract_agent(self, text: str) -> str:
        """Extract the primary agent name from routing text.

        Fix over the original: scanning the roster in fixed list order
        returned whichever name came first in the *list*, not the agent the
        text actually nominated (e.g. "Primary Agent: Forge ... Secondary
        Agents: Seer" yielded "seer").  Prefer the explicit
        "Primary Agent:" line the routing prompt requests, then fall back
        to the earliest roster mention in the text.
        """
        import re  # local import: only this parser needs it

        agents = ["seer", "forge", "quill", "echo", "mace", "helm"]
        text_lower = text.lower()

        # 1) Explicit "Primary Agent: <name>" line (brackets optional).
        match = re.search(r"primary agent:\s*\[?\s*([a-z]+)", text_lower)
        if match and match.group(1) in agents:
            return match.group(1)

        # 2) Earliest roster mention anywhere in the text.
        mentions = [(text_lower.index(agent), agent) for agent in agents if agent in text_lower]
        if mentions:
            return min(mentions)[1]

        return "timmy"  # Default to orchestrator
|
||||
80
src/agents/quill.py
Normal file
80
src/agents/quill.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Quill Agent — Writing and content generation.
|
||||
|
||||
Capabilities:
|
||||
- Documentation writing
|
||||
- Content creation
|
||||
- Text editing
|
||||
- Summarization
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agents.base import BaseAgent
|
||||
|
||||
|
||||
QUILL_SYSTEM_PROMPT = """You are Quill, a writing and content generation specialist.
|
||||
|
||||
Your role is to create, edit, and improve written content.
|
||||
|
||||
## Capabilities
|
||||
|
||||
- Documentation writing
|
||||
- Content creation
|
||||
- Text editing and refinement
|
||||
- Summarization
|
||||
- Style adaptation
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. **Write clearly** — Plain language, logical structure
|
||||
2. **Know your audience** — Adapt tone and complexity
|
||||
3. **Be concise** — Cut unnecessary words
|
||||
4. **Use formatting** — Headers, lists, emphasis for readability
|
||||
|
||||
## Tool Usage
|
||||
|
||||
- Use write_file to save documents
|
||||
- Use read_file to review existing content
|
||||
- Use memory_search to check style preferences
|
||||
|
||||
## Response Format
|
||||
|
||||
Provide written content with:
|
||||
- Clear structure (headers, sections)
|
||||
- Appropriate tone for the context
|
||||
- Proper formatting (markdown)
|
||||
- Brief explanation of choices made
|
||||
|
||||
You work for Timmy, the sovereign AI orchestrator. Create polished, professional content.
|
||||
"""
|
||||
|
||||
|
||||
class QuillAgent(BaseAgent):
    """Writing and content specialist.

    BaseAgent configured with the Quill system prompt and the writing
    tools (write_file, read_file, memory_search).
    """

    def __init__(self, agent_id: str = "quill") -> None:
        super().__init__(
            agent_id=agent_id,
            name="Quill",
            role="writing",
            system_prompt=QUILL_SYSTEM_PROMPT,
            tools=["write_file", "read_file", "memory_search"],
        )

    async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
        """Produce the requested written content and report completion."""
        request = (
            "Create the requested written content:\n\n"
            f"Task: {description}\n\n"
            "Write professionally with clear structure."
        )
        draft = await self.run(request)

        return dict(
            task_id=task_id,
            agent=self.agent_id,
            result=draft,
            status="completed",
        )

    async def write_documentation(self, topic: str, format: str = "markdown") -> str:
        """Write documentation for *topic* in the given format."""
        request = (
            f"Write comprehensive documentation for: {topic}\n\n"
            f"Format: {format}\nInclude: Overview, Usage, Examples, Notes"
        )
        return await self.run(request)
|
||||
91
src/agents/seer.py
Normal file
91
src/agents/seer.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Seer Agent — Research and information gathering.
|
||||
|
||||
Capabilities:
|
||||
- Web search
|
||||
- Information synthesis
|
||||
- Fact checking
|
||||
- Source evaluation
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agents.base import BaseAgent
|
||||
from events.bus import Event
|
||||
|
||||
|
||||
SEER_SYSTEM_PROMPT = """You are Seer, a research and information gathering specialist.
|
||||
|
||||
Your role is to find, evaluate, and synthesize information from external sources.
|
||||
|
||||
## Capabilities
|
||||
|
||||
- Web search for current information
|
||||
- File reading for local documents
|
||||
- Information synthesis and summarization
|
||||
- Source evaluation (credibility assessment)
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. **Be thorough** — Search multiple angles, verify facts
|
||||
2. **Be skeptical** — Evaluate source credibility
|
||||
3. **Be concise** — Summarize findings clearly
|
||||
4. **Cite sources** — Reference where information came from
|
||||
|
||||
## Tool Usage
|
||||
|
||||
- Use web_search for external information
|
||||
- Use read_file for local documents
|
||||
- Use memory_search to check if we already know this
|
||||
|
||||
## Response Format
|
||||
|
||||
Provide findings in structured format:
|
||||
- Summary (2-3 sentences)
|
||||
- Key facts (bullet points)
|
||||
- Sources (where information came from)
|
||||
- Confidence level (high/medium/low)
|
||||
|
||||
You work for Timmy, the sovereign AI orchestrator. Report findings clearly and objectively.
|
||||
"""
|
||||
|
||||
|
||||
class SeerAgent(BaseAgent):
    """Research specialist agent.

    BaseAgent configured with the Seer system prompt and the research
    tools (web_search, read_file, memory_search).
    """

    def __init__(self, agent_id: str = "seer") -> None:
        super().__init__(
            agent_id=agent_id,
            name="Seer",
            role="research",
            system_prompt=SEER_SYSTEM_PROMPT,
            tools=["web_search", "read_file", "memory_search"],
        )

    async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
        """Run a research task, choosing a document- or web-oriented prompt."""
        lowered = description.lower()
        # Mentions of a file/document steer toward local-document analysis.
        is_document_task = "file" in lowered or "document" in lowered

        if is_document_task:
            prompt = f"Read and analyze the referenced document. Provide key findings:\n\nTask: {description}"
        else:
            prompt = f"Research the following topic thoroughly. Search for current information, evaluate sources, and provide a comprehensive summary:\n\nTask: {description}"

        findings = await self.run(prompt)

        return {
            "task_id": task_id,
            "agent": self.agent_id,
            "result": findings,
            "status": "completed",
        }

    async def research_topic(self, topic: str, depth: str = "standard") -> str:
        """Research *topic* at the requested depth (quick/standard/deep)."""
        if depth == "quick":
            prompt = f"Quick search on: {topic}. Provide 3-5 key facts."
        elif depth == "deep":
            prompt = f"Deep research on: {topic}. Multiple searches, fact-checking, comprehensive report."
        else:
            # Unknown depth values fall back to the standard prompt.
            prompt = f"Research: {topic}. Search, synthesize, and summarize findings."
        return await self.run(prompt)
|
||||
184
src/agents/timmy.py
Normal file
184
src/agents/timmy.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Timmy — The orchestrator agent.
|
||||
|
||||
Coordinates all sub-agents and handles user interaction.
|
||||
Uses the three-tier memory system and MCP tools.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from agno.agent import Agent
|
||||
from agno.models.ollama import Ollama
|
||||
|
||||
from agents.base import BaseAgent
|
||||
from config import settings
|
||||
from events.bus import EventBus, event_bus
|
||||
from mcp.registry import tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
TIMMY_ORCHESTRATOR_PROMPT = """You are Timmy, a sovereign AI orchestrator running locally on this Mac.
|
||||
|
||||
## Your Role
|
||||
|
||||
You are the primary interface between the user and the agent swarm. You:
|
||||
1. Understand user requests
|
||||
2. Decide whether to handle directly or delegate to sub-agents
|
||||
3. Coordinate multi-agent workflows when needed
|
||||
4. Maintain continuity using the three-tier memory system
|
||||
|
||||
## Sub-Agent Roster
|
||||
|
||||
| Agent | Role | When to Use |
|
||||
|-------|------|-------------|
|
||||
| Seer | Research | External info, web search, facts |
|
||||
| Forge | Code | Programming, tools, file operations |
|
||||
| Quill | Writing | Documentation, content creation |
|
||||
| Echo | Memory | Past conversations, user profile |
|
||||
| Helm | Routing | Complex multi-step workflows |
|
||||
| Mace | Security | Validation, sensitive operations |
|
||||
|
||||
## Decision Framework
|
||||
|
||||
**Handle directly if:**
|
||||
- Simple question (identity, capabilities)
|
||||
- General knowledge
|
||||
- Social/conversational
|
||||
|
||||
**Delegate if:**
|
||||
- Requires specialized skills
|
||||
- Needs external research (Seer)
|
||||
- Involves code (Forge)
|
||||
- Needs past context (Echo)
|
||||
- Complex workflow (Helm)
|
||||
|
||||
## Memory System
|
||||
|
||||
You have three tiers of memory:
|
||||
1. **Hot Memory** — Always loaded (MEMORY.md)
|
||||
2. **Vault** — Structured storage (memory/)
|
||||
3. **Semantic** — Vector search for recall
|
||||
|
||||
Use `memory_search` when the user refers to past conversations.
|
||||
|
||||
## Principles
|
||||
|
||||
1. **Sovereignty** — Everything local, no cloud
|
||||
2. **Privacy** — User data stays on their Mac
|
||||
3. **Clarity** — Think clearly, speak plainly
|
||||
4. **Christian faith** — Grounded in biblical values
|
||||
5. **Bitcoin economics** — Sound money, self-custody
|
||||
|
||||
Sir, affirmative.
|
||||
"""
|
||||
|
||||
|
||||
class TimmyOrchestrator(BaseAgent):
    """Main orchestrator agent that coordinates the swarm.

    Classifies incoming requests (direct / memory / complex), answers
    simple ones itself, and otherwise delegates to registered sub-agents
    (Echo for memory recall, Helm for routing).
    """

    def __init__(self) -> None:
        super().__init__(
            agent_id="timmy",
            name="Timmy",
            role="orchestrator",
            system_prompt=TIMMY_ORCHESTRATOR_PROMPT,
            tools=["web_search", "read_file", "write_file", "python", "memory_search"],
        )

        # Sub-agent registry, keyed by agent_id.
        self.sub_agents: dict[str, BaseAgent] = {}

        # Connect to the shared module-level event bus.
        self.connect_event_bus(event_bus)

        logger.info("Timmy Orchestrator initialized")

    def register_sub_agent(self, agent: BaseAgent) -> None:
        """Register a sub-agent with the orchestrator and connect it to the bus."""
        self.sub_agents[agent.agent_id] = agent
        agent.connect_event_bus(event_bus)
        logger.info("Registered sub-agent: %s", agent.name)

    @staticmethod
    def _pattern_matches(pattern: str, request_lower: str, words: set) -> bool:
        """True when *pattern* occurs in the request.

        Fix over the original: single-word patterns must match a whole
        word — a plain substring test made "hi" fire inside "this" and
        "help" inside "helper", misrouting requests.  Multi-word phrases
        still use substring search (word order matters there).
        """
        if " " in pattern:
            return pattern in request_lower
        return pattern in words

    async def orchestrate(self, user_request: str) -> str:
        """Main entry point for user requests.

        Analyzes the request and either handles directly or delegates.

        Returns:
            The response text from this agent or the chosen sub-agent.
        """
        # Quick classification
        request_lower = user_request.lower()
        # Whole words (punctuation stripped) for single-word pattern checks.
        words = {w.strip(".,!?;:'\"()") for w in request_lower.split()}

        # Direct response patterns (no delegation needed)
        direct_patterns = [
            "your name", "who are you", "what are you",
            "hello", "hi", "how are you",
            "help", "what can you do",
        ]

        if any(self._pattern_matches(p, request_lower, words) for p in direct_patterns):
            return await self.run(user_request)

        # Check for memory references
        memory_patterns = [
            "we talked about", "we discussed", "remember",
            "what did i say", "what did we decide",
            "remind me", "have we",
        ]

        if any(self._pattern_matches(p, request_lower, words) for p in memory_patterns):
            # Use Echo agent for memory retrieval; fall through if absent.
            echo = self.sub_agents.get("echo")
            if echo:
                return await echo.recall(user_request)

        # Complex requests - use Helm for routing
        helm = self.sub_agents.get("helm")
        if helm:
            routing = await helm.route_request(user_request)
            agent_id = routing.get("primary_agent", "timmy")

            if agent_id in self.sub_agents and agent_id != "timmy":
                agent = self.sub_agents[agent_id]
                return await agent.run(user_request)

        # Default: handle directly
        return await self.run(user_request)

    async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
        """Execute a task (usually delegates to appropriate agent)."""
        return await self.orchestrate(description)

    def get_swarm_status(self) -> dict:
        """Get status of all agents in the swarm."""
        return {
            "orchestrator": self.get_status(),
            "sub_agents": {
                aid: agent.get_status()
                for aid, agent in self.sub_agents.items()
            },
            "total_agents": 1 + len(self.sub_agents),
        }
|
||||
|
||||
|
||||
# Factory function for creating fully configured Timmy
|
||||
def create_timmy_swarm() -> TimmyOrchestrator:
    """Create Timmy orchestrator with all sub-agents registered."""
    # Imported lazily to avoid circular imports at package load time.
    from agents.seer import SeerAgent
    from agents.forge import ForgeAgent
    from agents.quill import QuillAgent
    from agents.echo import EchoAgent
    from agents.helm import HelmAgent

    orchestrator = TimmyOrchestrator()

    # Instantiate and register each sub-agent in a fixed order.
    for agent_cls in (SeerAgent, ForgeAgent, QuillAgent, EchoAgent, HelmAgent):
        orchestrator.register_sub_agent(agent_cls())

    return orchestrator
|
||||
@@ -27,6 +27,7 @@ from dashboard.routes.spark import router as spark_router
|
||||
from dashboard.routes.creative import router as creative_router
|
||||
from dashboard.routes.discord import router as discord_router
|
||||
from dashboard.routes.self_modify import router as self_modify_router
|
||||
from router.api import router as cascade_router
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -101,6 +102,15 @@ async def lifespan(app: FastAPI):
|
||||
except Exception as exc:
|
||||
logger.error("Failed to spawn persona agents: %s", exc)
|
||||
|
||||
# Auto-bootstrap MCP tools
|
||||
from mcp.bootstrap import auto_bootstrap, get_bootstrap_status
|
||||
try:
|
||||
registered = auto_bootstrap()
|
||||
if registered:
|
||||
logger.info("MCP auto-bootstrap: %d tools registered", len(registered))
|
||||
except Exception as exc:
|
||||
logger.warning("MCP auto-bootstrap failed: %s", exc)
|
||||
|
||||
# Initialise Spark Intelligence engine
|
||||
from spark.engine import spark_engine
|
||||
if spark_engine.enabled:
|
||||
@@ -156,6 +166,7 @@ app.include_router(spark_router)
|
||||
app.include_router(creative_router)
|
||||
app.include_router(discord_router)
|
||||
app.include_router(self_modify_router)
|
||||
app.include_router(cascade_router)
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
|
||||
168
src/events/bus.py
Normal file
168
src/events/bus.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Async Event Bus for inter-agent communication.
|
||||
|
||||
Agents publish and subscribe to events for loose coupling.
|
||||
Events are typed and carry structured data.
|
||||
"""
|
||||
|
||||
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable, Coroutine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Event:
    """A typed event in the system.

    Carries a dotted event type, the emitting component, an arbitrary
    data payload, and auto-generated timestamp/id fields.
    """
    type: str  # e.g., "agent.task.assigned", "tool.execution.completed"
    source: str  # Agent or component that emitted the event
    data: dict = field(default_factory=dict)
    # ISO-8601 UTC creation time.
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    # Fix: the id was derived from the wall-clock timestamp, so two events
    # created within the same clock tick collided; uuid4 is always unique.
    id: str = field(default_factory=lambda: f"evt_{uuid.uuid4().hex}")
|
||||
|
||||
|
||||
# Type alias for event handlers
|
||||
EventHandler = Callable[[Event], Coroutine[Any, Any, None]]
|
||||
|
||||
|
||||
class EventBus:
    """Async publish/subscribe event bus.

    Handlers subscribe to dot-separated event-type patterns, optionally
    ending in a wildcard segment:

        bus = EventBus()

        @bus.subscribe("agent.task.*")
        async def handle_task(event: Event):
            print(f"Task event: {event.data}")

        await bus.publish(Event(
            type="agent.task.assigned",
            source="timmy",
            data={"task_id": "123", "agent": "forge"},
        ))
    """

    def __init__(self) -> None:
        # pattern -> handlers registered for that pattern
        self._subscribers: dict[str, list[EventHandler]] = {}
        # bounded record of recently published events
        self._history: list[Event] = []
        self._max_history = 1000
        logger.info("EventBus initialized")

    def subscribe(self, event_pattern: str) -> Callable[[EventHandler], EventHandler]:
        """Decorator registering a handler for events matching *event_pattern*.

        Supported patterns:
        - "agent.task.assigned" — exact match
        - "agent.task.*" — any task event
        - "agent.*" — any agent event
        - "*" — all events
        """
        def decorator(handler: EventHandler) -> EventHandler:
            self._subscribers.setdefault(event_pattern, []).append(handler)
            logger.debug("Subscribed handler to '%s'", event_pattern)
            return handler

        return decorator

    def unsubscribe(self, event_pattern: str, handler: EventHandler) -> bool:
        """Remove a handler from a subscription; True if it was registered."""
        registered = self._subscribers.get(event_pattern)
        if registered is None or handler not in registered:
            return False
        registered.remove(handler)
        logger.debug("Unsubscribed handler from '%s'", event_pattern)
        return True

    async def publish(self, event: Event) -> int:
        """Deliver *event* to every subscriber whose pattern matches.

        Returns:
            Number of handlers invoked
        """
        # Record the event in the bounded history.
        self._history.append(event)
        if len(self._history) > self._max_history:
            self._history = self._history[-self._max_history:]

        # Collect every handler whose pattern matches this event type.
        matched: list[EventHandler] = [
            handler
            for pattern, registered in self._subscribers.items()
            if self._match_pattern(event.type, pattern)
            for handler in registered
        ]

        # Run all matching handlers concurrently; individual failures are
        # contained by _invoke_handler.
        if matched:
            await asyncio.gather(
                *(self._invoke_handler(h, event) for h in matched),
                return_exceptions=True,
            )

        logger.debug("Published event '%s' to %d handlers", event.type, len(matched))
        return len(matched)

    async def _invoke_handler(self, handler: EventHandler, event: Event) -> None:
        """Await a single handler, logging (not propagating) any failure."""
        try:
            await handler(event)
        except Exception as exc:
            logger.error("Event handler failed for '%s': %s", event.type, exc)

    def _match_pattern(self, event_type: str, pattern: str) -> bool:
        """Return True when *event_type* matches the wildcard *pattern*."""
        if pattern == "*":
            return True
        if pattern.endswith(".*"):
            # "agent.*" matches "agent.x" and any deeper "agent.x.y".
            return event_type.startswith(pattern[:-2] + ".")
        return event_type == pattern

    def get_history(
        self,
        event_type: str | None = None,
        source: str | None = None,
        limit: int = 100,
    ) -> list[Event]:
        """Return up to *limit* recent events (newest last), optionally filtered."""
        selected = [
            e
            for e in self._history
            if (not event_type or e.type == event_type)
            and (not source or e.source == source)
        ]
        return selected[-limit:]

    def clear_history(self) -> None:
        """Drop all recorded events."""
        self._history.clear()
|
||||
|
||||
|
||||
# Module-level singleton shared across the process.
event_bus = EventBus()
|
||||
|
||||
|
||||
# Convenience functions
async def emit(event_type: str, source: str, data: dict) -> int:
    """Publish an event on the shared bus; returns the handler count."""
    event = Event(type=event_type, source=source, data=data)
    return await event_bus.publish(event)
|
||||
|
||||
|
||||
def on(event_pattern: str) -> Callable[[EventHandler], EventHandler]:
    """Decorator that subscribes a handler on the shared ``event_bus``."""
    return event_bus.subscribe(event_pattern)
|
||||
30
src/mcp/__init__.py
Normal file
30
src/mcp/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""MCP (Model Context Protocol) package.
|
||||
|
||||
Provides tool registry, server, schema management, and auto-discovery.
|
||||
"""
|
||||
|
||||
from mcp.registry import tool_registry, register_tool, ToolRegistry
|
||||
from mcp.server import mcp_server, MCPServer, MCPHTTPServer
|
||||
from mcp.schemas.base import create_tool_schema
|
||||
from mcp.discovery import ToolDiscovery, mcp_tool, get_discovery
|
||||
from mcp.bootstrap import auto_bootstrap, get_bootstrap_status
|
||||
|
||||
__all__ = [
|
||||
# Registry
|
||||
"tool_registry",
|
||||
"register_tool",
|
||||
"ToolRegistry",
|
||||
# Server
|
||||
"mcp_server",
|
||||
"MCPServer",
|
||||
"MCPHTTPServer",
|
||||
# Schemas
|
||||
"create_tool_schema",
|
||||
# Discovery
|
||||
"ToolDiscovery",
|
||||
"mcp_tool",
|
||||
"get_discovery",
|
||||
# Bootstrap
|
||||
"auto_bootstrap",
|
||||
"get_bootstrap_status",
|
||||
]
|
||||
148
src/mcp/bootstrap.py
Normal file
148
src/mcp/bootstrap.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""MCP Auto-Bootstrap — Auto-discover and register tools on startup.
|
||||
|
||||
Usage:
|
||||
from mcp.bootstrap import auto_bootstrap
|
||||
|
||||
# Auto-discover from 'tools' package
|
||||
registered = auto_bootstrap()
|
||||
|
||||
# Or specify custom packages
|
||||
registered = auto_bootstrap(packages=["tools", "custom_tools"])
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .discovery import ToolDiscovery, get_discovery
|
||||
from .registry import ToolRegistry, tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default packages to scan for tools
DEFAULT_TOOL_PACKAGES = ["tools"]

# Environment variable to disable auto-bootstrap (set to "0" to disable)
AUTO_BOOTSTRAP_ENV_VAR = "MCP_AUTO_BOOTSTRAP"
|
||||
|
||||
|
||||
def auto_bootstrap(
    packages: Optional[list[str]] = None,
    registry: Optional[ToolRegistry] = None,
    force: bool = False,
) -> list[str]:
    """Auto-discover and register MCP tools.

    Args:
        packages: Packages to scan (defaults to ["tools"])
        registry: Registry to register tools with (defaults to singleton)
        force: Force bootstrap even if disabled by env var

    Returns:
        List of registered tool names
    """
    # Honor the opt-out env var unless explicitly forced.
    disabled = os.environ.get(AUTO_BOOTSTRAP_ENV_VAR, "1") == "0"
    if disabled and not force:
        logger.info("MCP auto-bootstrap disabled via %s", AUTO_BOOTSTRAP_ENV_VAR)
        return []

    scan_packages = packages or DEFAULT_TOOL_PACKAGES
    target_registry = registry or tool_registry
    discovery = get_discovery(registry=target_registry)

    registered: list[str] = []
    logger.info("Starting MCP auto-bootstrap from packages: %s", scan_packages)

    for pkg in scan_packages:
        try:
            # Absent packages are a normal condition, not an error.
            try:
                __import__(pkg)
            except ImportError:
                logger.debug("Package %s not found, skipping", pkg)
                continue

            registered.extend(discovery.auto_register(pkg))
        except Exception as exc:
            logger.warning("Failed to bootstrap from %s: %s", pkg, exc)

    logger.info("MCP auto-bootstrap complete: %d tools registered", len(registered))
    return registered
|
||||
|
||||
|
||||
def _resolve_tool_function(py_file: Path, tool_name: str):
    """Import *py_file* and return the @mcp_tool callable named *tool_name*.

    AST-based discovery (``ToolDiscovery.discover_file``) cannot produce a
    live callable, so this loads the file as a throwaway module and searches
    its attributes for a decorated function whose MCP name matches.

    Returns None when the module cannot be loaded or no match exists.
    Raises whatever the module itself raises at import time; callers handle it.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location(
        f"_mcp_bootstrap_{py_file.stem}", py_file
    )
    if spec is None or spec.loader is None:
        return None
    module = importlib.util.module_from_spec(spec)
    # NOTE: executes arbitrary code from the tools directory — same trust
    # model as package-based bootstrap, which also imports these files.
    spec.loader.exec_module(module)

    for attr in vars(module).values():
        if not callable(attr) or not getattr(attr, "_mcp_tool", False):
            continue
        candidate = getattr(attr, "_mcp_name", getattr(attr, "__name__", ""))
        if candidate == tool_name:
            return attr
    return None


def bootstrap_from_directory(
    directory: Path,
    registry: Optional[ToolRegistry] = None,
) -> list[str]:
    """Bootstrap tools from a directory of Python files.

    Args:
        directory: Directory containing Python files with @mcp_tool functions
        registry: Registry to register tools with (defaults to singleton)

    Returns:
        List of registered tool names
    """
    registry = registry or tool_registry
    discovery = get_discovery(registry=registry)

    registered: list[str] = []

    if not directory.exists():
        logger.warning("Tools directory not found: %s", directory)
        return registered

    logger.info("Bootstrapping tools from directory: %s", directory)

    # Find all Python files, skipping private/dunder modules.
    for py_file in directory.rglob("*.py"):
        if py_file.name.startswith("_"):
            continue

        try:
            discovered = discovery.discover_file(py_file)

            for tool in discovered:
                function = tool.function
                if function is None:
                    # BUG FIX: discover_file() parses the AST and always
                    # leaves function=None, so the previous code skipped
                    # every tool — directory bootstrap registered nothing.
                    # Import the file and resolve the callable by name.
                    try:
                        function = _resolve_tool_function(py_file, tool.name)
                    except Exception as exc:
                        logger.warning(
                            "Failed to resolve %s from %s: %s",
                            tool.name, py_file, exc,
                        )
                if function is None:
                    continue

                try:
                    registry.register_tool(
                        name=tool.name,
                        function=function,
                        description=tool.description,
                        category=tool.category,
                        tags=tool.tags,
                    )
                    registered.append(tool.name)
                except Exception as exc:
                    logger.error("Failed to register %s: %s", tool.name, exc)

        except Exception as exc:
            logger.warning("Failed to process %s: %s", py_file, exc)

    logger.info("Directory bootstrap complete: %d tools registered", len(registered))
    return registered
|
||||
|
||||
|
||||
def get_bootstrap_status() -> dict:
    """Report auto-bootstrap configuration and current tool counts.

    Returns:
        Dict with bootstrap status info
    """
    enabled = os.environ.get(AUTO_BOOTSTRAP_ENV_VAR, "1") != "0"
    return {
        "auto_bootstrap_enabled": enabled,
        "discovered_tools_count": len(get_discovery().get_discovered()),
        "registered_tools_count": len(tool_registry.list_tools()),
        "default_packages": DEFAULT_TOOL_PACKAGES,
    }
|
||||
441
src/mcp/discovery.py
Normal file
441
src/mcp/discovery.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""MCP Tool Auto-Discovery — Introspect Python modules to find tools.
|
||||
|
||||
Automatically discovers functions marked with @mcp_tool decorator
|
||||
and registers them with the MCP registry. Generates JSON schemas
|
||||
from type hints.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import pkgutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, get_type_hints
|
||||
|
||||
from .registry import ToolRegistry, tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Decorator to mark functions as MCP tools
def mcp_tool(
    name: Optional[str] = None,
    description: Optional[str] = None,
    category: str = "general",
    tags: Optional[list[str]] = None,
):
    """Mark a function as an MCP tool so ToolDiscovery can find it.

    Args:
        name: Tool name (defaults to the function name)
        description: Tool description (defaults to the stripped docstring)
        category: Tool category for organization
        tags: Additional tags for filtering

    Example:
        @mcp_tool(name="weather", category="external")
        def get_weather(city: str) -> dict:
            '''Get weather for a city.'''
            ...
    """
    def decorator(func: Callable) -> Callable:
        # Discovery inspects these private markers via getattr().
        markers = {
            "_mcp_tool": True,
            "_mcp_name": name or func.__name__,
            "_mcp_description": description or (func.__doc__ or "").strip(),
            "_mcp_category": category,
            "_mcp_tags": tags or [],
        }
        for attr, value in markers.items():
            setattr(func, attr, value)
        return func

    return decorator
|
||||
|
||||
|
||||
@dataclass
class DiscoveredTool:
    """A tool discovered via introspection.

    Produced either by runtime introspection (``discover_module`` fills
    ``function`` with the live callable) or by AST parsing
    (``discover_file`` leaves ``function`` as ``None`` until the module
    is actually imported).
    """
    # Tool name (from @mcp_tool(name=...) or the function name).
    name: str
    # Human-readable description (decorator argument or docstring).
    description: str
    # Live callable, or None for AST-discovered tools not yet imported.
    function: Optional[Callable]
    # Dotted module path (runtime discovery) or file path (AST discovery).
    module: str
    # Category used for registry organization.
    category: str
    # Free-form tags for filtering.
    tags: list[str]
    # JSON schema describing the call parameters.
    parameters_schema: dict[str, Any]
    # JSON schema describing the return value.
    returns_schema: dict[str, Any]
    # Source file path, when resolvable.
    source_file: Optional[str] = None
    # 1-based line number of the definition; 0 when unknown.
    line_number: int = 0
|
||||
|
||||
|
||||
class ToolDiscovery:
|
||||
"""Discovers and registers MCP tools from Python modules.
|
||||
|
||||
Usage:
|
||||
discovery = ToolDiscovery()
|
||||
|
||||
# Discover from a module
|
||||
tools = discovery.discover_module("tools.git")
|
||||
|
||||
# Auto-register with registry
|
||||
discovery.auto_register("tools")
|
||||
|
||||
# Discover from all installed packages
|
||||
tools = discovery.discover_all_packages()
|
||||
"""
|
||||
|
||||
def __init__(self, registry: Optional[ToolRegistry] = None) -> None:
|
||||
self.registry = registry or tool_registry
|
||||
self._discovered: list[DiscoveredTool] = []
|
||||
|
||||
def discover_module(self, module_name: str) -> list[DiscoveredTool]:
|
||||
"""Discover all MCP tools in a module.
|
||||
|
||||
Args:
|
||||
module_name: Dotted path to module (e.g., "tools.git")
|
||||
|
||||
Returns:
|
||||
List of discovered tools
|
||||
"""
|
||||
discovered = []
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as exc:
|
||||
logger.warning("Failed to import module %s: %s", module_name, exc)
|
||||
return discovered
|
||||
|
||||
# Get module file path for source location
|
||||
module_file = getattr(module, "__file__", None)
|
||||
|
||||
# Iterate through module members
|
||||
for name, obj in inspect.getmembers(module):
|
||||
# Skip private and non-callable
|
||||
if name.startswith("_") or not callable(obj):
|
||||
continue
|
||||
|
||||
# Check if marked as MCP tool
|
||||
if not getattr(obj, "_mcp_tool", False):
|
||||
continue
|
||||
|
||||
# Get source location
|
||||
try:
|
||||
source_file = inspect.getfile(obj)
|
||||
line_number = inspect.getsourcelines(obj)[1]
|
||||
except (OSError, TypeError):
|
||||
source_file = module_file
|
||||
line_number = 0
|
||||
|
||||
# Build schemas from type hints
|
||||
try:
|
||||
sig = inspect.signature(obj)
|
||||
parameters_schema = self._build_parameters_schema(sig)
|
||||
returns_schema = self._build_returns_schema(sig, obj)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to build schema for %s: %s", name, exc)
|
||||
parameters_schema = {"type": "object", "properties": {}}
|
||||
returns_schema = {}
|
||||
|
||||
tool = DiscoveredTool(
|
||||
name=getattr(obj, "_mcp_name", name),
|
||||
description=getattr(obj, "_mcp_description", obj.__doc__ or ""),
|
||||
function=obj,
|
||||
module=module_name,
|
||||
category=getattr(obj, "_mcp_category", "general"),
|
||||
tags=getattr(obj, "_mcp_tags", []),
|
||||
parameters_schema=parameters_schema,
|
||||
returns_schema=returns_schema,
|
||||
source_file=source_file,
|
||||
line_number=line_number,
|
||||
)
|
||||
|
||||
discovered.append(tool)
|
||||
logger.debug("Discovered tool: %s from %s", tool.name, module_name)
|
||||
|
||||
self._discovered.extend(discovered)
|
||||
logger.info("Discovered %d tools from module %s", len(discovered), module_name)
|
||||
return discovered
|
||||
|
||||
def discover_package(self, package_name: str, recursive: bool = True) -> list[DiscoveredTool]:
|
||||
"""Discover tools from all modules in a package.
|
||||
|
||||
Args:
|
||||
package_name: Package name (e.g., "tools")
|
||||
recursive: Whether to search subpackages
|
||||
|
||||
Returns:
|
||||
List of discovered tools
|
||||
"""
|
||||
discovered = []
|
||||
|
||||
try:
|
||||
package = importlib.import_module(package_name)
|
||||
except ImportError as exc:
|
||||
logger.warning("Failed to import package %s: %s", package_name, exc)
|
||||
return discovered
|
||||
|
||||
package_path = getattr(package, "__path__", [])
|
||||
if not package_path:
|
||||
# Not a package, treat as module
|
||||
return self.discover_module(package_name)
|
||||
|
||||
# Walk package modules
|
||||
for _, name, is_pkg in pkgutil.iter_modules(package_path, prefix=f"{package_name}."):
|
||||
if is_pkg and recursive:
|
||||
discovered.extend(self.discover_package(name, recursive=True))
|
||||
else:
|
||||
discovered.extend(self.discover_module(name))
|
||||
|
||||
return discovered
|
||||
|
||||
def discover_file(self, file_path: Path) -> list[DiscoveredTool]:
|
||||
"""Discover tools from a Python file.
|
||||
|
||||
Args:
|
||||
file_path: Path to Python file
|
||||
|
||||
Returns:
|
||||
List of discovered tools
|
||||
"""
|
||||
discovered = []
|
||||
|
||||
try:
|
||||
source = file_path.read_text()
|
||||
tree = ast.parse(source)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to parse %s: %s", file_path, exc)
|
||||
return discovered
|
||||
|
||||
# Find all decorated functions
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.FunctionDef):
|
||||
continue
|
||||
|
||||
# Check for @mcp_tool decorator
|
||||
is_tool = False
|
||||
tool_name = node.name
|
||||
tool_description = ast.get_docstring(node) or ""
|
||||
tool_category = "general"
|
||||
tool_tags: list[str] = []
|
||||
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Call):
|
||||
if isinstance(decorator.func, ast.Name) and decorator.func.id == "mcp_tool":
|
||||
is_tool = True
|
||||
# Extract decorator arguments
|
||||
for kw in decorator.keywords:
|
||||
if kw.arg == "name" and isinstance(kw.value, ast.Constant):
|
||||
tool_name = kw.value.value
|
||||
elif kw.arg == "description" and isinstance(kw.value, ast.Constant):
|
||||
tool_description = kw.value.value
|
||||
elif kw.arg == "category" and isinstance(kw.value, ast.Constant):
|
||||
tool_category = kw.value.value
|
||||
elif kw.arg == "tags" and isinstance(kw.value, ast.List):
|
||||
tool_tags = [
|
||||
elt.value for elt in kw.value.elts
|
||||
if isinstance(elt, ast.Constant)
|
||||
]
|
||||
elif isinstance(decorator, ast.Name) and decorator.id == "mcp_tool":
|
||||
is_tool = True
|
||||
|
||||
if not is_tool:
|
||||
continue
|
||||
|
||||
# Build parameter schema from AST
|
||||
parameters_schema = self._build_schema_from_ast(node)
|
||||
|
||||
# We can't get the actual function without importing
|
||||
# So create a placeholder that will be resolved later
|
||||
tool = DiscoveredTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
function=None, # Will be resolved when registered
|
||||
module=str(file_path),
|
||||
category=tool_category,
|
||||
tags=tool_tags,
|
||||
parameters_schema=parameters_schema,
|
||||
returns_schema={"type": "object"},
|
||||
source_file=str(file_path),
|
||||
line_number=node.lineno,
|
||||
)
|
||||
|
||||
discovered.append(tool)
|
||||
|
||||
self._discovered.extend(discovered)
|
||||
logger.info("Discovered %d tools from file %s", len(discovered), file_path)
|
||||
return discovered
|
||||
|
||||
def auto_register(self, package_name: str = "tools") -> list[str]:
|
||||
"""Automatically discover and register tools.
|
||||
|
||||
Args:
|
||||
package_name: Package to scan for tools
|
||||
|
||||
Returns:
|
||||
List of registered tool names
|
||||
"""
|
||||
discovered = self.discover_package(package_name)
|
||||
registered = []
|
||||
|
||||
for tool in discovered:
|
||||
if tool.function is None:
|
||||
logger.warning("Skipping %s: no function resolved", tool.name)
|
||||
continue
|
||||
|
||||
try:
|
||||
self.registry.register_tool(
|
||||
name=tool.name,
|
||||
function=tool.function,
|
||||
description=tool.description,
|
||||
category=tool.category,
|
||||
tags=tool.tags,
|
||||
)
|
||||
registered.append(tool.name)
|
||||
logger.debug("Registered tool: %s", tool.name)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to register %s: %s", tool.name, exc)
|
||||
|
||||
logger.info("Auto-registered %d/%d tools", len(registered), len(discovered))
|
||||
return registered
|
||||
|
||||
def _build_parameters_schema(self, sig: inspect.Signature) -> dict[str, Any]:
|
||||
"""Build JSON schema for function parameters."""
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for name, param in sig.parameters.items():
|
||||
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
|
||||
continue
|
||||
|
||||
schema = self._type_to_schema(param.annotation)
|
||||
|
||||
if param.default is param.empty:
|
||||
required.append(name)
|
||||
else:
|
||||
schema["default"] = param.default
|
||||
|
||||
properties[name] = schema
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
def _build_returns_schema(
|
||||
self, sig: inspect.Signature, func: Callable
|
||||
) -> dict[str, Any]:
|
||||
"""Build JSON schema for return type."""
|
||||
return_annotation = sig.return_annotation
|
||||
|
||||
if return_annotation is sig.empty:
|
||||
return {"type": "object"}
|
||||
|
||||
return self._type_to_schema(return_annotation)
|
||||
|
||||
def _build_schema_from_ast(self, node: ast.FunctionDef) -> dict[str, Any]:
|
||||
"""Build parameter schema from AST node."""
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
# Get defaults (reversed, since they're at the end)
|
||||
defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + list(node.args.defaults)
|
||||
|
||||
for arg, default in zip(node.args.args, defaults):
|
||||
arg_name = arg.arg
|
||||
arg_type = "string" # Default
|
||||
|
||||
# Try to get type from annotation
|
||||
if arg.annotation:
|
||||
if isinstance(arg.annotation, ast.Name):
|
||||
arg_type = self._ast_type_to_json_type(arg.annotation.id)
|
||||
elif isinstance(arg.annotation, ast.Constant):
|
||||
arg_type = self._ast_type_to_json_type(str(arg.annotation.value))
|
||||
|
||||
schema = {"type": arg_type}
|
||||
|
||||
if default is None:
|
||||
required.append(arg_name)
|
||||
|
||||
properties[arg_name] = schema
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
|
||||
def _type_to_schema(self, annotation: Any) -> dict[str, Any]:
|
||||
"""Convert Python type annotation to JSON schema."""
|
||||
if annotation is inspect.Parameter.empty:
|
||||
return {"type": "string"}
|
||||
|
||||
origin = getattr(annotation, "__origin__", None)
|
||||
args = getattr(annotation, "__args__", ())
|
||||
|
||||
# Handle Optional[T] = Union[T, None]
|
||||
if origin is not None:
|
||||
if str(origin) == "typing.Union" and type(None) in args:
|
||||
# Optional type
|
||||
non_none_args = [a for a in args if a is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
schema = self._type_to_schema(non_none_args[0])
|
||||
return schema
|
||||
return {"type": "object"}
|
||||
|
||||
# Handle List[T], Dict[K,V]
|
||||
if origin in (list, tuple):
|
||||
items_schema = {"type": "object"}
|
||||
if args:
|
||||
items_schema = self._type_to_schema(args[0])
|
||||
return {"type": "array", "items": items_schema}
|
||||
|
||||
if origin is dict:
|
||||
return {"type": "object"}
|
||||
|
||||
# Handle basic types
|
||||
if annotation in (str,):
|
||||
return {"type": "string"}
|
||||
elif annotation in (int, float):
|
||||
return {"type": "number"}
|
||||
elif annotation in (bool,):
|
||||
return {"type": "boolean"}
|
||||
elif annotation in (list, tuple):
|
||||
return {"type": "array"}
|
||||
elif annotation in (dict,):
|
||||
return {"type": "object"}
|
||||
|
||||
return {"type": "object"}
|
||||
|
||||
def _ast_type_to_json_type(self, type_name: str) -> str:
|
||||
"""Convert AST type name to JSON schema type."""
|
||||
type_map = {
|
||||
"str": "string",
|
||||
"int": "number",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
"list": "array",
|
||||
"dict": "object",
|
||||
"List": "array",
|
||||
"Dict": "object",
|
||||
"Optional": "object",
|
||||
"Any": "object",
|
||||
}
|
||||
return type_map.get(type_name, "object")
|
||||
|
||||
def get_discovered(self) -> list[DiscoveredTool]:
|
||||
"""Get all discovered tools."""
|
||||
return list(self._discovered)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear discovered tools cache."""
|
||||
self._discovered.clear()
|
||||
|
||||
|
||||
# Module-level singleton
discovery: Optional[ToolDiscovery] = None


def get_discovery(registry: Optional[ToolRegistry] = None) -> ToolDiscovery:
    """Return the shared ToolDiscovery instance, creating it lazily.

    Note: *registry* only takes effect on the first call; subsequent
    calls return the existing singleton unchanged.
    """
    global discovery
    if discovery is not None:
        return discovery
    discovery = ToolDiscovery(registry=registry)
    return discovery
|
||||
444
src/mcp/registry.py
Normal file
444
src/mcp/registry.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""MCP Tool Registry — Dynamic tool discovery and management.
|
||||
|
||||
The registry maintains a catalog of all available tools, their schemas,
|
||||
and health status. Tools can be registered dynamically at runtime.
|
||||
|
||||
Usage:
|
||||
from mcp.registry import tool_registry
|
||||
|
||||
# Register a tool
|
||||
tool_registry.register("web_search", web_search_schema, web_search_func)
|
||||
|
||||
# Discover tools
|
||||
tools = tool_registry.discover(capabilities=["search"])
|
||||
|
||||
# Execute a tool
|
||||
result = tool_registry.execute("web_search", {"query": "Bitcoin"})
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from mcp.schemas.base import create_tool_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ToolRecord:
    """A registered tool with metadata and rolling execution metrics."""
    # Unique tool name used for lookup and execution.
    name: str
    # JSON schema describing inputs/outputs.
    schema: dict
    # Callable invoked on execute() (sync or async).
    handler: Callable
    # Category bucket for organization/filtering.
    category: str = "general"
    # Starts "unknown"; execute() sets it based on outcomes.
    health_status: str = "unknown"  # healthy, degraded, unhealthy
    # time.time() of the last successful execution.
    last_execution: Optional[float] = None
    # Total executions attempted (successes and failures).
    execution_count: int = 0
    # Total failed executions.
    error_count: int = 0
    # Exponentially weighted average latency of successful calls.
    avg_latency_ms: float = 0.0
    # time.time() when the tool was registered.
    added_at: float = field(default_factory=time.time)
    # If True, user must approve before execution.
    requires_confirmation: bool = False
    # Free-form tags for filtering.
    tags: list[str] = field(default_factory=list)
    # Module where the tool was defined, when known.
    source_module: Optional[str] = None
    # Whether the tool was registered via auto-discovery.
    auto_discovered: bool = False
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Central registry for all MCP tools."""
|
||||
|
||||
    def __init__(self) -> None:
        # name -> ToolRecord for every registered tool
        self._tools: dict[str, ToolRecord] = {}
        # category -> list of tool names in that category
        self._categories: dict[str, list[str]] = {}
        logger.info("ToolRegistry initialized")
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
schema: dict,
|
||||
handler: Callable,
|
||||
category: str = "general",
|
||||
requires_confirmation: bool = False,
|
||||
tags: Optional[list[str]] = None,
|
||||
source_module: Optional[str] = None,
|
||||
auto_discovered: bool = False,
|
||||
) -> ToolRecord:
|
||||
"""Register a new tool.
|
||||
|
||||
Args:
|
||||
name: Unique tool name
|
||||
schema: JSON schema describing inputs/outputs
|
||||
handler: Function to execute
|
||||
category: Tool category for organization
|
||||
requires_confirmation: If True, user must approve before execution
|
||||
tags: Tags for filtering and organization
|
||||
source_module: Module where tool was defined
|
||||
auto_discovered: Whether tool was auto-discovered
|
||||
|
||||
Returns:
|
||||
The registered ToolRecord
|
||||
"""
|
||||
if name in self._tools:
|
||||
logger.warning("Tool '%s' already registered, replacing", name)
|
||||
|
||||
record = ToolRecord(
|
||||
name=name,
|
||||
schema=schema,
|
||||
handler=handler,
|
||||
category=category,
|
||||
requires_confirmation=requires_confirmation,
|
||||
tags=tags or [],
|
||||
source_module=source_module,
|
||||
auto_discovered=auto_discovered,
|
||||
)
|
||||
|
||||
self._tools[name] = record
|
||||
|
||||
# Add to category
|
||||
if category not in self._categories:
|
||||
self._categories[category] = []
|
||||
if name not in self._categories[category]:
|
||||
self._categories[category].append(name)
|
||||
|
||||
logger.info("Registered tool: %s (category: %s)", name, category)
|
||||
return record
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
name: str,
|
||||
function: Callable,
|
||||
description: Optional[str] = None,
|
||||
category: str = "general",
|
||||
tags: Optional[list[str]] = None,
|
||||
source_module: Optional[str] = None,
|
||||
) -> ToolRecord:
|
||||
"""Register a tool from a function (convenience method for discovery).
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
function: Function to register
|
||||
description: Tool description (defaults to docstring)
|
||||
category: Tool category
|
||||
tags: Tags for organization
|
||||
source_module: Source module path
|
||||
|
||||
Returns:
|
||||
The registered ToolRecord
|
||||
"""
|
||||
# Build schema from function signature
|
||||
sig = inspect.signature(function)
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
|
||||
continue
|
||||
|
||||
param_schema: dict = {"type": "string"}
|
||||
|
||||
# Try to infer type from annotation
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
if param.annotation in (int, float):
|
||||
param_schema = {"type": "number"}
|
||||
elif param.annotation == bool:
|
||||
param_schema = {"type": "boolean"}
|
||||
elif param.annotation == list:
|
||||
param_schema = {"type": "array"}
|
||||
elif param.annotation == dict:
|
||||
param_schema = {"type": "object"}
|
||||
|
||||
if param.default is param.empty:
|
||||
required.append(param_name)
|
||||
else:
|
||||
param_schema["default"] = param.default
|
||||
|
||||
properties[param_name] = param_schema
|
||||
|
||||
schema = create_tool_schema(
|
||||
name=name,
|
||||
description=description or (function.__doc__ or f"Execute {name}"),
|
||||
parameters=properties,
|
||||
required=required,
|
||||
)
|
||||
|
||||
return self.register(
|
||||
name=name,
|
||||
schema=schema,
|
||||
handler=function,
|
||||
category=category,
|
||||
tags=tags,
|
||||
source_module=source_module or function.__module__,
|
||||
auto_discovered=True,
|
||||
)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Remove a tool from the registry."""
|
||||
if name not in self._tools:
|
||||
return False
|
||||
|
||||
record = self._tools.pop(name)
|
||||
|
||||
# Remove from category
|
||||
if record.category in self._categories:
|
||||
if name in self._categories[record.category]:
|
||||
self._categories[record.category].remove(name)
|
||||
|
||||
logger.info("Unregistered tool: %s", name)
|
||||
return True
|
||||
|
||||
    def get(self, name: str) -> Optional[ToolRecord]:
        """Get a tool record by name.

        Returns None when no tool with *name* is registered.
        """
        return self._tools.get(name)
|
||||
|
||||
def get_handler(self, name: str) -> Optional[Callable]:
|
||||
"""Get just the handler function for a tool."""
|
||||
record = self._tools.get(name)
|
||||
return record.handler if record else None
|
||||
|
||||
def get_schema(self, name: str) -> Optional[dict]:
|
||||
"""Get the JSON schema for a tool."""
|
||||
record = self._tools.get(name)
|
||||
return record.schema if record else None
|
||||
|
||||
def list_tools(self, category: Optional[str] = None) -> list[str]:
|
||||
"""List all tool names, optionally filtered by category."""
|
||||
if category:
|
||||
return self._categories.get(category, [])
|
||||
return list(self._tools.keys())
|
||||
|
||||
def list_categories(self) -> list[str]:
|
||||
"""List all tool categories."""
|
||||
return list(self._categories.keys())
|
||||
|
||||
def discover(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
healthy_only: bool = True,
|
||||
auto_discovered_only: bool = False,
|
||||
) -> list[ToolRecord]:
|
||||
"""Discover tools matching criteria.
|
||||
|
||||
Args:
|
||||
query: Search in tool names and descriptions
|
||||
category: Filter by category
|
||||
tags: Filter by tags (must have all specified tags)
|
||||
healthy_only: Only return healthy tools
|
||||
auto_discovered_only: Only return auto-discovered tools
|
||||
|
||||
Returns:
|
||||
List of matching ToolRecords
|
||||
"""
|
||||
results = []
|
||||
|
||||
for name, record in self._tools.items():
|
||||
# Category filter
|
||||
if category and record.category != category:
|
||||
continue
|
||||
|
||||
# Tags filter
|
||||
if tags:
|
||||
if not all(tag in record.tags for tag in tags):
|
||||
continue
|
||||
|
||||
# Health filter
|
||||
if healthy_only and record.health_status == "unhealthy":
|
||||
continue
|
||||
|
||||
# Auto-discovered filter
|
||||
if auto_discovered_only and not record.auto_discovered:
|
||||
continue
|
||||
|
||||
# Query filter
|
||||
if query:
|
||||
query_lower = query.lower()
|
||||
name_match = query_lower in name.lower()
|
||||
desc = record.schema.get("description", "")
|
||||
desc_match = query_lower in desc.lower()
|
||||
tag_match = any(query_lower in tag.lower() for tag in record.tags)
|
||||
if not (name_match or desc_match or tag_match):
|
||||
continue
|
||||
|
||||
results.append(record)
|
||||
|
||||
return results
|
||||
|
||||
async def execute(self, name: str, params: dict) -> Any:
|
||||
"""Execute a tool by name with given parameters.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
params: Parameters to pass to the tool
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
||||
Raises:
|
||||
ValueError: If tool not found
|
||||
RuntimeError: If tool execution fails
|
||||
"""
|
||||
record = self._tools.get(name)
|
||||
if not record:
|
||||
raise ValueError(f"Tool '{name}' not found in registry")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Check if handler is async
|
||||
if inspect.iscoroutinefunction(record.handler):
|
||||
result = await record.handler(**params)
|
||||
else:
|
||||
result = record.handler(**params)
|
||||
|
||||
# Update metrics
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
record.last_execution = time.time()
|
||||
record.execution_count += 1
|
||||
|
||||
# Update rolling average latency
|
||||
if record.execution_count == 1:
|
||||
record.avg_latency_ms = latency_ms
|
||||
else:
|
||||
record.avg_latency_ms = (
|
||||
record.avg_latency_ms * 0.9 + latency_ms * 0.1
|
||||
)
|
||||
|
||||
# Mark healthy on success
|
||||
record.health_status = "healthy"
|
||||
|
||||
logger.debug("Tool '%s' executed in %.2fms", name, latency_ms)
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
record.error_count += 1
|
||||
record.execution_count += 1
|
||||
|
||||
# Degrade health on repeated errors
|
||||
error_rate = record.error_count / record.execution_count
|
||||
if error_rate > 0.5:
|
||||
record.health_status = "unhealthy"
|
||||
logger.error("Tool '%s' marked unhealthy (error rate: %.1f%%)",
|
||||
name, error_rate * 100)
|
||||
elif error_rate > 0.2:
|
||||
record.health_status = "degraded"
|
||||
logger.warning("Tool '%s' degraded (error rate: %.1f%%)",
|
||||
name, error_rate * 100)
|
||||
|
||||
raise RuntimeError(f"Tool '{name}' execution failed: {exc}") from exc
|
||||
|
||||
def check_health(self, name: str) -> str:
|
||||
"""Check health status of a tool."""
|
||||
record = self._tools.get(name)
|
||||
if not record:
|
||||
return "not_found"
|
||||
return record.health_status
|
||||
|
||||
def get_metrics(self, name: Optional[str] = None) -> dict:
|
||||
"""Get metrics for a tool or all tools."""
|
||||
if name:
|
||||
record = self._tools.get(name)
|
||||
if not record:
|
||||
return {}
|
||||
return {
|
||||
"name": record.name,
|
||||
"category": record.category,
|
||||
"health": record.health_status,
|
||||
"executions": record.execution_count,
|
||||
"errors": record.error_count,
|
||||
"avg_latency_ms": round(record.avg_latency_ms, 2),
|
||||
}
|
||||
|
||||
# Return metrics for all tools
|
||||
return {
|
||||
name: self.get_metrics(name)
|
||||
for name in self._tools.keys()
|
||||
}
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Export registry as dictionary (for API/dashboard)."""
|
||||
return {
|
||||
"tools": [
|
||||
{
|
||||
"name": r.name,
|
||||
"schema": r.schema,
|
||||
"category": r.category,
|
||||
"health": r.health_status,
|
||||
"requires_confirmation": r.requires_confirmation,
|
||||
"tags": r.tags,
|
||||
"source_module": r.source_module,
|
||||
"auto_discovered": r.auto_discovered,
|
||||
}
|
||||
for r in self._tools.values()
|
||||
],
|
||||
"categories": self._categories,
|
||||
"total_tools": len(self._tools),
|
||||
"auto_discovered_count": sum(1 for r in self._tools.values() if r.auto_discovered),
|
||||
}
|
||||
|
||||
|
||||
# Module-level singleton: the process-wide registry shared by all agents
# and the MCP server. Prefer accessing it via get_registry().
tool_registry = ToolRegistry()
|
||||
|
||||
|
||||
def get_registry() -> ToolRegistry:
    """Get the global tool registry singleton.

    Returns:
        The module-level ToolRegistry shared by the whole process.
    """
    return tool_registry
|
||||
|
||||
|
||||
def register_tool(
    name: Optional[str] = None,
    category: str = "general",
    schema: Optional[dict] = None,
    requires_confirmation: bool = False,
):
    """Decorator for registering a function as an MCP tool.

    Args:
        name: Tool name; defaults to the decorated function's __name__.
        category: Registry category the tool is filed under.
        schema: Explicit MCP schema. When omitted, a minimal schema is
            inferred from the function signature — every parameter is
            typed ``"string"`` (real type hints are not inspected).
        requires_confirmation: Whether callers must confirm before
            the tool may be executed.

    Usage:
        @register_tool(name="web_search", category="research")
        def web_search(query: str, max_results: int = 5) -> str:
            ...
    """
    def decorator(func: Callable) -> Callable:
        tool_name = name or func.__name__

        # Auto-generate schema if not provided
        if schema is None:
            sig = inspect.signature(func)
            params: dict = {}
            required: list[str] = []

            for param_name, param in sig.parameters.items():
                # Parameter.empty is a sentinel: compare with `is`,
                # not `==` (a default with a custom __eq__ could
                # otherwise compare equal to the sentinel).
                if param.default is inspect.Parameter.empty:
                    required.append(param_name)
                    params[param_name] = {"type": "string"}
                else:
                    params[param_name] = {
                        "type": "string",
                        "default": str(param.default),
                    }

            tool_schema = create_tool_schema(
                name=tool_name,
                description=func.__doc__ or f"Execute {tool_name}",
                parameters=params,
                required=required,
            )
        else:
            tool_schema = schema

        tool_registry.register(
            name=tool_name,
            schema=tool_schema,
            handler=func,
            category=category,
            requires_confirmation=requires_confirmation,
        )

        return func
    return decorator
|
||||
52
src/mcp/schemas/base.py
Normal file
52
src/mcp/schemas/base.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Base schemas for MCP (Model Context Protocol) tools.
|
||||
|
||||
All tools must provide a JSON schema describing their interface.
|
||||
This enables dynamic discovery and type-safe invocation.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def create_tool_schema(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict[str, Any],
|
||||
required: list[str] | None = None,
|
||||
returns: dict[str, Any] | None = None,
|
||||
) -> dict:
|
||||
"""Create a standard MCP tool schema.
|
||||
|
||||
Args:
|
||||
name: Tool name (must be unique)
|
||||
description: Human-readable description
|
||||
parameters: JSON schema for input parameters
|
||||
required: List of required parameter names
|
||||
returns: JSON schema for return value
|
||||
|
||||
Returns:
|
||||
Complete tool schema dict
|
||||
"""
|
||||
return {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required or [],
|
||||
},
|
||||
"returns": returns or {"type": "string"},
|
||||
}
|
||||
|
||||
|
||||
# Common parameter schemas — reusable JSON-schema fragments for tool inputs.
PARAM_STRING = {"type": "string"}
PARAM_INTEGER = {"type": "integer"}
PARAM_BOOLEAN = {"type": "boolean"}
PARAM_ARRAY_STRINGS = {"type": "array", "items": {"type": "string"}}
PARAM_OBJECT = {"type": "object"}

# Common return schemas — reusable JSON-schema fragments for tool outputs.
RETURN_STRING = {"type": "string"}
RETURN_OBJECT = {"type": "object"}
RETURN_ARRAY = {"type": "array"}
RETURN_BOOLEAN = {"type": "boolean"}
|
||||
210
src/mcp/server.py
Normal file
210
src/mcp/server.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""MCP (Model Context Protocol) Server.
|
||||
|
||||
Implements the MCP protocol for tool discovery and execution.
|
||||
Agents communicate with this server to discover and invoke tools.
|
||||
|
||||
The server can run:
|
||||
1. In-process (direct method calls) — fastest, for local agents
|
||||
2. HTTP API — for external clients
|
||||
3. Stdio — for subprocess-based agents
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from mcp.registry import tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServer:
    """Model Context Protocol server for tool management.

    Wraps the shared tool registry and exposes the standard MCP
    operations:
    - list_tools: Discover available tools
    - call_tool: Execute a tool
    - get_schema: Get tool input/output schemas
    """

    def __init__(self) -> None:
        self.registry = tool_registry
        logger.info("MCP Server initialized")

    def list_tools(
        self,
        category: Optional[str] = None,
        query: Optional[str] = None,
    ) -> list[dict]:
        """Return descriptors for all healthy tools.

        MCP Protocol: tools/list
        """
        descriptors = []
        matches = self.registry.discover(
            query=query,
            category=category,
            healthy_only=True,
        )
        for rec in matches:
            descriptors.append({
                "name": rec.name,
                "description": rec.schema.get("description", ""),
                "parameters": rec.schema.get("parameters", {}),
                "category": rec.category,
            })
        return descriptors

    async def call_tool(self, name: str, arguments: dict) -> dict:
        """Execute a tool and wrap the outcome in an MCP result envelope.

        MCP Protocol: tools/call

        Args:
            name: Tool name
            arguments: Tool parameters

        Returns:
            Dict with a ``content`` list and an ``isError`` flag.
        """
        try:
            outcome = await self.registry.execute(name, arguments)
        except Exception as exc:
            logger.error("Tool execution failed: %s", exc)
            return {
                "content": [
                    {"type": "text", "text": f"Error: {exc}"}
                ],
                "isError": True,
            }
        return {
            "content": [
                {"type": "text", "text": str(outcome)}
            ],
            "isError": False,
        }

    def get_schema(self, name: str) -> Optional[dict]:
        """Return the JSON schema for a tool.

        MCP Protocol: tools/schema
        """
        return self.registry.get_schema(name)

    def get_tool_info(self, name: str) -> Optional[dict]:
        """Return schema, health, and metrics for one tool, or None."""
        record = self.registry.get(name)
        if record is None:
            return None

        return {
            "name": record.name,
            "schema": record.schema,
            "category": record.category,
            "health": record.health_status,
            "metrics": {
                "executions": record.execution_count,
                "errors": record.error_count,
                "avg_latency_ms": round(record.avg_latency_ms, 2),
            },
            "requires_confirmation": record.requires_confirmation,
        }

    def health_check(self) -> dict:
        """Summarize tool health counts across the whole registry."""
        names = self.registry.list_tools()
        counts = {"healthy": 0, "degraded": 0, "unhealthy": 0}
        for tool_name in names:
            status = self.registry.check_health(tool_name)
            if status in counts:
                counts[status] += 1

        return {
            "status": "healthy",
            "total_tools": len(names),
            "healthy_tools": counts["healthy"],
            "degraded_tools": counts["degraded"],
            "unhealthy_tools": counts["unhealthy"],
        }
|
||||
|
||||
|
||||
class MCPHTTPServer:
    """HTTP API wrapper for MCP Server.

    Builds a FastAPI router exposing the MCP operations over HTTP under
    the ``/mcp`` prefix. fastapi/pydantic are imported lazily inside
    get_routes so the in-process server works without the web stack.
    """

    def __init__(self) -> None:
        self.mcp = MCPServer()

    def get_routes(self) -> "APIRouter":
        """Build and return the FastAPI router with all MCP endpoints."""
        # Lazy imports keep fastapi/pydantic optional for non-HTTP users.
        from fastapi import APIRouter, HTTPException
        from pydantic import BaseModel

        router = APIRouter(prefix="/mcp", tags=["mcp"])

        class ToolCallRequest(BaseModel):
            # Request body for POST /mcp/tools/call.
            name: str
            arguments: dict = {}

        @router.get("/tools")
        async def list_tools(
            category: Optional[str] = None,
            query: Optional[str] = None,
        ):
            """List available tools."""
            return {"tools": self.mcp.list_tools(category, query)}

        @router.post("/tools/call")
        async def call_tool(request: ToolCallRequest):
            """Execute a tool."""
            result = await self.mcp.call_tool(request.name, request.arguments)
            return result

        @router.get("/tools/{name}")
        async def get_tool(name: str):
            """Get tool info (404 for unknown tools)."""
            info = self.mcp.get_tool_info(name)
            if not info:
                raise HTTPException(404, f"Tool '{name}' not found")
            return info

        @router.get("/tools/{name}/schema")
        async def get_schema(name: str):
            """Get tool schema (404 for unknown tools)."""
            schema = self.mcp.get_schema(name)
            if not schema:
                raise HTTPException(404, f"Tool '{name}' not found")
            return schema

        @router.get("/health")
        async def health():
            """Server health check."""
            return self.mcp.health_check()

        return router
|
||||
|
||||
|
||||
# Module-level singleton shared by in-process agents and the HTTP layer.
mcp_server = MCPServer()
|
||||
|
||||
|
||||
# Convenience functions for agents
|
||||
def discover_tools(query: Optional[str] = None) -> list[dict]:
    """Quick tool discovery.

    Convenience wrapper: list healthy tools from the global MCP server,
    optionally filtered by a free-text *query*.
    """
    return mcp_server.list_tools(query=query)
|
||||
|
||||
|
||||
async def use_tool(name: str, **kwargs) -> str:
    """Execute a tool through the global MCP server and return its text.

    Args:
        name: Tool name.
        **kwargs: Tool parameters, passed through as the argument dict.

    Raises:
        RuntimeError: If the tool reported an error.
    """
    outcome = await mcp_server.call_tool(name, kwargs)
    text = outcome["content"][0]["text"]

    if outcome.get("isError"):
        raise RuntimeError(text)

    return text
|
||||
12
src/router/__init__.py
Normal file
12
src/router/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Cascade LLM Router — Automatic failover between providers."""
|
||||
|
||||
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
|
||||
from .api import router
|
||||
|
||||
__all__ = [
|
||||
"CascadeRouter",
|
||||
"Provider",
|
||||
"ProviderStatus",
|
||||
"get_router",
|
||||
"router",
|
||||
]
|
||||
199
src/router/api.py
Normal file
199
src/router/api.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""API endpoints for Cascade Router monitoring and control."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .cascade import CascadeRouter, get_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/router", tags=["router"])
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
    """Request body for completions."""
    # OpenAI-style chat messages: each dict carries "role" and "content".
    messages: list[dict[str, str]]
    # Preferred model; each provider's default is used when None.
    model: str | None = None
    # Sampling temperature.
    temperature: float = 0.7
    # Generation cap; provider default when None.
    max_tokens: int | None = None
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
    """Response from completion endpoint."""
    # Generated assistant text.
    content: str
    # Name of the provider that served the request.
    provider: str
    # Model actually used.
    model: str
    # End-to-end request latency in milliseconds.
    latency_ms: float
|
||||
|
||||
|
||||
class ProviderControl(BaseModel):
    """Control a provider's status (body for /providers/{name}/control)."""
    action: str  # "enable", "disable", "reset_circuit"
|
||||
|
||||
|
||||
async def get_cascade_router() -> CascadeRouter:
    """FastAPI dependency that yields the process-wide CascadeRouter."""
    return get_router()
|
||||
|
||||
|
||||
@router.post("/complete", response_model=CompletionResponse)
async def complete(
    request: CompletionRequest,
    cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, Any]:
    """Complete a conversation with automatic failover.

    Routes through providers in priority order until one succeeds.

    Raises:
        HTTPException: 503 when every configured provider failed.
    """
    try:
        result = await cascade.complete(
            messages=request.messages,
            model=request.model,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )
        return result
    except RuntimeError as exc:
        # Chain the cause so logs retain the per-provider failure details.
        raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.get("/status")
async def get_status(
    cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, Any]:
    """Get router status and provider health summary."""
    return cascade.get_status()
|
||||
|
||||
|
||||
@router.get("/metrics")
async def get_metrics(
    cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, Any]:
    """Get detailed metrics for all providers (latency, errors, counts)."""
    return cascade.get_metrics()
|
||||
|
||||
|
||||
@router.get("/providers")
async def list_providers(
    cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> list[dict[str, Any]]:
    """List all configured providers with their runtime state."""
    summaries: list[dict[str, Any]] = []
    for prov in cascade.providers:
        summaries.append({
            "name": prov.name,
            "type": prov.type,
            "enabled": prov.enabled,
            "priority": prov.priority,
            "status": prov.status.value,
            "circuit_state": prov.circuit_state.value,
            "default_model": prov.get_default_model(),
            "models": [m["name"] for m in prov.models],
        })
    return summaries
|
||||
|
||||
|
||||
@router.post("/providers/{provider_name}/control")
async def control_provider(
    provider_name: str,
    control: ProviderControl,
    cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, str]:
    """Control a provider (enable/disable/reset_circuit).

    Raises:
        HTTPException: 404 for an unknown provider, 400 for an
            unknown action.
    """
    # Import the enums once, up front. The original enable branch used
    # provider.status.__class__.HEALTHY, which only works while the
    # current status happens to be a ProviderStatus instance.
    from .cascade import CircuitState, ProviderStatus

    provider = next(
        (p for p in cascade.providers if p.name == provider_name), None
    )
    if provider is None:
        raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")

    if control.action == "enable":
        provider.enabled = True
        provider.status = ProviderStatus.HEALTHY
        return {"message": f"Provider {provider_name} enabled"}

    if control.action == "disable":
        provider.enabled = False
        provider.status = ProviderStatus.DISABLED
        return {"message": f"Provider {provider_name} disabled"}

    if control.action == "reset_circuit":
        # Fully re-arm the breaker and clear the failure streak.
        provider.circuit_state = CircuitState.CLOSED
        provider.circuit_opened_at = None
        provider.half_open_calls = 0
        provider.metrics.consecutive_failures = 0
        provider.status = ProviderStatus.HEALTHY
        return {"message": f"Circuit breaker reset for {provider_name}"}

    raise HTTPException(status_code=400, detail=f"Unknown action: {control.action}")
|
||||
|
||||
|
||||
@router.post("/health-check")
async def run_health_check(
    cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, Any]:
    """Run availability checks on all providers and update their status."""
    # Hoisted out of the loop (the original re-imported per iteration);
    # also import CircuitState instead of reaching through
    # provider.circuit_state.__class__.
    from .cascade import CircuitState, ProviderStatus

    results = []

    for provider in cascade.providers:
        # Quick ping to check availability
        is_healthy = cascade._check_provider_available(provider)

        if is_healthy:
            if provider.status == ProviderStatus.UNHEALTHY:
                # Reset circuit if it was open but now healthy
                provider.circuit_state = CircuitState.CLOSED
                provider.circuit_opened_at = None
            provider.status = (
                ProviderStatus.HEALTHY
                if provider.metrics.error_rate < 0.1
                else ProviderStatus.DEGRADED
            )
        else:
            provider.status = ProviderStatus.UNHEALTHY

        results.append({
            "name": provider.name,
            "type": provider.type,
            "healthy": is_healthy,
            "status": provider.status.value,
        })

    return {
        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() here is deprecated since Python 3.10.
        "checked_at": asyncio.get_running_loop().time(),
        "providers": results,
        "healthy_count": sum(1 for r in results if r["healthy"]),
    }
|
||||
|
||||
|
||||
@router.get("/config")
async def get_config(
    cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, Any]:
    """Get router configuration (without secrets)."""
    cfg = cascade.config

    provider_summaries = []
    for prov in cascade.providers:
        provider_summaries.append({
            "name": prov.name,
            "type": prov.type,
            "priority": prov.priority,
            "enabled": prov.enabled,
        })

    return {
        "timeout_seconds": cfg.timeout_seconds,
        "max_retries_per_provider": cfg.max_retries_per_provider,
        "retry_delay_seconds": cfg.retry_delay_seconds,
        "circuit_breaker": {
            "failure_threshold": cfg.circuit_breaker_failure_threshold,
            "recovery_timeout": cfg.circuit_breaker_recovery_timeout,
            "half_open_max_calls": cfg.circuit_breaker_half_open_max_calls,
        },
        "providers": provider_summaries,
    }
|
||||
566
src/router/cascade.py
Normal file
566
src/router/cascade.py
Normal file
@@ -0,0 +1,566 @@
|
||||
"""Cascade LLM Router — Automatic failover between providers.
|
||||
|
||||
Routes requests through an ordered list of LLM providers,
|
||||
automatically failing over on rate limits or errors.
|
||||
Tracks metrics for latency, errors, and cost.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None # type: ignore
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProviderStatus(Enum):
    """Health status of a provider."""
    HEALTHY = "healthy"
    DEGRADED = "degraded"  # Working but slow or occasional errors
    UNHEALTHY = "unhealthy"  # Circuit breaker open
    DISABLED = "disabled"  # Turned off (e.g. via the control API)
|
||||
|
||||
|
||||
class CircuitState(Enum):
    """Circuit breaker state (classic three-state breaker)."""
    CLOSED = "closed"  # Normal operation
    OPEN = "open"  # Failing, rejecting requests
    HALF_OPEN = "half_open"  # Testing if recovered
|
||||
|
||||
|
||||
@dataclass
class ProviderMetrics:
    """Rolling request/latency/error counters for one provider."""
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_latency_ms: float = 0.0
    last_request_time: Optional[str] = None
    last_error_time: Optional[str] = None
    consecutive_failures: int = 0

    @property
    def avg_latency_ms(self) -> float:
        """Mean latency per request; 0.0 before any request was made."""
        if not self.total_requests:
            return 0.0
        return self.total_latency_ms / self.total_requests

    @property
    def error_rate(self) -> float:
        """Fraction of requests that failed; 0.0 before any request."""
        if not self.total_requests:
            return 0.0
        return self.failed_requests / self.total_requests
|
||||
|
||||
|
||||
@dataclass
class Provider:
    """LLM provider configuration and state."""
    # --- configuration (loaded from providers.yaml) ---
    name: str
    type: str  # ollama, openai, anthropic, airllm
    enabled: bool
    priority: int  # lower value = tried earlier in the cascade
    url: Optional[str] = None  # base URL (used for ollama availability probe)
    api_key: Optional[str] = None
    base_url: Optional[str] = None  # endpoint override for API providers
    models: list[dict] = field(default_factory=list)

    # Runtime state (mutated by the router and the control API)
    status: ProviderStatus = ProviderStatus.HEALTHY
    metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
    circuit_state: CircuitState = CircuitState.CLOSED
    circuit_opened_at: Optional[float] = None  # set while the breaker is open
    half_open_calls: int = 0  # probe calls made while half-open

    def get_default_model(self) -> Optional[str]:
        """Get the default model for this provider.

        Returns the first model flagged ``default``, otherwise the first
        configured model, otherwise None when no models are configured.
        """
        for model in self.models:
            if model.get("default"):
                return model["name"]
        if self.models:
            return self.models[0]["name"]
        return None
|
||||
|
||||
|
||||
@dataclass
class RouterConfig:
    """Cascade router configuration (mirrors the `cascade` YAML section)."""
    timeout_seconds: int = 30  # per-request timeout
    max_retries_per_provider: int = 2
    retry_delay_seconds: int = 1  # sleep between retries on one provider
    circuit_breaker_failure_threshold: int = 5
    circuit_breaker_recovery_timeout: int = 60
    circuit_breaker_half_open_max_calls: int = 2
    cost_tracking_enabled: bool = True
    budget_daily_usd: float = 10.0
|
||||
|
||||
|
||||
class CascadeRouter:
|
||||
"""Routes LLM requests with automatic failover.
|
||||
|
||||
Usage:
|
||||
router = CascadeRouter()
|
||||
|
||||
response = await router.complete(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
model="llama3.2"
|
||||
)
|
||||
|
||||
# Check metrics
|
||||
metrics = router.get_metrics()
|
||||
"""
|
||||
|
||||
    def __init__(self, config_path: Optional[Path] = None) -> None:
        """Load provider configuration and initialize router state.

        Args:
            config_path: YAML config file; defaults to config/providers.yaml.
        """
        self.config_path = config_path or Path("config/providers.yaml")
        self.providers: list[Provider] = []
        self.config: RouterConfig = RouterConfig()
        self._load_config()

        logger.info("CascadeRouter initialized with %d providers", len(self.providers))
|
||||
|
||||
    def _load_config(self) -> None:
        """Load configuration from YAML.

        A missing file or any parse/load error leaves the defaults in
        place (errors are logged, never raised). ``${VAR}`` references
        are expanded from the environment before parsing. Providers that
        are disabled or fail the availability probe are skipped; the
        rest are sorted by ascending priority.
        """
        if not self.config_path.exists():
            logger.warning("Config not found: %s, using defaults", self.config_path)
            return

        try:
            if yaml is None:
                raise RuntimeError("PyYAML not installed")

            content = self.config_path.read_text()
            # Expand environment variables
            content = self._expand_env_vars(content)
            data = yaml.safe_load(content)

            # Load cascade settings
            cascade = data.get("cascade", {})
            self.config = RouterConfig(
                timeout_seconds=cascade.get("timeout_seconds", 30),
                max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
                retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
                circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
                circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
                circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
            )

            # Load providers
            for p_data in data.get("providers", []):
                # Skip disabled providers
                if not p_data.get("enabled", False):
                    continue

                provider = Provider(
                    name=p_data["name"],
                    type=p_data["type"],
                    enabled=p_data.get("enabled", True),
                    priority=p_data.get("priority", 99),
                    url=p_data.get("url"),
                    api_key=p_data.get("api_key"),
                    base_url=p_data.get("base_url"),
                    models=p_data.get("models", []),
                )

                # Check if provider is actually available
                if self._check_provider_available(provider):
                    self.providers.append(provider)
                else:
                    logger.warning("Provider %s not available, skipping", provider.name)

            # Sort by priority (ascending — lower number is tried first)
            self.providers.sort(key=lambda p: p.priority)

        except Exception as exc:
            logger.error("Failed to load config: %s", exc)
|
||||
|
||||
def _expand_env_vars(self, content: str) -> str:
|
||||
"""Expand ${VAR} syntax in YAML content."""
|
||||
import os
|
||||
import re
|
||||
|
||||
def replace_var(match):
|
||||
var_name = match.group(1)
|
||||
return os.environ.get(var_name, match.group(0))
|
||||
|
||||
return re.sub(r"\$\{(\w+)\}", replace_var, content)
|
||||
|
||||
def _check_provider_available(self, provider: Provider) -> bool:
|
||||
"""Check if a provider is actually available."""
|
||||
if provider.type == "ollama":
|
||||
# Check if Ollama is running
|
||||
if requests is None:
|
||||
# Can't check without requests, assume available
|
||||
return True
|
||||
try:
|
||||
url = provider.url or "http://localhost:11434"
|
||||
response = requests.get(f"{url}/api/tags", timeout=5)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
elif provider.type == "airllm":
|
||||
# Check if airllm is installed
|
||||
try:
|
||||
import airllm
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
elif provider.type in ("openai", "anthropic"):
|
||||
# Check if API key is set
|
||||
return provider.api_key is not None and provider.api_key != ""
|
||||
|
||||
return True
|
||||
|
||||
    async def complete(
        self,
        messages: list[dict],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> dict:
        """Complete a chat conversation with automatic failover.

        Walks providers in priority order. A provider whose circuit
        breaker is open is skipped unless its recovery window has
        elapsed, in which case it is moved to half-open and probed.
        Each provider gets up to ``max_retries_per_provider`` attempts
        before the next one is tried.

        Args:
            messages: List of message dicts with role and content
            model: Preferred model (tries this first, then provider defaults)
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Dict with content, provider, model, and latency_ms.

        Raises:
            RuntimeError: If all providers fail
        """
        errors = []

        for provider in self.providers:
            # Skip unhealthy providers (circuit breaker)
            if provider.status == ProviderStatus.UNHEALTHY:
                # Check if circuit breaker can close
                if self._can_close_circuit(provider):
                    provider.circuit_state = CircuitState.HALF_OPEN
                    provider.half_open_calls = 0
                    logger.info("Circuit breaker half-open for %s", provider.name)
                    # No `continue` here — the half-open provider gets a
                    # probe attempt below.
                else:
                    logger.debug("Skipping %s (circuit open)", provider.name)
                    continue

            # Try this provider
            for attempt in range(self.config.max_retries_per_provider):
                try:
                    result = await self._try_provider(
                        provider=provider,
                        messages=messages,
                        model=model,
                        temperature=temperature,
                        max_tokens=max_tokens,
                    )

                    # Success! Update metrics and return
                    self._record_success(provider, result.get("latency_ms", 0))
                    return {
                        "content": result["content"],
                        "provider": provider.name,
                        "model": result.get("model", model or provider.get_default_model()),
                        "latency_ms": result.get("latency_ms", 0),
                    }

                except Exception as exc:
                    error_msg = str(exc)
                    logger.warning(
                        "Provider %s attempt %d failed: %s",
                        provider.name, attempt + 1, error_msg
                    )
                    errors.append(f"{provider.name}: {error_msg}")

                    # Back off between retries, but not after the final one.
                    if attempt < self.config.max_retries_per_provider - 1:
                        await asyncio.sleep(self.config.retry_delay_seconds)

            # All retries failed for this provider
            self._record_failure(provider)

        # All providers failed
        raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||
|
||||
async def _try_provider(
|
||||
self,
|
||||
provider: Provider,
|
||||
messages: list[dict],
|
||||
model: Optional[str],
|
||||
temperature: float,
|
||||
max_tokens: Optional[int],
|
||||
) -> dict:
|
||||
"""Try a single provider request."""
|
||||
start_time = time.time()
|
||||
|
||||
if provider.type == "ollama":
|
||||
result = await self._call_ollama(
|
||||
provider=provider,
|
||||
messages=messages,
|
||||
model=model or provider.get_default_model(),
|
||||
temperature=temperature,
|
||||
)
|
||||
elif provider.type == "openai":
|
||||
result = await self._call_openai(
|
||||
provider=provider,
|
||||
messages=messages,
|
||||
model=model or provider.get_default_model(),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
elif provider.type == "anthropic":
|
||||
result = await self._call_anthropic(
|
||||
provider=provider,
|
||||
messages=messages,
|
||||
model=model or provider.get_default_model(),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider.type}")
|
||||
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
result["latency_ms"] = latency_ms
|
||||
|
||||
return result
|
||||
|
||||
    async def _call_ollama(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
    ) -> dict:
        """Call Ollama's /api/chat endpoint (non-streaming).

        Returns:
            Dict with the generated ``content`` and the ``model`` used.

        Raises:
            RuntimeError: On any non-200 response.
        """
        import aiohttp

        url = f"{provider.url}/api/chat"

        payload = {
            "model": model,
            "messages": messages,
            "stream": False,  # single JSON response, not an event stream
            "options": {
                "temperature": temperature,
            },
        }

        timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    text = await response.text()
                    raise RuntimeError(f"Ollama error {response.status}: {text}")

                data = await response.json()
                return {
                    "content": data["message"]["content"],
                    "model": model,
                }
|
||||
|
||||
async def _call_openai(
    self,
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: Optional[int],
) -> dict:
    """Call OpenAI API.

    Sends a chat-completion request and returns the first choice's text.

    Args:
        provider: Provider supplying ``api_key`` and ``base_url`` (the
            latter lets this target OpenAI-compatible servers).
        messages: Chat messages in OpenAI format.
        model: Model name to request.
        temperature: Sampling temperature.
        max_tokens: Completion cap; omitted from the request when falsy
            so the API default applies.

    Returns:
        Dict with ``content`` (first choice text) and ``model`` (the
        model name reported by the server).
    """
    # Imported lazily so the SDK is only required when an OpenAI-type
    # provider is configured.
    import openai

    client = openai.AsyncOpenAI(
        api_key=provider.api_key,
        base_url=provider.base_url,
        timeout=self.config.timeout_seconds,
    )

    kwargs = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    # Only forward max_tokens when explicitly provided.
    if max_tokens:
        kwargs["max_tokens"] = max_tokens

    response = await client.chat.completions.create(**kwargs)

    return {
        "content": response.choices[0].message.content,
        "model": response.model,
    }
|
||||
|
||||
async def _call_anthropic(
    self,
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: Optional[int],
) -> dict:
    """Call Anthropic API.

    Converts OpenAI-style messages to Anthropic's format (the system
    prompt is a top-level parameter, not a message role) and returns the
    first content block's text.

    Args:
        provider: Provider supplying the API key.
        messages: OpenAI-style messages; any ``system`` entries are
            hoisted into the ``system`` request parameter.
        model: Model name to request.
        temperature: Sampling temperature.
        max_tokens: Completion cap; the Messages API requires one, hence
            the fallback to 1024 when not provided.

    Returns:
        Dict with ``content`` (first content block's text) and ``model``.
    """
    # Imported lazily so the SDK is only required when an Anthropic
    # provider is configured.
    import anthropic

    client = anthropic.AsyncAnthropic(
        api_key=provider.api_key,
        timeout=self.config.timeout_seconds,
    )

    # Convert messages to Anthropic format
    # NOTE(review): if several system messages are present, only the
    # last one is kept — confirm that is the intended behavior.
    system_msg = None
    conversation = []
    for msg in messages:
        if msg["role"] == "system":
            system_msg = msg["content"]
        else:
            conversation.append({
                "role": msg["role"],
                "content": msg["content"],
            })

    kwargs = {
        "model": model,
        "messages": conversation,
        "temperature": temperature,
        "max_tokens": max_tokens or 1024,
    }
    if system_msg:
        kwargs["system"] = system_msg

    response = await client.messages.create(**kwargs)

    return {
        "content": response.content[0].text,
        "model": response.model,
    }
|
||||
|
||||
def _record_success(self, provider: Provider, latency_ms: float) -> None:
    """Update provider bookkeeping after a successful request.

    Bumps request/latency counters, resets the consecutive-failure
    streak, advances the half-open circuit-breaker probe count, and
    re-grades the provider's health from its rolling error rate.
    """
    stats = provider.metrics
    stats.total_requests += 1
    stats.successful_requests += 1
    stats.total_latency_ms += latency_ms
    stats.last_request_time = datetime.now(timezone.utc).isoformat()
    stats.consecutive_failures = 0

    # A success while half-open counts toward fully re-closing the breaker.
    if provider.circuit_state == CircuitState.HALF_OPEN:
        provider.half_open_calls += 1
        if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
            self._close_circuit(provider)

    # Re-grade health: <10% errors is healthy, <30% merely degraded.
    rate = stats.error_rate
    if rate < 0.1:
        provider.status = ProviderStatus.HEALTHY
    elif rate < 0.3:
        provider.status = ProviderStatus.DEGRADED
|
||||
|
||||
def _record_failure(self, provider: Provider) -> None:
    """Update provider bookkeeping after a failed request.

    Bumps failure counters, extends the consecutive-failure streak
    (opening the circuit breaker at the configured threshold), and
    downgrades the provider's health based on its error rate.
    """
    stats = provider.metrics
    stats.total_requests += 1
    stats.failed_requests += 1
    stats.last_error_time = datetime.now(timezone.utc).isoformat()
    stats.consecutive_failures += 1

    # Too many failures in a row trips the breaker.
    if stats.consecutive_failures >= self.config.circuit_breaker_failure_threshold:
        self._open_circuit(provider)

    # >50% errors → unhealthy; >30% → degraded (net effect matches the
    # previous sequential-if form).
    rate = stats.error_rate
    if rate > 0.5:
        provider.status = ProviderStatus.UNHEALTHY
    elif rate > 0.3:
        provider.status = ProviderStatus.DEGRADED
|
||||
|
||||
def _open_circuit(self, provider: Provider) -> None:
    """Trip the circuit breaker: stop routing to *provider*.

    Marks the provider unhealthy and records the trip time so
    ``_can_close_circuit`` can later allow a recovery probe.
    """
    provider.status = ProviderStatus.UNHEALTHY
    provider.circuit_opened_at = time.time()
    provider.circuit_state = CircuitState.OPEN
    logger.warning("Circuit breaker OPEN for %s", provider.name)
|
||||
|
||||
def _can_close_circuit(self, provider: Provider) -> bool:
|
||||
"""Check if circuit breaker can transition to half-open."""
|
||||
if provider.circuit_opened_at is None:
|
||||
return False
|
||||
elapsed = time.time() - provider.circuit_opened_at
|
||||
return elapsed >= self.config.circuit_breaker_recovery_timeout
|
||||
|
||||
def _close_circuit(self, provider: Provider) -> None:
    """Reset the breaker after a successful recovery: provider is healthy again."""
    provider.status = ProviderStatus.HEALTHY
    provider.metrics.consecutive_failures = 0
    provider.half_open_calls = 0
    provider.circuit_opened_at = None
    provider.circuit_state = CircuitState.CLOSED
    logger.info("Circuit breaker CLOSED for %s", provider.name)
|
||||
|
||||
def get_metrics(self) -> dict:
    """Return per-provider request/latency metrics as a plain dict.

    Returns:
        ``{"providers": [...]}`` where each entry carries the provider's
        identity, health, breaker state, and rounded counters.
    """
    snapshots = []
    for prov in self.providers:
        stats = prov.metrics
        snapshots.append({
            "name": prov.name,
            "type": prov.type,
            "status": prov.status.value,
            "circuit_state": prov.circuit_state.value,
            "metrics": {
                "total_requests": stats.total_requests,
                "successful": stats.successful_requests,
                "failed": stats.failed_requests,
                "error_rate": round(stats.error_rate, 3),
                "avg_latency_ms": round(stats.avg_latency_ms, 2),
            },
        })
    return {"providers": snapshots}
|
||||
|
||||
def get_status(self) -> dict:
    """Summarize router health plus a per-provider status listing.

    Returns:
        Dict with healthy/degraded/unhealthy counts and one summary
        entry per configured provider.
    """
    tally = {
        ProviderStatus.HEALTHY: 0,
        ProviderStatus.DEGRADED: 0,
        ProviderStatus.UNHEALTHY: 0,
    }
    listing = []
    for prov in self.providers:
        if prov.status in tally:
            tally[prov.status] += 1
        listing.append({
            "name": prov.name,
            "type": prov.type,
            "status": prov.status.value,
            "priority": prov.priority,
            "default_model": prov.get_default_model(),
        })

    return {
        "total_providers": len(self.providers),
        "healthy_providers": tally[ProviderStatus.HEALTHY],
        "degraded_providers": tally[ProviderStatus.DEGRADED],
        "unhealthy_providers": tally[ProviderStatus.UNHEALTHY],
        "providers": listing,
    }
|
||||
|
||||
|
||||
# Module-level singleton
# Lazily created shared router; call get_router() rather than
# constructing CascadeRouter directly.
cascade_router: Optional[CascadeRouter] = None


def get_router() -> CascadeRouter:
    """Get or create the cascade router singleton.

    Returns:
        The process-wide CascadeRouter instance, creating it on first use.

    NOTE(review): creation is not lock-guarded; fine for single-threaded
    startup, confirm if first access can happen from concurrent tasks.
    """
    global cascade_router
    if cascade_router is None:
        cascade_router = CascadeRouter()
    return cascade_router
|
||||
124
src/tools/code_exec.py
Normal file
124
src/tools/code_exec.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Code execution tool.
|
||||
|
||||
MCP-compliant tool for executing Python code.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from mcp.registry import register_tool
|
||||
from mcp.schemas.base import create_tool_schema, PARAM_STRING, PARAM_BOOLEAN, RETURN_STRING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# MCP schema for the `python` tool: one required string parameter (the
# code) plus an optional flag selecting expression-result return.
PYTHON_SCHEMA = create_tool_schema(
    name="python",
    description="Execute Python code. Use for calculations, data processing, or when precise computation is needed. Code runs in a restricted environment.",
    parameters={
        "code": {
            **PARAM_STRING,
            "description": "Python code to execute",
        },
        "return_output": {
            **PARAM_BOOLEAN,
            "description": "Return the value of the last expression",
            "default": True,
        },
    },
    required=["code"],
    returns=RETURN_STRING,
)
|
||||
|
||||
|
||||
def python(code: str, return_output: bool = True) -> str:
    """Execute Python code in a restricted environment.

    The code runs with a whitelisted set of builtins (no ``__import__``,
    ``open``, ``eval`` etc. exposed to the sandboxed code) plus a few
    pre-imported safe stdlib modules. ``print`` is a no-op inside the
    sandbox, so results must come from the final expression.

    Args:
        code: Python code to execute.
        return_output: If True, first try to evaluate ``code`` as a single
            expression and return its value; multi-statement code still
            falls back to plain execution.

    Returns:
        ``"Result: <value>"`` for an evaluated expression,
        ``"Code executed successfully."`` for executed statements, or an
        ``"Error: ..."`` message with a traceback on failure.
    """
    # SECURITY NOTE: eval/exec on model- or user-supplied code is only as
    # safe as the builtins whitelist below — do not widen it casually.
    safe_globals = _build_safe_globals()

    try:
        if return_output:
            try:
                compiled = compile(code, "<string>", "eval")
            except SyntaxError:
                # Not a single expression — run it as statements instead.
                compiled = None
            if compiled is not None:
                result = eval(compiled, safe_globals, {})
                return f"Result: {result}"

        exec(compile(code, "<string>", "exec"), safe_globals, {})
        return "Code executed successfully."

    except Exception as exc:
        # Single error path replaces the previously duplicated handlers.
        logger.error("Python execution failed: %s", exc)
        return f"Error: {exc}\n\n{traceback.format_exc()}"


def _build_safe_globals() -> dict:
    """Build a fresh whitelisted globals mapping for one sandboxed run."""
    safe_globals: dict = {
        "__builtins__": {
            "abs": abs,
            "all": all,
            "any": any,
            "bin": bin,
            "bool": bool,
            "dict": dict,
            "enumerate": enumerate,
            "filter": filter,
            "float": float,
            "format": format,
            "hex": hex,
            "int": int,
            "isinstance": isinstance,
            "issubclass": issubclass,
            "len": len,
            "list": list,
            "map": map,
            "max": max,
            "min": min,
            "next": next,
            "oct": oct,
            "ord": ord,
            "pow": pow,
            "print": lambda *args, **kwargs: None,  # Disabled inside sandbox
            "range": range,
            "repr": repr,
            "reversed": reversed,
            "round": round,
            "set": set,
            "slice": slice,
            "sorted": sorted,
            "str": str,
            "sum": sum,
            "tuple": tuple,
            "type": type,
            "zip": zip,
        }
    }

    # Pre-import a small set of safe stdlib modules; __import__ itself is
    # deliberately absent from the whitelist above.
    for mod_name in ("math", "random", "statistics", "datetime", "json"):
        try:
            safe_globals[mod_name] = __import__(mod_name)
        except ImportError:
            pass

    return safe_globals
|
||||
|
||||
|
||||
# Register with MCP
# Exposes the python() function to the tool registry under the "code" category.
register_tool(name="python", schema=PYTHON_SCHEMA, category="code")(python)
|
||||
179
src/tools/file_ops.py
Normal file
179
src/tools/file_ops.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""File operations tool.
|
||||
|
||||
MCP-compliant tool for reading, writing, and listing files.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from mcp.registry import register_tool
|
||||
from mcp.schemas.base import create_tool_schema, PARAM_STRING, PARAM_BOOLEAN, RETURN_STRING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Read File Schema
# Read-only: path is required; limit=0 means "return the whole file".
READ_FILE_SCHEMA = create_tool_schema(
    name="read_file",
    description="Read contents of a file. Use when user explicitly asks to read a file.",
    parameters={
        "path": {
            **PARAM_STRING,
            "description": "Path to file (relative to project root or absolute)",
        },
        "limit": {
            "type": "integer",
            "description": "Maximum lines to read (0 = all)",
            "default": 0,
        },
    },
    required=["path"],
    returns=RETURN_STRING,
)

# Write File Schema
# Mutating: both path and content are required; append defaults to overwrite.
WRITE_FILE_SCHEMA = create_tool_schema(
    name="write_file",
    description="Write content to a file. Use when user explicitly asks to save content.",
    parameters={
        "path": {
            **PARAM_STRING,
            "description": "Path to file",
        },
        "content": {
            **PARAM_STRING,
            "description": "Content to write",
        },
        "append": {
            **PARAM_BOOLEAN,
            "description": "Append to file instead of overwrite",
            "default": False,
        },
    },
    required=["path", "content"],
    returns=RETURN_STRING,
)

# List Directory Schema
# Read-only: no required parameters; defaults list everything in ".".
LIST_DIR_SCHEMA = create_tool_schema(
    name="list_directory",
    description="List files in a directory.",
    parameters={
        "path": {
            **PARAM_STRING,
            "description": "Directory path (default: current)",
            "default": ".",
        },
        "pattern": {
            **PARAM_STRING,
            "description": "File pattern filter (e.g., '*.py')",
            "default": "*",
        },
    },
    returns=RETURN_STRING,
)
|
||||
|
||||
|
||||
def _resolve_path(path: str) -> Path:
|
||||
"""Resolve path relative to project root."""
|
||||
from config import settings
|
||||
|
||||
p = Path(path)
|
||||
if p.is_absolute():
|
||||
return p
|
||||
|
||||
# Try relative to project root
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
return project_root / p
|
||||
|
||||
|
||||
def read_file(path: str, limit: int = 0) -> str:
    """Read file contents.

    Args:
        path: File path, absolute or relative to the project root.
        limit: Maximum number of lines to return (0 = whole file).

    Returns:
        The file text (possibly truncated with a marker), or an
        ``Error: ...`` string on failure.
    """
    try:
        filepath = _resolve_path(path)

        if not filepath.exists():
            return f"Error: File not found: {path}"

        if not filepath.is_file():
            return f"Error: Path is not a file: {path}"

        content = filepath.read_text()

        if limit > 0:
            # Split once; only mark truncation when lines were actually
            # dropped (the old check also fired on files with exactly
            # `limit` lines, and split the content twice).
            lines = content.split('\n')
            if len(lines) > limit:
                content = '\n'.join(lines[:limit])
                content += f"\n\n... [{limit} lines shown]"

        return content

    except Exception as exc:
        logger.error("Read file failed: %s", exc)
        return f"Error reading file: {exc}"
|
||||
|
||||
|
||||
def write_file(path: str, content: str, append: bool = False) -> str:
    """Write content to file.

    Args:
        path: File path, absolute or relative to the project root.
        content: Text to write.
        append: If True, append to the file instead of overwriting.

    Returns:
        A success message, or an ``Error: ...`` string on failure.
    """
    try:
        filepath = _resolve_path(path)

        # Ensure the parent directory exists before writing.
        filepath.parent.mkdir(parents=True, exist_ok=True)

        # BUG FIX: `mode` was computed but write_text() always overwrote,
        # so append=True silently truncated the file. Open with the
        # selected mode instead.
        mode = "a" if append else "w"
        with filepath.open(mode) as fh:
            fh.write(content)

        action = "appended to" if append else "wrote"
        return f"Successfully {action} {filepath}"

    except Exception as exc:
        logger.error("Write file failed: %s", exc)
        return f"Error writing file: {exc}"
|
||||
|
||||
|
||||
def list_directory(path: str = ".", pattern: str = "*") -> str:
    """List directory contents.

    Args:
        path: Directory path (absolute or project-relative; default ".").
        pattern: Glob pattern used to filter entries (e.g. ``*.py``).

    Returns:
        A human-readable listing (directories first, each group sorted),
        or an ``Error: ...`` string on failure.
    """
    try:
        dirpath = _resolve_path(path)

        if not dirpath.exists():
            return f"Error: Directory not found: {path}"
        if not dirpath.is_dir():
            return f"Error: Path is not a directory: {path}"

        subdirs = []
        entries = []
        for child in dirpath.glob(pattern):
            if child.is_dir():
                subdirs.append(f"📁 {child.name}/")
                continue
            nbytes = child.stat().st_size
            label = f"{nbytes}B" if nbytes < 1024 else f"{nbytes//1024}KB"
            entries.append(f"📄 {child.name} ({label})")

        lines = [f"Contents of {dirpath}:", ""]
        lines.extend(sorted(subdirs))
        lines.extend(sorted(entries))
        return "\n".join(lines)

    except Exception as exc:
        logger.error("List directory failed: %s", exc)
        return f"Error listing directory: {exc}"
|
||||
|
||||
|
||||
# Register with MCP
# read_file / list_directory are read-only; write_file mutates the
# filesystem and is registered with requires_confirmation=True.
register_tool(name="read_file", schema=READ_FILE_SCHEMA, category="files")(read_file)
register_tool(
    name="write_file",
    schema=WRITE_FILE_SCHEMA,
    category="files",
    requires_confirmation=True,
)(write_file)
register_tool(name="list_directory", schema=LIST_DIR_SCHEMA, category="files")(list_directory)
|
||||
70
src/tools/memory_tool.py
Normal file
70
src/tools/memory_tool.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Memory search tool.
|
||||
|
||||
MCP-compliant tool for searching Timmy's memory.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from mcp.registry import register_tool
|
||||
from mcp.schemas.base import create_tool_schema, PARAM_STRING, PARAM_INTEGER, RETURN_STRING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# MCP schema for memory_search: required query string plus a bounded
# top_k result count (1-10).
MEMORY_SEARCH_SCHEMA = create_tool_schema(
    name="memory_search",
    description="Search Timmy's memory for past conversations, facts, and context. Use when user asks about previous discussions or when you need to recall something from memory.",
    parameters={
        "query": {
            **PARAM_STRING,
            "description": "What to search for in memory",
        },
        "top_k": {
            **PARAM_INTEGER,
            "description": "Number of results to return (1-10)",
            "default": 5,
            "minimum": 1,
            "maximum": 10,
        },
    },
    required=["query"],
    returns=RETURN_STRING,
)
|
||||
|
||||
|
||||
def memory_search(query: str, top_k: int = 5) -> str:
    """Search Timmy's memory.

    Args:
        query: Search query.
        top_k: Number of results.

    Returns:
        Relevant memories from past conversations, or an error message.
    """
    try:
        # Imported lazily so the tool registry stays importable even when
        # the semantic-memory backend is unavailable.
        from timmy.semantic_memory import memory_search as semantic_search

        hits = semantic_search(query, top_k=top_k)

        if not hits:
            return "No relevant memories found."

        lines = ["Relevant memories from past conversations:", ""]
        for rank, (content, score) in enumerate(hits, start=1):
            # Marker reflects match strength: strong / moderate / weak.
            if score > 0.8:
                marker = "🔥"
            elif score > 0.5:
                marker = "⭐"
            else:
                marker = "📄"
            lines.append(f"{marker} [{rank}] (score: {score:.2f})")
            lines.append(f" {content[:300]}...")
            lines.append("")

        return "\n".join(lines)

    except Exception as exc:
        logger.error("Memory search failed: %s", exc)
        return f"Memory search error: {exc}"
|
||||
|
||||
|
||||
# Register with MCP
# Exposes memory_search() to the tool registry under the "memory" category.
register_tool(name="memory_search", schema=MEMORY_SEARCH_SCHEMA, category="memory")(memory_search)
|
||||
74
src/tools/web_search.py
Normal file
74
src/tools/web_search.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Web search tool using DuckDuckGo.
|
||||
|
||||
MCP-compliant tool for searching the web.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from mcp.registry import register_tool
|
||||
from mcp.schemas.base import create_tool_schema, PARAM_STRING, PARAM_INTEGER, RETURN_STRING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# MCP schema for web_search: required query string plus a bounded
# max_results count (1-10).
WEB_SEARCH_SCHEMA = create_tool_schema(
    name="web_search",
    description="Search the web using DuckDuckGo. Use for current events, news, real-time data, and information not in your training data.",
    parameters={
        "query": {
            **PARAM_STRING,
            "description": "Search query string",
        },
        "max_results": {
            **PARAM_INTEGER,
            "description": "Maximum number of results (1-10)",
            "default": 5,
            "minimum": 1,
            "maximum": 10,
        },
    },
    required=["query"],
    returns=RETURN_STRING,
)
|
||||
|
||||
|
||||
def web_search(query: str, max_results: int = 5) -> str:
    """Search the web using DuckDuckGo.

    Args:
        query: Search query.
        max_results: Maximum results to return.

    Returns:
        Formatted search results, or an error message on failure.
    """
    try:
        # Imported lazily so the registry loads even without the package.
        from duckduckgo_search import DDGS

        with DDGS() as ddgs:
            hits = list(ddgs.text(query, max_results=max_results))

        if not hits:
            return "No results found."

        blocks = []
        for rank, hit in enumerate(hits, start=1):
            title = hit.get("title", "No title")
            body = hit.get("body", "No description")
            href = hit.get("href", "")
            blocks.append(f"{rank}. {title}\n {body[:150]}...\n {href}")

        return "\n\n".join(blocks)

    except Exception as exc:
        logger.error("Web search failed: %s", exc)
        return f"Search error: {exc}"
|
||||
|
||||
|
||||
# Register with MCP
# Exposes web_search() to the tool registry under the "research" category.
register_tool(
    name="web_search",
    schema=WEB_SEARCH_SCHEMA,
    category="research",
)(web_search)
|
||||
Reference in New Issue
Block a user