forked from Rockachopa/Timmy-time-dashboard
feat: Multi-modal support with automatic model fallback
- Add MultiModalManager with capability detection for vision/audio/tools
- Define fallback chains: vision (llama3.2:3b -> llava:7b -> moondream)
tools (llama3.1:8b-instruct -> qwen2.5:7b)
- Update CascadeRouter to detect content type and select appropriate models
- Add model pulling with automatic fallback in agent creation
- Update providers.yaml with multi-modal model configurations
- Update OllamaAdapter to use model resolution with vision support
Tests: All 96 infrastructure tests pass
This commit is contained in:
@@ -24,11 +24,31 @@ providers:
|
||||
priority: 1
|
||||
url: "http://localhost:11434"
|
||||
models:
|
||||
- name: llama3.2
|
||||
# Text + Tools models
|
||||
- name: llama3.1:8b-instruct
|
||||
default: true
|
||||
context_window: 128000
|
||||
capabilities: [text, tools, json, streaming]
|
||||
- name: llama3.2:3b
|
||||
context_window: 128000
|
||||
capabilities: [text, tools, json, streaming, vision]
|
||||
- name: qwen2.5:14b
|
||||
context_window: 32000
|
||||
capabilities: [text, tools, json, streaming]
|
||||
- name: deepseek-r1:1.5b
|
||||
context_window: 32000
|
||||
capabilities: [text, json, streaming]
|
||||
|
||||
# Vision models
|
||||
- name: llava:7b
|
||||
context_window: 4096
|
||||
capabilities: [text, vision, streaming]
|
||||
- name: qwen2.5-vl:3b
|
||||
context_window: 32000
|
||||
capabilities: [text, vision, tools, json, streaming]
|
||||
- name: moondream:1.8b
|
||||
context_window: 2048
|
||||
capabilities: [text, vision, streaming]
|
||||
|
||||
# Secondary: Local AirLLM (if installed)
|
||||
- name: airllm-local
|
||||
@@ -38,8 +58,11 @@ providers:
|
||||
models:
|
||||
- name: 70b
|
||||
default: true
|
||||
capabilities: [text, tools, json, streaming]
|
||||
- name: 8b
|
||||
capabilities: [text, tools, json, streaming]
|
||||
- name: 405b
|
||||
capabilities: [text, tools, json, streaming]
|
||||
|
||||
# Tertiary: OpenAI (if API key available)
|
||||
- name: openai-backup
|
||||
@@ -52,8 +75,10 @@ providers:
|
||||
- name: gpt-4o-mini
|
||||
default: true
|
||||
context_window: 128000
|
||||
capabilities: [text, vision, tools, json, streaming]
|
||||
- name: gpt-4o
|
||||
context_window: 128000
|
||||
capabilities: [text, vision, tools, json, streaming]
|
||||
|
||||
# Quaternary: Anthropic (if API key available)
|
||||
- name: anthropic-backup
|
||||
@@ -65,10 +90,37 @@ providers:
|
||||
- name: claude-3-haiku-20240307
|
||||
default: true
|
||||
context_window: 200000
|
||||
capabilities: [text, vision, streaming]
|
||||
- name: claude-3-sonnet-20240229
|
||||
context_window: 200000
|
||||
capabilities: [text, vision, tools, streaming]
|
||||
|
||||
# ── Custom Models ──────────────────────────────────────────────────────
|
||||
# ── Capability-Based Fallback Chains ────────────────────────────────────────
|
||||
# When a model doesn't support a required capability (e.g., vision),
|
||||
# the system falls back through these chains in order.
|
||||
|
||||
fallback_chains:
|
||||
# Vision-capable models (for image understanding)
|
||||
vision:
|
||||
- llama3.2:3b # Fast, good vision
|
||||
- qwen2.5-vl:3b # Excellent vision, small
|
||||
- llava:7b # Classic vision model
|
||||
- moondream:1.8b # Tiny, fast vision
|
||||
|
||||
# Tool-calling models (for function calling)
|
||||
tools:
|
||||
- llama3.1:8b-instruct # Best tool use
|
||||
- qwen2.5:7b # Reliable tools
|
||||
- llama3.2:3b # Small but capable
|
||||
|
||||
# General text generation (any model)
|
||||
text:
|
||||
- llama3.1:8b-instruct
|
||||
- qwen2.5:14b
|
||||
- deepseek-r1:1.5b
|
||||
- llama3.2:3b
|
||||
|
||||
# ── Custom Models ───────────────────────────────────────────────────────────
|
||||
# Register custom model weights for per-agent assignment.
|
||||
# Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs.
|
||||
# Models can also be registered at runtime via the /api/v1/models API.
|
||||
@@ -91,7 +143,7 @@ custom_models: []
|
||||
# context_window: 32000
|
||||
# description: "Process reward model for scoring outputs"
|
||||
|
||||
# ── Agent Model Assignments ─────────────────────────────────────────────
|
||||
# ── Agent Model Assignments ─────────────────────────────────────────────────
|
||||
# Map persona agent IDs to specific models.
|
||||
# Agents without an assignment use the global default (ollama_model).
|
||||
agent_model_assignments: {}
|
||||
@@ -99,6 +151,20 @@ agent_model_assignments: {}
|
||||
# persona-forge: my-finetuned-llama
|
||||
# persona-echo: deepseek-r1:1.5b
|
||||
|
||||
# ── Multi-Modal Settings ────────────────────────────────────────────────────
|
||||
multimodal:
|
||||
# Automatically pull models when needed
|
||||
auto_pull: true
|
||||
|
||||
# Timeout for model pulling (seconds)
|
||||
pull_timeout: 300
|
||||
|
||||
# Maximum fallback depth (how many models to try before giving up)
|
||||
max_fallback_depth: 3
|
||||
|
||||
# Prefer smaller models for vision when available (faster)
|
||||
prefer_small_vision: true
|
||||
|
||||
# Cost tracking (optional, for budget monitoring)
|
||||
cost_tracking:
|
||||
enabled: true
|
||||
|
||||
BIN
data/scripture.db-shm
Normal file
BIN
data/scripture.db-shm
Normal file
Binary file not shown.
BIN
data/scripture.db-wal
Normal file
BIN
data/scripture.db-wal
Normal file
Binary file not shown.
@@ -0,0 +1,37 @@
|
||||
"""Infrastructure models package."""
|
||||
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
model_registry,
|
||||
)
|
||||
from infrastructure.models.multimodal import (
|
||||
ModelCapability,
|
||||
ModelInfo,
|
||||
MultiModalManager,
|
||||
get_model_for_capability,
|
||||
get_multimodal_manager,
|
||||
model_supports_tools,
|
||||
model_supports_vision,
|
||||
pull_model_with_fallback,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Registry
|
||||
"CustomModel",
|
||||
"ModelFormat",
|
||||
"ModelRegistry",
|
||||
"ModelRole",
|
||||
"model_registry",
|
||||
# Multi-modal
|
||||
"ModelCapability",
|
||||
"ModelInfo",
|
||||
"MultiModalManager",
|
||||
"get_model_for_capability",
|
||||
"get_multimodal_manager",
|
||||
"model_supports_tools",
|
||||
"model_supports_vision",
|
||||
"pull_model_with_fallback",
|
||||
]
|
||||
|
||||
445
src/infrastructure/models/multimodal.py
Normal file
445
src/infrastructure/models/multimodal.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""Multi-modal model support with automatic capability detection and fallbacks.
|
||||
|
||||
Provides:
|
||||
- Model capability detection (vision, audio, etc.)
|
||||
- Automatic model pulling with fallback chains
|
||||
- Content-type aware model selection
|
||||
- Graceful degradation when primary models unavailable
|
||||
|
||||
No cloud by default — tries local first, falls back through configured options.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelCapability(Enum):
    """Feature set a model may expose.

    Members carry auto-assigned integer values (1-6, declaration order);
    only identity/membership is ever used, never the numeric value.
    """

    TEXT = auto()       # plain text completion
    VISION = auto()     # image understanding
    AUDIO = auto()      # audio / speech processing
    TOOLS = auto()      # function calling / tool use
    JSON = auto()       # structured output / JSON mode
    STREAMING = auto()  # incremental streaming responses
