feat: Multi-modal support with automatic model fallback

- Add MultiModalManager with capability detection for vision/audio/tools
- Define fallback chains: vision (llama3.2:3b -> llava:7b -> moondream)
                       tools (llama3.1:8b-instruct -> qwen2.5:7b)
- Update CascadeRouter to detect content type and select appropriate models
- Add model pulling with automatic fallback in agent creation
- Update providers.yaml with multi-modal model configurations
- Update OllamaAdapter to use model resolution with vision support

Tests: All 96 infrastructure tests pass
This commit is contained in:
Alexander Payne
2026-02-26 22:29:44 -05:00
parent a85661274c
commit 72a58f1f49
8 changed files with 990 additions and 14 deletions

View File

@@ -24,11 +24,31 @@ providers:
priority: 1
url: "http://localhost:11434"
models:
- name: llama3.2
# Text + Tools models
- name: llama3.1:8b-instruct
default: true
context_window: 128000
capabilities: [text, tools, json, streaming]
- name: llama3.2:3b
context_window: 128000
capabilities: [text, tools, json, streaming, vision]
- name: qwen2.5:14b
context_window: 32000
capabilities: [text, tools, json, streaming]
- name: deepseek-r1:1.5b
context_window: 32000
capabilities: [text, json, streaming]
# Vision models
- name: llava:7b
context_window: 4096
capabilities: [text, vision, streaming]
- name: qwen2.5-vl:3b
context_window: 32000
capabilities: [text, vision, tools, json, streaming]
- name: moondream:1.8b
context_window: 2048
capabilities: [text, vision, streaming]
# Secondary: Local AirLLM (if installed)
- name: airllm-local
@@ -38,8 +58,11 @@ providers:
models:
- name: 70b
default: true
capabilities: [text, tools, json, streaming]
- name: 8b
capabilities: [text, tools, json, streaming]
- name: 405b
capabilities: [text, tools, json, streaming]
# Tertiary: OpenAI (if API key available)
- name: openai-backup
@@ -52,8 +75,10 @@ providers:
- name: gpt-4o-mini
default: true
context_window: 128000
capabilities: [text, vision, tools, json, streaming]
- name: gpt-4o
context_window: 128000
capabilities: [text, vision, tools, json, streaming]
# Quaternary: Anthropic (if API key available)
- name: anthropic-backup
@@ -65,10 +90,37 @@ providers:
- name: claude-3-haiku-20240307
default: true
context_window: 200000
capabilities: [text, vision, streaming]
- name: claude-3-sonnet-20240229
context_window: 200000
capabilities: [text, vision, tools, streaming]
# ── Custom Models ──────────────────────────────────────────────────────
# ── Capability-Based Fallback Chains ────────────────────────────────────────
# When a model doesn't support a required capability (e.g., vision),
# the system falls back through these chains in order.
fallback_chains:
# Vision-capable models (for image understanding)
vision:
- llama3.2:3b # Fast, good vision
- qwen2.5-vl:3b # Excellent vision, small
- llava:7b # Classic vision model
- moondream:1.8b # Tiny, fast vision
# Tool-calling models (for function calling)
tools:
- llama3.1:8b-instruct # Best tool use
- qwen2.5:7b # Reliable tools
- llama3.2:3b # Small but capable
# General text generation (any model)
text:
- llama3.1:8b-instruct
- qwen2.5:14b
- deepseek-r1:1.5b
- llama3.2:3b
# ── Custom Models ───────────────────────────────────────────────────────────
# Register custom model weights for per-agent assignment.
# Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs.
# Models can also be registered at runtime via the /api/v1/models API.
@@ -91,7 +143,7 @@ custom_models: []
# context_window: 32000
# description: "Process reward model for scoring outputs"
# ── Agent Model Assignments ─────────────────────────────────────────────
# ── Agent Model Assignments ─────────────────────────────────────────────────
# Map persona agent IDs to specific models.
# Agents without an assignment use the global default (ollama_model).
agent_model_assignments: {}
@@ -99,6 +151,20 @@ agent_model_assignments: {}
# persona-forge: my-finetuned-llama
# persona-echo: deepseek-r1:1.5b
# ── Multi-Modal Settings ────────────────────────────────────────────────────
multimodal:
# Automatically pull models when needed
auto_pull: true
# Timeout for model pulling (seconds)
pull_timeout: 300
# Maximum fallback depth (how many models to try before giving up)
max_fallback_depth: 3
# Prefer smaller models for vision when available (faster)
prefer_small_vision: true
# Cost tracking (optional, for budget monitoring)
cost_tracking:
enabled: true

BIN
data/scripture.db-shm Normal file

Binary file not shown.

BIN
data/scripture.db-wal Normal file

Binary file not shown.

View File

