# forked from Rockachopa/Timmy-time-dashboard
"""Cascade LLM Router — Automatic failover between providers.
|
||
|
||
Routes requests through an ordered list of LLM providers,
|
||
automatically failing over on rate limits or errors.
|
||
Tracks metrics for latency, errors, and cost.
|
||
|
||
Now with multi-modal support — automatically selects vision-capable
|
||
models for image inputs and falls back through capability chains.
|
||
"""
|
||
|
||
import asyncio
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from infrastructure.router.classifier import TaskComplexity

from config import settings

# Optional dependencies: the router degrades gracefully when either is
# missing (config loading / availability probes are skipped or relaxed).
try:
    import yaml
except ImportError:
    yaml = None  # type: ignore

try:
    import requests
except ImportError:
    requests = None  # type: ignore

# Mixins, plus data-model re-exports so existing
# ``from …cascade import X`` imports keep working.
from .health import HealthMixin
from .models import (  # noqa: F401 – re-exports
    CircuitState,
    ContentType,
    ModelCapability,
    Provider,
    ProviderMetrics,
    ProviderStatus,
    RouterConfig,
)
from .providers import ProviderCallsMixin

logger = logging.getLogger(__name__)
class CascadeRouter(HealthMixin, ProviderCallsMixin):
    """Routes LLM requests with automatic failover.

    Providers are tried in priority order; failures (rate limits, errors,
    timeouts) cascade to the next provider, with circuit-breaker state and
    per-provider metrics tracked via the mixins.

    Multi-modal support:
    - Automatically detects content type (text, vision, audio)
    - Selects appropriate models based on capabilities
    - Falls back through capability-specific model chains
    - Supports image URLs and base64 encoding

    Usage:
        router = CascadeRouter()

        # Text request
        response = await router.complete(
            messages=[{"role": "user", "content": "Hello"}],
            model="llama3.2"
        )

        # Vision request (automatically detects and selects vision model)
        response = await router.complete(
            messages=[{
                "role": "user",
                "content": "What's in this image?",
                "images": ["path/to/image.jpg"]
            }],
            model="llava:7b"
        )

        # Check metrics
        metrics = router.get_metrics()
    """
