feat: code quality audit + autoresearch integration + infra hardening (#150)

This commit is contained in:
Alexander Whitestone
2026-03-08 12:50:44 -04:00
committed by GitHub
parent fd0ede0d51
commit ae3bb1cc21
186 changed files with 5129 additions and 3289 deletions

View File

@@ -119,9 +119,7 @@ def capture_error(
return None
# Format the stack trace
tb_str = "".join(
traceback.format_exception(type(exc), exc, exc.__traceback__)
)
tb_str = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
# Extract file/line from traceback
tb_obj = exc.__traceback__

View File

@@ -19,38 +19,39 @@ logger = logging.getLogger(__name__)
class EventBroadcaster:
"""Broadcasts events to WebSocket clients.
Usage:
from infrastructure.events.broadcaster import event_broadcaster
event_broadcaster.broadcast(event)
"""
def __init__(self) -> None:
self._ws_manager: Optional = None
def _get_ws_manager(self):
"""Lazy import to avoid circular deps."""
if self._ws_manager is None:
try:
from infrastructure.ws_manager.handler import ws_manager
self._ws_manager = ws_manager
except Exception as exc:
logger.debug("WebSocket manager not available: %s", exc)
return self._ws_manager
async def broadcast(self, event: EventLogEntry) -> int:
"""Broadcast an event to all connected WebSocket clients.
Args:
event: The event to broadcast
Returns:
Number of clients notified
"""
ws_manager = self._get_ws_manager()
if not ws_manager:
return 0
# Build message payload
payload = {
"type": "event",
@@ -62,9 +63,9 @@ class EventBroadcaster:
"agent_id": event.agent_id,
"timestamp": event.timestamp,
"data": event.data,
}
},
}
try:
# Broadcast to all connected clients
count = await ws_manager.broadcast_json(payload)
@@ -73,10 +74,10 @@ class EventBroadcaster:
except Exception as exc:
logger.error("Failed to broadcast event: %s", exc)
return 0
def broadcast_sync(self, event: EventLogEntry) -> None:
"""Synchronous wrapper for broadcast.
Use this from synchronous code - it schedules the async broadcast
in the event loop if one is running.
"""
@@ -151,11 +152,11 @@ def get_event_label(event_type: str) -> str:
def format_event_for_display(event: EventLogEntry) -> dict:
"""Format event for display in activity feed.
Returns dict with display-friendly fields.
"""
data = event.data or {}
# Build description based on event type
description = ""
if event.event_type.value == "task.created":
@@ -178,7 +179,7 @@ def format_event_for_display(event: EventLogEntry) -> dict:
val = str(data[key])
description = val[:60] + "..." if len(val) > 60 else val
break
return {
"id": event.id,
"icon": get_event_icon(event.event_type.value),

View File

@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
@dataclass
class Event:
"""A typed event in the system."""
type: str # e.g., "agent.task.assigned", "tool.execution.completed"
source: str # Agent or component that emitted the event
data: dict = field(default_factory=dict)
@@ -29,15 +30,15 @@ EventHandler = Callable[[Event], Coroutine[Any, Any, None]]
class EventBus:
"""Async event bus for publish/subscribe pattern.
Usage:
bus = EventBus()
# Subscribe to events
@bus.subscribe("agent.task.*")
async def handle_task(event: Event):
print(f"Task event: {event.data}")
# Publish events
await bus.publish(Event(
type="agent.task.assigned",
@@ -45,88 +46,89 @@ class EventBus:
data={"task_id": "123", "agent": "forge"}
))
"""
def __init__(self) -> None:
    """Start with no subscribers and an empty, size-capped event history."""
    # publish() trims the history list down to this many entries.
    self._max_history = 1000
    self._history: list[Event] = []
    # Maps a subscription pattern to the handlers registered under it.
    self._subscribers: dict[str, list[EventHandler]] = {}
    logger.info("EventBus initialized")
def subscribe(self, event_pattern: str) -> Callable[[EventHandler], EventHandler]:
    """Return a decorator that registers a handler under ``event_pattern``.

    Patterns support wildcards:
    - "agent.task.assigned" — exact match
    - "agent.task.*" — any task event
    - "agent.*" — any agent event
    - "*" — all events

    Nothing is registered until the returned decorator is applied; the
    decorator hands the handler back unchanged.
    """

    def decorator(handler: EventHandler) -> EventHandler:
        # Create the handler list for this pattern on first registration.
        registered = self._subscribers.setdefault(event_pattern, [])
        registered.append(handler)
        logger.debug("Subscribed handler to '%s'", event_pattern)
        return handler

    return decorator
def unsubscribe(self, event_pattern: str, handler: EventHandler) -> bool:
    """Remove ``handler`` from the subscription for ``event_pattern``.

    Returns:
        True if the handler was registered under that pattern and has been
        removed; False if the pattern is unknown or the handler is not in
        its list.
    """
    registered = self._subscribers.get(event_pattern)
    if registered is None or handler not in registered:
        return False
    registered.remove(handler)
    logger.debug("Unsubscribed handler from '%s'", event_pattern)
    return True
async def publish(self, event: Event) -> int:
    """Record ``event`` in the history and deliver it to matching handlers.

    Handlers from every subscription pattern that matches ``event.type`` are
    awaited together through ``asyncio.gather`` with
    ``return_exceptions=True``.

    Returns:
        Number of handlers invoked
    """
    # Keep only the newest _max_history entries.
    self._history.append(event)
    if len(self._history) > self._max_history:
        self._history = self._history[-self._max_history :]

    # Collect handlers pattern by pattern, preserving registration order.
    handlers: list[EventHandler] = [
        handler
        for pattern, pattern_handlers in self._subscribers.items()
        if self._match_pattern(event.type, pattern)
        for handler in pattern_handlers
    ]

    if handlers:
        await asyncio.gather(
            *(self._invoke_handler(handler, event) for handler in handlers),
            return_exceptions=True,
        )

    logger.debug("Published event '%s' to %d handlers", event.type, len(handlers))
    return len(handlers)
async def _invoke_handler(self, handler: EventHandler, event: Event) -> None:
"""Invoke a handler with error handling."""
try:
await handler(event)
except Exception as exc:
logger.error("Event handler failed for '%s': %s", event.type, exc)
def _match_pattern(self, event_type: str, pattern: str) -> bool:
"""Check if event type matches a wildcard pattern."""
if pattern == "*":
return True
if pattern.endswith(".*"):
prefix = pattern[:-2]
return event_type.startswith(prefix + ".")
return event_type == pattern
def get_history(
self,
event_type: str | None = None,
@@ -135,15 +137,15 @@ class EventBus:
) -> list[Event]:
"""Get recent event history with optional filtering."""
events = self._history
if event_type:
events = [e for e in events if e.type == event_type]
if source:
events = [e for e in events if e.source == source]
return events[-limit:]
def clear_history(self) -> None:
    """Empty the event history list in place."""
    del self._history[:]
@@ -156,11 +158,13 @@ event_bus = EventBus()
# Convenience functions
async def emit(event_type: str, source: str, data: dict) -> int:
    """Build an ``Event`` from the arguments and publish it on ``event_bus``.

    Returns:
        Whatever ``event_bus.publish`` returns (the number of handlers invoked).
    """
    event = Event(type=event_type, source=source, data=data)
    return await event_bus.publish(event)
def on(event_pattern: str) -> Callable[[EventHandler], EventHandler]:

View File

@@ -11,7 +11,7 @@ Usage:
result = await git_hand.run("status")
"""
from infrastructure.hands.shell import shell_hand
from infrastructure.hands.git import git_hand
from infrastructure.hands.shell import shell_hand
__all__ = ["shell_hand", "git_hand"]

View File

@@ -25,16 +25,18 @@ from config import settings
logger = logging.getLogger(__name__)
# Git operations that require explicit confirmation before execution.
DESTRUCTIVE_OPS = frozenset(
    (
        "push --force",
        "push -f",
        "reset --hard",
        "clean -fd",
        "clean -f",
        "branch -D",
        "checkout -- .",
        "restore .",
    )
)
@dataclass
@@ -190,7 +192,9 @@ class GitHand:
flag = "-b" if create else ""
return await self.run(f"checkout {flag} {branch}".strip())
async def push(self, remote: str = "origin", branch: str = "", force: bool = False) -> GitResult:
async def push(
self, remote: str = "origin", branch: str = "", force: bool = False
) -> GitResult:
"""Push to remote. Force-push requires explicit opt-in."""
args = f"push -u {remote} {branch}".strip()
if force:

View File

@@ -26,15 +26,17 @@ from config import settings
logger = logging.getLogger(__name__)
# Commands that are always blocked, regardless of the allow-list.
_BLOCKED_COMMANDS = frozenset(
    (
        "rm -rf /",
        "rm -rf /*",
        "mkfs",
        "dd if=/dev/zero",
        ":(){ :|:& };:",  # fork bomb
        "> /dev/sda",
        "chmod -R 777 /",
    )
)
# Default allow-list: safe build/dev commands
DEFAULT_ALLOWED_PREFIXES = (
@@ -199,9 +201,7 @@ class ShellHand:
proc.kill()
await proc.wait()
latency = (time.time() - start) * 1000
logger.warning(
"Shell command timed out after %ds: %s", effective_timeout, command
)
logger.warning("Shell command timed out after %ds: %s", effective_timeout, command)
return ShellResult(
command=command,
success=False,

View File

@@ -11,15 +11,17 @@ the tool registry.
import logging
from typing import Any
from infrastructure.hands.shell import shell_hand
from infrastructure.hands.git import git_hand
from infrastructure.hands.shell import shell_hand
try:
    from mcp.schemas.base import create_tool_schema
except ImportError:

    def create_tool_schema(**kwargs):
        """Fallback used when ``mcp.schemas.base`` cannot be imported.

        Returns the keyword arguments unchanged as a dict.
        """
        schema = kwargs
        return schema
logger = logging.getLogger(__name__)
# ── Tool schemas ─────────────────────────────────────────────────────────────
@@ -83,6 +85,7 @@ PERSONA_LOCAL_HAND_MAP: dict[str, list[str]] = {
# ── Handlers ─────────────────────────────────────────────────────────────────
async def _handle_shell(**kwargs: Any) -> str:
"""Handler for the shell MCP tool."""
command = kwargs.get("command", "")

View File

@@ -1,12 +1,5 @@
"""Infrastructure models package."""
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
model_registry,
)
from infrastructure.models.multimodal import (
ModelCapability,
ModelInfo,
@@ -17,6 +10,13 @@ from infrastructure.models.multimodal import (
model_supports_vision,
pull_model_with_fallback,
)
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
model_registry,
)
__all__ = [
# Registry

View File

@@ -21,39 +21,130 @@ logger = logging.getLogger(__name__)
class ModelCapability(Enum):
    """Capabilities a model can have.

    Values are assigned by ``auto()`` in declaration order.
    """

    TEXT = auto()  # Standard text completion
    VISION = auto()  # Image understanding
    AUDIO = auto()  # Audio/speech processing
    TOOLS = auto()  # Function calling / tool use
    JSON = auto()  # Structured output / JSON mode
    STREAMING = auto()  # Streaming responses
# Known model capabilities (local Ollama models)
# These are used when we can't query the model directly
KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
# Llama 3.x series
"llama3.1": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"llama3.1:8b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"llama3.1:8b-instruct": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"llama3.1:70b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"llama3.1:405b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"llama3.2": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
"llama3.1": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"llama3.1:8b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"llama3.1:8b-instruct": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"llama3.1:70b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"llama3.1:405b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"llama3.2": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"llama3.2:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
"llama3.2-vision": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
"llama3.2-vision:11b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
"llama3.2:3b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"llama3.2-vision": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"llama3.2-vision:11b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
# Qwen series
"qwen2.5": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"qwen2.5:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"qwen2.5:14b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"qwen2.5:32b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"qwen2.5:72b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"qwen2.5-vl": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
"qwen2.5-vl:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
"qwen2.5-vl:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
"qwen2.5": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"qwen2.5:7b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"qwen2.5:14b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"qwen2.5:32b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"qwen2.5:72b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"qwen2.5-vl": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"qwen2.5-vl:3b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"qwen2.5-vl:7b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
# DeepSeek series
"deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
@@ -61,21 +152,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
"deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"deepseek-v3": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"deepseek-v3": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
# Gemma series
"gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
# Mistral series
"mistral": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"mistral:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"mistral-nemo": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"mistral-small": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"mistral-large": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"mistral": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"mistral:7b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"mistral-nemo": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"mistral-small": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"mistral-large": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
# Vision-specific models
"llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
"llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
@@ -86,21 +204,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
"bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
"moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
"moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
# Phi series
"phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"phi4": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"phi4": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
# Command R
"command-r": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"command-r:35b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"command-r-plus": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"command-r": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"command-r:35b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"command-r-plus": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
# Granite (IBM)
"granite3-dense": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"granite3-moe": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
"granite3-dense": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"granite3-moe": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
}
@@ -108,15 +253,15 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
# These are tried in order when the primary model doesn't support a capability
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
ModelCapability.VISION: [
"llama3.2:3b", # Fast vision model
"llava:7b", # Classic vision model
"qwen2.5-vl:3b", # Qwen vision
"moondream:1.8b", # Tiny vision model (last resort)
"llama3.2:3b", # Fast vision model
"llava:7b", # Classic vision model
"qwen2.5-vl:3b", # Qwen vision
"moondream:1.8b", # Tiny vision model (last resort)
],
ModelCapability.TOOLS: [
"llama3.1:8b-instruct", # Best tool use
"llama3.2:3b", # Smaller but capable
"qwen2.5:7b", # Reliable fallback
"llama3.2:3b", # Smaller but capable
"qwen2.5:7b", # Reliable fallback
],
ModelCapability.AUDIO: [
# Audio models are less common in Ollama
@@ -128,13 +273,14 @@ DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
@dataclass
class ModelInfo:
    """Information about a model's capabilities and availability."""

    name: str
    # Capabilities recorded for this model; empty unless provided.
    capabilities: set[ModelCapability] = field(default_factory=set)
    is_available: bool = False
    is_pulled: bool = False
    size_mb: Optional[int] = None
    description: str = ""

    def supports(self, capability: ModelCapability) -> bool:
        """Return True when ``capability`` is in this model's capability set."""
        known = self.capabilities
        return capability in known
@@ -142,26 +288,26 @@ class ModelInfo:
class MultiModalManager:
"""Manages multi-modal model capabilities and fallback chains.
This class:
1. Detects what capabilities each model has
2. Maintains fallback chains for different capabilities
3. Pulls models on-demand with automatic fallback
4. Routes requests to appropriate models based on content type
"""
def __init__(self, ollama_url: Optional[str] = None) -> None:
    """Set up the manager and query Ollama for the models it has.

    Args:
        ollama_url: Base URL of the Ollama server; ``settings.ollama_url``
            is used when this is omitted or empty.
    """
    self.ollama_url = ollama_url if ollama_url else settings.ollama_url
    self._available_models: dict[str, ModelInfo] = {}
    # Shallow copy: the per-capability mapping is private to this instance.
    self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
    self._refresh_available_models()
def _refresh_available_models(self) -> None:
"""Query Ollama for available models."""
try:
import urllib.request
import json
import urllib.request
url = self.ollama_url.replace("localhost", "127.0.0.1")
req = urllib.request.Request(
f"{url}/api/tags",
@@ -170,7 +316,7 @@ class MultiModalManager:
)
with urllib.request.urlopen(req, timeout=5) as response:
data = json.loads(response.read().decode())
for model_data in data.get("models", []):
name = model_data.get("name", "")
self._available_models[name] = ModelInfo(
@@ -181,58 +327,53 @@ class MultiModalManager:
size_mb=model_data.get("size", 0) // (1024 * 1024),
description=model_data.get("details", {}).get("family", ""),
)
logger.info("Found %d models in Ollama", len(self._available_models))
except Exception as exc:
logger.warning("Could not refresh available models: %s", exc)
def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
    """Look up capabilities for ``model_name`` in ``KNOWN_MODEL_CAPABILITIES``.

    The full name is tried first, then the name with everything from the
    first ``:`` onward removed. Unknown models default to TEXT and STREAMING.
    A new set is returned each time.
    """
    for candidate in (model_name, model_name.split(":")[0]):
        if candidate in KNOWN_MODEL_CAPABILITIES:
            return set(KNOWN_MODEL_CAPABILITIES[candidate])

    # Default to text-only for unknown models
    logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
    return {ModelCapability.TEXT, ModelCapability.STREAMING}
def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
    """Get capabilities for a specific model.

    Models present in the available-model cache use the capabilities stored
    on their ``ModelInfo``; any other name is resolved through
    ``_detect_capabilities``.

    Returns:
        A set the caller may modify freely: the cached ``ModelInfo`` set is
        copied rather than returned by reference.
    """
    if model_name in self._available_models:
        # Copy so that mutating the result cannot corrupt the cached entry.
        return set(self._available_models[model_name].capabilities)
    return self._detect_capabilities(model_name)
def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
    """Return True if ``capability`` is among the capabilities of ``model_name``."""
    return capability in self.get_model_capabilities(model_name)
def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
    """Collect every available model whose capability set contains ``capability``.

    The result follows the iteration order of the available-model mapping.
    """
    matching = []
    for info in self._available_models.values():
        if capability in info.capabilities:
            matching.append(info)
    return matching
def get_best_model_for(
self,
capability: ModelCapability,
preferred_model: Optional[str] = None
self, capability: ModelCapability, preferred_model: Optional[str] = None
) -> Optional[str]:
"""Get the best available model for a specific capability.
Args:
capability: The required capability
preferred_model: Preferred model to use if available and capable
Returns:
Model name or None if no suitable model found
"""
@@ -243,25 +384,26 @@ class MultiModalManager:
return preferred_model
logger.debug(
"Preferred model %s doesn't support %s, checking fallbacks",
preferred_model, capability.name
preferred_model,
capability.name,
)
# Check fallback chain for this capability
fallback_chain = self._fallback_chains.get(capability, [])
for model_name in fallback_chain:
if model_name in self._available_models:
logger.debug("Using fallback model %s for %s", model_name, capability.name)
return model_name
# Find any available model with this capability
capable_models = self.get_models_with_capability(capability)
if capable_models:
# Sort by size (prefer smaller/faster models as fallback)
capable_models.sort(key=lambda m: m.size_mb or float('inf'))
capable_models.sort(key=lambda m: m.size_mb or float("inf"))
return capable_models[0].name
return None
def pull_model_with_fallback(
self,
primary_model: str,
@@ -269,58 +411,58 @@ class MultiModalManager:
auto_pull: bool = True,
) -> tuple[str, bool]:
"""Pull a model with automatic fallback if unavailable.
Args:
primary_model: The desired model to use
capability: Required capability (for finding fallback)
auto_pull: Whether to attempt pulling missing models
Returns:
Tuple of (model_name, is_fallback)
"""
# Check if primary model is already available
if primary_model in self._available_models:
return primary_model, False
# Try to pull the primary model
if auto_pull:
if self._pull_model(primary_model):
return primary_model, False
# Need to find a fallback
if capability:
fallback = self.get_best_model_for(capability, primary_model)
if fallback:
logger.info(
"Primary model %s unavailable, using fallback %s",
primary_model, fallback
"Primary model %s unavailable, using fallback %s", primary_model, fallback
)
return fallback, True
# Last resort: use the configured default model
default_model = settings.ollama_model
if default_model in self._available_models:
logger.warning(
"Falling back to default model %s (primary: %s unavailable)",
default_model, primary_model
default_model,
primary_model,
)
return default_model, True
# Absolute last resort
return primary_model, False
def _pull_model(self, model_name: str) -> bool:
"""Attempt to pull a model from Ollama.
Returns:
True if successful or model already exists
"""
try:
import urllib.request
import json
import urllib.request
logger.info("Pulling model: %s", model_name)
url = self.ollama_url.replace("localhost", "127.0.0.1")
req = urllib.request.Request(
f"{url}/api/pull",
@@ -328,7 +470,7 @@ class MultiModalManager:
headers={"Content-Type": "application/json"},
data=json.dumps({"name": model_name, "stream": False}).encode(),
)
with urllib.request.urlopen(req, timeout=300) as response:
if response.status == 200:
logger.info("Successfully pulled model: %s", model_name)
@@ -338,55 +480,51 @@ class MultiModalManager:
else:
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
return False
except Exception as exc:
logger.error("Error pulling model %s: %s", model_name, exc)
return False
def configure_fallback_chain(self, capability: ModelCapability, models: list[str]) -> None:
    """Configure a custom fallback chain for a capability.

    Args:
        capability: Capability whose chain is being replaced.
        models: Model names in the order they should be tried.

    The list is copied before it is stored, so later changes to the caller's
    list do not alter the configured chain (``get_fallback_chain`` likewise
    returns copies).
    """
    chain = list(models)
    self._fallback_chains[capability] = chain
    logger.info("Configured fallback chain for %s: %s", capability.name, chain)
def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
    """Return a copy of the fallback chain for ``capability`` (empty if none is set)."""
    configured = self._fallback_chains.get(capability, ())
    return [*configured]
def list_available_models(self) -> list[ModelInfo]:
    """Return a new list holding the ``ModelInfo`` of every available model."""
    return [*self._available_models.values()]
def refresh(self) -> None:
    """Re-query the available models by delegating to ``_refresh_available_models``."""
    self._refresh_available_models()
def get_model_for_content(
self,
content_type: str, # "text", "image", "audio", "multimodal"
preferred_model: Optional[str] = None,
) -> tuple[str, bool]:
"""Get appropriate model based on content type.
Args:
content_type: Type of content (text, image, audio, multimodal)
preferred_model: User's preferred model
Returns:
Tuple of (model_name, is_fallback)
"""
content_type = content_type.lower()
if content_type in ("image", "vision", "multimodal"):
# For vision content, we need a vision-capable model
return self.pull_model_with_fallback(
preferred_model or "llava:7b",
capability=ModelCapability.VISION,
)
elif content_type == "audio":
# Audio support is limited in Ollama
# Would need specific audio models
@@ -395,7 +533,7 @@ class MultiModalManager:
preferred_model or settings.ollama_model,
capability=ModelCapability.TEXT,
)
else:
# Standard text content
return self.pull_model_with_fallback(
@@ -417,8 +555,7 @@ def get_multimodal_manager() -> MultiModalManager:
def get_model_for_capability(
    capability: ModelCapability, preferred_model: Optional[str] = None
) -> Optional[str]:
    """Convenience wrapper around ``get_best_model_for`` on the shared manager.

    Returns whatever ``get_multimodal_manager().get_best_model_for`` returns
    for the given capability and preferred model.
    """
    manager = get_multimodal_manager()
    return manager.get_best_model_for(capability, preferred_model)
@@ -430,9 +567,7 @@ def pull_model_with_fallback(
auto_pull: bool = True,
) -> tuple[str, bool]:
"""Convenience function to pull model with fallback."""
return get_multimodal_manager().pull_model_with_fallback(
primary_model, capability, auto_pull
)
return get_multimodal_manager().pull_model_with_fallback(primary_model, capability, auto_pull)
def model_supports_vision(model_name: str) -> bool:

View File

@@ -26,26 +26,29 @@ DB_PATH = Path("data/swarm.db")
class ModelFormat(str, Enum):
    """Supported model weight formats (string-valued enum)."""

    GGUF = "gguf"  # Ollama-compatible quantised weights
    SAFETENSORS = "safetensors"  # HuggingFace safetensors
    HF_CHECKPOINT = "hf"  # Full HuggingFace checkpoint directory
    OLLAMA = "ollama"  # Already loaded in Ollama by name
class ModelRole(str, Enum):
    """Role a model can play in the system (OpenClaw-RL style); string-valued enum."""

    GENERAL = "general"  # Default agent inference
    REWARD = "reward"  # Process Reward Model (PRM) scoring
    TEACHER = "teacher"  # On-policy distillation teacher
    JUDGE = "judge"  # Output quality evaluation
@dataclass
class CustomModel:
"""A registered custom model."""
name: str
format: ModelFormat
path: str # Absolute path or Ollama model name
path: str # Absolute path or Ollama model name
role: ModelRole = ModelRole.GENERAL
context_window: int = 4096
description: str = ""
@@ -141,10 +144,16 @@ class ModelRegistry:
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
model.name, model.format.value, model.path,
model.role.value, model.context_window, model.description,
model.registered_at, int(model.active),
model.default_temperature, model.max_tokens,
model.name,
model.format.value,
model.path,
model.role.value,
model.context_window,
model.description,
model.registered_at,
int(model.active),
model.default_temperature,
model.max_tokens,
),
)
conn.commit()
@@ -160,9 +169,7 @@ class ModelRegistry:
return False
conn = _get_conn()
conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
conn.execute(
"DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)
)
conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,))
conn.commit()
conn.close()
del self._models[name]