@@ -0,0 +1,37 @@
"""Infrastructure models package."""
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
model_registry,
)
from infrastructure.models.multimodal import (
ModelCapability,
ModelInfo,
MultiModalManager,
get_model_for_capability,
get_multimodal_manager,
model_supports_tools,
model_supports_vision,
pull_model_with_fallback,
)
__all__ = [
# Registry
"CustomModel",
"ModelFormat",
"ModelRegistry",
"ModelRole",
"model_registry",
# Multi-modal
"ModelCapability",
"ModelInfo",
"MultiModalManager",
"get_model_for_capability",
"get_multimodal_manager",
"model_supports_tools",
"model_supports_vision",
"pull_model_with_fallback",
]

View File

@@ -0,0 +1,445 @@
"""Multi-modal model support with automatic capability detection and fallbacks.
Provides:
- Model capability detection (vision, audio, etc.)
- Automatic model pulling with fallback chains
- Content-type aware model selection
- Graceful degradation when primary models unavailable
No cloud by default — tries local first, falls back through configured options.
"""
import logging
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Optional
from config import settings
logger = logging.getLogger(__name__)
class ModelCapability(Enum):
    """Capabilities a model can have.

    Values come from ``auto()`` — only identity/membership is meaningful,
    the numeric value is never persisted or compared.
    """

    TEXT = auto()       # Standard text completion
    VISION = auto()     # Image understanding
    AUDIO = auto()      # Audio/speech processing
    TOOLS = auto()      # Function calling / tool use
    JSON = auto()       # Structured output / JSON mode
    STREAMING = auto()  # Streaming responses
# Known model capabilities (local Ollama models)
# These are used when we can't query the model directly.
# Keys may be a bare family name ("llava") or a fully tagged name
# ("llava:7b"); _detect_capabilities tries the exact name first, then the
# base name before the ":".
KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
    # Llama 3.x series
    "llama3.1": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.1:8b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.1:8b-instruct": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.1:70b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.1:405b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.2": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    "llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.2:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    "llama3.2-vision": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    "llama3.2-vision:11b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    # Qwen series
    "qwen2.5": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5:14b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5:32b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5:72b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5-vl": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    "qwen2.5-vl:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    "qwen2.5-vl:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    # DeepSeek series
    "deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:7b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-v3": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    # Gemma series
    "gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    # Mistral series
    "mistral": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "mistral:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "mistral-nemo": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "mistral-small": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "mistral-large": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    # Vision-specific models
    "llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava:13b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava:34b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava-phi3": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava-llama3": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    # Phi series
    "phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "phi4": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    # Command R
    "command-r": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "command-r:35b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "command-r-plus": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    # Granite (IBM)
    "granite3-dense": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "granite3-moe": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
}

# Default fallback chains for each capability.
# These are tried in order when the primary model doesn't support a capability.
# NOTE(review): the vision order here (llava before qwen2.5-vl) differs from
# the `fallback_chains.vision` order in providers.yaml — confirm which order
# is intended.
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
    ModelCapability.VISION: [
        "llama3.2:3b",    # Fast vision model
        "llava:7b",       # Classic vision model
        "qwen2.5-vl:3b",  # Qwen vision
        "moondream:1.8b", # Tiny vision model (last resort)
    ],
    ModelCapability.TOOLS: [
        "llama3.1:8b-instruct",  # Best tool use
        "llama3.2:3b",           # Smaller but capable
        "qwen2.5:7b",            # Reliable fallback
    ],
    ModelCapability.AUDIO: [
        # Audio models are less common in Ollama.
        # Would need specific audio-capable models here.
    ],
}
@dataclass
class ModelInfo:
    """Capability and availability metadata for a single model.

    Attributes:
        name: Model identifier, e.g. ``"llava:7b"``.
        capabilities: Set of ``ModelCapability`` values the model offers.
        is_available: Whether the model is currently usable.
        is_pulled: Whether the model weights are present locally.
        size_mb: On-disk size in megabytes, if known.
        description: Free-form description (e.g. the model family).
    """

    name: str
    capabilities: set[ModelCapability] = field(default_factory=set)
    is_available: bool = False
    is_pulled: bool = False
    size_mb: Optional[int] = None
    description: str = ""

    def supports(self, capability: ModelCapability) -> bool:
        """Return True when this model advertises *capability*."""
        return self.capabilities.issuperset({capability})