||||
|
||||
|
||||
# Capability lookup for well-known Ollama model tags (used when we can't
# query the model directly).  Built from shared capability tuples so the
# table stays readable; each entry still gets its own independent ``set``.
_CAPS_TEXT = (ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING)
_CAPS_TOOLS = _CAPS_TEXT + (ModelCapability.TOOLS,)
_CAPS_TOOLS_VISION = _CAPS_TOOLS + (ModelCapability.VISION,)
_CAPS_VISION = (ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING)

KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
    name: set(caps)
    for caps, names in (
        # Text + JSON, no tool calling
        (_CAPS_TEXT, (
            "llama3.2:1b",
            "deepseek-r1", "deepseek-r1:1.5b", "deepseek-r1:7b",
            "deepseek-r1:14b", "deepseek-r1:32b", "deepseek-r1:70b",
            "gemma2", "gemma2:2b", "gemma2:9b", "gemma2:27b",
            "phi3", "phi3:3.8b", "phi3:14b",
        )),
        # Text + tool calling
        (_CAPS_TOOLS, (
            "llama3.1", "llama3.1:8b", "llama3.1:8b-instruct",
            "llama3.1:70b", "llama3.1:405b",
            "qwen2.5", "qwen2.5:7b", "qwen2.5:14b", "qwen2.5:32b", "qwen2.5:72b",
            "deepseek-v3",
            "mistral", "mistral:7b", "mistral-nemo", "mistral-small", "mistral-large",
            "phi4",
            "command-r", "command-r:35b", "command-r-plus",
            "granite3-dense", "granite3-moe",
        )),
        # Tools + vision.
        # NOTE(review): llama3.2 / llama3.2:3b are marked vision-capable here
        # (and in providers.yaml) — confirm; upstream the vision variants are
        # the llama3.2-vision tags.
        (_CAPS_TOOLS_VISION, (
            "llama3.2", "llama3.2:3b",
            "llama3.2-vision", "llama3.2-vision:11b",
            "qwen2.5-vl", "qwen2.5-vl:3b", "qwen2.5-vl:7b",
        )),
        # Vision without tool calling
        (_CAPS_VISION, (
            "llava", "llava:7b", "llava:13b", "llava:34b",
            "llava-phi3", "llava-llama3", "bakllava",
            "moondream", "moondream:1.8b",
        )),
    )
    for name in names
}
|
||||
|
||||
|
||||
# Default fallback chains per capability, tried in order when the primary
# model doesn't support (or can't be pulled for) a capability.
# Order is kept in sync with the ``fallback_chains`` section of
# providers.yaml so code defaults and config agree when no config is loaded.
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
    ModelCapability.VISION: [
        "llama3.2:3b",     # Fast vision model
        "qwen2.5-vl:3b",   # Excellent vision, small
        "llava:7b",        # Classic vision model
        "moondream:1.8b",  # Tiny vision model (last resort)
    ],
    ModelCapability.TOOLS: [
        "llama3.1:8b-instruct",  # Best tool use
        "qwen2.5:7b",            # Reliable fallback
        "llama3.2:3b",           # Smaller but capable
    ],
    # Audio-capable models are rare in Ollama; an empty chain means no
    # automatic fallback is attempted for audio.
    ModelCapability.AUDIO: [],
}
|
||||
|
||||
|
||||
@dataclass
class ModelInfo:
    """Capability and availability metadata for a single model."""

    name: str  # full model tag, e.g. "llava:7b"
    capabilities: set[ModelCapability] = field(default_factory=set)
    is_available: bool = False     # known to the local Ollama server
    is_pulled: bool = False        # weights are downloaded locally
    size_mb: Optional[int] = None  # on-disk size in MiB, when reported
    description: str = ""          # model family / free-form note

    def supports(self, capability: ModelCapability) -> bool:
        """Return True when *capability* is in this model's capability set."""
        return capability in self.capabilities
