feat: Multi-modal support with automatic model fallback

- Add MultiModalManager with capability detection for vision/audio/tools
- Define fallback chains: vision (llama3.2:3b -> llava:7b -> moondream),
  tools (llama3.1:8b-instruct -> qwen2.5:7b)
- Update CascadeRouter to detect content type and select appropriate models
- Add model pulling with automatic fallback in agent creation
- Update providers.yaml with multi-modal model configurations
- Update OllamaAdapter to use model resolution with vision support

Tests: All 96 infrastructure tests pass

@@ -3,14 +3,19 @@
Routes requests through an ordered list of LLM providers,
automatically failing over on rate limits or errors.
Tracks metrics for latency, errors, and cost.

Now with multi-modal support — automatically selects vision-capable
models for image inputs and falls back through capability chains.
"""

import asyncio
import base64
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Optional

@@ -43,6 +48,14 @@ class CircuitState(Enum):
    HALF_OPEN = "half_open"  # Testing if recovered


class ContentType(Enum):
    """Type of content in the request."""
    TEXT = "text"
    VISION = "vision"  # Contains images
    AUDIO = "audio"  # Contains audio
    MULTIMODAL = "multimodal"  # Multiple content types


@dataclass
class ProviderMetrics:
    """Metrics for a single provider."""
@@ -67,6 +80,18 @@ class ProviderMetrics:
        return self.failed_requests / self.total_requests


@dataclass
class ModelCapability:
    """Capabilities a model supports."""
    name: str
    supports_vision: bool = False
    supports_audio: bool = False
    supports_tools: bool = False
    supports_json: bool = False
    supports_streaming: bool = True
    context_window: int = 4096
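
For illustration, a vision model entry expressed with this dataclass might look as follows; the specific capability values shown for llava:7b are assumptions, not measured facts, and in practice would come from providers.yaml:

# Hypothetical capability profile; real values would be loaded from
# providers.yaml rather than hardcoded like this.
llava_caps = ModelCapability(
    name="llava:7b",
    supports_vision=True,
    supports_streaming=True,
    context_window=4096,
)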


@dataclass
class Provider:
    """LLM provider configuration and state."""
@@ -94,6 +119,23 @@ class Provider:
        if self.models:
            return self.models[0]["name"]
        return None

    def get_model_with_capability(self, capability: str) -> Optional[str]:
        """Get a model that supports the given capability."""
        for model in self.models:
            capabilities = model.get("capabilities", [])
            if capability in capabilities:
                return model["name"]
        # Fall back to the default model (which may lack the capability)
        return self.get_default_model()

    def model_has_capability(self, model_name: str, capability: str) -> bool:
        """Check if a specific model has a capability."""
        for model in self.models:
            if model["name"] == model_name:
                capabilities = model.get("capabilities", [])
                return capability in capabilities
        return False
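
A minimal sketch of how these two lookups behave, assuming provider.models entries of the shape the code reads (a "name" key plus an optional "capabilities" list; the entries below are hypothetical):

# Assumed models list in the shape get_model_with_capability() reads
provider.models = [
    {"name": "llama3.2:3b", "capabilities": ["tools"]},
    {"name": "llava:7b", "capabilities": ["vision"]},
]

provider.get_model_with_capability("vision")            # -> "llava:7b"
provider.model_has_capability("llama3.2:3b", "vision")  # -> False
provider.get_model_with_capability("audio")  # no match -> default model

Note that get_model_with_capability() falls back to the default model rather than returning None, so callers that require the capability must re-check it on the result.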


@dataclass
@@ -107,19 +149,39 @@ class RouterConfig:
    circuit_breaker_half_open_max_calls: int = 2
    cost_tracking_enabled: bool = True
    budget_daily_usd: float = 10.0
    # Multi-modal settings
    auto_pull_models: bool = True
    fallback_chains: dict = field(default_factory=dict)
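
As a sketch, the new fields populated with the vision chain named in the commit message; the dict maps a capability name to an ordered list of model names, and the remaining fields are assumed to keep their defaults:

# Chain mirrors the one listed in the commit message
config = RouterConfig(
    auto_pull_models=True,
    fallback_chains={
        "vision": ["llama3.2:3b", "llava:7b", "moondream"],
    },
)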


class CascadeRouter:
    """Routes LLM requests with automatic failover.

    Now with multi-modal support:
    - Automatically detects content type (text, vision, audio)
    - Selects appropriate models based on capabilities
    - Falls back through capability-specific model chains
    - Supports image paths and base64 data URLs (HTTP URLs not yet supported)

    Usage:
        router = CascadeRouter()

        # Text request
        response = await router.complete(
            messages=[{"role": "user", "content": "Hello"}],
            model="llama3.2"
        )

        # Vision request (automatically detects and selects vision model)
        response = await router.complete(
            messages=[{
                "role": "user",
                "content": "What's in this image?",
                "images": ["path/to/image.jpg"]
            }],
            model="llava:7b"
        )

        # Check metrics
        metrics = router.get_metrics()
    """
@@ -130,6 +192,14 @@ class CascadeRouter:
        self.config: RouterConfig = RouterConfig()
        self._load_config()

        # Initialize multi-modal manager if available
        self._mm_manager: Optional[Any] = None
        try:
            from infrastructure.models.multimodal import get_multimodal_manager
            self._mm_manager = get_multimodal_manager()
        except Exception as exc:
            logger.debug("Multi-modal manager not available: %s", exc)

        logger.info("CascadeRouter initialized with %d providers", len(self.providers))

    def _load_config(self) -> None:
@@ -149,6 +219,13 @@ class CascadeRouter:

        # Load cascade settings
        cascade = data.get("cascade", {})

        # Load fallback chains
        fallback_chains = data.get("fallback_chains", {})

        # Load multi-modal settings
        multimodal = data.get("multimodal", {})

        self.config = RouterConfig(
            timeout_seconds=cascade.get("timeout_seconds", 30),
            max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
@@ -156,6 +233,8 @@
            circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
            circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
            circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
            auto_pull_models=multimodal.get("auto_pull", True),
            fallback_chains=fallback_chains,
        )
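
Put together, a parsed providers.yaml that would satisfy this loader could look like the dict below. Only the keys actually read above are grounded in the code; the surrounding layout is an assumption:

# Hypothetical result of parsing providers.yaml
data = {
    "cascade": {
        "timeout_seconds": 30,
        "max_retries_per_provider": 2,
        "circuit_breaker": {
            "failure_threshold": 5,
            "recovery_timeout": 60,
            "half_open_max_calls": 2,
        },
    },
    "multimodal": {"auto_pull": True},
    "fallback_chains": {
        "vision": ["llama3.2:3b", "llava:7b", "moondream"],
        "tools": ["llama3.1:8b-instruct", "qwen2.5:7b"],
    },
}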

        # Load providers
@@ -226,6 +305,81 @@ class CascadeRouter:

        return True

    def _detect_content_type(self, messages: list[dict]) -> ContentType:
        """Detect the type of content in the messages.

        Checks for images, audio, etc. in the message content.
        """
        has_image = False
        has_audio = False

        for msg in messages:
            content = msg.get("content", "")

            # Explicit image attachments
            if msg.get("images"):
                has_image = True

            # Heuristic: image file extensions or data URLs in string content
            if isinstance(content, str):
                image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
                if any(ext in content.lower() for ext in image_extensions):
                    has_image = True
                if content.startswith("data:image/"):
                    has_image = True

            # Explicit audio attachments
            if msg.get("audio"):
                has_audio = True

            # Structured multimodal content (list of typed items)
            if isinstance(content, list):
                for item in content:
                    if isinstance(item, dict):
                        if item.get("type") == "image_url":
                            has_image = True
                        elif item.get("type") == "audio":
                            has_audio = True

        if has_image and has_audio:
            return ContentType.MULTIMODAL
        if has_image:
            return ContentType.VISION
        if has_audio:
            return ContentType.AUDIO
        return ContentType.TEXT
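
A few examples of how detection classifies messages. Note the extension heuristic means a plain-text message that merely mentions a .png filename is also classified as vision:

router = CascadeRouter()

router._detect_content_type([{"role": "user", "content": "Hello"}])
# -> ContentType.TEXT

router._detect_content_type([
    {"role": "user", "content": "Describe this", "images": ["cat.png"]}
])
# -> ContentType.VISION

router._detect_content_type([
    {"role": "user", "content": [{"type": "image_url"}, {"type": "audio"}]}
])
# -> ContentType.MULTIMODAL (structured items trimmed to "type" for brevity)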

    def _get_fallback_model(
        self,
        provider: Provider,
        original_model: str,
        content_type: ContentType,
    ) -> Optional[str]:
        """Get a fallback model for the given content type."""
        # Map content type to the capability it requires
        capability_map = {
            ContentType.VISION: "vision",
            ContentType.AUDIO: "audio",
            ContentType.MULTIMODAL: "vision",  # Vision models often do both
        }

        capability = capability_map.get(content_type)
        if not capability:
            return None

        # Check the provider's models for the capability; note that
        # get_model_with_capability() falls back to the default model,
        # so verify the capability before accepting the result
        fallback_model = provider.get_model_with_capability(capability)
        if (
            fallback_model
            and fallback_model != original_model
            and provider.model_has_capability(fallback_model, capability)
        ):
            return fallback_model

        # Use fallback chains from config
        fallback_chain = self.config.fallback_chains.get(capability, [])
        for model_name in fallback_chain:
            if provider.model_has_capability(model_name, capability):
                return model_name

        return None
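
Resolution order, in short: a capability-tagged model on the provider wins, then the configured chain is walked, and a chain entry is only returned if the provider actually lists it with that capability. A sketch, reusing the hypothetical provider from the earlier example (llava:7b tagged "vision"):

router.config.fallback_chains = {"vision": ["llava:7b", "moondream"]}

router._get_fallback_model(provider, "llama3.2:3b", ContentType.VISION)
# -> "llava:7b"
router._get_fallback_model(provider, "llama3.2:3b", ContentType.TEXT)
# -> None (text content requires no special capability)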

    async def complete(
        self,
        messages: list[dict],
@@ -235,6 +389,11 @@
    ) -> dict:
        """Complete a chat conversation with automatic failover.

        Multi-modal support:
        - Automatically detects if messages contain images
        - Falls back to vision-capable models when needed
        - Supports image paths and base64 data URLs (HTTP URLs not yet supported)

        Args:
            messages: List of message dicts with role and content
            model: Preferred model (tries this first, then provider defaults)
@@ -247,6 +406,11 @@
        Raises:
            RuntimeError: If all providers fail
        """
        # Detect content type for multi-modal routing
        content_type = self._detect_content_type(messages)
        if content_type != ContentType.TEXT:
            logger.debug("Detected %s content, selecting appropriate model", content_type.value)

        errors = []

        for provider in self.providers:
@@ -266,15 +430,48 @@
                logger.debug("Skipping %s (circuit open)", provider.name)
                continue

            # Determine which model to use
            selected_model = model or provider.get_default_model()
            is_fallback_model = False

            # For non-text content, check if the model supports it
            if content_type != ContentType.TEXT and selected_model:
                if provider.type == "ollama" and self._mm_manager:
                    # Local import: this capability enum is distinct from the
                    # ModelCapability dataclass defined earlier in this module
                    from infrastructure.models.multimodal import ModelCapability

                    # Check if the selected model supports the required capability
                    if content_type == ContentType.VISION:
                        supports = self._mm_manager.model_supports(
                            selected_model, ModelCapability.VISION
                        )
                        if not supports:
                            # Find a fallback model
                            fallback = self._get_fallback_model(
                                provider, selected_model, content_type
                            )
                            if fallback:
                                logger.info(
                                    "Model %s doesn't support vision, falling back to %s",
                                    selected_model, fallback
                                )
                                selected_model = fallback
                                is_fallback_model = True
                            else:
                                logger.warning(
                                    "No vision-capable model found on %s, trying anyway",
                                    provider.name
                                )

            # Try this provider
            for attempt in range(self.config.max_retries_per_provider):
                try:
                    result = await self._try_provider(
                        provider=provider,
                        messages=messages,
-                        model=model,
+                        model=selected_model,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        content_type=content_type,
                    )

                    # Success! Update metrics and return
@@ -282,8 +479,9 @@
                    return {
                        "content": result["content"],
                        "provider": provider.name,
-                        "model": result.get("model", model or provider.get_default_model()),
+                        "model": result.get("model", selected_model or provider.get_default_model()),
                        "latency_ms": result.get("latency_ms", 0),
                        "is_fallback_model": is_fallback_model,
                    }

                except Exception as exc:
@@ -307,9 +505,10 @@
        self,
        provider: Provider,
        messages: list[dict],
-        model: Optional[str],
+        model: str,
        temperature: float,
        max_tokens: Optional[int],
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
        """Try a single provider request."""
        start_time = time.time()
@@ -320,6 +519,7 @@
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                content_type=content_type,
            )
        elif provider.type == "openai":
            result = await self._call_openai(
@@ -359,15 +559,19 @@
        messages: list[dict],
        model: str,
        temperature: float,
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
-        """Call Ollama API."""
+        """Call Ollama API with multi-modal support."""
        import aiohttp

        url = f"{provider.url}/api/chat"

        # Transform messages to Ollama format (including images)
        transformed_messages = self._transform_messages_for_ollama(messages)

        payload = {
            "model": model,
-            "messages": messages,
+            "messages": transformed_messages,
            "stream": False,
            "options": {
                "temperature": temperature,
@@ -388,6 +592,41 @@
            "model": model,
        }
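
For a vision request, the resulting payload follows Ollama's /api/chat format, with raw base64 image data attached per message. A sketch, with the base64 string truncated:

payload = {
    "model": "llava:7b",
    "messages": [{
        "role": "user",
        "content": "What's in this image?",
        "images": ["iVBORw0KGgoAAAANSUhEUg..."],  # base64, no data: prefix
    }],
    "stream": False,
    "options": {"temperature": 0.7},
}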

    def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
        """Transform messages to Ollama format, handling images."""
        transformed = []

        for msg in messages:
            new_msg = {
                "role": msg.get("role", "user"),
                "content": msg.get("content", ""),
            }

            # Handle images
            images = msg.get("images", [])
            if images:
                new_msg["images"] = []
                for img in images:
                    if isinstance(img, str):
                        if img.startswith("data:image/"):
                            # Base64 data URL: strip the "data:...;base64," prefix
                            new_msg["images"].append(img.split(",", 1)[1])
                        elif img.startswith(("http://", "https://")):
                            # URL - would need to download, skip for now
                            logger.warning("Image URLs not yet supported, skipping: %s", img)
                        elif Path(img).exists():
                            # Local file path - read and base64-encode
                            try:
                                with open(img, "rb") as f:
                                    img_data = base64.b64encode(f.read()).decode()
                                new_msg["images"].append(img_data)
                            except Exception as exc:
                                logger.error("Failed to read image %s: %s", img, exc)

            transformed.append(new_msg)

        return transformed
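
For example, a data-URL image is stripped to its raw base64 payload while the role and content pass through unchanged:

messages = [{
    "role": "user",
    "content": "Describe this",
    "images": ["data:image/png;base64,iVBORw0KGgo..."],
}]
out = router._transform_messages_for_ollama(messages)
# out == [{"role": "user", "content": "Describe this",
#          "images": ["iVBORw0KGgo..."]}]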

    async def _call_openai(
        self,
        provider: Provider,
@@ -496,7 +735,7 @@
            "content": response.choices[0].message.content,
            "model": response.model,
        }


    def _record_success(self, provider: Provider, latency_ms: float) -> None:
        """Record a successful request."""
        provider.metrics.total_requests += 1
@@ -598,6 +837,35 @@
                for p in self.providers
            ],
        }

    async def generate_with_image(
        self,
        prompt: str,
        image_path: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
    ) -> dict:
        """Convenience method for vision requests.

        Args:
            prompt: Text prompt about the image
            image_path: Path to image file
            model: Vision-capable model (auto-selected if not provided)
            temperature: Sampling temperature

        Returns:
            Response dict with content and metadata
        """
        messages = [{
            "role": "user",
            "content": prompt,
            "images": [image_path],
        }]
        return await self.complete(
            messages=messages,
            model=model,
            temperature=temperature,
        )
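
Usage is then a one-liner; the wrapper just builds the images-bearing message and delegates to complete():

response = await router.generate_with_image(
    prompt="What's in this image?",
    image_path="path/to/image.jpg",
)
print(response["content"], response["model"], response["is_fallback_model"])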


# Module-level singleton