class MultiModalManager:
    """Manages multi-modal model capabilities and fallback chains.

    This class:
    1. Detects what capabilities each model has
    2. Maintains fallback chains for different capabilities
    3. Pulls models on-demand with automatic fallback
    4. Routes requests to appropriate models based on content type
    """

    def __init__(self, ollama_url: Optional[str] = None) -> None:
        """Initialize the manager and take an initial model snapshot.

        Args:
            ollama_url: Base URL of the Ollama server; defaults to the
                configured ``settings.ollama_url``.
        """
        self.ollama_url = ollama_url or settings.ollama_url
        self._available_models: dict[str, ModelInfo] = {}
        # Copy the defaults so configure_fallback_chain() on one instance
        # never mutates the module-level DEFAULT_FALLBACK_CHAINS.
        self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
        self._refresh_available_models()

    def _refresh_available_models(self) -> None:
        """Query Ollama for available models.

        Fix: builds a fresh snapshot and swaps it in only on success, so
        models deleted from Ollama no longer linger as "available" (the
        previous implementation merged into the old dict and never removed
        stale entries). A failed refresh keeps the last good snapshot.
        """
        try:
            import urllib.request
            import json
            # Avoid slow IPv6 localhost resolution on some systems.
            url = self.ollama_url.replace("localhost", "127.0.0.1")
            req = urllib.request.Request(
                f"{url}/api/tags",
                method="GET",
                headers={"Accept": "application/json"},
            )
            with urllib.request.urlopen(req, timeout=5) as response:
                data = json.loads(response.read().decode())
            snapshot: dict[str, ModelInfo] = {}
            for model_data in data.get("models", []):
                name = model_data.get("name", "")
                snapshot[name] = ModelInfo(
                    name=name,
                    capabilities=self._detect_capabilities(name),
                    is_available=True,
                    is_pulled=True,
                    size_mb=model_data.get("size", 0) // (1024 * 1024),
                    description=model_data.get("details", {}).get("family", ""),
                )
            # Atomic swap: either the whole refresh lands or none of it does.
            self._available_models = snapshot
            logger.info("Found %d models in Ollama", len(self._available_models))
        except Exception as exc:
            logger.warning("Could not refresh available models: %s", exc)

    def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
        """Detect capabilities for a model based on known data.

        Tries an exact name match first, then the base name with the tag
        stripped; unknown models default to TEXT + STREAMING.
        """
        # Normalize model name (strip tags for lookup)
        base_name = model_name.split(":")[0]
        # Try exact match first
        if model_name in KNOWN_MODEL_CAPABILITIES:
            return set(KNOWN_MODEL_CAPABILITIES[model_name])
        # Try base name match
        if base_name in KNOWN_MODEL_CAPABILITIES:
            return set(KNOWN_MODEL_CAPABILITIES[base_name])
        # Default to text-only for unknown models
        logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
        return {ModelCapability.TEXT, ModelCapability.STREAMING}

    def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
        """Get capabilities for a specific model.

        Prefers the live snapshot; falls back to the static known-model
        table for models that are not currently pulled.
        """
        if model_name in self._available_models:
            return self._available_models[model_name].capabilities
        return self._detect_capabilities(model_name)

    def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
        """Check if a model supports a specific capability."""
        capabilities = self.get_model_capabilities(model_name)
        return capability in capabilities

    def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
        """Get all available (pulled) models that support a capability."""
        return [
            info for info in self._available_models.values()
            if capability in info.capabilities
        ]

    def get_best_model_for(
        self,
        capability: ModelCapability,
        preferred_model: Optional[str] = None
    ) -> Optional[str]:
        """Get the best available model for a specific capability.

        Resolution order: the preferred model (if pulled and capable),
        then the configured fallback chain, then any pulled capable model
        (smallest first).

        Args:
            capability: The required capability
            preferred_model: Preferred model to use if available and capable

        Returns:
            Model name or None if no suitable model found
        """
        # Check if preferred model supports this capability
        if preferred_model:
            if preferred_model in self._available_models:
                if self.model_supports(preferred_model, capability):
                    return preferred_model
            logger.debug(
                "Preferred model %s doesn't support %s, checking fallbacks",
                preferred_model, capability.name
            )
        # Check fallback chain for this capability
        fallback_chain = self._fallback_chains.get(capability, [])
        for model_name in fallback_chain:
            if model_name in self._available_models:
                logger.debug("Using fallback model %s for %s", model_name, capability.name)
                return model_name
        # Find any available model with this capability
        capable_models = self.get_models_with_capability(capability)
        if capable_models:
            # Sort by size (prefer smaller/faster models as fallback).
            # NOTE: size_mb of 0 is falsy and sorts as unknown (inf).
            capable_models.sort(key=lambda m: m.size_mb or float('inf'))
            return capable_models[0].name
        return None

    def pull_model_with_fallback(
        self,
        primary_model: str,
        capability: Optional[ModelCapability] = None,
        auto_pull: bool = True,
    ) -> tuple[str, bool]:
        """Pull a model with automatic fallback if unavailable.

        Args:
            primary_model: The desired model to use
            capability: Required capability (for finding fallback)
            auto_pull: Whether to attempt pulling missing models

        Returns:
            Tuple of (model_name, is_fallback)
        """
        # Check if primary model is already available
        if primary_model in self._available_models:
            return primary_model, False
        # Try to pull the primary model
        if auto_pull:
            if self._pull_model(primary_model):
                return primary_model, False
        # Need to find a fallback
        if capability:
            fallback = self.get_best_model_for(capability, primary_model)
            if fallback:
                logger.info(
                    "Primary model %s unavailable, using fallback %s",
                    primary_model, fallback
                )
                return fallback, True
        # Last resort: use the configured default model
        default_model = settings.ollama_model
        if default_model in self._available_models:
            logger.warning(
                "Falling back to default model %s (primary: %s unavailable)",
                default_model, primary_model
            )
            return default_model, True
        # Absolute last resort: report the primary model and let the
        # downstream call surface the error.
        # NOTE(review): is_fallback is False here even though the model is
        # known to be unavailable — confirm callers expect that.
        return primary_model, False

    def _pull_model(self, model_name: str) -> bool:
        """Attempt to pull a model from Ollama.

        Returns:
            True if successful or model already exists
        """
        try:
            import urllib.request
            import json
            logger.info("Pulling model: %s", model_name)
            url = self.ollama_url.replace("localhost", "127.0.0.1")
            req = urllib.request.Request(
                f"{url}/api/pull",
                method="POST",
                headers={"Content-Type": "application/json"},
                data=json.dumps({"name": model_name, "stream": False}).encode(),
            )
            # Model downloads can be large; generous 5-minute timeout.
            with urllib.request.urlopen(req, timeout=300) as response:
                if response.status == 200:
                    logger.info("Successfully pulled model: %s", model_name)
                    # Refresh available models so the new pull is visible.
                    self._refresh_available_models()
                    return True
                else:
                    logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
                    return False
        except Exception as exc:
            logger.error("Error pulling model %s: %s", model_name, exc)
            return False

    def configure_fallback_chain(
        self,
        capability: ModelCapability,
        models: list[str]
    ) -> None:
        """Configure a custom fallback chain for a capability."""
        self._fallback_chains[capability] = models
        logger.info("Configured fallback chain for %s: %s", capability.name, models)

    def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
        """Get a copy of the fallback chain for a capability."""
        return list(self._fallback_chains.get(capability, []))

    def list_available_models(self) -> list[ModelInfo]:
        """List all available models with their capabilities."""
        return list(self._available_models.values())

    def refresh(self) -> None:
        """Refresh the list of available models."""
        self._refresh_available_models()

    def get_model_for_content(
        self,
        content_type: str,  # "text", "image", "audio", "multimodal"
        preferred_model: Optional[str] = None,
    ) -> tuple[str, bool]:
        """Get appropriate model based on content type.

        Args:
            content_type: Type of content (text, image, audio, multimodal)
            preferred_model: User's preferred model

        Returns:
            Tuple of (model_name, is_fallback)
        """
        content_type = content_type.lower()
        if content_type in ("image", "vision", "multimodal"):
            # For vision content, we need a vision-capable model
            return self.pull_model_with_fallback(
                preferred_model or "llava:7b",
                capability=ModelCapability.VISION,
            )
        elif content_type == "audio":
            # Audio support is limited in Ollama; degrade to a text model.
            logger.warning("Audio support is limited, falling back to text model")
            return self.pull_model_with_fallback(
                preferred_model or settings.ollama_model,
                capability=ModelCapability.TEXT,
            )
        else:
            # Standard text content (also the path for unknown types)
            return self.pull_model_with_fallback(
                preferred_model or settings.ollama_model,
                capability=ModelCapability.TEXT,
            )