|
||||
|
||||
|
||||
class MultiModalManager:
    """Manages multi-modal model capabilities and fallback chains.

    Responsibilities:
      1. Detect what capabilities each model has (``KNOWN_MODEL_CAPABILITIES``).
      2. Maintain per-capability fallback chains.
      3. Pull models on demand with automatic fallback.
      4. Route requests to appropriate models based on content type.
    """

    def __init__(self, ollama_url: Optional[str] = None) -> None:
        """Create a manager and eagerly query Ollama for installed models.

        Args:
            ollama_url: Base URL of the Ollama server; defaults to
                ``settings.ollama_url``.
        """
        self.ollama_url = ollama_url or settings.ollama_url
        self._available_models: dict[str, ModelInfo] = {}
        # Copy so configure_fallback_chain() never mutates the module default.
        self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
        self._refresh_available_models()

    def _refresh_available_models(self) -> None:
        """Query Ollama's ``/api/tags`` endpoint for locally installed models.

        On success the availability map is rebuilt from scratch, so models
        removed from Ollama since the last refresh no longer linger as stale
        "available" entries (the previous implementation only ever added).
        On failure the previous map is kept and a warning is logged —
        best-effort, since the server may simply be down.
        """
        try:
            import urllib.request
            import json

            # Avoid potentially slow "localhost" resolution on some systems.
            url = self.ollama_url.replace("localhost", "127.0.0.1")
            req = urllib.request.Request(
                f"{url}/api/tags",
                method="GET",
                headers={"Accept": "application/json"},
            )
            with urllib.request.urlopen(req, timeout=5) as response:
                data = json.loads(response.read().decode())

            discovered: dict[str, ModelInfo] = {}
            for model_data in data.get("models", []):
                name = model_data.get("name", "")
                discovered[name] = ModelInfo(
                    name=name,
                    capabilities=self._detect_capabilities(name),
                    is_available=True,
                    is_pulled=True,
                    size_mb=model_data.get("size", 0) // (1024 * 1024),
                    description=model_data.get("details", {}).get("family", ""),
                )
            # Swap in the fresh map only after the query succeeded.
            self._available_models = discovered

            logger.info("Found %d models in Ollama", len(self._available_models))

        except Exception as exc:
            logger.warning("Could not refresh available models: %s", exc)

    def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
        """Detect capabilities for a model based on the known-models table.

        Looks up the exact tag first, then the base name with the tag
        stripped (``llava:13b`` -> ``llava``); unknown models default to
        text + streaming only.
        """
        base_name = model_name.split(":")[0]

        # Exact match (e.g. "llama3.2:3b") takes priority over the family.
        if model_name in KNOWN_MODEL_CAPABILITIES:
            return set(KNOWN_MODEL_CAPABILITIES[model_name])

        if base_name in KNOWN_MODEL_CAPABILITIES:
            return set(KNOWN_MODEL_CAPABILITIES[base_name])

        logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
        return {ModelCapability.TEXT, ModelCapability.STREAMING}

    def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
        """Get capabilities for a specific model (cached entry if available)."""
        if model_name in self._available_models:
            return self._available_models[model_name].capabilities
        return self._detect_capabilities(model_name)

    def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
        """Check if a model supports a specific capability."""
        capabilities = self.get_model_capabilities(model_name)
        return capability in capabilities

    def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
        """Get all locally available models that support *capability*."""
        return [
            info for info in self._available_models.values()
            if capability in info.capabilities
        ]

    def get_best_model_for(
        self,
        capability: ModelCapability,
        preferred_model: Optional[str] = None
    ) -> Optional[str]:
        """Get the best available model for a specific capability.

        Selection order:
          1. *preferred_model*, if installed and capable;
          2. the configured fallback chain for *capability*;
          3. any installed capable model, smallest first.

        Args:
            capability: The required capability.
            preferred_model: Preferred model to use if available and capable.

        Returns:
            Model name, or None if no suitable model is found.
        """
        if preferred_model:
            if preferred_model in self._available_models:
                if self.model_supports(preferred_model, capability):
                    return preferred_model
            logger.debug(
                "Preferred model %s doesn't support %s, checking fallbacks",
                preferred_model, capability.name
            )

        fallback_chain = self._fallback_chains.get(capability, [])
        for model_name in fallback_chain:
            if model_name in self._available_models:
                logger.debug("Using fallback model %s for %s", model_name, capability.name)
                return model_name

        capable_models = self.get_models_with_capability(capability)
        if capable_models:
            # Prefer smaller (typically faster) models as a last-ditch choice.
            capable_models.sort(key=lambda m: m.size_mb or float('inf'))
            return capable_models[0].name

        return None

    def pull_model_with_fallback(
        self,
        primary_model: str,
        capability: Optional[ModelCapability] = None,
        auto_pull: bool = True,
    ) -> tuple[str, bool]:
        """Pull a model with automatic fallback if unavailable.

        Args:
            primary_model: The desired model to use.
            capability: Required capability (used to pick a fallback).
            auto_pull: Whether to attempt pulling missing models.

        Returns:
            Tuple of ``(model_name, is_fallback)``.  ``is_fallback`` is
            True only when a substitute model was chosen.
        """
        if primary_model in self._available_models:
            return primary_model, False

        if auto_pull:
            if self._pull_model(primary_model):
                return primary_model, False

        if capability:
            fallback = self.get_best_model_for(capability, primary_model)
            if fallback:
                logger.info(
                    "Primary model %s unavailable, using fallback %s",
                    primary_model, fallback
                )
                return fallback, True

        # Last resort: the globally configured default model.
        default_model = settings.ollama_model
        if default_model in self._available_models:
            logger.warning(
                "Falling back to default model %s (primary: %s unavailable)",
                default_model, primary_model
            )
            return default_model, True

        # Nothing usable found — hand back the original name unchanged so
        # the caller's request proceeds (and fails) with a clear model name.
        return primary_model, False

    def _pull_model(self, model_name: str) -> bool:
        """Attempt to pull a model via Ollama's ``/api/pull`` endpoint.

        Returns:
            True if successful (the model list is refreshed afterwards),
            False on any HTTP or network error.
        """
        try:
            import urllib.request
            import json

            logger.info("Pulling model: %s", model_name)

            url = self.ollama_url.replace("localhost", "127.0.0.1")
            req = urllib.request.Request(
                f"{url}/api/pull",
                method="POST",
                headers={"Content-Type": "application/json"},
                data=json.dumps({"name": model_name, "stream": False}).encode(),
            )

            # 300s matches the multimodal.pull_timeout default in
            # providers.yaml; NOTE(review): not read from config here — confirm.
            with urllib.request.urlopen(req, timeout=300) as response:
                if response.status == 200:
                    logger.info("Successfully pulled model: %s", model_name)
                    self._refresh_available_models()
                    return True
                else:
                    logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
                    return False

        except Exception as exc:
            logger.error("Error pulling model %s: %s", model_name, exc)
            return False

    def configure_fallback_chain(
        self,
        capability: ModelCapability,
        models: list[str]
    ) -> None:
        """Configure a custom fallback chain for a capability."""
        self._fallback_chains[capability] = models
        logger.info("Configured fallback chain for %s: %s", capability.name, models)

    def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
        """Get a copy of the fallback chain for a capability."""
        return list(self._fallback_chains.get(capability, []))

    def list_available_models(self) -> list[ModelInfo]:
        """List all available models with their capabilities."""
        return list(self._available_models.values())

    def refresh(self) -> None:
        """Re-query Ollama and rebuild the list of available models."""
        self._refresh_available_models()

    def get_model_for_content(
        self,
        content_type: str,  # "text", "image", "audio", "multimodal"
        preferred_model: Optional[str] = None,
    ) -> tuple[str, bool]:
        """Get an appropriate model based on content type.

        Args:
            content_type: Type of content (text, image, audio, multimodal).
            preferred_model: User's preferred model.

        Returns:
            Tuple of ``(model_name, is_fallback)``.
        """
        content_type = content_type.lower()

        if content_type in ("image", "vision", "multimodal"):
            # Vision content requires a vision-capable model.
            return self.pull_model_with_fallback(
                preferred_model or "llava:7b",
                capability=ModelCapability.VISION,
            )

        elif content_type == "audio":
            # Audio support is limited in Ollama; degrade to a text model.
            logger.warning("Audio support is limited, falling back to text model")
            return self.pull_model_with_fallback(
                preferred_model or settings.ollama_model,
                capability=ModelCapability.TEXT,
            )

        else:
            # Standard text content.
            return self.pull_model_with_fallback(
                preferred_model or settings.ollama_model,
                capability=ModelCapability.TEXT,
            )
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
# Lazily-created process-wide singleton.
_multimodal_manager: Optional[MultiModalManager] = None