def __init__(self, config_path: Path | None = None) -> None:
|
||
self.config_path = config_path or Path("config/providers.yaml")
|
||
self.providers: list[Provider] = []
|
||
self.config: RouterConfig = RouterConfig()
|
||
self._load_config()
|
||
|
||
# Initialize multi-modal manager if available
|
||
self._mm_manager: Any | None = None
|
||
try:
|
||
from infrastructure.models.multimodal import get_multimodal_manager
|
||
|
||
self._mm_manager = get_multimodal_manager()
|
||
except Exception as exc:
|
||
logger.debug("Multi-modal manager not available: %s", exc)
|
||
|
||
logger.info("CascadeRouter initialized with %d providers", len(self.providers))
|
||
|
||
def _load_config(self) -> None:
|
||
"""Load configuration from YAML."""
|
||
if not self.config_path.exists():
|
||
logger.warning("Config not found: %s, using defaults", self.config_path)
|
||
return
|
||
|
||
try:
|
||
if yaml is None:
|
||
raise RuntimeError("PyYAML not installed")
|
||
|
||
content = self.config_path.read_text()
|
||
content = self._expand_env_vars(content)
|
||
data = yaml.safe_load(content)
|
||
|
||
self.config = self._parse_router_config(data)
|
||
self._load_providers(data)
|
||
|
||
except Exception as exc:
|
||
logger.error("Failed to load config: %s", exc)
|
||
|
||
def _parse_router_config(self, data: dict) -> RouterConfig:
|
||
"""Build a RouterConfig from parsed YAML data."""
|
||
cascade = data.get("cascade", {})
|
||
cb = cascade.get("circuit_breaker", {})
|
||
multimodal = data.get("multimodal", {})
|
||
|
||
return RouterConfig(
|
||
timeout_seconds=cascade.get("timeout_seconds", 30),
|
||
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
|
||
retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
|
||
circuit_breaker_failure_threshold=cb.get("failure_threshold", 5),
|
||
circuit_breaker_recovery_timeout=cb.get("recovery_timeout", 60),
|
||
circuit_breaker_half_open_max_calls=cb.get("half_open_max_calls", 2),
|
||
auto_pull_models=multimodal.get("auto_pull", True),
|
||
fallback_chains=data.get("fallback_chains", {}),
|
||
)
|
||
|
||
def _load_providers(self, data: dict) -> None:
|
||
"""Load, filter, and sort providers from parsed YAML data."""
|
||
for p_data in data.get("providers", []):
|
||
if not p_data.get("enabled", False):
|
||
continue
|
||
|
||
provider = Provider(
|
||
name=p_data["name"],
|
||
type=p_data["type"],
|
||
enabled=p_data.get("enabled", True),
|
||
priority=p_data.get("priority", 99),
|
||
tier=p_data.get("tier"),
|
||
url=p_data.get("url"),
|
||
api_key=p_data.get("api_key"),
|
||
base_url=p_data.get("base_url"),
|
||
models=p_data.get("models", []),
|
||
)
|
||
|
||
if self._check_provider_available(provider):
|
||
self.providers.append(provider)
|
||
else:
|
||
logger.warning("Provider %s not available, skipping", provider.name)
|
||
|
||
self.providers.sort(key=lambda p: p.priority)
|
||
|
||
def _expand_env_vars(self, content: str) -> str:
|
||
"""Expand ${VAR} syntax in YAML content.
|
||
|
||
Uses os.environ directly (not settings) because this is a generic
|
||
YAML config loader that must expand arbitrary variable references.
|
||
"""
|
||
import os
|
||
import re
|
||
|
||
def replace_var(match: "re.Match[str]") -> str:
|
||
var_name = match.group(1)
|
||
return os.environ.get(var_name, match.group(0))
|
||
|
||
return re.sub(r"\$\{(\w+)\}", replace_var, content)
|
||
|
||
def _check_provider_available(self, provider: Provider) -> bool:
|
||
"""Check if a provider is actually available."""
|
||
if provider.type == "ollama":
|
||
# Check if Ollama is running
|
||
if requests is None:
|
||
# Can't check without requests, assume available
|
||
return True
|
||
try:
|
||
url = provider.url or settings.ollama_url
|
||
response = requests.get(f"{url}/api/tags", timeout=5)
|
||
return response.status_code == 200
|
||
except Exception as exc:
|
||
logger.debug("Ollama provider check error: %s", exc)
|
||
return False
|
||
|
||
elif provider.type == "vllm_mlx":
|
||
# Check if local vllm-mlx server is running (OpenAI-compatible)
|
||
if requests is None:
|
||
return True
|
||
try:
|
||
base_url = provider.base_url or provider.url or "http://localhost:8000"
|
||
# Strip /v1 suffix — health endpoint is at the root
|
||
server_root = base_url.rstrip("/")
|
||
if server_root.endswith("/v1"):
|
||
server_root = server_root[:-3]
|
||
response = requests.get(f"{server_root}/health", timeout=5)
|
||
return response.status_code == 200
|
||
except Exception as exc:
|
||
logger.debug("vllm-mlx provider check error: %s", exc)
|
||
return False
|
||
|
||
elif provider.type in ("openai", "anthropic", "grok"):
|
||
# Check if API key is set
|
||
return provider.api_key is not None and provider.api_key != ""
|
||
|
||
return True
|
||
|
||
def _detect_content_type(self, messages: list[dict]) -> ContentType:
|
||
"""Detect the type of content in the messages.
|
||
|
||
Checks for images, audio, etc. in the message content.
|
||
"""
|
||
has_image = False
|
||
has_audio = False
|
||
|
||
for msg in messages:
|
||
content = msg.get("content", "")
|
||
|
||
# Check for image URLs/paths
|
||
if msg.get("images"):
|
||
has_image = True
|
||
|
||
# Check for image URLs in content
|
||
if isinstance(content, str):
|
||
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||
if any(ext in content.lower() for ext in image_extensions):
|
||
has_image = True
|
||
if content.startswith("data:image/"):
|
||
has_image = True
|
||
|
||
# Check for audio
|
||
if msg.get("audio"):
|
||
has_audio = True
|
||
|
||
# Check for multimodal content structure
|
||
if isinstance(content, list):
|
||
for item in content:
|
||
if isinstance(item, dict):
|
||
if item.get("type") == "image_url":
|
||
has_image = True
|
||
elif item.get("type") == "audio":
|
||
has_audio = True
|
||
|
||
if has_image and has_audio:
|
||
return ContentType.MULTIMODAL
|
||
elif has_image:
|
||
return ContentType.VISION
|
||
elif has_audio:
|
||
return ContentType.AUDIO
|
||
return ContentType.TEXT
|
||
|
||
def _get_fallback_model(
|
||
self, provider: Provider, original_model: str, content_type: ContentType
|
||
) -> str | None:
|
||
"""Get a fallback model for the given content type."""
|
||
# Map content type to capability
|
||
capability_map = {
|
||
ContentType.VISION: "vision",
|
||
ContentType.AUDIO: "audio",
|
||
ContentType.MULTIMODAL: "vision", # Vision models often do both
|
||
}
|
||
|
||
capability = capability_map.get(content_type)
|
||
if not capability:
|
||
return None
|
||
|
||
# Check provider's models for capability
|
||
fallback_model = provider.get_model_with_capability(capability)
|
||
if fallback_model and fallback_model != original_model:
|
||
return fallback_model
|
||
|
||
# Use fallback chains from config
|
||
fallback_chain = self.config.fallback_chains.get(capability, [])
|
||
for model_name in fallback_chain:
|
||
if provider.model_has_capability(model_name, capability):
|
||
return model_name
|
||
|
||
return None
|
||
|
||
def _select_model(
|
||
self, provider: Provider, model: str | None, content_type: ContentType
|
||
) -> tuple[str | None, bool]:
|
||
"""Select the best model for the request, with vision fallback.
|
||
|
||
Returns:
|
||
Tuple of (selected_model, is_fallback_model).
|
||
"""
|
||
selected_model = model or provider.get_default_model()
|
||
is_fallback = False
|
||
|
||
if content_type != ContentType.TEXT and selected_model:
|
||
if provider.type == "ollama" and self._mm_manager:
|
||
from infrastructure.models.multimodal import ModelCapability
|
||
|
||
if content_type == ContentType.VISION:
|
||
supports = self._mm_manager.model_supports(
|
||
selected_model, ModelCapability.VISION
|
||
)
|
||
if not supports:
|
||
fallback = self._get_fallback_model(provider, selected_model, content_type)
|
||
if fallback:
|
||
logger.info(
|
||
"Model %s doesn't support vision, falling back to %s",
|
||
selected_model,
|
||
fallback,
|
||
)
|
||
selected_model = fallback
|
||
is_fallback = True
|
||
else:
|
||
logger.warning(
|
||
"No vision-capable model found on %s, trying anyway",
|
||
provider.name,
|
||
)
|
||
|
||
return selected_model, is_fallback
|
||
|
||
async def _attempt_with_retry(
|
||
self,
|
||
provider: Provider,
|
||
messages: list[dict],
|
||
model: str | None,
|
||
temperature: float,
|
||
max_tokens: int | None,
|
||
content_type: ContentType,
|
||
) -> dict:
|
||
"""Try a provider with retries, returning the result dict.
|
||
|
||
Raises:
|
||
RuntimeError: If all retry attempts fail.
|
||
Returns error strings collected during retries via the exception message.
|
||
"""
|
||
errors: list[str] = []
|
||
for attempt in range(self.config.max_retries_per_provider):
|
||
try:
|
||
return await self._try_provider(
|
||
provider=provider,
|
||
messages=messages,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
content_type=content_type,
|
||
)
|
||
except Exception as exc:
|
||
error_msg = str(exc)
|
||
logger.warning(
|
||
"Provider %s attempt %d failed: %s",
|
||
provider.name,
|
||
attempt + 1,
|
||
error_msg,
|
||
)
|
||
errors.append(f"{provider.name}: {error_msg}")
|
||
|
||
if attempt < self.config.max_retries_per_provider - 1:
|
||
await asyncio.sleep(self.config.retry_delay_seconds)
|
||
|
||
raise RuntimeError("; ".join(errors))
|
||
|
||
def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]:
|
||
"""Return the provider list filtered by tier.
|
||
|
||
Raises:
|
||
RuntimeError: If a tier is specified but no matching providers exist.
|
||
"""
|
||
if cascade_tier == "frontier_required":
|
||
providers = [p for p in self.providers if p.type == "anthropic"]
|
||
if not providers:
|
||
raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.")
|
||
return providers
|
||
if cascade_tier:
|
||
providers = [p for p in self.providers if p.tier == cascade_tier]
|
||
if not providers:
|
||
raise RuntimeError(f"No providers found for tier: {cascade_tier}")
|
||
return providers
|
||
return self.providers
|
||
|
||
async def _try_single_provider(
|
||
self,
|
||
provider: "Provider",
|
||
messages: list[dict],
|
||
model: str | None,
|
||
temperature: float,
|
||
max_tokens: int | None,
|
||
content_type: ContentType,
|
||
errors: list[str],
|
||
) -> dict | None:
|
||
"""Attempt one provider, returning a result dict on success or None on failure.
|
||
|
||
On failure the error string is appended to *errors* and the provider's
|
||
failure metrics are updated so the caller can move on to the next provider.
|
||
"""
|
||
if not self._is_provider_available(provider):
|
||
return None
|
||
|
||
# Metabolic protocol: skip cloud providers when quota is low
|
||
if provider.type in ("anthropic", "openai", "grok"):
|
||
if not self._quota_allows_cloud(provider):
|
||
logger.info(
|
||
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
||
provider.name,
|
||
)
|
||
return None
|
||
|
||
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
|
||
|
||
try:
|
||
result = await self._attempt_with_retry(
|
||
provider, messages, selected_model, temperature, max_tokens, content_type
|
||
)
|
||
except RuntimeError as exc:
|
||
errors.append(str(exc))
|
||
self._record_failure(provider)
|
||
return None
|
||
|
||
self._record_success(provider, result.get("latency_ms", 0))
|
||
return {
|
||
"content": result["content"],
|
||
"provider": provider.name,
|
||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||
"latency_ms": result.get("latency_ms", 0),
|
||
"is_fallback_model": is_fallback_model,
|
||
}
|
||
|
||
def _get_model_for_complexity(
|
||
self, provider: Provider, complexity: "TaskComplexity"
|
||
) -> str | None:
|
||
"""Return the best model on *provider* for the given complexity tier.
|
||
|
||
Checks fallback chains first (routine / complex), then falls back to
|
||
any model with the matching capability tag, then the provider default.
|
||
"""
|
||
from infrastructure.router.classifier import TaskComplexity
|
||
|
||
chain_key = "routine" if complexity == TaskComplexity.SIMPLE else "complex"
|
||
|
||
# Walk the capability fallback chain — first model present on this provider wins
|
||
for model_name in self.config.fallback_chains.get(chain_key, []):
|
||
if any(m["name"] == model_name for m in provider.models):
|
||
return model_name
|
||
|
||
# Direct capability lookup — only return if a model explicitly has the tag
|
||
# (do not use get_model_with_capability here as it falls back to the default)
|
||
cap_model = next(
|
||
(m["name"] for m in provider.models if chain_key in m.get("capabilities", [])),
|
||
None,
|
||
)
|
||
if cap_model:
|
||
return cap_model
|
||
|
||
return None # Caller will use provider default
|
||
|
||
    async def complete(
        self,
        messages: list[dict],
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        cascade_tier: str | None = None,
        complexity_hint: str | None = None,
    ) -> dict:
        """Complete a chat conversation with automatic failover.

        Multi-modal support:
        - Automatically detects if messages contain images
        - Falls back to vision-capable models when needed
        - Supports image URLs, paths, and base64 encoding

        Complexity-based routing (issue #1065):
        - ``complexity_hint="simple"`` -> routes to the "routine" chain (low-latency)
        - ``complexity_hint="complex"`` -> routes to the "complex" chain (quality)
        - ``complexity_hint=None`` (default) -> auto-classifies from messages

        Args:
            messages: List of message dicts with role and content
            model: Preferred model (tries this first; complexity routing is
                skipped when an explicit model is given)
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            cascade_tier: If specified, filters providers by this tier.
                - "frontier_required": Uses only Anthropic provider for top-tier models.
            complexity_hint: "simple", "complex", or None (auto-detect).

        Returns:
            Dict with content, provider, model, latency_ms,
            is_fallback_model, and complexity fields.

        Raises:
            RuntimeError: If all providers fail
        """
        # Local import avoids a module-level cycle with the classifier.
        from infrastructure.router.classifier import TaskComplexity, classify_task

        content_type = self._detect_content_type(messages)
        if content_type != ContentType.TEXT:
            logger.debug("Detected %s content, selecting appropriate model", content_type.value)

        # Resolve task complexity.
        # Skip complexity routing when caller explicitly specifies a model.
        complexity: TaskComplexity | None = None
        if model is None:
            if complexity_hint is not None:
                try:
                    complexity = TaskComplexity(complexity_hint.lower())
                except ValueError:
                    # Invalid hint is non-fatal: fall back to auto-classification.
                    logger.warning("Unknown complexity_hint %r, auto-classifying", complexity_hint)
                    complexity = classify_task(messages)
            else:
                complexity = classify_task(messages)
            logger.debug("Task complexity: %s", complexity.value)

        errors: list[str] = []
        providers = self._filter_providers(cascade_tier)

        # Cascade: walk providers in priority order until one succeeds.
        for provider in providers:
            if not self._is_provider_available(provider):
                continue

            # Metabolic protocol: skip cloud providers when quota is low.
            if provider.type in ("anthropic", "openai", "grok"):
                if not self._quota_allows_cloud(provider):
                    logger.info(
                        "Metabolic protocol: skipping cloud provider %s (quota too low)",
                        provider.name,
                    )
                    continue

            # Complexity-based model selection (only when no explicit model).
            effective_model = model
            if effective_model is None and complexity is not None:
                effective_model = self._get_model_for_complexity(provider, complexity)
                if effective_model:
                    logger.debug(
                        "Complexity routing [%s]: %s → %s",
                        complexity.value,
                        provider.name,
                        effective_model,
                    )

            # May swap in a vision-capable fallback for image content.
            selected_model, is_fallback_model = self._select_model(
                provider, effective_model, content_type
            )

            try:
                result = await self._attempt_with_retry(
                    provider,
                    messages,
                    selected_model,
                    temperature,
                    max_tokens,
                    content_type,
                )
            except RuntimeError as exc:
                # Retries exhausted on this provider — record and cascade on.
                errors.append(str(exc))
                self._record_failure(provider)
                continue

            self._record_success(provider, result.get("latency_ms", 0))
            return {
                "content": result["content"],
                "provider": provider.name,
                "model": result.get("model", selected_model or provider.get_default_model()),
                "latency_ms": result.get("latency_ms", 0),
                "is_fallback_model": is_fallback_model,
                "complexity": complexity.value if complexity is not None else None,
            }

        raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