# Module-level singleton, created lazily on first access.
_multimodal_manager: Optional[MultiModalManager] = None


def get_multimodal_manager() -> MultiModalManager:
    """Return the process-wide MultiModalManager, creating it on first use."""
    global _multimodal_manager
    manager = _multimodal_manager
    if manager is None:
        manager = MultiModalManager()
        _multimodal_manager = manager
    return manager
def get_model_for_capability(
    capability: ModelCapability,
    preferred_model: Optional[str] = None
) -> Optional[str]:
    """Module-level shortcut for :meth:`MultiModalManager.get_best_model_for`."""
    manager = get_multimodal_manager()
    return manager.get_best_model_for(capability, preferred_model)
def pull_model_with_fallback(
    primary_model: str,
    capability: Optional[ModelCapability] = None,
    auto_pull: bool = True,
) -> tuple[str, bool]:
    """Module-level shortcut for :meth:`MultiModalManager.pull_model_with_fallback`."""
    manager = get_multimodal_manager()
    return manager.pull_model_with_fallback(primary_model, capability, auto_pull)
def model_supports_vision(model_name: str) -> bool:
    """Return True when *model_name* can process image inputs."""
    manager = get_multimodal_manager()
    return manager.model_supports(model_name, ModelCapability.VISION)
def model_supports_tools(model_name: str) -> bool:
    """Return True when *model_name* can do function/tool calling."""
    manager = get_multimodal_manager()
    return manager.model_supports(model_name, ModelCapability.TOOLS)

View File

