forked from Rockachopa/Timmy-time-dashboard
feat: code quality audit + autoresearch integration + infra hardening (#150)
This commit is contained in:
committed by
GitHub
parent
fd0ede0d51
commit
ae3bb1cc21
@@ -119,9 +119,7 @@ def capture_error(
|
||||
return None
|
||||
|
||||
# Format the stack trace
|
||||
tb_str = "".join(
|
||||
traceback.format_exception(type(exc), exc, exc.__traceback__)
|
||||
)
|
||||
tb_str = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
|
||||
|
||||
# Extract file/line from traceback
|
||||
tb_obj = exc.__traceback__
|
||||
|
||||
@@ -19,38 +19,39 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class EventBroadcaster:
|
||||
"""Broadcasts events to WebSocket clients.
|
||||
|
||||
|
||||
Usage:
|
||||
from infrastructure.events.broadcaster import event_broadcaster
|
||||
event_broadcaster.broadcast(event)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._ws_manager: Optional = None
|
||||
|
||||
|
||||
def _get_ws_manager(self):
|
||||
"""Lazy import to avoid circular deps."""
|
||||
if self._ws_manager is None:
|
||||
try:
|
||||
from infrastructure.ws_manager.handler import ws_manager
|
||||
|
||||
self._ws_manager = ws_manager
|
||||
except Exception as exc:
|
||||
logger.debug("WebSocket manager not available: %s", exc)
|
||||
return self._ws_manager
|
||||
|
||||
|
||||
async def broadcast(self, event: EventLogEntry) -> int:
|
||||
"""Broadcast an event to all connected WebSocket clients.
|
||||
|
||||
|
||||
Args:
|
||||
event: The event to broadcast
|
||||
|
||||
|
||||
Returns:
|
||||
Number of clients notified
|
||||
"""
|
||||
ws_manager = self._get_ws_manager()
|
||||
if not ws_manager:
|
||||
return 0
|
||||
|
||||
|
||||
# Build message payload
|
||||
payload = {
|
||||
"type": "event",
|
||||
@@ -62,9 +63,9 @@ class EventBroadcaster:
|
||||
"agent_id": event.agent_id,
|
||||
"timestamp": event.timestamp,
|
||||
"data": event.data,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# Broadcast to all connected clients
|
||||
count = await ws_manager.broadcast_json(payload)
|
||||
@@ -73,10 +74,10 @@ class EventBroadcaster:
|
||||
except Exception as exc:
|
||||
logger.error("Failed to broadcast event: %s", exc)
|
||||
return 0
|
||||
|
||||
|
||||
def broadcast_sync(self, event: EventLogEntry) -> None:
|
||||
"""Synchronous wrapper for broadcast.
|
||||
|
||||
|
||||
Use this from synchronous code - it schedules the async broadcast
|
||||
in the event loop if one is running.
|
||||
"""
|
||||
@@ -151,11 +152,11 @@ def get_event_label(event_type: str) -> str:
|
||||
|
||||
def format_event_for_display(event: EventLogEntry) -> dict:
|
||||
"""Format event for display in activity feed.
|
||||
|
||||
|
||||
Returns dict with display-friendly fields.
|
||||
"""
|
||||
data = event.data or {}
|
||||
|
||||
|
||||
# Build description based on event type
|
||||
description = ""
|
||||
if event.event_type.value == "task.created":
|
||||
@@ -178,7 +179,7 @@ def format_event_for_display(event: EventLogEntry) -> dict:
|
||||
val = str(data[key])
|
||||
description = val[:60] + "..." if len(val) > 60 else val
|
||||
break
|
||||
|
||||
|
||||
return {
|
||||
"id": event.id,
|
||||
"icon": get_event_icon(event.event_type.value),
|
||||
|
||||
@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class Event:
|
||||
"""A typed event in the system."""
|
||||
|
||||
type: str # e.g., "agent.task.assigned", "tool.execution.completed"
|
||||
source: str # Agent or component that emitted the event
|
||||
data: dict = field(default_factory=dict)
|
||||
@@ -29,15 +30,15 @@ EventHandler = Callable[[Event], Coroutine[Any, Any, None]]
|
||||
|
||||
class EventBus:
|
||||
"""Async event bus for publish/subscribe pattern.
|
||||
|
||||
|
||||
Usage:
|
||||
bus = EventBus()
|
||||
|
||||
|
||||
# Subscribe to events
|
||||
@bus.subscribe("agent.task.*")
|
||||
async def handle_task(event: Event):
|
||||
print(f"Task event: {event.data}")
|
||||
|
||||
|
||||
# Publish events
|
||||
await bus.publish(Event(
|
||||
type="agent.task.assigned",
|
||||
@@ -45,88 +46,89 @@ class EventBus:
|
||||
data={"task_id": "123", "agent": "forge"}
|
||||
))
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._subscribers: dict[str, list[EventHandler]] = {}
|
||||
self._history: list[Event] = []
|
||||
self._max_history = 1000
|
||||
logger.info("EventBus initialized")
|
||||
|
||||
|
||||
def subscribe(self, event_pattern: str) -> Callable[[EventHandler], EventHandler]:
|
||||
"""Decorator to subscribe to events matching a pattern.
|
||||
|
||||
|
||||
Patterns support wildcards:
|
||||
- "agent.task.assigned" — exact match
|
||||
- "agent.task.*" — any task event
|
||||
- "agent.*" — any agent event
|
||||
- "*" — all events
|
||||
"""
|
||||
|
||||
def decorator(handler: EventHandler) -> EventHandler:
|
||||
if event_pattern not in self._subscribers:
|
||||
self._subscribers[event_pattern] = []
|
||||
self._subscribers[event_pattern].append(handler)
|
||||
logger.debug("Subscribed handler to '%s'", event_pattern)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def unsubscribe(self, event_pattern: str, handler: EventHandler) -> bool:
|
||||
"""Remove a handler from a subscription."""
|
||||
if event_pattern not in self._subscribers:
|
||||
return False
|
||||
|
||||
|
||||
if handler in self._subscribers[event_pattern]:
|
||||
self._subscribers[event_pattern].remove(handler)
|
||||
logger.debug("Unsubscribed handler from '%s'", event_pattern)
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def publish(self, event: Event) -> int:
|
||||
"""Publish an event to all matching subscribers.
|
||||
|
||||
|
||||
Returns:
|
||||
Number of handlers invoked
|
||||
"""
|
||||
# Store in history
|
||||
self._history.append(event)
|
||||
if len(self._history) > self._max_history:
|
||||
self._history = self._history[-self._max_history:]
|
||||
|
||||
self._history = self._history[-self._max_history :]
|
||||
|
||||
# Find matching handlers
|
||||
handlers: list[EventHandler] = []
|
||||
|
||||
|
||||
for pattern, pattern_handlers in self._subscribers.items():
|
||||
if self._match_pattern(event.type, pattern):
|
||||
handlers.extend(pattern_handlers)
|
||||
|
||||
|
||||
# Invoke handlers concurrently
|
||||
if handlers:
|
||||
await asyncio.gather(
|
||||
*[self._invoke_handler(h, event) for h in handlers],
|
||||
return_exceptions=True
|
||||
*[self._invoke_handler(h, event) for h in handlers], return_exceptions=True
|
||||
)
|
||||
|
||||
|
||||
logger.debug("Published event '%s' to %d handlers", event.type, len(handlers))
|
||||
return len(handlers)
|
||||
|
||||
|
||||
async def _invoke_handler(self, handler: EventHandler, event: Event) -> None:
|
||||
"""Invoke a handler with error handling."""
|
||||
try:
|
||||
await handler(event)
|
||||
except Exception as exc:
|
||||
logger.error("Event handler failed for '%s': %s", event.type, exc)
|
||||
|
||||
|
||||
def _match_pattern(self, event_type: str, pattern: str) -> bool:
|
||||
"""Check if event type matches a wildcard pattern."""
|
||||
if pattern == "*":
|
||||
return True
|
||||
|
||||
|
||||
if pattern.endswith(".*"):
|
||||
prefix = pattern[:-2]
|
||||
return event_type.startswith(prefix + ".")
|
||||
|
||||
|
||||
return event_type == pattern
|
||||
|
||||
|
||||
def get_history(
|
||||
self,
|
||||
event_type: str | None = None,
|
||||
@@ -135,15 +137,15 @@ class EventBus:
|
||||
) -> list[Event]:
|
||||
"""Get recent event history with optional filtering."""
|
||||
events = self._history
|
||||
|
||||
|
||||
if event_type:
|
||||
events = [e for e in events if e.type == event_type]
|
||||
|
||||
|
||||
if source:
|
||||
events = [e for e in events if e.source == source]
|
||||
|
||||
|
||||
return events[-limit:]
|
||||
|
||||
|
||||
def clear_history(self) -> None:
|
||||
"""Clear event history."""
|
||||
self._history.clear()
|
||||
@@ -156,11 +158,13 @@ event_bus = EventBus()
|
||||
# Convenience functions
|
||||
async def emit(event_type: str, source: str, data: dict) -> int:
|
||||
"""Quick emit an event."""
|
||||
return await event_bus.publish(Event(
|
||||
type=event_type,
|
||||
source=source,
|
||||
data=data,
|
||||
))
|
||||
return await event_bus.publish(
|
||||
Event(
|
||||
type=event_type,
|
||||
source=source,
|
||||
data=data,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def on(event_pattern: str) -> Callable[[EventHandler], EventHandler]:
|
||||
|
||||
@@ -11,7 +11,7 @@ Usage:
|
||||
result = await git_hand.run("status")
|
||||
"""
|
||||
|
||||
from infrastructure.hands.shell import shell_hand
|
||||
from infrastructure.hands.git import git_hand
|
||||
from infrastructure.hands.shell import shell_hand
|
||||
|
||||
__all__ = ["shell_hand", "git_hand"]
|
||||
|
||||
@@ -25,16 +25,18 @@ from config import settings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Operations that require explicit confirmation before execution
|
||||
DESTRUCTIVE_OPS = frozenset({
|
||||
"push --force",
|
||||
"push -f",
|
||||
"reset --hard",
|
||||
"clean -fd",
|
||||
"clean -f",
|
||||
"branch -D",
|
||||
"checkout -- .",
|
||||
"restore .",
|
||||
})
|
||||
DESTRUCTIVE_OPS = frozenset(
|
||||
{
|
||||
"push --force",
|
||||
"push -f",
|
||||
"reset --hard",
|
||||
"clean -fd",
|
||||
"clean -f",
|
||||
"branch -D",
|
||||
"checkout -- .",
|
||||
"restore .",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -190,7 +192,9 @@ class GitHand:
|
||||
flag = "-b" if create else ""
|
||||
return await self.run(f"checkout {flag} {branch}".strip())
|
||||
|
||||
async def push(self, remote: str = "origin", branch: str = "", force: bool = False) -> GitResult:
|
||||
async def push(
|
||||
self, remote: str = "origin", branch: str = "", force: bool = False
|
||||
) -> GitResult:
|
||||
"""Push to remote. Force-push requires explicit opt-in."""
|
||||
args = f"push -u {remote} {branch}".strip()
|
||||
if force:
|
||||
|
||||
@@ -26,15 +26,17 @@ from config import settings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Commands that are always blocked regardless of allow-list
|
||||
_BLOCKED_COMMANDS = frozenset({
|
||||
"rm -rf /",
|
||||
"rm -rf /*",
|
||||
"mkfs",
|
||||
"dd if=/dev/zero",
|
||||
":(){ :|:& };:", # fork bomb
|
||||
"> /dev/sda",
|
||||
"chmod -R 777 /",
|
||||
})
|
||||
_BLOCKED_COMMANDS = frozenset(
|
||||
{
|
||||
"rm -rf /",
|
||||
"rm -rf /*",
|
||||
"mkfs",
|
||||
"dd if=/dev/zero",
|
||||
":(){ :|:& };:", # fork bomb
|
||||
"> /dev/sda",
|
||||
"chmod -R 777 /",
|
||||
}
|
||||
)
|
||||
|
||||
# Default allow-list: safe build/dev commands
|
||||
DEFAULT_ALLOWED_PREFIXES = (
|
||||
@@ -199,9 +201,7 @@ class ShellHand:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
latency = (time.time() - start) * 1000
|
||||
logger.warning(
|
||||
"Shell command timed out after %ds: %s", effective_timeout, command
|
||||
)
|
||||
logger.warning("Shell command timed out after %ds: %s", effective_timeout, command)
|
||||
return ShellResult(
|
||||
command=command,
|
||||
success=False,
|
||||
|
||||
@@ -11,15 +11,17 @@ the tool registry.
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from infrastructure.hands.shell import shell_hand
|
||||
from infrastructure.hands.git import git_hand
|
||||
from infrastructure.hands.shell import shell_hand
|
||||
|
||||
try:
|
||||
from mcp.schemas.base import create_tool_schema
|
||||
except ImportError:
|
||||
|
||||
def create_tool_schema(**kwargs):
|
||||
return kwargs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Tool schemas ─────────────────────────────────────────────────────────────
|
||||
@@ -83,6 +85,7 @@ PERSONA_LOCAL_HAND_MAP: dict[str, list[str]] = {
|
||||
|
||||
# ── Handlers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_shell(**kwargs: Any) -> str:
|
||||
"""Handler for the shell MCP tool."""
|
||||
command = kwargs.get("command", "")
|
||||
|
||||
@@ -1,12 +1,5 @@
|
||||
"""Infrastructure models package."""
|
||||
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
model_registry,
|
||||
)
|
||||
from infrastructure.models.multimodal import (
|
||||
ModelCapability,
|
||||
ModelInfo,
|
||||
@@ -17,6 +10,13 @@ from infrastructure.models.multimodal import (
|
||||
model_supports_vision,
|
||||
pull_model_with_fallback,
|
||||
)
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
model_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Registry
|
||||
|
||||
@@ -21,39 +21,130 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelCapability(Enum):
|
||||
"""Capabilities a model can have."""
|
||||
TEXT = auto() # Standard text completion
|
||||
VISION = auto() # Image understanding
|
||||
AUDIO = auto() # Audio/speech processing
|
||||
TOOLS = auto() # Function calling / tool use
|
||||
JSON = auto() # Structured output / JSON mode
|
||||
STREAMING = auto() # Streaming responses
|
||||
|
||||
TEXT = auto() # Standard text completion
|
||||
VISION = auto() # Image understanding
|
||||
AUDIO = auto() # Audio/speech processing
|
||||
TOOLS = auto() # Function calling / tool use
|
||||
JSON = auto() # Structured output / JSON mode
|
||||
STREAMING = auto() # Streaming responses
|
||||
|
||||
|
||||
# Known model capabilities (local Ollama models)
|
||||
# These are used when we can't query the model directly
|
||||
KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||
# Llama 3.x series
|
||||
"llama3.1": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.1:8b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.1:8b-instruct": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.1:70b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.1:405b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.2": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"llama3.1": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.1:8b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.1:8b-instruct": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.1:70b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.1:405b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.2": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.2:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"llama3.2-vision": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"llama3.2-vision:11b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
|
||||
"llama3.2:3b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"llama3.2-vision": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"llama3.2-vision:11b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
# Qwen series
|
||||
"qwen2.5": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5:14b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5:32b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5:72b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5-vl": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"qwen2.5-vl:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"qwen2.5-vl:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
|
||||
"qwen2.5": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5:7b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5:14b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5:32b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5:72b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5-vl": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"qwen2.5-vl:3b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"qwen2.5-vl:7b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
# DeepSeek series
|
||||
"deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
@@ -61,21 +152,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||
"deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"deepseek-v3": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
"deepseek-v3": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
# Gemma series
|
||||
"gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
# Mistral series
|
||||
"mistral": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"mistral:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"mistral-nemo": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"mistral-small": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"mistral-large": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
"mistral": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"mistral:7b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"mistral-nemo": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"mistral-small": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"mistral-large": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
# Vision-specific models
|
||||
"llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
"llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
@@ -86,21 +204,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||
"bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
"moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
"moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
|
||||
# Phi series
|
||||
"phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"phi4": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
"phi4": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
# Command R
|
||||
"command-r": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"command-r:35b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"command-r-plus": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
"command-r": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"command-r:35b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"command-r-plus": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
# Granite (IBM)
|
||||
"granite3-dense": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"granite3-moe": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"granite3-dense": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"granite3-moe": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -108,15 +253,15 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||
# These are tried in order when the primary model doesn't support a capability
|
||||
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
|
||||
ModelCapability.VISION: [
|
||||
"llama3.2:3b", # Fast vision model
|
||||
"llava:7b", # Classic vision model
|
||||
"qwen2.5-vl:3b", # Qwen vision
|
||||
"moondream:1.8b", # Tiny vision model (last resort)
|
||||
"llama3.2:3b", # Fast vision model
|
||||
"llava:7b", # Classic vision model
|
||||
"qwen2.5-vl:3b", # Qwen vision
|
||||
"moondream:1.8b", # Tiny vision model (last resort)
|
||||
],
|
||||
ModelCapability.TOOLS: [
|
||||
"llama3.1:8b-instruct", # Best tool use
|
||||
"llama3.2:3b", # Smaller but capable
|
||||
"qwen2.5:7b", # Reliable fallback
|
||||
"llama3.2:3b", # Smaller but capable
|
||||
"qwen2.5:7b", # Reliable fallback
|
||||
],
|
||||
ModelCapability.AUDIO: [
|
||||
# Audio models are less common in Ollama
|
||||
@@ -128,13 +273,14 @@ DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Information about a model's capabilities and availability."""
|
||||
|
||||
name: str
|
||||
capabilities: set[ModelCapability] = field(default_factory=set)
|
||||
is_available: bool = False
|
||||
is_pulled: bool = False
|
||||
size_mb: Optional[int] = None
|
||||
description: str = ""
|
||||
|
||||
|
||||
def supports(self, capability: ModelCapability) -> bool:
|
||||
"""Check if model supports a specific capability."""
|
||||
return capability in self.capabilities
|
||||
@@ -142,26 +288,26 @@ class ModelInfo:
|
||||
|
||||
class MultiModalManager:
|
||||
"""Manages multi-modal model capabilities and fallback chains.
|
||||
|
||||
|
||||
This class:
|
||||
1. Detects what capabilities each model has
|
||||
2. Maintains fallback chains for different capabilities
|
||||
3. Pulls models on-demand with automatic fallback
|
||||
4. Routes requests to appropriate models based on content type
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, ollama_url: Optional[str] = None) -> None:
|
||||
self.ollama_url = ollama_url or settings.ollama_url
|
||||
self._available_models: dict[str, ModelInfo] = {}
|
||||
self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
|
||||
self._refresh_available_models()
|
||||
|
||||
|
||||
def _refresh_available_models(self) -> None:
|
||||
"""Query Ollama for available models."""
|
||||
try:
|
||||
import urllib.request
|
||||
import json
|
||||
|
||||
import urllib.request
|
||||
|
||||
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
||||
req = urllib.request.Request(
|
||||
f"{url}/api/tags",
|
||||
@@ -170,7 +316,7 @@ class MultiModalManager:
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=5) as response:
|
||||
data = json.loads(response.read().decode())
|
||||
|
||||
|
||||
for model_data in data.get("models", []):
|
||||
name = model_data.get("name", "")
|
||||
self._available_models[name] = ModelInfo(
|
||||
@@ -181,58 +327,53 @@ class MultiModalManager:
|
||||
size_mb=model_data.get("size", 0) // (1024 * 1024),
|
||||
description=model_data.get("details", {}).get("family", ""),
|
||||
)
|
||||
|
||||
|
||||
logger.info("Found %d models in Ollama", len(self._available_models))
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Could not refresh available models: %s", exc)
|
||||
|
||||
|
||||
def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
|
||||
"""Detect capabilities for a model based on known data."""
|
||||
# Normalize model name (strip tags for lookup)
|
||||
base_name = model_name.split(":")[0]
|
||||
|
||||
|
||||
# Try exact match first
|
||||
if model_name in KNOWN_MODEL_CAPABILITIES:
|
||||
return set(KNOWN_MODEL_CAPABILITIES[model_name])
|
||||
|
||||
|
||||
# Try base name match
|
||||
if base_name in KNOWN_MODEL_CAPABILITIES:
|
||||
return set(KNOWN_MODEL_CAPABILITIES[base_name])
|
||||
|
||||
|
||||
# Default to text-only for unknown models
|
||||
logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
|
||||
return {ModelCapability.TEXT, ModelCapability.STREAMING}
|
||||
|
||||
|
||||
def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
|
||||
"""Get capabilities for a specific model."""
|
||||
if model_name in self._available_models:
|
||||
return self._available_models[model_name].capabilities
|
||||
return self._detect_capabilities(model_name)
|
||||
|
||||
|
||||
def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
|
||||
"""Check if a model supports a specific capability."""
|
||||
capabilities = self.get_model_capabilities(model_name)
|
||||
return capability in capabilities
|
||||
|
||||
|
||||
def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
|
||||
"""Get all available models that support a capability."""
|
||||
return [
|
||||
info for info in self._available_models.values()
|
||||
if capability in info.capabilities
|
||||
]
|
||||
|
||||
return [info for info in self._available_models.values() if capability in info.capabilities]
|
||||
|
||||
def get_best_model_for(
|
||||
self,
|
||||
capability: ModelCapability,
|
||||
preferred_model: Optional[str] = None
|
||||
self, capability: ModelCapability, preferred_model: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""Get the best available model for a specific capability.
|
||||
|
||||
|
||||
Args:
|
||||
capability: The required capability
|
||||
preferred_model: Preferred model to use if available and capable
|
||||
|
||||
|
||||
Returns:
|
||||
Model name or None if no suitable model found
|
||||
"""
|
||||
@@ -243,25 +384,26 @@ class MultiModalManager:
|
||||
return preferred_model
|
||||
logger.debug(
|
||||
"Preferred model %s doesn't support %s, checking fallbacks",
|
||||
preferred_model, capability.name
|
||||
preferred_model,
|
||||
capability.name,
|
||||
)
|
||||
|
||||
|
||||
# Check fallback chain for this capability
|
||||
fallback_chain = self._fallback_chains.get(capability, [])
|
||||
for model_name in fallback_chain:
|
||||
if model_name in self._available_models:
|
||||
logger.debug("Using fallback model %s for %s", model_name, capability.name)
|
||||
return model_name
|
||||
|
||||
|
||||
# Find any available model with this capability
|
||||
capable_models = self.get_models_with_capability(capability)
|
||||
if capable_models:
|
||||
# Sort by size (prefer smaller/faster models as fallback)
|
||||
capable_models.sort(key=lambda m: m.size_mb or float('inf'))
|
||||
capable_models.sort(key=lambda m: m.size_mb or float("inf"))
|
||||
return capable_models[0].name
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def pull_model_with_fallback(
|
||||
self,
|
||||
primary_model: str,
|
||||
@@ -269,58 +411,58 @@ class MultiModalManager:
|
||||
auto_pull: bool = True,
|
||||
) -> tuple[str, bool]:
|
||||
"""Pull a model with automatic fallback if unavailable.
|
||||
|
||||
|
||||
Args:
|
||||
primary_model: The desired model to use
|
||||
capability: Required capability (for finding fallback)
|
||||
auto_pull: Whether to attempt pulling missing models
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (model_name, is_fallback)
|
||||
"""
|
||||
# Check if primary model is already available
|
||||
if primary_model in self._available_models:
|
||||
return primary_model, False
|
||||
|
||||
|
||||
# Try to pull the primary model
|
||||
if auto_pull:
|
||||
if self._pull_model(primary_model):
|
||||
return primary_model, False
|
||||
|
||||
|
||||
# Need to find a fallback
|
||||
if capability:
|
||||
fallback = self.get_best_model_for(capability, primary_model)
|
||||
if fallback:
|
||||
logger.info(
|
||||
"Primary model %s unavailable, using fallback %s",
|
||||
primary_model, fallback
|
||||
"Primary model %s unavailable, using fallback %s", primary_model, fallback
|
||||
)
|
||||
return fallback, True
|
||||
|
||||
|
||||
# Last resort: use the configured default model
|
||||
default_model = settings.ollama_model
|
||||
if default_model in self._available_models:
|
||||
logger.warning(
|
||||
"Falling back to default model %s (primary: %s unavailable)",
|
||||
default_model, primary_model
|
||||
default_model,
|
||||
primary_model,
|
||||
)
|
||||
return default_model, True
|
||||
|
||||
|
||||
# Absolute last resort
|
||||
return primary_model, False
|
||||
|
||||
|
||||
def _pull_model(self, model_name: str) -> bool:
|
||||
"""Attempt to pull a model from Ollama.
|
||||
|
||||
|
||||
Returns:
|
||||
True if successful or model already exists
|
||||
"""
|
||||
try:
|
||||
import urllib.request
|
||||
import json
|
||||
|
||||
import urllib.request
|
||||
|
||||
logger.info("Pulling model: %s", model_name)
|
||||
|
||||
|
||||
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
||||
req = urllib.request.Request(
|
||||
f"{url}/api/pull",
|
||||
@@ -328,7 +470,7 @@ class MultiModalManager:
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps({"name": model_name, "stream": False}).encode(),
|
||||
)
|
||||
|
||||
|
||||
with urllib.request.urlopen(req, timeout=300) as response:
|
||||
if response.status == 200:
|
||||
logger.info("Successfully pulled model: %s", model_name)
|
||||
@@ -338,55 +480,51 @@ class MultiModalManager:
|
||||
else:
|
||||
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
|
||||
return False
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error pulling model %s: %s", model_name, exc)
|
||||
return False
|
||||
|
||||
def configure_fallback_chain(
|
||||
self,
|
||||
capability: ModelCapability,
|
||||
models: list[str]
|
||||
) -> None:
|
||||
|
||||
def configure_fallback_chain(self, capability: ModelCapability, models: list[str]) -> None:
|
||||
"""Configure a custom fallback chain for a capability."""
|
||||
self._fallback_chains[capability] = models
|
||||
logger.info("Configured fallback chain for %s: %s", capability.name, models)
|
||||
|
||||
|
||||
def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
|
||||
"""Get the fallback chain for a capability."""
|
||||
return list(self._fallback_chains.get(capability, []))
|
||||
|
||||
|
||||
def list_available_models(self) -> list[ModelInfo]:
|
||||
"""List all available models with their capabilities."""
|
||||
return list(self._available_models.values())
|
||||
|
||||
|
||||
def refresh(self) -> None:
|
||||
"""Refresh the list of available models."""
|
||||
self._refresh_available_models()
|
||||
|
||||
|
||||
def get_model_for_content(
|
||||
self,
|
||||
content_type: str, # "text", "image", "audio", "multimodal"
|
||||
preferred_model: Optional[str] = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Get appropriate model based on content type.
|
||||
|
||||
|
||||
Args:
|
||||
content_type: Type of content (text, image, audio, multimodal)
|
||||
preferred_model: User's preferred model
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (model_name, is_fallback)
|
||||
"""
|
||||
content_type = content_type.lower()
|
||||
|
||||
|
||||
if content_type in ("image", "vision", "multimodal"):
|
||||
# For vision content, we need a vision-capable model
|
||||
return self.pull_model_with_fallback(
|
||||
preferred_model or "llava:7b",
|
||||
capability=ModelCapability.VISION,
|
||||
)
|
||||
|
||||
|
||||
elif content_type == "audio":
|
||||
# Audio support is limited in Ollama
|
||||
# Would need specific audio models
|
||||
@@ -395,7 +533,7 @@ class MultiModalManager:
|
||||
preferred_model or settings.ollama_model,
|
||||
capability=ModelCapability.TEXT,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
# Standard text content
|
||||
return self.pull_model_with_fallback(
|
||||
@@ -417,8 +555,7 @@ def get_multimodal_manager() -> MultiModalManager:
|
||||
|
||||
|
||||
def get_model_for_capability(
|
||||
capability: ModelCapability,
|
||||
preferred_model: Optional[str] = None
|
||||
capability: ModelCapability, preferred_model: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""Convenience function to get best model for a capability."""
|
||||
return get_multimodal_manager().get_best_model_for(capability, preferred_model)
|
||||
@@ -430,9 +567,7 @@ def pull_model_with_fallback(
|
||||
auto_pull: bool = True,
|
||||
) -> tuple[str, bool]:
|
||||
"""Convenience function to pull model with fallback."""
|
||||
return get_multimodal_manager().pull_model_with_fallback(
|
||||
primary_model, capability, auto_pull
|
||||
)
|
||||
return get_multimodal_manager().pull_model_with_fallback(primary_model, capability, auto_pull)
|
||||
|
||||
|
||||
def model_supports_vision(model_name: str) -> bool:
|
||||
|
||||
@@ -26,26 +26,29 @@ DB_PATH = Path("data/swarm.db")
|
||||
|
||||
class ModelFormat(str, Enum):
|
||||
"""Supported model weight formats."""
|
||||
GGUF = "gguf" # Ollama-compatible quantised weights
|
||||
SAFETENSORS = "safetensors" # HuggingFace safetensors
|
||||
HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory
|
||||
OLLAMA = "ollama" # Already loaded in Ollama by name
|
||||
|
||||
GGUF = "gguf" # Ollama-compatible quantised weights
|
||||
SAFETENSORS = "safetensors" # HuggingFace safetensors
|
||||
HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory
|
||||
OLLAMA = "ollama" # Already loaded in Ollama by name
|
||||
|
||||
|
||||
class ModelRole(str, Enum):
|
||||
"""Role a model can play in the system (OpenClaw-RL style)."""
|
||||
GENERAL = "general" # Default agent inference
|
||||
REWARD = "reward" # Process Reward Model (PRM) scoring
|
||||
TEACHER = "teacher" # On-policy distillation teacher
|
||||
JUDGE = "judge" # Output quality evaluation
|
||||
|
||||
GENERAL = "general" # Default agent inference
|
||||
REWARD = "reward" # Process Reward Model (PRM) scoring
|
||||
TEACHER = "teacher" # On-policy distillation teacher
|
||||
JUDGE = "judge" # Output quality evaluation
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomModel:
|
||||
"""A registered custom model."""
|
||||
|
||||
name: str
|
||||
format: ModelFormat
|
||||
path: str # Absolute path or Ollama model name
|
||||
path: str # Absolute path or Ollama model name
|
||||
role: ModelRole = ModelRole.GENERAL
|
||||
context_window: int = 4096
|
||||
description: str = ""
|
||||
@@ -141,10 +144,16 @@ class ModelRegistry:
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
model.name, model.format.value, model.path,
|
||||
model.role.value, model.context_window, model.description,
|
||||
model.registered_at, int(model.active),
|
||||
model.default_temperature, model.max_tokens,
|
||||
model.name,
|
||||
model.format.value,
|
||||
model.path,
|
||||
model.role.value,
|
||||
model.context_window,
|
||||
model.description,
|
||||
model.registered_at,
|
||||
int(model.active),
|
||||
model.default_temperature,
|
||||
model.max_tokens,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
@@ -160,9 +169,7 @@ class ModelRegistry:
|
||||
return False
|
||||
conn = _get_conn()
|
||||
conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
|
||||
conn.execute(
|
||||
"DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)
|
||||
)
|
||||
conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
del self._models[name]
|
||||
|
||||
@@ -9,8 +9,8 @@ No cloud push services — everything stays local.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import platform
|
||||
import subprocess
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
@@ -25,9 +25,7 @@ class Notification:
|
||||
title: str
|
||||
message: str
|
||||
category: str # swarm | task | agent | system | payment
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
read: bool = False
|
||||
|
||||
|
||||
@@ -74,9 +72,11 @@ class PushNotifier:
|
||||
def _native_notify(self, title: str, message: str) -> None:
|
||||
"""Send a native macOS notification via osascript."""
|
||||
try:
|
||||
safe_message = message.replace("\\", "\\\\").replace('"', '\\"')
|
||||
safe_title = title.replace("\\", "\\\\").replace('"', '\\"')
|
||||
script = (
|
||||
f'display notification "{message}" '
|
||||
f'with title "Agent Dashboard" subtitle "{title}"'
|
||||
f'display notification "{safe_message}" '
|
||||
f'with title "Agent Dashboard" subtitle "{safe_title}"'
|
||||
)
|
||||
subprocess.Popen(
|
||||
["osascript", "-e", script],
|
||||
@@ -114,7 +114,7 @@ class PushNotifier:
|
||||
def clear(self) -> None:
|
||||
self._notifications.clear()
|
||||
|
||||
def add_listener(self, callback) -> None:
|
||||
def add_listener(self, callback: "Callable[[Notification], None]") -> None:
|
||||
"""Register a callback for real-time notification delivery."""
|
||||
self._listeners.append(callback)
|
||||
|
||||
@@ -139,10 +139,7 @@ async def notify_briefing_ready(briefing) -> None:
|
||||
logger.info("Briefing ready but no pending approvals — skipping native notification")
|
||||
return
|
||||
|
||||
message = (
|
||||
f"Your morning briefing is ready. "
|
||||
f"{n_approvals} item(s) await your approval."
|
||||
)
|
||||
message = f"Your morning briefing is ready. " f"{n_approvals} item(s) await your approval."
|
||||
notifier.notify(
|
||||
title="Morning Briefing Ready",
|
||||
message=message,
|
||||
|
||||
@@ -156,33 +156,23 @@ class OpenFangClient:
|
||||
|
||||
async def browse(self, url: str, instruction: str = "") -> HandResult:
|
||||
"""Web automation via OpenFang's Browser hand."""
|
||||
return await self.execute_hand(
|
||||
"browser", {"url": url, "instruction": instruction}
|
||||
)
|
||||
return await self.execute_hand("browser", {"url": url, "instruction": instruction})
|
||||
|
||||
async def collect(self, target: str, depth: str = "shallow") -> HandResult:
|
||||
"""OSINT collection via OpenFang's Collector hand."""
|
||||
return await self.execute_hand(
|
||||
"collector", {"target": target, "depth": depth}
|
||||
)
|
||||
return await self.execute_hand("collector", {"target": target, "depth": depth})
|
||||
|
||||
async def predict(self, question: str, horizon: str = "1w") -> HandResult:
|
||||
"""Superforecasting via OpenFang's Predictor hand."""
|
||||
return await self.execute_hand(
|
||||
"predictor", {"question": question, "horizon": horizon}
|
||||
)
|
||||
return await self.execute_hand("predictor", {"question": question, "horizon": horizon})
|
||||
|
||||
async def find_leads(self, icp: str, max_results: int = 10) -> HandResult:
|
||||
"""Prospect discovery via OpenFang's Lead hand."""
|
||||
return await self.execute_hand(
|
||||
"lead", {"icp": icp, "max_results": max_results}
|
||||
)
|
||||
return await self.execute_hand("lead", {"icp": icp, "max_results": max_results})
|
||||
|
||||
async def research(self, topic: str, depth: str = "standard") -> HandResult:
|
||||
"""Deep research via OpenFang's Researcher hand."""
|
||||
return await self.execute_hand(
|
||||
"researcher", {"topic": topic, "depth": depth}
|
||||
)
|
||||
return await self.execute_hand("researcher", {"topic": topic, "depth": depth})
|
||||
|
||||
# ── Inventory ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -22,9 +22,11 @@ from infrastructure.openfang.client import OPENFANG_HANDS, openfang_client
|
||||
try:
|
||||
from mcp.schemas.base import create_tool_schema
|
||||
except ImportError:
|
||||
|
||||
def create_tool_schema(**kwargs):
|
||||
return kwargs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Tool schemas ─────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Cascade LLM Router — Automatic failover between providers."""
|
||||
|
||||
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
|
||||
from .api import router
|
||||
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
|
||||
|
||||
__all__ = [
|
||||
"CascadeRouter",
|
||||
|
||||
@@ -15,6 +15,7 @@ router = APIRouter(prefix="/api/v1/router", tags=["router"])
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
"""Request body for completions."""
|
||||
|
||||
messages: list[dict[str, str]]
|
||||
model: str | None = None
|
||||
temperature: float = 0.7
|
||||
@@ -23,6 +24,7 @@ class CompletionRequest(BaseModel):
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
"""Response from completion endpoint."""
|
||||
|
||||
content: str
|
||||
provider: str
|
||||
model: str
|
||||
@@ -31,6 +33,7 @@ class CompletionResponse(BaseModel):
|
||||
|
||||
class ProviderControl(BaseModel):
|
||||
"""Control a provider's status."""
|
||||
|
||||
action: str # "enable", "disable", "reset_circuit"
|
||||
|
||||
|
||||
@@ -45,7 +48,7 @@ async def complete(
|
||||
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
|
||||
) -> dict[str, Any]:
|
||||
"""Complete a conversation with automatic failover.
|
||||
|
||||
|
||||
Routes through providers in priority order until one succeeds.
|
||||
"""
|
||||
try:
|
||||
@@ -108,30 +111,32 @@ async def control_provider(
|
||||
if p.name == provider_name:
|
||||
provider = p
|
||||
break
|
||||
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
|
||||
|
||||
|
||||
if control.action == "enable":
|
||||
provider.enabled = True
|
||||
provider.status = provider.status.__class__.HEALTHY
|
||||
return {"message": f"Provider {provider_name} enabled"}
|
||||
|
||||
|
||||
elif control.action == "disable":
|
||||
provider.enabled = False
|
||||
from .cascade import ProviderStatus
|
||||
|
||||
provider.status = ProviderStatus.DISABLED
|
||||
return {"message": f"Provider {provider_name} disabled"}
|
||||
|
||||
|
||||
elif control.action == "reset_circuit":
|
||||
from .cascade import CircuitState, ProviderStatus
|
||||
|
||||
provider.circuit_state = CircuitState.CLOSED
|
||||
provider.circuit_opened_at = None
|
||||
provider.half_open_calls = 0
|
||||
provider.metrics.consecutive_failures = 0
|
||||
provider.status = ProviderStatus.HEALTHY
|
||||
return {"message": f"Circuit breaker reset for {provider_name}"}
|
||||
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown action: {control.action}")
|
||||
|
||||
@@ -142,28 +147,35 @@ async def run_health_check(
|
||||
) -> dict[str, Any]:
|
||||
"""Run health checks on all providers."""
|
||||
results = []
|
||||
|
||||
|
||||
for provider in cascade.providers:
|
||||
# Quick ping to check availability
|
||||
is_healthy = cascade._check_provider_available(provider)
|
||||
|
||||
|
||||
from .cascade import ProviderStatus
|
||||
|
||||
if is_healthy:
|
||||
if provider.status == ProviderStatus.UNHEALTHY:
|
||||
# Reset circuit if it was open but now healthy
|
||||
provider.circuit_state = provider.circuit_state.__class__.CLOSED
|
||||
provider.circuit_opened_at = None
|
||||
provider.status = ProviderStatus.HEALTHY if provider.metrics.error_rate < 0.1 else ProviderStatus.DEGRADED
|
||||
provider.status = (
|
||||
ProviderStatus.HEALTHY
|
||||
if provider.metrics.error_rate < 0.1
|
||||
else ProviderStatus.DEGRADED
|
||||
)
|
||||
else:
|
||||
provider.status = ProviderStatus.UNHEALTHY
|
||||
|
||||
results.append({
|
||||
"name": provider.name,
|
||||
"type": provider.type,
|
||||
"healthy": is_healthy,
|
||||
"status": provider.status.value,
|
||||
})
|
||||
|
||||
|
||||
results.append(
|
||||
{
|
||||
"name": provider.name,
|
||||
"type": provider.type,
|
||||
"healthy": is_healthy,
|
||||
"status": provider.status.value,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"checked_at": asyncio.get_event_loop().time(),
|
||||
"providers": results,
|
||||
@@ -177,7 +189,7 @@ async def get_config(
|
||||
) -> dict[str, Any]:
|
||||
"""Get router configuration (without secrets)."""
|
||||
cfg = cascade.config
|
||||
|
||||
|
||||
return {
|
||||
"timeout_seconds": cfg.timeout_seconds,
|
||||
"max_retries_per_provider": cfg.max_retries_per_provider,
|
||||
|
||||
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ProviderStatus(Enum):
|
||||
"""Health status of a provider."""
|
||||
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded" # Working but slow or occasional errors
|
||||
UNHEALTHY = "unhealthy" # Circuit breaker open
|
||||
@@ -41,22 +42,25 @@ class ProviderStatus(Enum):
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker state."""
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Failing, rejecting requests
|
||||
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Failing, rejecting requests
|
||||
HALF_OPEN = "half_open" # Testing if recovered
|
||||
|
||||
|
||||
class ContentType(Enum):
|
||||
"""Type of content in the request."""
|
||||
|
||||
TEXT = "text"
|
||||
VISION = "vision" # Contains images
|
||||
AUDIO = "audio" # Contains audio
|
||||
VISION = "vision" # Contains images
|
||||
AUDIO = "audio" # Contains audio
|
||||
MULTIMODAL = "multimodal" # Multiple content types
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMetrics:
|
||||
"""Metrics for a single provider."""
|
||||
|
||||
total_requests: int = 0
|
||||
successful_requests: int = 0
|
||||
failed_requests: int = 0
|
||||
@@ -64,13 +68,13 @@ class ProviderMetrics:
|
||||
last_request_time: Optional[str] = None
|
||||
last_error_time: Optional[str] = None
|
||||
consecutive_failures: int = 0
|
||||
|
||||
|
||||
@property
|
||||
def avg_latency_ms(self) -> float:
|
||||
if self.total_requests == 0:
|
||||
return 0.0
|
||||
return self.total_latency_ms / self.total_requests
|
||||
|
||||
|
||||
@property
|
||||
def error_rate(self) -> float:
|
||||
if self.total_requests == 0:
|
||||
@@ -81,6 +85,7 @@ class ProviderMetrics:
|
||||
@dataclass
|
||||
class ModelCapability:
|
||||
"""Capabilities a model supports."""
|
||||
|
||||
name: str
|
||||
supports_vision: bool = False
|
||||
supports_audio: bool = False
|
||||
@@ -93,6 +98,7 @@ class ModelCapability:
|
||||
@dataclass
|
||||
class Provider:
|
||||
"""LLM provider configuration and state."""
|
||||
|
||||
name: str
|
||||
type: str # ollama, openai, anthropic, airllm
|
||||
enabled: bool
|
||||
@@ -101,14 +107,14 @@ class Provider:
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
models: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
# Runtime state
|
||||
status: ProviderStatus = ProviderStatus.HEALTHY
|
||||
metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
|
||||
circuit_state: CircuitState = CircuitState.CLOSED
|
||||
circuit_opened_at: Optional[float] = None
|
||||
half_open_calls: int = 0
|
||||
|
||||
|
||||
def get_default_model(self) -> Optional[str]:
|
||||
"""Get the default model for this provider."""
|
||||
for model in self.models:
|
||||
@@ -117,7 +123,7 @@ class Provider:
|
||||
if self.models:
|
||||
return self.models[0]["name"]
|
||||
return None
|
||||
|
||||
|
||||
def get_model_with_capability(self, capability: str) -> Optional[str]:
|
||||
"""Get a model that supports the given capability."""
|
||||
for model in self.models:
|
||||
@@ -126,7 +132,7 @@ class Provider:
|
||||
return model["name"]
|
||||
# Fall back to default
|
||||
return self.get_default_model()
|
||||
|
||||
|
||||
def model_has_capability(self, model_name: str, capability: str) -> bool:
|
||||
"""Check if a specific model has a capability."""
|
||||
for model in self.models:
|
||||
@@ -139,6 +145,7 @@ class Provider:
|
||||
@dataclass
|
||||
class RouterConfig:
|
||||
"""Cascade router configuration."""
|
||||
|
||||
timeout_seconds: int = 30
|
||||
max_retries_per_provider: int = 2
|
||||
retry_delay_seconds: int = 1
|
||||
@@ -154,22 +161,22 @@ class RouterConfig:
|
||||
|
||||
class CascadeRouter:
|
||||
"""Routes LLM requests with automatic failover.
|
||||
|
||||
|
||||
Now with multi-modal support:
|
||||
- Automatically detects content type (text, vision, audio)
|
||||
- Selects appropriate models based on capabilities
|
||||
- Falls back through capability-specific model chains
|
||||
- Supports image URLs and base64 encoding
|
||||
|
||||
|
||||
Usage:
|
||||
router = CascadeRouter()
|
||||
|
||||
|
||||
# Text request
|
||||
response = await router.complete(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
model="llama3.2"
|
||||
)
|
||||
|
||||
|
||||
# Vision request (automatically detects and selects vision model)
|
||||
response = await router.complete(
|
||||
messages=[{
|
||||
@@ -179,68 +186,75 @@ class CascadeRouter:
|
||||
}],
|
||||
model="llava:7b"
|
||||
)
|
||||
|
||||
|
||||
# Check metrics
|
||||
metrics = router.get_metrics()
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, config_path: Optional[Path] = None) -> None:
|
||||
self.config_path = config_path or Path("config/providers.yaml")
|
||||
self.providers: list[Provider] = []
|
||||
self.config: RouterConfig = RouterConfig()
|
||||
self._load_config()
|
||||
|
||||
|
||||
# Initialize multi-modal manager if available
|
||||
self._mm_manager: Optional[Any] = None
|
||||
try:
|
||||
from infrastructure.models.multimodal import get_multimodal_manager
|
||||
|
||||
self._mm_manager = get_multimodal_manager()
|
||||
except Exception as exc:
|
||||
logger.debug("Multi-modal manager not available: %s", exc)
|
||||
|
||||
|
||||
logger.info("CascadeRouter initialized with %d providers", len(self.providers))
|
||||
|
||||
|
||||
def _load_config(self) -> None:
|
||||
"""Load configuration from YAML."""
|
||||
if not self.config_path.exists():
|
||||
logger.warning("Config not found: %s, using defaults", self.config_path)
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
if yaml is None:
|
||||
raise RuntimeError("PyYAML not installed")
|
||||
|
||||
|
||||
content = self.config_path.read_text()
|
||||
# Expand environment variables
|
||||
content = self._expand_env_vars(content)
|
||||
data = yaml.safe_load(content)
|
||||
|
||||
|
||||
# Load cascade settings
|
||||
cascade = data.get("cascade", {})
|
||||
|
||||
|
||||
# Load fallback chains
|
||||
fallback_chains = data.get("fallback_chains", {})
|
||||
|
||||
|
||||
# Load multi-modal settings
|
||||
multimodal = data.get("multimodal", {})
|
||||
|
||||
|
||||
self.config = RouterConfig(
|
||||
timeout_seconds=cascade.get("timeout_seconds", 30),
|
||||
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
|
||||
retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
|
||||
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
|
||||
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
|
||||
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
|
||||
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get(
|
||||
"failure_threshold", 5
|
||||
),
|
||||
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get(
|
||||
"recovery_timeout", 60
|
||||
),
|
||||
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get(
|
||||
"half_open_max_calls", 2
|
||||
),
|
||||
auto_pull_models=multimodal.get("auto_pull", True),
|
||||
fallback_chains=fallback_chains,
|
||||
)
|
||||
|
||||
|
||||
# Load providers
|
||||
for p_data in data.get("providers", []):
|
||||
# Skip disabled providers
|
||||
if not p_data.get("enabled", False):
|
||||
continue
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name=p_data["name"],
|
||||
type=p_data["type"],
|
||||
@@ -251,30 +265,34 @@ class CascadeRouter:
|
||||
base_url=p_data.get("base_url"),
|
||||
models=p_data.get("models", []),
|
||||
)
|
||||
|
||||
|
||||
# Check if provider is actually available
|
||||
if self._check_provider_available(provider):
|
||||
self.providers.append(provider)
|
||||
else:
|
||||
logger.warning("Provider %s not available, skipping", provider.name)
|
||||
|
||||
|
||||
# Sort by priority
|
||||
self.providers.sort(key=lambda p: p.priority)
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Failed to load config: %s", exc)
|
||||
|
||||
|
||||
def _expand_env_vars(self, content: str) -> str:
|
||||
"""Expand ${VAR} syntax in YAML content."""
|
||||
"""Expand ${VAR} syntax in YAML content.
|
||||
|
||||
Uses os.environ directly (not settings) because this is a generic
|
||||
YAML config loader that must expand arbitrary variable references.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
|
||||
def replace_var(match):
|
||||
|
||||
def replace_var(match: "re.Match[str]") -> str:
|
||||
var_name = match.group(1)
|
||||
return os.environ.get(var_name, match.group(0))
|
||||
|
||||
|
||||
return re.sub(r"\$\{(\w+)\}", replace_var, content)
|
||||
|
||||
|
||||
def _check_provider_available(self, provider: Provider) -> bool:
|
||||
"""Check if a provider is actually available."""
|
||||
if provider.type == "ollama":
|
||||
@@ -288,48 +306,49 @@ class CascadeRouter:
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
elif provider.type == "airllm":
|
||||
# Check if airllm is installed
|
||||
try:
|
||||
import airllm
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
elif provider.type in ("openai", "anthropic", "grok"):
|
||||
# Check if API key is set
|
||||
return provider.api_key is not None and provider.api_key != ""
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _detect_content_type(self, messages: list[dict]) -> ContentType:
|
||||
"""Detect the type of content in the messages.
|
||||
|
||||
|
||||
Checks for images, audio, etc. in the message content.
|
||||
"""
|
||||
has_image = False
|
||||
has_audio = False
|
||||
|
||||
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
|
||||
|
||||
# Check for image URLs/paths
|
||||
if msg.get("images"):
|
||||
has_image = True
|
||||
|
||||
|
||||
# Check for image URLs in content
|
||||
if isinstance(content, str):
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
|
||||
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
if any(ext in content.lower() for ext in image_extensions):
|
||||
has_image = True
|
||||
if content.startswith("data:image/"):
|
||||
has_image = True
|
||||
|
||||
|
||||
# Check for audio
|
||||
if msg.get("audio"):
|
||||
has_audio = True
|
||||
|
||||
|
||||
# Check for multimodal content structure
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
@@ -338,7 +357,7 @@ class CascadeRouter:
|
||||
has_image = True
|
||||
elif item.get("type") == "audio":
|
||||
has_audio = True
|
||||
|
||||
|
||||
if has_image and has_audio:
|
||||
return ContentType.MULTIMODAL
|
||||
elif has_image:
|
||||
@@ -346,12 +365,9 @@ class CascadeRouter:
|
||||
elif has_audio:
|
||||
return ContentType.AUDIO
|
||||
return ContentType.TEXT
|
||||
|
||||
|
||||
def _get_fallback_model(
|
||||
self,
|
||||
provider: Provider,
|
||||
original_model: str,
|
||||
content_type: ContentType
|
||||
self, provider: Provider, original_model: str, content_type: ContentType
|
||||
) -> Optional[str]:
|
||||
"""Get a fallback model for the given content type."""
|
||||
# Map content type to capability
|
||||
@@ -360,24 +376,24 @@ class CascadeRouter:
|
||||
ContentType.AUDIO: "audio",
|
||||
ContentType.MULTIMODAL: "vision", # Vision models often do both
|
||||
}
|
||||
|
||||
|
||||
capability = capability_map.get(content_type)
|
||||
if not capability:
|
||||
return None
|
||||
|
||||
|
||||
# Check provider's models for capability
|
||||
fallback_model = provider.get_model_with_capability(capability)
|
||||
if fallback_model and fallback_model != original_model:
|
||||
return fallback_model
|
||||
|
||||
|
||||
# Use fallback chains from config
|
||||
fallback_chain = self.config.fallback_chains.get(capability, [])
|
||||
for model_name in fallback_chain:
|
||||
if provider.model_has_capability(model_name, capability):
|
||||
return model_name
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict],
|
||||
@@ -386,21 +402,21 @@ class CascadeRouter:
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""Complete a chat conversation with automatic failover.
|
||||
|
||||
|
||||
Multi-modal support:
|
||||
- Automatically detects if messages contain images
|
||||
- Falls back to vision-capable models when needed
|
||||
- Supports image URLs, paths, and base64 encoding
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with role and content
|
||||
model: Preferred model (tries this first, then provider defaults)
|
||||
temperature: Sampling temperature
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with content, provider_used, and metrics
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all providers fail
|
||||
"""
|
||||
@@ -408,15 +424,15 @@ class CascadeRouter:
|
||||
content_type = self._detect_content_type(messages)
|
||||
if content_type != ContentType.TEXT:
|
||||
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
||||
|
||||
|
||||
errors = []
|
||||
|
||||
|
||||
for provider in self.providers:
|
||||
# Skip disabled providers
|
||||
if not provider.enabled:
|
||||
logger.debug("Skipping %s (disabled)", provider.name)
|
||||
continue
|
||||
|
||||
|
||||
# Skip unhealthy providers (circuit breaker)
|
||||
if provider.status == ProviderStatus.UNHEALTHY:
|
||||
# Check if circuit breaker can close
|
||||
@@ -427,16 +443,16 @@ class CascadeRouter:
|
||||
else:
|
||||
logger.debug("Skipping %s (circuit open)", provider.name)
|
||||
continue
|
||||
|
||||
|
||||
# Determine which model to use
|
||||
selected_model = model or provider.get_default_model()
|
||||
is_fallback_model = False
|
||||
|
||||
|
||||
# For non-text content, check if model supports it
|
||||
if content_type != ContentType.TEXT and selected_model:
|
||||
if provider.type == "ollama" and self._mm_manager:
|
||||
from infrastructure.models.multimodal import ModelCapability
|
||||
|
||||
|
||||
# Check if selected model supports the required capability
|
||||
if content_type == ContentType.VISION:
|
||||
supports = self._mm_manager.model_supports(
|
||||
@@ -450,16 +466,17 @@ class CascadeRouter:
|
||||
if fallback:
|
||||
logger.info(
|
||||
"Model %s doesn't support vision, falling back to %s",
|
||||
selected_model, fallback
|
||||
selected_model,
|
||||
fallback,
|
||||
)
|
||||
selected_model = fallback
|
||||
is_fallback_model = True
|
||||
else:
|
||||
logger.warning(
|
||||
"No vision-capable model found on %s, trying anyway",
|
||||
provider.name
|
||||
provider.name,
|
||||
)
|
||||
|
||||
|
||||
# Try this provider
|
||||
for attempt in range(self.config.max_retries_per_provider):
|
||||
try:
|
||||
@@ -471,34 +488,35 @@ class CascadeRouter:
|
||||
max_tokens=max_tokens,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
|
||||
# Success! Update metrics and return
|
||||
self._record_success(provider, result.get("latency_ms", 0))
|
||||
return {
|
||||
"content": result["content"],
|
||||
"provider": provider.name,
|
||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||||
"model": result.get(
|
||||
"model", selected_model or provider.get_default_model()
|
||||
),
|
||||
"latency_ms": result.get("latency_ms", 0),
|
||||
"is_fallback_model": is_fallback_model,
|
||||
}
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
error_msg = str(exc)
|
||||
logger.warning(
|
||||
"Provider %s attempt %d failed: %s",
|
||||
provider.name, attempt + 1, error_msg
|
||||
"Provider %s attempt %d failed: %s", provider.name, attempt + 1, error_msg
|
||||
)
|
||||
errors.append(f"{provider.name}: {error_msg}")
|
||||
|
||||
|
||||
if attempt < self.config.max_retries_per_provider - 1:
|
||||
await asyncio.sleep(self.config.retry_delay_seconds)
|
||||
|
||||
|
||||
# All retries failed for this provider
|
||||
self._record_failure(provider)
|
||||
|
||||
|
||||
# All providers failed
|
||||
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||
|
||||
|
||||
async def _try_provider(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -510,7 +528,7 @@ class CascadeRouter:
|
||||
) -> dict:
|
||||
"""Try a single provider request."""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
if provider.type == "ollama":
|
||||
result = await self._call_ollama(
|
||||
provider=provider,
|
||||
@@ -545,12 +563,12 @@ class CascadeRouter:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider.type}")
|
||||
|
||||
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
result["latency_ms"] = latency_ms
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _call_ollama(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -561,12 +579,12 @@ class CascadeRouter:
|
||||
) -> dict:
|
||||
"""Call Ollama API with multi-modal support."""
|
||||
import aiohttp
|
||||
|
||||
|
||||
url = f"{provider.url}/api/chat"
|
||||
|
||||
|
||||
# Transform messages for Ollama format (including images)
|
||||
transformed_messages = self._transform_messages_for_ollama(messages)
|
||||
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": transformed_messages,
|
||||
@@ -575,31 +593,31 @@ class CascadeRouter:
|
||||
"temperature": temperature,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)
|
||||
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(url, json=payload) as response:
|
||||
if response.status != 200:
|
||||
text = await response.text()
|
||||
raise RuntimeError(f"Ollama error {response.status}: {text}")
|
||||
|
||||
|
||||
data = await response.json()
|
||||
return {
|
||||
"content": data["message"]["content"],
|
||||
"model": model,
|
||||
}
|
||||
|
||||
|
||||
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
|
||||
"""Transform messages to Ollama format, handling images."""
|
||||
transformed = []
|
||||
|
||||
|
||||
for msg in messages:
|
||||
new_msg = {
|
||||
"role": msg.get("role", "user"),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
|
||||
|
||||
# Handle images
|
||||
images = msg.get("images", [])
|
||||
if images:
|
||||
@@ -620,11 +638,11 @@ class CascadeRouter:
|
||||
new_msg["images"].append(img_data)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to read image %s: %s", img, exc)
|
||||
|
||||
|
||||
transformed.append(new_msg)
|
||||
|
||||
|
||||
return transformed
|
||||
|
||||
|
||||
async def _call_openai(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -635,13 +653,13 @@ class CascadeRouter:
|
||||
) -> dict:
|
||||
"""Call OpenAI API."""
|
||||
import openai
|
||||
|
||||
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=provider.api_key,
|
||||
base_url=provider.base_url,
|
||||
timeout=self.config.timeout_seconds,
|
||||
)
|
||||
|
||||
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
@@ -649,14 +667,14 @@ class CascadeRouter:
|
||||
}
|
||||
if max_tokens:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
|
||||
response = await client.chat.completions.create(**kwargs)
|
||||
|
||||
|
||||
return {
|
||||
"content": response.choices[0].message.content,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
|
||||
async def _call_anthropic(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -667,12 +685,12 @@ class CascadeRouter:
|
||||
) -> dict:
|
||||
"""Call Anthropic API."""
|
||||
import anthropic
|
||||
|
||||
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=provider.api_key,
|
||||
timeout=self.config.timeout_seconds,
|
||||
)
|
||||
|
||||
|
||||
# Convert messages to Anthropic format
|
||||
system_msg = None
|
||||
conversation = []
|
||||
@@ -680,11 +698,13 @@ class CascadeRouter:
|
||||
if msg["role"] == "system":
|
||||
system_msg = msg["content"]
|
||||
else:
|
||||
conversation.append({
|
||||
"role": msg["role"],
|
||||
"content": msg["content"],
|
||||
})
|
||||
|
||||
conversation.append(
|
||||
{
|
||||
"role": msg["role"],
|
||||
"content": msg["content"],
|
||||
}
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"messages": conversation,
|
||||
@@ -693,9 +713,9 @@ class CascadeRouter:
|
||||
}
|
||||
if system_msg:
|
||||
kwargs["system"] = system_msg
|
||||
|
||||
|
||||
response = await client.messages.create(**kwargs)
|
||||
|
||||
|
||||
return {
|
||||
"content": response.content[0].text,
|
||||
"model": response.model,
|
||||
@@ -733,7 +753,7 @@ class CascadeRouter:
|
||||
"content": response.choices[0].message.content,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
|
||||
def _record_success(self, provider: Provider, latency_ms: float) -> None:
|
||||
"""Record a successful request."""
|
||||
provider.metrics.total_requests += 1
|
||||
@@ -741,50 +761,50 @@ class CascadeRouter:
|
||||
provider.metrics.total_latency_ms += latency_ms
|
||||
provider.metrics.last_request_time = datetime.now(timezone.utc).isoformat()
|
||||
provider.metrics.consecutive_failures = 0
|
||||
|
||||
|
||||
# Close circuit breaker if half-open
|
||||
if provider.circuit_state == CircuitState.HALF_OPEN:
|
||||
provider.half_open_calls += 1
|
||||
if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
|
||||
self._close_circuit(provider)
|
||||
|
||||
|
||||
# Update status based on error rate
|
||||
if provider.metrics.error_rate < 0.1:
|
||||
provider.status = ProviderStatus.HEALTHY
|
||||
elif provider.metrics.error_rate < 0.3:
|
||||
provider.status = ProviderStatus.DEGRADED
|
||||
|
||||
|
||||
def _record_failure(self, provider: Provider) -> None:
|
||||
"""Record a failed request."""
|
||||
provider.metrics.total_requests += 1
|
||||
provider.metrics.failed_requests += 1
|
||||
provider.metrics.last_error_time = datetime.now(timezone.utc).isoformat()
|
||||
provider.metrics.consecutive_failures += 1
|
||||
|
||||
|
||||
# Check if we should open circuit breaker
|
||||
if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold:
|
||||
self._open_circuit(provider)
|
||||
|
||||
|
||||
# Update status
|
||||
if provider.metrics.error_rate > 0.3:
|
||||
provider.status = ProviderStatus.DEGRADED
|
||||
if provider.metrics.error_rate > 0.5:
|
||||
provider.status = ProviderStatus.UNHEALTHY
|
||||
|
||||
|
||||
def _open_circuit(self, provider: Provider) -> None:
|
||||
"""Open the circuit breaker for a provider."""
|
||||
provider.circuit_state = CircuitState.OPEN
|
||||
provider.circuit_opened_at = time.time()
|
||||
provider.status = ProviderStatus.UNHEALTHY
|
||||
logger.warning("Circuit breaker OPEN for %s", provider.name)
|
||||
|
||||
|
||||
def _can_close_circuit(self, provider: Provider) -> bool:
|
||||
"""Check if circuit breaker can transition to half-open."""
|
||||
if provider.circuit_opened_at is None:
|
||||
return False
|
||||
elapsed = time.time() - provider.circuit_opened_at
|
||||
return elapsed >= self.config.circuit_breaker_recovery_timeout
|
||||
|
||||
|
||||
def _close_circuit(self, provider: Provider) -> None:
|
||||
"""Close the circuit breaker (provider healthy again)."""
|
||||
provider.circuit_state = CircuitState.CLOSED
|
||||
@@ -793,7 +813,7 @@ class CascadeRouter:
|
||||
provider.metrics.consecutive_failures = 0
|
||||
provider.status = ProviderStatus.HEALTHY
|
||||
logger.info("Circuit breaker CLOSED for %s", provider.name)
|
||||
|
||||
|
||||
def get_metrics(self) -> dict:
|
||||
"""Get metrics for all providers."""
|
||||
return {
|
||||
@@ -814,16 +834,20 @@ class CascadeRouter:
|
||||
for p in self.providers
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get current router status."""
|
||||
healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY)
|
||||
|
||||
|
||||
return {
|
||||
"total_providers": len(self.providers),
|
||||
"healthy_providers": healthy,
|
||||
"degraded_providers": sum(1 for p in self.providers if p.status == ProviderStatus.DEGRADED),
|
||||
"unhealthy_providers": sum(1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY),
|
||||
"degraded_providers": sum(
|
||||
1 for p in self.providers if p.status == ProviderStatus.DEGRADED
|
||||
),
|
||||
"unhealthy_providers": sum(
|
||||
1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY
|
||||
),
|
||||
"providers": [
|
||||
{
|
||||
"name": p.name,
|
||||
@@ -835,7 +859,7 @@ class CascadeRouter:
|
||||
for p in self.providers
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def generate_with_image(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -844,21 +868,23 @@ class CascadeRouter:
|
||||
temperature: float = 0.7,
|
||||
) -> dict:
|
||||
"""Convenience method for vision requests.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: Text prompt about the image
|
||||
image_path: Path to image file
|
||||
model: Vision-capable model (auto-selected if not provided)
|
||||
temperature: Sampling temperature
|
||||
|
||||
|
||||
Returns:
|
||||
Response dict with content and metadata
|
||||
"""
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"images": [image_path],
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"images": [image_path],
|
||||
}
|
||||
]
|
||||
return await self.complete(
|
||||
messages=messages,
|
||||
model=model,
|
||||
|
||||
@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class WSEvent:
|
||||
"""A WebSocket event to broadcast to connected clients."""
|
||||
|
||||
event: str
|
||||
data: dict
|
||||
timestamp: str
|
||||
@@ -93,28 +94,42 @@ class WebSocketManager:
|
||||
await self.broadcast("agent_left", {"agent_id": agent_id, "name": name})
|
||||
|
||||
async def broadcast_task_posted(self, task_id: str, description: str) -> None:
|
||||
await self.broadcast("task_posted", {
|
||||
"task_id": task_id, "description": description,
|
||||
})
|
||||
await self.broadcast(
|
||||
"task_posted",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
|
||||
async def broadcast_bid_submitted(
|
||||
self, task_id: str, agent_id: str, bid_sats: int
|
||||
) -> None:
|
||||
await self.broadcast("bid_submitted", {
|
||||
"task_id": task_id, "agent_id": agent_id, "bid_sats": bid_sats,
|
||||
})
|
||||
async def broadcast_bid_submitted(self, task_id: str, agent_id: str, bid_sats: int) -> None:
|
||||
await self.broadcast(
|
||||
"bid_submitted",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"agent_id": agent_id,
|
||||
"bid_sats": bid_sats,
|
||||
},
|
||||
)
|
||||
|
||||
async def broadcast_task_assigned(self, task_id: str, agent_id: str) -> None:
|
||||
await self.broadcast("task_assigned", {
|
||||
"task_id": task_id, "agent_id": agent_id,
|
||||
})
|
||||
await self.broadcast(
|
||||
"task_assigned",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def broadcast_task_completed(
|
||||
self, task_id: str, agent_id: str, result: str
|
||||
) -> None:
|
||||
await self.broadcast("task_completed", {
|
||||
"task_id": task_id, "agent_id": agent_id, "result": result[:200],
|
||||
})
|
||||
async def broadcast_task_completed(self, task_id: str, agent_id: str, result: str) -> None:
|
||||
await self.broadcast(
|
||||
"task_completed",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"agent_id": agent_id,
|
||||
"result": result[:200],
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def connection_count(self) -> int:
|
||||
@@ -122,28 +137,28 @@ class WebSocketManager:
|
||||
|
||||
async def broadcast_json(self, data: dict) -> int:
|
||||
"""Broadcast raw JSON data to all connected clients.
|
||||
|
||||
|
||||
Args:
|
||||
data: Dictionary to send as JSON
|
||||
|
||||
|
||||
Returns:
|
||||
Number of clients notified
|
||||
"""
|
||||
message = json.dumps(data)
|
||||
disconnected = []
|
||||
count = 0
|
||||
|
||||
|
||||
for ws in self._connections:
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
count += 1
|
||||
except Exception:
|
||||
disconnected.append(ws)
|
||||
|
||||
|
||||
# Clean up dead connections
|
||||
for ws in disconnected:
|
||||
self.disconnect(ws)
|
||||
|
||||
|
||||
return count
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user