feat: Multi-modal support with automatic model fallback

- Add MultiModalManager with capability detection for vision/audio/tools
- Define fallback chains:
    vision: llama3.2:3b -> llava:7b -> moondream
    tools:  llama3.1:8b-instruct -> qwen2.5:7b
- Update CascadeRouter to detect content type and select appropriate models
- Add model pulling with automatic fallback in agent creation
- Update providers.yaml with multi-modal model configurations
- Update OllamaAdapter to use model resolution with vision support

Tests: All 96 infrastructure tests pass
Alexander Payne
2026-02-26 22:29:44 -05:00
parent a85661274c
commit 72a58f1f49
8 changed files with 990 additions and 14 deletions
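
The providers.yaml changes are part of this commit but are not shown in the diff below. As a rough sketch only, this is the parsed shape that _load_config reads: the chain entries come from the commit message, and every other value is an illustrative assumption matching the loader's defaults.

# Hypothetical parsed form of providers.yaml; keys mirror what
# _load_config reads below, values are examples only.
data = {
    "cascade": {
        "timeout_seconds": 30,
        "max_retries_per_provider": 2,
        "circuit_breaker": {
            "failure_threshold": 5,
            "recovery_timeout": 60,
            "half_open_max_calls": 2,
        },
    },
    "multimodal": {"auto_pull": True},
    "fallback_chains": {
        "vision": ["llama3.2:3b", "llava:7b", "moondream"],
        "tools": ["llama3.1:8b-instruct", "qwen2.5:7b"],
    },
}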


@@ -3,14 +3,19 @@
Routes requests through an ordered list of LLM providers,
automatically failing over on rate limits or errors.
Tracks metrics for latency, errors, and cost.
Now with multi-modal support: automatically selects vision-capable
models for image inputs and falls back through capability chains.
"""
import asyncio
import base64
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Optional
@@ -43,6 +48,14 @@ class CircuitState(Enum):
    HALF_OPEN = "half_open"  # Testing if recovered


class ContentType(Enum):
    """Type of content in the request."""

    TEXT = "text"
    VISION = "vision"          # Contains images
    AUDIO = "audio"            # Contains audio
    MULTIMODAL = "multimodal"  # Multiple content types

@dataclass
class ProviderMetrics:
    """Metrics for a single provider."""
@@ -67,6 +80,18 @@ class ProviderMetrics:
        return self.failed_requests / self.total_requests

@dataclass
class ModelCapability:
    """Capabilities a model supports."""

    name: str
    supports_vision: bool = False
    supports_audio: bool = False
    supports_tools: bool = False
    supports_json: bool = False
    supports_streaming: bool = True
    context_window: int = 4096


@dataclass
class Provider:
    """LLM provider configuration and state."""
@@ -94,6 +119,23 @@ class Provider:
        if self.models:
            return self.models[0]["name"]
        return None

    def get_model_with_capability(self, capability: str) -> Optional[str]:
        """Get a model that supports the given capability."""
        for model in self.models:
            capabilities = model.get("capabilities", [])
            if capability in capabilities:
                return model["name"]
        # Fall back to default
        return self.get_default_model()

    def model_has_capability(self, model_name: str, capability: str) -> bool:
        """Check if a specific model has a capability."""
        for model in self.models:
            if model["name"] == model_name:
                capabilities = model.get("capabilities", [])
                return capability in capabilities
        return False
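    # Illustrative shape of the models list these helpers iterate over
    # (hypothetical entries; the real list comes from providers.yaml):
    #     models = [
    #         {"name": "llama3.2:3b", "capabilities": ["tools", "json"]},
    #         {"name": "llava:7b", "capabilities": ["vision"]},
    #     ]
    # Here get_model_with_capability("vision") returns "llava:7b", while
    # model_has_capability("llama3.2:3b", "vision") returns False.
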
@dataclass
@@ -107,19 +149,39 @@ class RouterConfig:
    circuit_breaker_half_open_max_calls: int = 2
    cost_tracking_enabled: bool = True
    budget_daily_usd: float = 10.0

    # Multi-modal settings
    auto_pull_models: bool = True
    fallback_chains: dict = field(default_factory=dict)

class CascadeRouter:
"""Routes LLM requests with automatic failover.
Now with multi-modal support:
- Automatically detects content type (text, vision, audio)
- Selects appropriate models based on capabilities
- Falls back through capability-specific model chains
- Supports image URLs and base64 encoding
Usage:
router = CascadeRouter()
# Text request
response = await router.complete(
messages=[{"role": "user", "content": "Hello"}],
model="llama3.2"
)
# Vision request (automatically detects and selects vision model)
response = await router.complete(
messages=[{
"role": "user",
"content": "What's in this image?",
"images": ["path/to/image.jpg"]
}],
model="llava:7b"
)
# Check metrics
metrics = router.get_metrics()
"""
@@ -130,6 +192,14 @@ class CascadeRouter:
        self.config: RouterConfig = RouterConfig()
        self._load_config()

        # Initialize multi-modal manager if available
        self._mm_manager: Optional[Any] = None
        try:
            from infrastructure.models.multimodal import get_multimodal_manager
            self._mm_manager = get_multimodal_manager()
        except Exception as exc:
            logger.debug("Multi-modal manager not available: %s", exc)

        logger.info("CascadeRouter initialized with %d providers", len(self.providers))

    def _load_config(self) -> None:
@@ -149,6 +219,13 @@ class CascadeRouter:
        # Load cascade settings
        cascade = data.get("cascade", {})

        # Load fallback chains
        fallback_chains = data.get("fallback_chains", {})

        # Load multi-modal settings
        multimodal = data.get("multimodal", {})

        self.config = RouterConfig(
            timeout_seconds=cascade.get("timeout_seconds", 30),
            max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
@@ -156,6 +233,8 @@ class CascadeRouter:
            circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
            circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
            circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
            auto_pull_models=multimodal.get("auto_pull", True),
            fallback_chains=fallback_chains,
        )

        # Load providers
@@ -226,6 +305,81 @@ class CascadeRouter:
        return True

    def _detect_content_type(self, messages: list[dict]) -> ContentType:
        """Detect the type of content in the messages.

        Checks for images, audio, etc. in the message content.
        """
        has_image = False
        has_audio = False

        for msg in messages:
            content = msg.get("content", "")

            # Check for attached images
            if msg.get("images"):
                has_image = True

            # Check for image paths/URLs or data URIs in string content
            if isinstance(content, str):
                image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
                if any(ext in content.lower() for ext in image_extensions):
                    has_image = True
                if content.startswith("data:image/"):
                    has_image = True

            # Check for audio
            if msg.get("audio"):
                has_audio = True

            # Check for structured multimodal content (list of typed parts)
            if isinstance(content, list):
                for item in content:
                    if isinstance(item, dict):
                        if item.get("type") == "image_url":
                            has_image = True
                        elif item.get("type") == "audio":
                            has_audio = True

        if has_image and has_audio:
            return ContentType.MULTIMODAL
        elif has_image:
            return ContentType.VISION
        elif has_audio:
            return ContentType.AUDIO
        return ContentType.TEXT
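    # Illustrative results of _detect_content_type (hypothetical inputs):
    #     [{"role": "user", "content": "Hello"}]                -> TEXT
    #     [{"role": "user", "content": "Describe this",
    #       "images": ["photo.png"]}]                           -> VISION
    # Note the extension heuristic is loose: plain text that merely
    # mentions a name like "report.png" is also classified as VISION.
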
    def _get_fallback_model(
        self,
        provider: Provider,
        original_model: str,
        content_type: ContentType,
    ) -> Optional[str]:
        """Get a fallback model for the given content type."""
        # Map content type to capability
        capability_map = {
            ContentType.VISION: "vision",
            ContentType.AUDIO: "audio",
            ContentType.MULTIMODAL: "vision",  # Vision models often do both
        }
        capability = capability_map.get(content_type)
        if not capability:
            return None

        # Check provider's models for capability
        fallback_model = provider.get_model_with_capability(capability)
        if fallback_model and fallback_model != original_model:
            return fallback_model

        # Use fallback chains from config
        fallback_chain = self.config.fallback_chains.get(capability, [])
        for model_name in fallback_chain:
            if provider.model_has_capability(model_name, capability):
                return model_name

        return None
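    # Illustrative walk-through, assuming chains loaded from providers.yaml
    # match the commit message:
    #     config.fallback_chains = {
    #         "vision": ["llama3.2:3b", "llava:7b", "moondream"],
    #         "tools": ["llama3.1:8b-instruct", "qwen2.5:7b"],
    #     }
    # For VISION content the provider's own capability list wins first;
    # otherwise the first chain entry the provider advertises with the
    # "vision" capability is returned, and None if nothing matches.
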
    async def complete(
        self,
        messages: list[dict],
@@ -235,6 +389,11 @@ class CascadeRouter:
    ) -> dict:
        """Complete a chat conversation with automatic failover.

        Multi-modal support:
        - Automatically detects if messages contain images
        - Falls back to vision-capable models when needed
        - Supports image URLs, paths, and base64 encoding

        Args:
            messages: List of message dicts with role and content
            model: Preferred model (tries this first, then provider defaults)
@@ -247,6 +406,11 @@ class CascadeRouter:
        Raises:
            RuntimeError: If all providers fail
        """
        # Detect content type for multi-modal routing
        content_type = self._detect_content_type(messages)
        if content_type != ContentType.TEXT:
            logger.debug("Detected %s content, selecting appropriate model", content_type.value)

        errors = []

        for provider in self.providers:
@@ -266,15 +430,48 @@ class CascadeRouter:
logger.debug("Skipping %s (circuit open)", provider.name)
continue
# Determine which model to use
selected_model = model or provider.get_default_model()
is_fallback_model = False
# For non-text content, check if model supports it
if content_type != ContentType.TEXT and selected_model:
if provider.type == "ollama" and self._mm_manager:
from infrastructure.models.multimodal import ModelCapability
# Check if selected model supports the required capability
if content_type == ContentType.VISION:
supports = self._mm_manager.model_supports(
selected_model, ModelCapability.VISION
)
if not supports:
# Find fallback model
fallback = self._get_fallback_model(
provider, selected_model, content_type
)
if fallback:
logger.info(
"Model %s doesn't support vision, falling back to %s",
selected_model, fallback
)
selected_model = fallback
is_fallback_model = True
else:
logger.warning(
"No vision-capable model found on %s, trying anyway",
provider.name
)
# Try this provider
for attempt in range(self.config.max_retries_per_provider):
try:
result = await self._try_provider(
provider=provider,
messages=messages,
model=model,
model=selected_model,
temperature=temperature,
max_tokens=max_tokens,
content_type=content_type,
)
# Success! Update metrics and return
@@ -282,8 +479,9 @@ class CascadeRouter:
                    return {
                        "content": result["content"],
                        "provider": provider.name,
-                        "model": result.get("model", model or provider.get_default_model()),
+                        "model": result.get("model", selected_model or provider.get_default_model()),
                        "latency_ms": result.get("latency_ms", 0),
                        "is_fallback_model": is_fallback_model,
                    }

                except Exception as exc:
@@ -307,9 +505,10 @@ class CascadeRouter:
        self,
        provider: Provider,
        messages: list[dict],
-        model: Optional[str],
+        model: str,
        temperature: float,
        max_tokens: Optional[int],
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
        """Try a single provider request."""
        start_time = time.time()
@@ -320,6 +519,7 @@ class CascadeRouter:
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                content_type=content_type,
            )
        elif provider.type == "openai":
            result = await self._call_openai(
@@ -359,15 +559,19 @@ class CascadeRouter:
        messages: list[dict],
        model: str,
        temperature: float,
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
-        """Call Ollama API."""
+        """Call Ollama API with multi-modal support."""
        import aiohttp

        url = f"{provider.url}/api/chat"

        # Transform messages for Ollama format (including images)
        transformed_messages = self._transform_messages_for_ollama(messages)

        payload = {
            "model": model,
-            "messages": messages,
+            "messages": transformed_messages,
            "stream": False,
            "options": {
                "temperature": temperature,
@@ -388,6 +592,41 @@ class CascadeRouter:
"model": model,
}
    def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
        """Transform messages to Ollama format, handling images."""
        transformed = []
        for msg in messages:
            new_msg = {
                "role": msg.get("role", "user"),
                "content": msg.get("content", ""),
            }

            # Handle images
            images = msg.get("images", [])
            if images:
                new_msg["images"] = []
                for img in images:
                    if isinstance(img, str):
                        if img.startswith("data:image/"):
                            # Base64 encoded image
                            new_msg["images"].append(img.split(",")[1])
                        elif img.startswith("http://") or img.startswith("https://"):
                            # URL - would need to download, skip for now
                            logger.warning("Image URLs not yet supported, skipping: %s", img)
                        elif Path(img).exists():
                            # Local file path - read and encode
                            try:
                                with open(img, "rb") as f:
                                    img_data = base64.b64encode(f.read()).decode()
                                new_msg["images"].append(img_data)
                            except Exception as exc:
                                logger.error("Failed to read image %s: %s", img, exc)

            transformed.append(new_msg)

        return transformed
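    # Example transformation (hypothetical input; "iVBORw0..." stands in
    # for real base64 data):
    #     [{"role": "user", "content": "What is this?",
    #       "images": ["data:image/png;base64,iVBORw0..."]}]
    # becomes
    #     [{"role": "user", "content": "What is this?", "images": ["iVBORw0..."]}]
    # Data-URI prefixes are stripped, local files are read and base64-encoded,
    # and http(s) URLs are skipped with a warning.
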
    async def _call_openai(
        self,
        provider: Provider,
@@ -496,7 +735,7 @@ class CascadeRouter:
"content": response.choices[0].message.content,
"model": response.model,
}
def _record_success(self, provider: Provider, latency_ms: float) -> None:
"""Record a successful request."""
provider.metrics.total_requests += 1
@@ -598,6 +837,35 @@ class CascadeRouter:
                for p in self.providers
            ],
        }
    async def generate_with_image(
        self,
        prompt: str,
        image_path: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
    ) -> dict:
        """Convenience method for vision requests.

        Args:
            prompt: Text prompt about the image
            image_path: Path to image file
            model: Vision-capable model (auto-selected if not provided)
            temperature: Sampling temperature

        Returns:
            Response dict with content and metadata
        """
        messages = [{
            "role": "user",
            "content": prompt,
            "images": [image_path],
        }]

        return await self.complete(
            messages=messages,
            model=model,
            temperature=temperature,
        )
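    # Hypothetical usage, relying on the vision fallback chain to pick
    # a model:
    #     result = await router.generate_with_image(
    #         prompt="What's in this photo?",
    #         image_path="photo.jpg",
    #     )
    #     result["model"], result["is_fallback_model"]  # which model ran
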
# Module-level singleton