@@ -3,14 +3,19 @@
Routes requests through an ordered list of LLM providers,
automatically failing over on rate limits or errors.
Tracks metrics for latency, errors, and cost.
Now with multi-modal support — automatically selects vision-capable
models for image inputs and falls back through capability chains.
"""
import asyncio
import base64
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Optional
from pathlib import Path
@@ -43,6 +48,14 @@ class CircuitState(Enum):
HALF_OPEN = "half_open" # Testing if recovered
class ContentType(Enum):
    """Type of content in the request.

    Used by the router to select a model whose capabilities match the
    payload (e.g. a vision-capable model for image inputs).
    """

    TEXT = "text"
    VISION = "vision"          # Contains images
    AUDIO = "audio"            # Contains audio
    MULTIMODAL = "multimodal"  # Multiple content types
@dataclass
class ProviderMetrics:
"""Metrics for a single provider."""
@@ -67,6 +80,18 @@ class ProviderMetrics:
return self.failed_requests / self.total_requests
@dataclass
class ModelCapability:
    """Capabilities a model supports.

    NOTE(review): this router-local dataclass shares its name with the
    ``ModelCapability`` enum imported from infrastructure.models.multimodal
    inside ``complete()`` — consider renaming one of them to avoid
    confusion.
    """

    name: str  # Model identifier
    supports_vision: bool = False
    supports_audio: bool = False
    supports_tools: bool = False
    supports_json: bool = False
    supports_streaming: bool = True
    context_window: int = 4096  # Maximum context length in tokens
@dataclass
class Provider:
"""LLM provider configuration and state."""
@@ -94,6 +119,23 @@ class Provider:
if self.models:
return self.models[0]["name"]
return None
def get_model_with_capability(self, capability: str) -> Optional[str]:
    """Get a model that supports the given capability.

    Returns the first configured model whose ``capabilities`` list
    contains *capability*.

    NOTE(review): when no configured model matches, this falls back to
    the provider's default model, which may NOT support the capability —
    callers should re-check before relying on the result.
    """
    for model in self.models:
        capabilities = model.get("capabilities", [])
        if capability in capabilities:
            return model["name"]
    # Fall back to default
    return self.get_default_model()
def model_has_capability(self, model_name: str, capability: str) -> bool:
    """Check if a specific model has a capability.

    Returns False both when the model is configured without the
    capability and when the model is not configured for this provider
    at all.
    """
    for model in self.models:
        if model["name"] == model_name:
            capabilities = model.get("capabilities", [])
            return capability in capabilities
    return False
@dataclass
@@ -107,19 +149,39 @@ class RouterConfig:
circuit_breaker_half_open_max_calls: int = 2
cost_tracking_enabled: bool = True
budget_daily_usd: float = 10.0
# Multi-modal settings
auto_pull_models: bool = True
fallback_chains: dict = field(default_factory=dict)
class CascadeRouter:
"""Routes LLM requests with automatic failover.
Now with multi-modal support:
- Automatically detects content type (text, vision, audio)
- Selects appropriate models based on capabilities
- Falls back through capability-specific model chains
- Supports image URLs and base64 encoding
Usage:
router = CascadeRouter()
# Text request
response = await router.complete(
messages=[{"role": "user", "content": "Hello"}],
model="llama3.2"
)
# Vision request (automatically detects and selects vision model)
response = await router.complete(
messages=[{
"role": "user",
"content": "What's in this image?",
"images": ["path/to/image.jpg"]
}],
model="llava:7b"
)
# Check metrics
metrics = router.get_metrics()
"""
@@ -130,6 +192,14 @@ class CascadeRouter:
self.config: RouterConfig = RouterConfig()
self._load_config()
# Initialize multi-modal manager if available
self._mm_manager: Optional[Any] = None
try:
from infrastructure.models.multimodal import get_multimodal_manager
self._mm_manager = get_multimodal_manager()
except Exception as exc:
logger.debug("Multi-modal manager not available: %s", exc)
logger.info("CascadeRouter initialized with %d providers", len(self.providers))
def _load_config(self) -> None:
@@ -149,6 +219,13 @@ class CascadeRouter:
# Load cascade settings
cascade = data.get("cascade", {})
# Load fallback chains
fallback_chains = data.get("fallback_chains", {})
# Load multi-modal settings
multimodal = data.get("multimodal", {})
self.config = RouterConfig(
timeout_seconds=cascade.get("timeout_seconds", 30),
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
@@ -156,6 +233,8 @@ class CascadeRouter:
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
auto_pull_models=multimodal.get("auto_pull", True),
fallback_chains=fallback_chains,
)
# Load providers
@@ -226,6 +305,81 @@ class CascadeRouter:
return True
def _detect_content_type(self, messages: list[dict]) -> ContentType:
    """Detect the type of content in the messages.

    Checks for images, audio, etc. in the message content. Recognizes:
    an ``images`` key on a message, image-file extensions or ``data:image/``
    prefixes in string content, an ``audio`` key, and OpenAI-style content
    part lists (``{"type": "image_url"}`` / ``{"type": "audio"}``).

    NOTE(review): the extension scan matches the substring anywhere in the
    text, so a plain-text mention of e.g. ".png" flags the request as
    vision — confirm this heuristic is acceptable.
    """
    has_image = False
    has_audio = False
    for msg in messages:
        content = msg.get("content", "")
        # Check for image URLs/paths attached directly to the message
        if msg.get("images"):
            has_image = True
        # Check for image URLs in content
        if isinstance(content, str):
            image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
            if any(ext in content.lower() for ext in image_extensions):
                has_image = True
            if content.startswith("data:image/"):
                has_image = True
        # Check for audio
        if msg.get("audio"):
            has_audio = True
        # Check for multimodal content structure (list of typed parts)
        if isinstance(content, list):
            for item in content:
                if isinstance(item, dict):
                    if item.get("type") == "image_url":
                        has_image = True
                    elif item.get("type") == "audio":
                        has_audio = True
    if has_image and has_audio:
        return ContentType.MULTIMODAL
    elif has_image:
        return ContentType.VISION
    elif has_audio:
        return ContentType.AUDIO
    return ContentType.TEXT
def _get_fallback_model(
    self,
    provider: Provider,
    original_model: str,
    content_type: ContentType
) -> Optional[str]:
    """Get a fallback model for the given content type.

    Resolution order:
    1. The provider's own model list via ``get_model_with_capability``.
       NOTE(review): that helper falls back to the provider default even
       when it lacks the capability, so step 1 can hand back a
       non-capable model — confirm intended.
    2. The capability's fallback chain from config, first model the
       provider declares with that capability.

    Returns None for plain-text content or when nothing suitable exists.
    """
    # Map content type to capability
    capability_map = {
        ContentType.VISION: "vision",
        ContentType.AUDIO: "audio",
        ContentType.MULTIMODAL: "vision",  # Vision models often do both
    }
    capability = capability_map.get(content_type)
    if not capability:
        return None
    # Check provider's models for capability
    fallback_model = provider.get_model_with_capability(capability)
    if fallback_model and fallback_model != original_model:
        return fallback_model
    # Use fallback chains from config
    fallback_chain = self.config.fallback_chains.get(capability, [])
    for model_name in fallback_chain:
        if provider.model_has_capability(model_name, capability):
            return model_name
    return None
async def complete(
self,
messages: list[dict],
@@ -235,6 +389,11 @@ class CascadeRouter:
) -> dict:
"""Complete a chat conversation with automatic failover.
Multi-modal support:
- Automatically detects if messages contain images
- Falls back to vision-capable models when needed
- Supports image URLs, paths, and base64 encoding
Args:
messages: List of message dicts with role and content
model: Preferred model (tries this first, then provider defaults)
@@ -247,6 +406,11 @@ class CascadeRouter:
Raises:
RuntimeError: If all providers fail
"""
# Detect content type for multi-modal routing
content_type = self._detect_content_type(messages)
if content_type != ContentType.TEXT:
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
errors = []
for provider in self.providers:
@@ -266,15 +430,48 @@ class CascadeRouter:
logger.debug("Skipping %s (circuit open)", provider.name)
continue
# Determine which model to use
selected_model = model or provider.get_default_model()
is_fallback_model = False
# For non-text content, check if model supports it
if content_type != ContentType.TEXT and selected_model:
if provider.type == "ollama" and self._mm_manager:
from infrastructure.models.multimodal import ModelCapability
# Check if selected model supports the required capability
if content_type == ContentType.VISION:
supports = self._mm_manager.model_supports(
selected_model, ModelCapability.VISION
)
if not supports:
# Find fallback model
fallback = self._get_fallback_model(
provider, selected_model, content_type
)
if fallback:
logger.info(
"Model %s doesn't support vision, falling back to %s",
selected_model, fallback
)
selected_model = fallback
is_fallback_model = True
else:
logger.warning(
"No vision-capable model found on %s, trying anyway",
provider.name
)
# Try this provider
for attempt in range(self.config.max_retries_per_provider):
try:
result = await self._try_provider(
provider=provider,
messages=messages,
model=model,
model=selected_model,
temperature=temperature,
max_tokens=max_tokens,
content_type=content_type,
)
# Success! Update metrics and return
@@ -282,8 +479,9 @@ class CascadeRouter:
return {
"content": result["content"],
"provider": provider.name,
"model": result.get("model", model or provider.get_default_model()),
"model": result.get("model", selected_model or provider.get_default_model()),
"latency_ms": result.get("latency_ms", 0),
"is_fallback_model": is_fallback_model,
}
except Exception as exc:
@@ -307,9 +505,10 @@ class CascadeRouter:
self,
provider: Provider,
messages: list[dict],
model: Optional[str],
model: str,
temperature: float,
max_tokens: Optional[int],
content_type: ContentType = ContentType.TEXT,
) -> dict:
"""Try a single provider request."""
start_time = time.time()
@@ -320,6 +519,7 @@ class CascadeRouter:
messages=messages,
model=model or provider.get_default_model(),
temperature=temperature,
content_type=content_type,
)
elif provider.type == "openai":
result = await self._call_openai(
@@ -359,15 +559,19 @@ class CascadeRouter:
messages: list[dict],
model: str,
temperature: float,
content_type: ContentType = ContentType.TEXT,
) -> dict:
"""Call Ollama API."""
"""Call Ollama API with multi-modal support."""
import aiohttp
url = f"{provider.url}/api/chat"
# Transform messages for Ollama format (including images)
transformed_messages = self._transform_messages_for_ollama(messages)
payload = {
"model": model,
"messages": messages,
"messages": transformed_messages,
"stream": False,
"options": {
"temperature": temperature,
@@ -388,6 +592,41 @@ class CascadeRouter:
"model": model,
}
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
"""Transform messages to Ollama format, handling images."""
transformed = []
for msg in messages:
new_msg = {
"role": msg.get("role", "user"),
"content": msg.get("content", ""),
}
# Handle images
images = msg.get("images", [])
if images:
new_msg["images"] = []
for img in images:
if isinstance(img, str):
if img.startswith("data:image/"):
# Base64 encoded image
new_msg["images"].append(img.split(",")[1])
elif img.startswith("http://") or img.startswith("https://"):
# URL - would need to download, skip for now
logger.warning("Image URLs not yet supported, skipping: %s", img)
elif Path(img).exists():
# Local file path - read and encode
try:
with open(img, "rb") as f:
img_data = base64.b64encode(f.read()).decode()
new_msg["images"].append(img_data)
except Exception as exc:
logger.error("Failed to read image %s: %s", img, exc)
transformed.append(new_msg)
return transformed
async def _call_openai(
self,
provider: Provider,
@@ -496,7 +735,7 @@ class CascadeRouter:
"content": response.choices[0].message.content,
"model": response.model,
}
def _record_success(self, provider: Provider, latency_ms: float) -> None:
"""Record a successful request."""
provider.metrics.total_requests += 1
@@ -598,6 +837,35 @@ class CascadeRouter:
for p in self.providers
],
}
async def generate_with_image(
    self,
    prompt: str,
    image_path: str,
    model: Optional[str] = None,
    temperature: float = 0.7,
) -> dict:
    """Convenience wrapper that routes a single-image vision request.

    Builds a one-message conversation carrying the image and delegates to
    :meth:`complete`, which handles model selection and fallback.

    Args:
        prompt: Text prompt about the image.
        image_path: Path to the image file.
        model: Vision-capable model (auto-selected if not provided).
        temperature: Sampling temperature.

    Returns:
        Response dict with content and metadata.
    """
    vision_message = {
        "role": "user",
        "content": prompt,
        "images": [image_path],
    }
    return await self.complete(
        messages=[vision_message],
        model=model,
        temperature=temperature,
    )
# Module-level singleton

View File

@@ -5,11 +5,16 @@ Memory Architecture:
- Tier 2 (Vault): memory/ — structured markdown, append-only
- Tier 3 (Semantic): Vector search over vault files
Model Management:
- Pulls requested model automatically if not available
- Falls back through capability-based model chains
- Multi-modal support with vision model fallbacks
Handoff Protocol maintains continuity across sessions.
"""
import logging
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Optional, Union
from agno.agent import Agent
from agno.db.sqlite import SqliteDb
@@ -24,6 +29,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Fallback chain for text/tool models (in order of preference).  Tried
# top-to-bottom when the requested model is unavailable and cannot be pulled.
DEFAULT_MODEL_FALLBACKS: list[str] = [
    "llama3.1:8b-instruct",
    "llama3.1",
    "qwen2.5:14b",
    "qwen2.5:7b",
    "llama3.2:3b",
]