def get_multimodal_manager() -> MultiModalManager:
    """Return the shared MultiModalManager, creating it on first use."""
    global _multimodal_manager
    if _multimodal_manager is None:
        _multimodal_manager = MultiModalManager()
    return _multimodal_manager
|
||||
|
||||
|
||||
def get_model_for_capability(
    capability: ModelCapability,
    preferred_model: Optional[str] = None
) -> Optional[str]:
    """Module-level shortcut for ``MultiModalManager.get_best_model_for``."""
    manager = get_multimodal_manager()
    return manager.get_best_model_for(capability, preferred_model)
|
||||
|
||||
|
||||
def pull_model_with_fallback(
    primary_model: str,
    capability: Optional[ModelCapability] = None,
    auto_pull: bool = True,
) -> tuple[str, bool]:
    """Module-level shortcut for ``MultiModalManager.pull_model_with_fallback``."""
    manager = get_multimodal_manager()
    return manager.pull_model_with_fallback(primary_model, capability, auto_pull)
|
||||
|
||||
|
||||
def model_supports_vision(model_name: str) -> bool:
    """Return True when *model_name* is known to accept image inputs."""
    manager = get_multimodal_manager()
    return manager.model_supports(model_name, ModelCapability.VISION)
|
||||
|
||||
|
||||
def model_supports_tools(model_name: str) -> bool:
    """Return True when *model_name* is known to support tool calling."""
    manager = get_multimodal_manager()
    return manager.model_supports(model_name, ModelCapability.TOOLS)