def reload_config(self) -> dict:
|
||
"""Hot-reload providers.yaml, preserving runtime state.
|
||
|
||
Re-reads the config file, rebuilds the provider list, and
|
||
preserves circuit breaker state and metrics for providers
|
||
that still exist after reload.
|
||
|
||
Returns:
|
||
Summary dict with added/removed/preserved counts.
|
||
"""
|
||
# Snapshot current runtime state keyed by provider name
|
||
old_state: dict[
|
||
str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]
|
||
] = {}
|
||
for p in self.providers:
|
||
old_state[p.name] = (
|
||
p.metrics,
|
||
p.circuit_state,
|
||
p.circuit_opened_at,
|
||
p.half_open_calls,
|
||
p.status,
|
||
)
|
||
|
||
old_names = set(old_state.keys())
|
||
|
||
# Reload from disk
|
||
self.providers = []
|
||
self._load_config()
|
||
|
||
# Restore preserved state
|
||
new_names = {p.name for p in self.providers}
|
||
preserved = 0
|
||
for p in self.providers:
|
||
if p.name in old_state:
|
||
metrics, circuit, opened_at, half_open, status = old_state[p.name]
|
||
p.metrics = metrics
|
||
p.circuit_state = circuit
|
||
p.circuit_opened_at = opened_at
|
||
p.half_open_calls = half_open
|
||
p.status = status
|
||
preserved += 1
|
||
|
||
added = new_names - old_names
|
||
removed = old_names - new_names
|
||
|
||
logger.info(
|
||
"Config reloaded: %d providers (%d preserved, %d added, %d removed)",
|
||
len(self.providers),
|
||
preserved,
|
||
len(added),
|
||
len(removed),
|
||
)
|
||
|
||
return {
|
||
"total_providers": len(self.providers),
|
||
"preserved": preserved,
|
||
"added": sorted(added),
|
||
"removed": sorted(removed),
|
||
}
|
||
|
||
def get_metrics(self) -> dict:
|
||
"""Get metrics for all providers."""
|
||
return {
|
||
"providers": [
|
||
{
|
||
"name": p.name,
|
||
"type": p.type,
|
||
"status": p.status.value,
|
||
"circuit_state": p.circuit_state.value,
|
||
"metrics": {
|
||
"total_requests": p.metrics.total_requests,
|
||
"successful": p.metrics.successful_requests,
|
||
"failed": p.metrics.failed_requests,
|
||
"error_rate": round(p.metrics.error_rate, 3),
|
||
"avg_latency_ms": round(p.metrics.avg_latency_ms, 2),
|
||
},
|
||
}
|
||
for p in self.providers
|
||
]
|
||
}
|
||
|
||
def get_status(self) -> dict:
|
||
"""Get current router status."""
|
||
healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY)
|
||
|
||
return {
|
||
"total_providers": len(self.providers),
|
||
"healthy_providers": healthy,
|
||
"degraded_providers": sum(
|
||
1 for p in self.providers if p.status == ProviderStatus.DEGRADED
|
||
),
|
||
"unhealthy_providers": sum(
|
||
1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY
|
||
),
|
||
"providers": [
|
||
{
|
||
"name": p.name,
|
||
"type": p.type,
|
||
"status": p.status.value,
|
||
"priority": p.priority,
|
||
"default_model": p.get_default_model(),
|
||
}
|
||
for p in self.providers
|
||
],
|
||
}
|
||
|
||
async def generate_with_image(
|
||
self,
|
||
prompt: str,
|
||
image_path: str,
|
||
model: str | None = None,
|
||
temperature: float = 0.7,
|
||
) -> dict:
|
||
"""Convenience method for vision requests.
|
||
|
||
Args:
|
||
prompt: Text prompt about the image
|
||
image_path: Path to image file
|
||
model: Vision-capable model (auto-selected if not provided)
|
||
temperature: Sampling temperature
|
||
|
||
Returns:
|
||
Response dict with content and metadata
|
||
"""
|
||
messages = [
|
||
{
|
||
"role": "user",
|
||
"content": prompt,
|
||
"images": [image_path],
|
||
}
|
||
]
|
||
return await self.complete(
|
||
messages=messages,
|
||
model=model,
|
||
temperature=temperature,
|
||
)
|
||
|
||
|
||
# Module-level singleton, created lazily by get_router().
cascade_router: CascadeRouter | None = None


def get_router() -> CascadeRouter:
    """Return the process-wide CascadeRouter, creating it on first use."""
    global cascade_router
    if cascade_router is None:
        cascade_router = CascadeRouter()
    return cascade_router