View File

@@ -9,8 +9,8 @@ No cloud push services — everything stays local.
"""
import logging
import subprocess
import platform
import subprocess
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
@@ -25,9 +25,7 @@ class Notification:
title: str
message: str
category: str # swarm | task | agent | system | payment
timestamp: str = field(
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
read: bool = False
@@ -74,9 +72,11 @@ class PushNotifier:
def _native_notify(self, title: str, message: str) -> None:
"""Send a native macOS notification via osascript."""
try:
safe_message = message.replace("\\", "\\\\").replace('"', '\\"')
safe_title = title.replace("\\", "\\\\").replace('"', '\\"')
script = (
f'display notification "{message}" '
f'with title "Agent Dashboard" subtitle "{title}"'
f'display notification "{safe_message}" '
f'with title "Agent Dashboard" subtitle "{safe_title}"'
)
subprocess.Popen(
["osascript", "-e", script],
@@ -114,7 +114,7 @@ class PushNotifier:
def clear(self) -> None:
self._notifications.clear()
def add_listener(self, callback: "Callable[[Notification], None]") -> None:
    """Register a callback for real-time notification delivery.

    Args:
        callback: Callable to register; it is appended to
            ``self._listeners``.
    """
    self._listeners.append(callback)
@@ -139,10 +139,7 @@ async def notify_briefing_ready(briefing) -> None:
logger.info("Briefing ready but no pending approvals — skipping native notification")
return
message = (
f"Your morning briefing is ready. "
f"{n_approvals} item(s) await your approval."
)
message = f"Your morning briefing is ready. " f"{n_approvals} item(s) await your approval."
notifier.notify(
title="Morning Briefing Ready",
message=message,

View File

@@ -156,33 +156,23 @@ class OpenFangClient:
async def browse(self, url: str, instruction: str = "") -> HandResult:
    """Web automation via OpenFang's Browser hand.

    Args:
        url: Forwarded as the ``url`` parameter.
        instruction: Forwarded as the ``instruction`` parameter
            (defaults to an empty string).

    Returns:
        Whatever ``self.execute_hand`` returns for the ``"browser"`` hand.
    """
    params = {"url": url, "instruction": instruction}
    return await self.execute_hand("browser", params)
async def collect(self, target: str, depth: str = "shallow") -> HandResult:
    """OSINT collection via OpenFang's Collector hand.

    Args:
        target: Forwarded as the ``target`` parameter.
        depth: Forwarded as the ``depth`` parameter (defaults to
            ``"shallow"``).

    Returns:
        Whatever ``self.execute_hand`` returns for the ``"collector"`` hand.
    """
    params = {"target": target, "depth": depth}
    return await self.execute_hand("collector", params)
async def predict(self, question: str, horizon: str = "1w") -> HandResult:
    """Superforecasting via OpenFang's Predictor hand.

    Args:
        question: Forwarded as the ``question`` parameter.
        horizon: Forwarded as the ``horizon`` parameter (defaults to
            ``"1w"``).

    Returns:
        Whatever ``self.execute_hand`` returns for the ``"predictor"`` hand.
    """
    params = {"question": question, "horizon": horizon}
    return await self.execute_hand("predictor", params)
async def find_leads(self, icp: str, max_results: int = 10) -> HandResult:
    """Prospect discovery via OpenFang's Lead hand.

    Args:
        icp: Forwarded as the ``icp`` parameter.
        max_results: Forwarded as the ``max_results`` parameter
            (defaults to 10).

    Returns:
        Whatever ``self.execute_hand`` returns for the ``"lead"`` hand.
    """
    params = {"icp": icp, "max_results": max_results}
    return await self.execute_hand("lead", params)
async def research(self, topic: str, depth: str = "standard") -> HandResult:
    """Deep research via OpenFang's Researcher hand.

    Args:
        topic: Forwarded as the ``topic`` parameter.
        depth: Forwarded as the ``depth`` parameter (defaults to
            ``"standard"``).

    Returns:
        Whatever ``self.execute_hand`` returns for the ``"researcher"`` hand.
    """
    params = {"topic": topic, "depth": depth}
    return await self.execute_hand("researcher", params)
# ── Inventory ────────────────────────────────────────────────────────────

View File

@@ -22,9 +22,11 @@ from infrastructure.openfang.client import OPENFANG_HANDS, openfang_client
try:
from mcp.schemas.base import create_tool_schema
except ImportError:
def create_tool_schema(**fields):
    """Fallback schema builder: return the keyword arguments as a dict.

    Defined under ``except ImportError`` for when ``mcp.schemas.base``
    cannot be imported.
    """
    return fields
logger = logging.getLogger(__name__)
# ── Tool schemas ─────────────────────────────────────────────────────────────

View File

@@ -1,7 +1,7 @@
"""Cascade LLM Router — Automatic failover between providers."""
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
from .api import router
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
__all__ = [
"CascadeRouter",

View File

@@ -15,6 +15,7 @@ router = APIRouter(prefix="/api/v1/router", tags=["router"])
class CompletionRequest(BaseModel):
"""Request body for completions."""
messages: list[dict[str, str]]
model: str | None = None
temperature: float = 0.7
@@ -23,6 +24,7 @@ class CompletionRequest(BaseModel):
class CompletionResponse(BaseModel):
"""Response from completion endpoint."""
content: str
provider: str
model: str
@@ -31,6 +33,7 @@ class CompletionResponse(BaseModel):
class ProviderControl(BaseModel):
"""Control a provider's status."""
action: str # "enable", "disable", "reset_circuit"
@@ -45,7 +48,7 @@ async def complete(
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, Any]:
"""Complete a conversation with automatic failover.
Routes through providers in priority order until one succeeds.
"""
try:
@@ -108,30 +111,32 @@ async def control_provider(
if p.name == provider_name:
provider = p
break
if not provider:
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
if control.action == "enable":
provider.enabled = True
provider.status = provider.status.__class__.HEALTHY
return {"message": f"Provider {provider_name} enabled"}
elif control.action == "disable":
provider.enabled = False
from .cascade import ProviderStatus
provider.status = ProviderStatus.DISABLED
return {"message": f"Provider {provider_name} disabled"}
elif control.action == "reset_circuit":
from .cascade import CircuitState, ProviderStatus
provider.circuit_state = CircuitState.CLOSED
provider.circuit_opened_at = None
provider.half_open_calls = 0
provider.metrics.consecutive_failures = 0
provider.status = ProviderStatus.HEALTHY
return {"message": f"Circuit breaker reset for {provider_name}"}
else:
raise HTTPException(status_code=400, detail=f"Unknown action: {control.action}")
@@ -142,28 +147,35 @@ async def run_health_check(
) -> dict[str, Any]:
"""Run health checks on all providers."""
results = []
for provider in cascade.providers:
# Quick ping to check availability
is_healthy = cascade._check_provider_available(provider)
from .cascade import ProviderStatus
if is_healthy:
if provider.status == ProviderStatus.UNHEALTHY:
# Reset circuit if it was open but now healthy
provider.circuit_state = provider.circuit_state.__class__.CLOSED
provider.circuit_opened_at = None
provider.status = ProviderStatus.HEALTHY if provider.metrics.error_rate < 0.1 else ProviderStatus.DEGRADED
provider.status = (
ProviderStatus.HEALTHY
if provider.metrics.error_rate < 0.1
else ProviderStatus.DEGRADED
)
else:
provider.status = ProviderStatus.UNHEALTHY
results.append({
"name": provider.name,
"type": provider.type,
"healthy": is_healthy,
"status": provider.status.value,
})
results.append(
{
"name": provider.name,
"type": provider.type,
"healthy": is_healthy,
"status": provider.status.value,
}
)
return {
"checked_at": asyncio.get_event_loop().time(),
"providers": results,
@@ -177,7 +189,7 @@ async def get_config(
) -> dict[str, Any]:
"""Get router configuration (without secrets)."""
cfg = cascade.config
return {
"timeout_seconds": cfg.timeout_seconds,
"max_retries_per_provider": cfg.max_retries_per_provider,

View File

@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
class ProviderStatus(Enum):
"""Health status of a provider."""
HEALTHY = "healthy"
DEGRADED = "degraded" # Working but slow or occasional errors
UNHEALTHY = "unhealthy" # Circuit breaker open
@@ -41,22 +42,25 @@ class ProviderStatus(Enum):
class CircuitState(Enum):
    """Circuit breaker state.

    Each member is declared exactly once: reusing a member name in an
    ``Enum`` body raises ``TypeError`` at class-creation time.
    """

    CLOSED = "closed"  # Normal operation
    OPEN = "open"  # Failing, rejecting requests
    HALF_OPEN = "half_open"  # Testing if recovered
class ContentType(Enum):
    """Type of content in the request.

    Each member is declared exactly once: reusing a member name in an
    ``Enum`` body raises ``TypeError`` at class-creation time.
    """

    TEXT = "text"
    VISION = "vision"  # Contains images
    AUDIO = "audio"  # Contains audio
    MULTIMODAL = "multimodal"  # Multiple content types
@dataclass
class ProviderMetrics:
"""Metrics for a single provider."""
total_requests: int = 0
successful_requests: int = 0
failed_requests: int = 0
@@ -64,13 +68,13 @@ class ProviderMetrics:
last_request_time: Optional[str] = None
last_error_time: Optional[str] = None
consecutive_failures: int = 0
@property
def avg_latency_ms(self) -> float:
if self.total_requests == 0:
return 0.0
return self.total_latency_ms / self.total_requests
@property
def error_rate(self) -> float:
if self.total_requests == 0:
@@ -81,6 +85,7 @@ class ProviderMetrics:
@dataclass
class ModelCapability:
"""Capabilities a model supports."""
name: str
supports_vision: bool = False
supports_audio: bool = False
@@ -93,6 +98,7 @@ class ModelCapability:
@dataclass
class Provider:
"""LLM provider configuration and state."""
name: str
type: str # ollama, openai, anthropic, airllm
enabled: bool
@@ -101,14 +107,14 @@ class Provider:
api_key: Optional[str] = None
base_url: Optional[str] = None
models: list[dict] = field(default_factory=list)
# Runtime state
status: ProviderStatus = ProviderStatus.HEALTHY
metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
circuit_state: CircuitState = CircuitState.CLOSED
circuit_opened_at: Optional[float] = None
half_open_calls: int = 0
def get_default_model(self) -> Optional[str]:
"""Get the default model for this provider."""
for model in self.models:
@@ -117,7 +123,7 @@ class Provider:
if self.models:
return self.models[0]["name"]
return None
def get_model_with_capability(self, capability: str) -> Optional[str]:
"""Get a model that supports the given capability."""
for model in self.models:
@@ -126,7 +132,7 @@ class Provider:
return model["name"]
# Fall back to default
return self.get_default_model()
def model_has_capability(self, model_name: str, capability: str) -> bool:
"""Check if a specific model has a capability."""
for model in self.models:
@@ -139,6 +145,7 @@ class Provider:
@dataclass
class RouterConfig:
"""Cascade router configuration."""
timeout_seconds: int = 30
max_retries_per_provider: int = 2
retry_delay_seconds: int = 1
@@ -154,22 +161,22 @@ class RouterConfig:
class CascadeRouter:
"""Routes LLM requests with automatic failover.
Now with multi-modal support:
- Automatically detects content type (text, vision, audio)
- Selects appropriate models based on capabilities
- Falls back through capability-specific model chains
- Supports image URLs and base64 encoding
Usage:
router = CascadeRouter()
# Text request
response = await router.complete(
messages=[{"role": "user", "content": "Hello"}],
model="llama3.2"
)
# Vision request (automatically detects and selects vision model)
response = await router.complete(
messages=[{
@@ -179,68 +186,75 @@ class CascadeRouter:
}],
model="llava:7b"
)
# Check metrics
metrics = router.get_metrics()
"""
def __init__(self, config_path: Optional[Path] = None) -> None:
self.config_path = config_path or Path("config/providers.yaml")
self.providers: list[Provider] = []
self.config: RouterConfig = RouterConfig()
self._load_config()
# Initialize multi-modal manager if available
self._mm_manager: Optional[Any] = None
try:
from infrastructure.models.multimodal import get_multimodal_manager
self._mm_manager = get_multimodal_manager()
except Exception as exc:
logger.debug("Multi-modal manager not available: %s", exc)
logger.info("CascadeRouter initialized with %d providers", len(self.providers))
def _load_config(self) -> None:
"""Load configuration from YAML."""
if not self.config_path.exists():
logger.warning("Config not found: %s, using defaults", self.config_path)
return
try:
if yaml is None:
raise RuntimeError("PyYAML not installed")
content = self.config_path.read_text()
# Expand environment variables
content = self._expand_env_vars(content)
data = yaml.safe_load(content)
# Load cascade settings
cascade = data.get("cascade", {})
# Load fallback chains
fallback_chains = data.get("fallback_chains", {})
# Load multi-modal settings
multimodal = data.get("multimodal", {})
self.config = RouterConfig(
timeout_seconds=cascade.get("timeout_seconds", 30),
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get(
"failure_threshold", 5
),
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get(
"recovery_timeout", 60
),
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get(
"half_open_max_calls", 2
),
auto_pull_models=multimodal.get("auto_pull", True),
fallback_chains=fallback_chains,
)
# Load providers
for p_data in data.get("providers", []):
# Skip disabled providers
if not p_data.get("enabled", False):
continue
provider = Provider(
name=p_data["name"],
type=p_data["type"],
@@ -251,30 +265,34 @@ class CascadeRouter:
base_url=p_data.get("base_url"),
models=p_data.get("models", []),
)
# Check if provider is actually available
if self._check_provider_available(provider):
self.providers.append(provider)
else:
logger.warning("Provider %s not available, skipping", provider.name)
# Sort by priority
self.providers.sort(key=lambda p: p.priority)
except Exception as exc:
logger.error("Failed to load config: %s", exc)
def _expand_env_vars(self, content: str) -> str:
"""Expand ${VAR} syntax in YAML content."""
"""Expand ${VAR} syntax in YAML content.
Uses os.environ directly (not settings) because this is a generic
YAML config loader that must expand arbitrary variable references.
"""
import os
import re
def replace_var(match):
def replace_var(match: "re.Match[str]") -> str:
var_name = match.group(1)
return os.environ.get(var_name, match.group(0))
return re.sub(r"\$\{(\w+)\}", replace_var, content)
def _check_provider_available(self, provider: Provider) -> bool:
"""Check if a provider is actually available."""
if provider.type == "ollama":
@@ -288,48 +306,49 @@ class CascadeRouter:
return response.status_code == 200
except Exception:
return False
elif provider.type == "airllm":
# Check if airllm is installed
try:
import airllm
return True
except ImportError:
return False
elif provider.type in ("openai", "anthropic", "grok"):
# Check if API key is set
return provider.api_key is not None and provider.api_key != ""
return True
def _detect_content_type(self, messages: list[dict]) -> ContentType:
"""Detect the type of content in the messages.
Checks for images, audio, etc. in the message content.
"""
has_image = False
has_audio = False
for msg in messages:
content = msg.get("content", "")
# Check for image URLs/paths
if msg.get("images"):
has_image = True
# Check for image URLs in content
if isinstance(content, str):
image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
if any(ext in content.lower() for ext in image_extensions):
has_image = True
if content.startswith("data:image/"):
has_image = True
# Check for audio
if msg.get("audio"):
has_audio = True
# Check for multimodal content structure
if isinstance(content, list):
for item in content:
@@ -338,7 +357,7 @@ class CascadeRouter:
has_image = True
elif item.get("type") == "audio":
has_audio = True
if has_image and has_audio:
return ContentType.MULTIMODAL
elif has_image:
@@ -346,12 +365,9 @@ class CascadeRouter:
elif has_audio:
return ContentType.AUDIO
return ContentType.TEXT
def _get_fallback_model(
self,
provider: Provider,
original_model: str,
content_type: ContentType
self, provider: Provider, original_model: str, content_type: ContentType
) -> Optional[str]:
"""Get a fallback model for the given content type."""
# Map content type to capability
@@ -360,24 +376,24 @@ class CascadeRouter:
ContentType.AUDIO: "audio",
ContentType.MULTIMODAL: "vision", # Vision models often do both
}
capability = capability_map.get(content_type)
if not capability:
return None
# Check provider's models for capability
fallback_model = provider.get_model_with_capability(capability)
if fallback_model and fallback_model != original_model:
return fallback_model
# Use fallback chains from config
fallback_chain = self.config.fallback_chains.get(capability, [])
for model_name in fallback_chain:
if provider.model_has_capability(model_name, capability):
return model_name
return None
async def complete(
self,
messages: list[dict],
@@ -386,21 +402,21 @@ class CascadeRouter:
max_tokens: Optional[int] = None,
) -> dict:
"""Complete a chat conversation with automatic failover.
Multi-modal support:
- Automatically detects if messages contain images
- Falls back to vision-capable models when needed
- Supports image URLs, paths, and base64 encoding
Args:
messages: List of message dicts with role and content
model: Preferred model (tries this first, then provider defaults)
temperature: Sampling temperature
max_tokens: Maximum tokens to generate
Returns:
Dict with content, provider_used, and metrics
Raises:
RuntimeError: If all providers fail
"""
@@ -408,15 +424,15 @@ class CascadeRouter:
content_type = self._detect_content_type(messages)
if content_type != ContentType.TEXT:
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
errors = []
for provider in self.providers:
# Skip disabled providers
if not provider.enabled:
logger.debug("Skipping %s (disabled)", provider.name)
continue
# Skip unhealthy providers (circuit breaker)
if provider.status == ProviderStatus.UNHEALTHY:
# Check if circuit breaker can close
@@ -427,16 +443,16 @@ class CascadeRouter:
else:
logger.debug("Skipping %s (circuit open)", provider.name)
continue
# Determine which model to use
selected_model = model or provider.get_default_model()
is_fallback_model = False
# For non-text content, check if model supports it
if content_type != ContentType.TEXT and selected_model:
if provider.type == "ollama" and self._mm_manager:
from infrastructure.models.multimodal import ModelCapability
# Check if selected model supports the required capability
if content_type == ContentType.VISION:
supports = self._mm_manager.model_supports(
@@ -450,16 +466,17 @@ class CascadeRouter:
if fallback:
logger.info(
"Model %s doesn't support vision, falling back to %s",
selected_model, fallback
selected_model,
fallback,
)
selected_model = fallback
is_fallback_model = True
else:
logger.warning(
"No vision-capable model found on %s, trying anyway",
provider.name
provider.name,
)
# Try this provider
for attempt in range(self.config.max_retries_per_provider):
try:
@@ -471,34 +488,35 @@ class CascadeRouter:
max_tokens=max_tokens,
content_type=content_type,
)
# Success! Update metrics and return
self._record_success(provider, result.get("latency_ms", 0))
return {
"content": result["content"],
"provider": provider.name,
"model": result.get("model", selected_model or provider.get_default_model()),
"model": result.get(
"model", selected_model or provider.get_default_model()
),
"latency_ms": result.get("latency_ms", 0),
"is_fallback_model": is_fallback_model,
}
except Exception as exc:
error_msg = str(exc)
logger.warning(
"Provider %s attempt %d failed: %s",
provider.name, attempt + 1, error_msg
"Provider %s attempt %d failed: %s", provider.name, attempt + 1, error_msg
)
errors.append(f"{provider.name}: {error_msg}")
if attempt < self.config.max_retries_per_provider - 1:
await asyncio.sleep(self.config.retry_delay_seconds)
# All retries failed for this provider
self._record_failure(provider)
# All providers failed
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
async def _try_provider(
self,
provider: Provider,
@@ -510,7 +528,7 @@ class CascadeRouter:
) -> dict:
"""Try a single provider request."""
start_time = time.time()
if provider.type == "ollama":
result = await self._call_ollama(
provider=provider,
@@ -545,12 +563,12 @@ class CascadeRouter:
)
else:
raise ValueError(f"Unknown provider type: {provider.type}")
latency_ms = (time.time() - start_time) * 1000
result["latency_ms"] = latency_ms
return result
async def _call_ollama(
self,
provider: Provider,
@@ -561,12 +579,12 @@ class CascadeRouter:
) -> dict:
"""Call Ollama API with multi-modal support."""
import aiohttp
url = f"{provider.url}/api/chat"
# Transform messages for Ollama format (including images)
transformed_messages = self._transform_messages_for_ollama(messages)
payload = {
"model": model,
"messages": transformed_messages,
@@ -575,31 +593,31 @@ class CascadeRouter:
"temperature": temperature,
},
}
timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(url, json=payload) as response:
if response.status != 200:
text = await response.text()
raise RuntimeError(f"Ollama error {response.status}: {text}")
data = await response.json()
return {
"content": data["message"]["content"],
"model": model,
}
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
"""Transform messages to Ollama format, handling images."""
transformed = []
for msg in messages:
new_msg = {
"role": msg.get("role", "user"),
"content": msg.get("content", ""),
}
# Handle images
images = msg.get("images", [])
if images:
@@ -620,11 +638,11 @@ class CascadeRouter:
new_msg["images"].append(img_data)
except Exception as exc:
logger.error("Failed to read image %s: %s", img, exc)
transformed.append(new_msg)
return transformed
async def _call_openai(
self,
provider: Provider,
@@ -635,13 +653,13 @@ class CascadeRouter:
) -> dict:
"""Call OpenAI API."""
import openai
client = openai.AsyncOpenAI(
api_key=provider.api_key,
base_url=provider.base_url,
timeout=self.config.timeout_seconds,
)
kwargs = {
"model": model,
"messages": messages,
@@ -649,14 +667,14 @@ class CascadeRouter:
}
if max_tokens:
kwargs["max_tokens"] = max_tokens
response = await client.chat.completions.create(**kwargs)
return {
"content": response.choices[0].message.content,
"model": response.model,
}
async def _call_anthropic(
self,
provider: Provider,
@@ -667,12 +685,12 @@ class CascadeRouter:
) -> dict:
"""Call Anthropic API."""
import anthropic
client = anthropic.AsyncAnthropic(
api_key=provider.api_key,
timeout=self.config.timeout_seconds,
)
# Convert messages to Anthropic format
system_msg = None
conversation = []
@@ -680,11 +698,13 @@ class CascadeRouter:
if msg["role"] == "system":
system_msg = msg["content"]
else:
conversation.append({
"role": msg["role"],
"content": msg["content"],
})
conversation.append(
{
"role": msg["role"],
"content": msg["content"],
}
)
kwargs = {
"model": model,
"messages": conversation,
@@ -693,9 +713,9 @@ class CascadeRouter:
}
if system_msg:
kwargs["system"] = system_msg
response = await client.messages.create(**kwargs)
return {
"content": response.content[0].text,
"model": response.model,
@@ -733,7 +753,7 @@ class CascadeRouter:
"content": response.choices[0].message.content,
"model": response.model,
}
def _record_success(self, provider: Provider, latency_ms: float) -> None:
"""Record a successful request."""
provider.metrics.total_requests += 1
@@ -741,50 +761,50 @@ class CascadeRouter:
provider.metrics.total_latency_ms += latency_ms
provider.metrics.last_request_time = datetime.now(timezone.utc).isoformat()
provider.metrics.consecutive_failures = 0
# Close circuit breaker if half-open
if provider.circuit_state == CircuitState.HALF_OPEN:
provider.half_open_calls += 1
if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
self._close_circuit(provider)
# Update status based on error rate
if provider.metrics.error_rate < 0.1:
provider.status = ProviderStatus.HEALTHY
elif provider.metrics.error_rate < 0.3:
provider.status = ProviderStatus.DEGRADED
def _record_failure(self, provider: Provider) -> None:
    """Record a failed request and update the provider's health state.

    Increments the request and failure counters, stamps
    ``last_error_time`` with the current UTC time (ISO format), then
    derives the new status:

    * The circuit is opened (via ``_open_circuit``) when
      ``consecutive_failures`` reaches
      ``config.circuit_breaker_failure_threshold``.
    * It is also opened when the error rate exceeds 0.5 and no open
      timestamp is recorded yet. ``_can_close_circuit`` returns False
      while ``circuit_opened_at`` is None, so marking a provider
      UNHEALTHY without that timestamp would leave it with no recovery
      timeout to wait out.
    * Otherwise an error rate above 0.5 marks the provider UNHEALTHY and
      one above 0.3 marks it DEGRADED.

    A circuit opened by the threshold is no longer downgraded to DEGRADED
    by the error-rate check in the same call.
    """
    metrics = provider.metrics
    metrics.total_requests += 1
    metrics.failed_requests += 1
    metrics.last_error_time = datetime.now(timezone.utc).isoformat()
    metrics.consecutive_failures += 1

    threshold_hit = (
        metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold
    )
    error_rate = metrics.error_rate

    if threshold_hit or (error_rate > 0.5 and provider.circuit_opened_at is None):
        # _open_circuit sets state, timestamp and UNHEALTHY status together,
        # so the recovery timeout always has a starting point.
        self._open_circuit(provider)
    elif error_rate > 0.5:
        provider.status = ProviderStatus.UNHEALTHY
    elif error_rate > 0.3:
        provider.status = ProviderStatus.DEGRADED
def _open_circuit(self, provider: Provider) -> None:
    """Trip the circuit breaker for ``provider``.

    Sets the circuit state to OPEN, records the current ``time.time()``
    value as ``circuit_opened_at``, marks the provider UNHEALTHY and logs
    a warning.
    """
    provider.circuit_state = CircuitState.OPEN
    # Timestamp read by _can_close_circuit to measure the recovery timeout.
    opened_at = time.time()
    provider.circuit_opened_at = opened_at
    provider.status = ProviderStatus.UNHEALTHY
    logger.warning("Circuit breaker OPEN for %s", provider.name)
def _can_close_circuit(self, provider: Provider) -> bool:
"""Check if circuit breaker can transition to half-open."""
if provider.circuit_opened_at is None:
return False
elapsed = time.time() - provider.circuit_opened_at
return elapsed >= self.config.circuit_breaker_recovery_timeout
def _close_circuit(self, provider: Provider) -> None:
"""Close the circuit breaker (provider healthy again)."""
provider.circuit_state = CircuitState.CLOSED
@@ -793,7 +813,7 @@ class CascadeRouter:
provider.metrics.consecutive_failures = 0
provider.status = ProviderStatus.HEALTHY
logger.info("Circuit breaker CLOSED for %s", provider.name)
def get_metrics(self) -> dict:
"""Get metrics for all providers."""
return {
@@ -814,16 +834,20 @@ class CascadeRouter:
for p in self.providers
]
}
def get_status(self) -> dict:
"""Get current router status."""
healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY)
return {
"total_providers": len(self.providers),
"healthy_providers": healthy,
"degraded_providers": sum(1 for p in self.providers if p.status == ProviderStatus.DEGRADED),
"unhealthy_providers": sum(1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY),
"degraded_providers": sum(
1 for p in self.providers if p.status == ProviderStatus.DEGRADED
),
"unhealthy_providers": sum(
1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY
),
"providers": [
{
"name": p.name,
@@ -835,7 +859,7 @@ class CascadeRouter:
for p in self.providers
],
}
async def generate_with_image(
self,
prompt: str,
@@ -844,21 +868,23 @@ class CascadeRouter:
temperature: float = 0.7,
) -> dict:
"""Convenience method for vision requests.
Args:
prompt: Text prompt about the image
image_path: Path to image file
model: Vision-capable model (auto-selected if not provided)
temperature: Sampling temperature
Returns:
Response dict with content and metadata
"""
messages = [{
"role": "user",
"content": prompt,
"images": [image_path],
}]
messages = [
{
"role": "user",
"content": prompt,
"images": [image_path],
}
]
return await self.complete(
messages=messages,
model=model,

View File

@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
@dataclass
class WSEvent:
"""A WebSocket event to broadcast to connected clients."""
event: str
data: dict
timestamp: str
@@ -93,28 +94,42 @@ class WebSocketManager:
await self.broadcast("agent_left", {"agent_id": agent_id, "name": name})
async def broadcast_task_posted(self, task_id: str, description: str) -> None:
    """Broadcast a single ``task_posted`` event.

    The payload carries ``task_id`` and ``description`` and is sent once
    through ``self.broadcast``.
    """
    await self.broadcast(
        "task_posted",
        {
            "task_id": task_id,
            "description": description,
        },
    )
async def broadcast_bid_submitted(self, task_id: str, agent_id: str, bid_sats: int) -> None:
    """Broadcast a single ``bid_submitted`` event.

    The payload carries ``task_id``, ``agent_id`` and ``bid_sats`` and is
    sent once through ``self.broadcast``.
    """
    await self.broadcast(
        "bid_submitted",
        {
            "task_id": task_id,
            "agent_id": agent_id,
            "bid_sats": bid_sats,
        },
    )
async def broadcast_task_assigned(self, task_id: str, agent_id: str) -> None:
    """Broadcast a single ``task_assigned`` event.

    The payload carries ``task_id`` and ``agent_id`` and is sent once
    through ``self.broadcast``.
    """
    await self.broadcast(
        "task_assigned",
        {
            "task_id": task_id,
            "agent_id": agent_id,
        },
    )
async def broadcast_task_completed(self, task_id: str, agent_id: str, result: str) -> None:
    """Broadcast a single ``task_completed`` event.

    The payload carries ``task_id``, ``agent_id`` and ``result`` and is
    sent once through ``self.broadcast``. Only the first 200 characters of
    ``result`` are included.
    """
    await self.broadcast(
        "task_completed",
        {
            "task_id": task_id,
            "agent_id": agent_id,
            "result": result[:200],
        },
    )
@property
def connection_count(self) -> int:
@@ -122,28 +137,28 @@ class WebSocketManager:
async def broadcast_json(self, data: dict) -> int:
    """Broadcast raw JSON data to all connected clients.

    The payload is serialised once with ``json.dumps``. Delivery is
    best-effort: a client whose ``send_text`` raises is treated as dead
    and passed to ``self.disconnect`` after the loop.

    Args:
        data: Dictionary to send as JSON

    Returns:
        Number of clients notified
    """
    message = json.dumps(data)
    disconnected = []
    count = 0

    # Iterate over a snapshot: each ``await`` yields to the event loop, so
    # other tasks may add or remove connections while we are sending.
    for ws in list(self._connections):
        try:
            await ws.send_text(message)
            count += 1
        except Exception:
            disconnected.append(ws)

    # Clean up dead connections
    for ws in disconnected:
        self.disconnect(ws)

    return count
@property