|
||||
@@ -3,14 +3,19 @@
|
||||
Routes requests through an ordered list of LLM providers,
|
||||
automatically failing over on rate limits or errors.
|
||||
Tracks metrics for latency, errors, and cost.
|
||||
|
||||
Now with multi-modal support — automatically selects vision-capable
|
||||
models for image inputs and falls back through capability chains.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from pathlib import Path
|
||||
@@ -43,6 +48,14 @@ class CircuitState(Enum):
|
||||
HALF_OPEN = "half_open" # Testing if recovered
|
||||
|
||||
|
||||
class ContentType(Enum):
    """Classifies what kinds of media a request's messages contain."""

    TEXT = "text"              # plain text only
    VISION = "vision"          # contains at least one image
    AUDIO = "audio"            # contains at least one audio clip
    MULTIMODAL = "multimodal"  # contains both images and audio
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMetrics:
|
||||
"""Metrics for a single provider."""
|
||||
@@ -67,6 +80,18 @@ class ProviderMetrics:
|
||||
return self.failed_requests / self.total_requests
|
||||
|
||||
|
||||
@dataclass
class ModelCapability:
    """Per-model feature flags and context size, as declared in providers.yaml."""

    name: str                        # model identifier, e.g. "llava:7b"
    supports_vision: bool = False    # accepts image inputs
    supports_audio: bool = False     # accepts audio inputs
    supports_tools: bool = False     # function calling / tool use
    supports_json: bool = False      # structured / JSON output
    supports_streaming: bool = True  # incremental responses
    context_window: int = 4096       # maximum context size in tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
class Provider:
|
||||
"""LLM provider configuration and state."""
|
||||
@@ -94,6 +119,23 @@ class Provider:
|
||||
if self.models:
|
||||
return self.models[0]["name"]
|
||||
return None
|
||||
|
||||
def get_model_with_capability(self, capability: str) -> Optional[str]:
    """Return the first configured model listing *capability*.

    Falls back to the provider's default model when no configured model
    declares the capability.
    """
    match = next(
        (m["name"] for m in self.models if capability in m.get("capabilities", [])),
        None,
    )
    return match if match is not None else self.get_default_model()