# Fallback chain for vision models (in order of preference).
# NOTE(review): llama3.2:3b leads this chain — providers.yaml lists it with
# the `vision` capability, but upstream Llama 3.2 vision variants start at
# 11B; confirm this tag really accepts images before relying on it.
VISION_MODEL_FALLBACKS: list[str] = [
    "llama3.2:3b",
    "llava:7b",
    "qwen2.5-vl:3b",
    "moondream:1.8b",
]
# Union type for callers that want to hint the return type.
TimmyAgent = Union[Agent, "TimmyAirLLMAgent", "GrokBackend"]
@@ -40,6 +62,120 @@ _SMALL_MODEL_PATTERNS = (
)
def _check_model_available(model_name: str) -> bool:
"""Check if an Ollama model is available locally."""
try:
import urllib.request
import json
url = settings.ollama_url.replace("localhost", "127.0.0.1")
req = urllib.request.Request(
f"{url}/api/tags",
method="GET",
headers={"Accept": "application/json"},
)
with urllib.request.urlopen(req, timeout=5) as response:
data = json.loads(response.read().decode())
models = [m.get("name", "") for m in data.get("models", [])]
# Check for exact match or model name without tag
return any(
model_name == m or model_name == m.split(":")[0] or m.startswith(model_name)
for m in models
)
except Exception as exc:
logger.debug("Could not check model availability: %s", exc)
return False
def _pull_model(model_name: str) -> bool:
    """Attempt to pull a model from Ollama.

    Issues a blocking POST to the local Ollama ``/api/pull`` endpoint and
    waits for the download to finish.

    Returns:
        True if successful or model already exists
    """
    try:
        import urllib.request
        import json

        logger.info("Pulling model: %s", model_name)
        base_url = settings.ollama_url.replace("localhost", "127.0.0.1")
        body = json.dumps({"name": model_name, "stream": False}).encode()
        request = urllib.request.Request(
            f"{base_url}/api/pull",
            method="POST",
            headers={"Content-Type": "application/json"},
            data=body,
        )
        # Model downloads can be multi-gigabyte, hence the generous timeout.
        with urllib.request.urlopen(request, timeout=300) as response:
            if response.status != 200:
                logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
                return False
            logger.info("Successfully pulled model: %s", model_name)
            return True
    except Exception as exc:
        logger.error("Error pulling model %s: %s", model_name, exc)
        return False