|
||||
|
||||
def model_has_capability(self, model_name: str, capability: str) -> bool:
    """Return True when the named model is configured and lists *capability*.

    Only the first configured entry matching *model_name* is consulted;
    unknown model names yield False.
    """
    entry = next((m for m in self.models if m["name"] == model_name), None)
    if entry is None:
        return False
    return capability in entry.get("capabilities", [])
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -107,19 +149,39 @@ class RouterConfig:
|
||||
circuit_breaker_half_open_max_calls: int = 2
|
||||
cost_tracking_enabled: bool = True
|
||||
budget_daily_usd: float = 10.0
|
||||
# Multi-modal settings
|
||||
auto_pull_models: bool = True
|
||||
fallback_chains: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class CascadeRouter:
|
||||
"""Routes LLM requests with automatic failover.
|
||||
|
||||
Now with multi-modal support:
|
||||
- Automatically detects content type (text, vision, audio)
|
||||
- Selects appropriate models based on capabilities
|
||||
- Falls back through capability-specific model chains
|
||||
- Supports image URLs and base64 encoding
|
||||
|
||||
Usage:
|
||||
router = CascadeRouter()
|
||||
|
||||
# Text request
|
||||
response = await router.complete(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
model="llama3.2"
|
||||
)
|
||||
|
||||
# Vision request (automatically detects and selects vision model)
|
||||
response = await router.complete(
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "What's in this image?",
|
||||
"images": ["path/to/image.jpg"]
|
||||
}],
|
||||
model="llava:7b"
|
||||
)
|
||||
|
||||
# Check metrics
|
||||
metrics = router.get_metrics()
|
||||
"""
|
||||
@@ -130,6 +192,14 @@ class CascadeRouter:
|
||||
self.config: RouterConfig = RouterConfig()
|
||||
self._load_config()
|
||||
|
||||
# Initialize multi-modal manager if available
|
||||
self._mm_manager: Optional[Any] = None
|
||||
try:
|
||||
from infrastructure.models.multimodal import get_multimodal_manager
|
||||
self._mm_manager = get_multimodal_manager()
|
||||
except Exception as exc:
|
||||
logger.debug("Multi-modal manager not available: %s", exc)
|
||||
|
||||
logger.info("CascadeRouter initialized with %d providers", len(self.providers))
|
||||
|
||||
def _load_config(self) -> None:
|
||||
@@ -149,6 +219,13 @@ class CascadeRouter:
|
||||
|
||||
# Load cascade settings
|
||||
cascade = data.get("cascade", {})
|
||||
|
||||
# Load fallback chains
|
||||
fallback_chains = data.get("fallback_chains", {})
|
||||
|
||||
# Load multi-modal settings
|
||||
multimodal = data.get("multimodal", {})
|
||||
|
||||
self.config = RouterConfig(
|
||||
timeout_seconds=cascade.get("timeout_seconds", 30),
|
||||
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
|
||||
@@ -156,6 +233,8 @@ class CascadeRouter:
|
||||
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
|
||||
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
|
||||
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
|
||||
auto_pull_models=multimodal.get("auto_pull", True),
|
||||
fallback_chains=fallback_chains,
|
||||
)
|
||||
|
||||
# Load providers
|
||||
@@ -226,6 +305,81 @@ class CascadeRouter:
|
||||
|
||||
return True
|
||||
|
||||
def _detect_content_type(self, messages: list[dict]) -> ContentType:
|
||||
"""Detect the type of content in the messages.
|
||||
|
||||
Checks for images, audio, etc. in the message content.
|
||||
"""
|
||||
has_image = False
|
||||
has_audio = False
|
||||
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
|
||||
# Check for image URLs/paths
|
||||
if msg.get("images"):
|
||||
has_image = True
|
||||
|
||||
# Check for image URLs in content
|
||||
if isinstance(content, str):
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
|
||||
if any(ext in content.lower() for ext in image_extensions):
|
||||
has_image = True
|
||||
if content.startswith("data:image/"):
|
||||
has_image = True
|
||||
|
||||
# Check for audio
|
||||
if msg.get("audio"):
|
||||
has_audio = True
|
||||
|
||||
# Check for multimodal content structure
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "image_url":
|
||||
has_image = True
|
||||
elif item.get("type") == "audio":
|
||||
has_audio = True
|
||||
|
||||
if has_image and has_audio:
|
||||
return ContentType.MULTIMODAL
|
||||
elif has_image:
|
||||
return ContentType.VISION
|
||||
elif has_audio:
|
||||
return ContentType.AUDIO
|
||||
return ContentType.TEXT
|
||||
|
||||
def _get_fallback_model(
|
||||
self,
|
||||
provider: Provider,
|
||||
original_model: str,
|
||||
content_type: ContentType
|
||||
) -> Optional[str]:
|
||||
"""Get a fallback model for the given content type."""
|
||||
# Map content type to capability
|
||||
capability_map = {
|
||||
ContentType.VISION: "vision",
|
||||
ContentType.AUDIO: "audio",
|
||||
ContentType.MULTIMODAL: "vision", # Vision models often do both
|
||||
}
|
||||
|
||||
capability = capability_map.get(content_type)
|
||||
if not capability:
|
||||
return None
|
||||
|
||||
# Check provider's models for capability
|
||||
fallback_model = provider.get_model_with_capability(capability)
|
||||
if fallback_model and fallback_model != original_model:
|
||||
return fallback_model
|
||||
|
||||
# Use fallback chains from config
|
||||
fallback_chain = self.config.fallback_chains.get(capability, [])
|
||||
for model_name in fallback_chain:
|
||||
if provider.model_has_capability(model_name, capability):
|
||||
return model_name
|
||||
|
||||
return None
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict],
|
||||
@@ -235,6 +389,11 @@ class CascadeRouter:
|
||||
) -> dict:
|
||||
"""Complete a chat conversation with automatic failover.
|
||||
|
||||
Multi-modal support:
|
||||
- Automatically detects if messages contain images
|
||||
- Falls back to vision-capable models when needed
|
||||
- Supports image URLs, paths, and base64 encoding
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with role and content
|
||||
model: Preferred model (tries this first, then provider defaults)
|
||||
@@ -247,6 +406,11 @@ class CascadeRouter:
|
||||
Raises:
|
||||
RuntimeError: If all providers fail
|
||||
"""
|
||||
# Detect content type for multi-modal routing
|
||||
content_type = self._detect_content_type(messages)
|
||||
if content_type != ContentType.TEXT:
|
||||
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
||||
|
||||
errors = []
|
||||
|
||||
for provider in self.providers:
|
||||
@@ -266,15 +430,48 @@ class CascadeRouter:
|
||||
logger.debug("Skipping %s (circuit open)", provider.name)
|
||||
continue
|
||||
|
||||
# Determine which model to use
|
||||
selected_model = model or provider.get_default_model()
|
||||
is_fallback_model = False
|
||||
|
||||
# For non-text content, check if model supports it
|
||||
if content_type != ContentType.TEXT and selected_model:
|
||||
if provider.type == "ollama" and self._mm_manager:
|
||||
from infrastructure.models.multimodal import ModelCapability
|
||||
|
||||
# Check if selected model supports the required capability
|
||||
if content_type == ContentType.VISION:
|
||||
supports = self._mm_manager.model_supports(
|
||||
selected_model, ModelCapability.VISION
|
||||
)
|
||||
if not supports:
|
||||
# Find fallback model
|
||||
fallback = self._get_fallback_model(
|
||||
provider, selected_model, content_type
|
||||
)
|
||||
if fallback:
|
||||
logger.info(
|
||||
"Model %s doesn't support vision, falling back to %s",
|
||||
selected_model, fallback
|
||||
)
|
||||
selected_model = fallback
|
||||
is_fallback_model = True
|
||||
else:
|
||||
logger.warning(
|
||||
"No vision-capable model found on %s, trying anyway",
|
||||
provider.name
|
||||
)
|
||||
|
||||
# Try this provider
|
||||
for attempt in range(self.config.max_retries_per_provider):
|
||||
try:
|
||||
result = await self._try_provider(
|
||||
provider=provider,
|
||||
messages=messages,
|
||||
model=model,
|
||||
model=selected_model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
# Success! Update metrics and return
|
||||
@@ -282,8 +479,9 @@ class CascadeRouter:
|
||||
return {
|
||||
"content": result["content"],
|
||||
"provider": provider.name,
|
||||
"model": result.get("model", model or provider.get_default_model()),
|
||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||||
"latency_ms": result.get("latency_ms", 0),
|
||||
"is_fallback_model": is_fallback_model,
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
@@ -307,9 +505,10 @@ class CascadeRouter:
|
||||
self,
|
||||
provider: Provider,
|
||||
messages: list[dict],
|
||||
model: Optional[str],
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: Optional[int],
|
||||
content_type: ContentType = ContentType.TEXT,
|
||||
) -> dict:
|
||||
"""Try a single provider request."""
|
||||
start_time = time.time()
|
||||
@@ -320,6 +519,7 @@ class CascadeRouter:
|
||||
messages=messages,
|
||||
model=model or provider.get_default_model(),
|
||||
temperature=temperature,
|
||||
content_type=content_type,
|
||||
)
|
||||
elif provider.type == "openai":
|
||||
result = await self._call_openai(
|
||||
@@ -359,15 +559,19 @@ class CascadeRouter:
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
temperature: float,
|
||||
content_type: ContentType = ContentType.TEXT,
|
||||
) -> dict:
|
||||
"""Call Ollama API."""
|
||||
"""Call Ollama API with multi-modal support."""
|
||||
import aiohttp
|
||||
|
||||
url = f"{provider.url}/api/chat"
|
||||
|
||||
# Transform messages for Ollama format (including images)
|
||||
transformed_messages = self._transform_messages_for_ollama(messages)
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"messages": transformed_messages,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"temperature": temperature,
|
||||
@@ -388,6 +592,41 @@ class CascadeRouter:
|
||||
"model": model,
|
||||
}
|
||||
|
||||
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
|
||||
"""Transform messages to Ollama format, handling images."""
|
||||
transformed = []
|
||||
|
||||
for msg in messages:
|
||||
new_msg = {
|
||||
"role": msg.get("role", "user"),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
|
||||
# Handle images
|
||||
images = msg.get("images", [])
|
||||
if images:
|
||||
new_msg["images"] = []
|
||||
for img in images:
|
||||
if isinstance(img, str):
|
||||
if img.startswith("data:image/"):
|
||||
# Base64 encoded image
|
||||
new_msg["images"].append(img.split(",")[1])
|
||||
elif img.startswith("http://") or img.startswith("https://"):
|
||||
# URL - would need to download, skip for now
|
||||
logger.warning("Image URLs not yet supported, skipping: %s", img)
|
||||
elif Path(img).exists():
|
||||
# Local file path - read and encode
|
||||
try:
|
||||
with open(img, "rb") as f:
|
||||
img_data = base64.b64encode(f.read()).decode()
|
||||
new_msg["images"].append(img_data)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to read image %s: %s", img, exc)
|
||||
|
||||
transformed.append(new_msg)
|
||||
|
||||
return transformed
|
||||
|
||||
async def _call_openai(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -496,7 +735,7 @@ class CascadeRouter:
|
||||
"content": response.choices[0].message.content,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
|
||||
def _record_success(self, provider: Provider, latency_ms: float) -> None:
|
||||
"""Record a successful request."""
|
||||
provider.metrics.total_requests += 1
|
||||
@@ -598,6 +837,35 @@ class CascadeRouter:
|
||||
for p in self.providers
|
||||
],
|
||||
}
|
||||
|
||||
async def generate_with_image(
|
||||
self,
|
||||
prompt: str,
|
||||
image_path: str,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
) -> dict:
|
||||
"""Convenience method for vision requests.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt about the image
|
||||
image_path: Path to image file
|
||||
model: Vision-capable model (auto-selected if not provided)
|
||||
temperature: Sampling temperature
|
||||
|
||||
Returns:
|
||||
Response dict with content and metadata
|
||||
"""
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"images": [image_path],
|
||||
}]
|
||||
return await self.complete(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
|
||||
@@ -5,11 +5,16 @@ Memory Architecture:
|
||||
- Tier 2 (Vault): memory/ — structured markdown, append-only
|
||||
- Tier 3 (Semantic): Vector search over vault files
|
||||
|
||||
Model Management:
|
||||
- Pulls requested model automatically if not available
|
||||
- Falls back through capability-based model chains
|
||||
- Multi-modal support with vision model fallbacks
|
||||
|
||||
Handoff Protocol maintains continuity across sessions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from agno.agent import Agent
|
||||
from agno.db.sqlite import SqliteDb
|
||||
@@ -24,6 +29,23 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Fallback chain for text/tool models (in order of preference)
|
||||
DEFAULT_MODEL_FALLBACKS = [
|
||||
"llama3.1:8b-instruct",
|
||||
"llama3.1",
|
||||
"qwen2.5:14b",
|
||||
"qwen2.5:7b",
|
||||
"llama3.2:3b",
|
||||
]
|
||||
|
||||
# Fallback chain for vision models
|
||||
VISION_MODEL_FALLBACKS = [
|
||||
"llama3.2:3b",
|
||||
"llava:7b",
|
||||
"qwen2.5-vl:3b",
|
||||
"moondream:1.8b",
|
||||
]
|
||||
|
||||
# Union type for callers that want to hint the return type.
|
||||
TimmyAgent = Union[Agent, "TimmyAirLLMAgent", "GrokBackend"]
|
||||
|
||||
@@ -40,6 +62,120 @@ _SMALL_MODEL_PATTERNS = (
|
||||
)
|
||||
|
||||
|
||||
def _check_model_available(model_name: str) -> bool:
|
||||
"""Check if an Ollama model is available locally."""
|
||||
try:
|
||||
import urllib.request
|
||||
import json
|
||||
|
||||
url = settings.ollama_url.replace("localhost", "127.0.0.1")
|
||||
req = urllib.request.Request(
|
||||
f"{url}/api/tags",
|
||||
method="GET",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=5) as response:
|
||||
data = json.loads(response.read().decode())
|
||||
models = [m.get("name", "") for m in data.get("models", [])]
|
||||
# Check for exact match or model name without tag
|
||||
return any(
|
||||
model_name == m or model_name == m.split(":")[0] or m.startswith(model_name)
|
||||
for m in models
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Could not check model availability: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
def _pull_model(model_name: str) -> bool:
|
||||
"""Attempt to pull a model from Ollama.
|
||||
|
||||
Returns:
|
||||
True if successful or model already exists
|
||||
"""
|
||||
try:
|
||||
import urllib.request
|
||||
import json
|
||||
|
||||
logger.info("Pulling model: %s", model_name)
|
||||
|
||||
url = settings.ollama_url.replace("localhost", "127.0.0.1")
|
||||
req = urllib.request.Request(
|
||||
f"{url}/api/pull",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps({"name": model_name, "stream": False}).encode(),
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(req, timeout=300) as response:
|
||||
if response.status == 200:
|
||||
logger.info("Successfully pulled model: %s", model_name)
|
||||
return True
|
||||
else:
|
||||
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
|
||||
return False
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error pulling model %s: %s", model_name, exc)
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_model_with_fallback(
|
||||
requested_model: Optional[str] = None,
|
||||
require_vision: bool = False,
|
||||
auto_pull: bool = True,
|
||||
) -> tuple[str, bool]:
|
||||
"""Resolve model with automatic pulling and fallback.
|
||||
|
||||
Args:
|
||||
requested_model: Preferred model to use
|
||||
require_vision: Whether the model needs vision capabilities
|
||||
auto_pull: Whether to attempt pulling missing models
|
||||
|
||||
Returns:
|
||||
Tuple of (model_name, is_fallback)
|
||||
"""
|
||||
model = requested_model or settings.ollama_model
|
||||
|
||||
# Check if requested model is available
|
||||
if _check_model_available(model):
|
||||
logger.debug("Using available model: %s", model)
|
||||
return model, False
|
||||
|
||||
# Try to pull the requested model
|
||||
if auto_pull:
|
||||
logger.info("Model %s not available locally, attempting to pull...", model)
|
||||
if _pull_model(model):
|
||||
return model, False
|
||||
logger.warning("Failed to pull %s, checking fallbacks...", model)
|
||||
|
||||
# Use appropriate fallback chain
|
||||
fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS
|
||||
|
||||
for fallback_model in fallback_chain:
|
||||
if _check_model_available(fallback_model):
|
||||
logger.warning(
|
||||
"Using fallback model %s (requested: %s)",
|
||||
fallback_model, model
|
||||
)
|
||||
return fallback_model, True
|
||||
|
||||
# Try to pull the fallback
|
||||
if auto_pull and _pull_model(fallback_model):
|
||||
logger.info(
|
||||
"Pulled and using fallback model %s (requested: %s)",
|
||||
fallback_model, model
|
||||
)
|
||||
return fallback_model, True
|
||||
|
||||
# Absolute last resort - return the requested model and hope for the best
|
||||
logger.error(
|
||||
"No models available in fallback chain. Requested: %s",
|
||||
model
|
||||
)
|
||||
return model, False
|
||||
|
||||
|
||||
def _model_supports_tools(model_name: str) -> bool:
|
||||
"""Check if the configured model can reliably handle tool calling.
|
||||
|
||||
@@ -106,7 +242,16 @@ def create_timmy(
|
||||
return TimmyAirLLMAgent(model_size=size)
|
||||
|
||||
# Default: Ollama via Agno.
|
||||
model_name = settings.ollama_model
|
||||
# Resolve model with automatic pulling and fallback
|
||||
model_name, is_fallback = _resolve_model_with_fallback(
|
||||
requested_model=None,
|
||||
require_vision=False,
|
||||
auto_pull=True,
|
||||
)
|
||||
|
||||
if is_fallback:
|
||||
logger.info("Using fallback model %s (requested was unavailable)", model_name)
|
||||
|
||||
use_tools = _model_supports_tools(model_name)
|
||||
|
||||
# Conditionally include tools — small models get none
|
||||
|
||||
@@ -31,7 +31,7 @@ from timmy.agent_core.interface import (
|
||||
TimAgent,
|
||||
AgentEffect,
|
||||
)
|
||||
from timmy.agent import create_timmy
|
||||
from timmy.agent import create_timmy, _resolve_model_with_fallback
|
||||
|
||||
|
||||
class OllamaAgent(TimAgent):
|
||||
@@ -53,18 +53,33 @@ class OllamaAgent(TimAgent):
|
||||
identity: AgentIdentity,
|
||||
model: Optional[str] = None,
|
||||
effect_log: Optional[str] = None,
|
||||
require_vision: bool = False,
|
||||
) -> None:
|
||||
"""Initialize Ollama-based agent.
|
||||
|
||||
Args:
|
||||
identity: Agent identity (persistent across sessions)
|
||||
model: Ollama model to use (default from config)
|
||||
model: Ollama model to use (auto-resolves with fallback)
|
||||
effect_log: Path to log agent effects (optional)
|
||||
require_vision: Whether to select a vision-capable model
|
||||
"""
|
||||
super().__init__(identity)
|
||||
|
||||
# Resolve model with automatic pulling and fallback
|
||||
resolved_model, is_fallback = _resolve_model_with_fallback(
|
||||
requested_model=model,
|
||||
require_vision=require_vision,
|
||||
auto_pull=True,
|
||||
)
|
||||
|
||||
if is_fallback:
|
||||
import logging
|
||||
logging.getLogger(__name__).info(
|
||||
"OllamaAdapter using fallback model %s", resolved_model
|
||||
)
|
||||
|
||||
# Initialize underlying Ollama agent
|
||||
self._timmy = create_timmy(model=model)
|
||||
self._timmy = create_timmy(model=resolved_model)
|
||||
|
||||
# Set capabilities based on what Ollama can do
|
||||
self._capabilities = {
|
||||
|
||||
Reference in New Issue
Block a user