def _resolve_model_with_fallback(
requested_model: Optional[str] = None,
require_vision: bool = False,
auto_pull: bool = True,
) -> tuple[str, bool]:
"""Resolve model with automatic pulling and fallback.
Args:
requested_model: Preferred model to use
require_vision: Whether the model needs vision capabilities
auto_pull: Whether to attempt pulling missing models
Returns:
Tuple of (model_name, is_fallback)
"""
model = requested_model or settings.ollama_model
# Check if requested model is available
if _check_model_available(model):
logger.debug("Using available model: %s", model)
return model, False
# Try to pull the requested model
if auto_pull:
logger.info("Model %s not available locally, attempting to pull...", model)
if _pull_model(model):
return model, False
logger.warning("Failed to pull %s, checking fallbacks...", model)
# Use appropriate fallback chain
fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS
for fallback_model in fallback_chain:
if _check_model_available(fallback_model):
logger.warning(
"Using fallback model %s (requested: %s)",
fallback_model, model
)
return fallback_model, True
# Try to pull the fallback
if auto_pull and _pull_model(fallback_model):
logger.info(
"Pulled and using fallback model %s (requested: %s)",
fallback_model, model
)
return fallback_model, True
# Absolute last resort - return the requested model and hope for the best
logger.error(
"No models available in fallback chain. Requested: %s",
model
)
return model, False
def _model_supports_tools(model_name: str) -> bool:
"""Check if the configured model can reliably handle tool calling.
@@ -106,7 +242,16 @@ def create_timmy(
return TimmyAirLLMAgent(model_size=size)
# Default: Ollama via Agno.
model_name = settings.ollama_model
# Resolve model with automatic pulling and fallback
model_name, is_fallback = _resolve_model_with_fallback(
requested_model=None,
require_vision=False,
auto_pull=True,
)
if is_fallback:
logger.info("Using fallback model %s (requested was unavailable)", model_name)
use_tools = _model_supports_tools(model_name)
# Conditionally include tools — small models get none

View File

@@ -31,7 +31,7 @@ from timmy.agent_core.interface import (
TimAgent,
AgentEffect,
)
from timmy.agent import create_timmy
from timmy.agent import create_timmy, _resolve_model_with_fallback
class OllamaAgent(TimAgent):
@@ -53,18 +53,33 @@ class OllamaAgent(TimAgent):
identity: AgentIdentity,
model: Optional[str] = None,
effect_log: Optional[str] = None,
require_vision: bool = False,
) -> None:
"""Initialize Ollama-based agent.
Args:
identity: Agent identity (persistent across sessions)
model: Ollama model to use (default from config)
model: Ollama model to use (auto-resolves with fallback)
effect_log: Path to log agent effects (optional)
require_vision: Whether to select a vision-capable model
"""
super().__init__(identity)
# Resolve model with automatic pulling and fallback
resolved_model, is_fallback = _resolve_model_with_fallback(
requested_model=model,
require_vision=require_vision,
auto_pull=True,
)
if is_fallback:
import logging
logging.getLogger(__name__).info(
"OllamaAdapter using fallback model %s", resolved_model
)
# Initialize underlying Ollama agent
self._timmy = create_timmy(model=model)
self._timmy = create_timmy(model=resolved_model)
# Set capabilities based on what Ollama can do
self._capabilities = {