diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index 306b157f..ca50264a 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -1,38 +1,37 @@ -"""Cascade LLM Router — Automatic failover between providers. +"""Cascade LLM Router — Automatic failover between LLM providers. -Routes requests through an ordered list of LLM providers, -automatically failing over on rate limits or errors. -Tracks metrics for latency, errors, and cost. - -Now with multi-modal support — automatically selects vision-capable -models for image inputs and falls back through capability chains. +Supports multi-modal content (vision/audio), circuit-breaker health tracking, +complexity-based model routing, and the metabolic quota protocol. """ import asyncio -import base64 import logging -import time -from dataclasses import dataclass, field -from datetime import UTC, datetime -from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from infrastructure.router.classifier import TaskComplexity -from config import settings +# Re-export all public names so existing imports remain unchanged. +from infrastructure.router.models import ( # noqa: F401 + CircuitState, + ContentType, + ModelCapability, + Provider, + ProviderMetrics, + ProviderStatus, + RouterConfig, +) +from infrastructure.router import health as _health +from infrastructure.router import content as _content +from infrastructure.router import config_loader as _config_loader +from infrastructure.router import reporting as _reporting try: import yaml except ImportError: yaml = None # type: ignore -try: - import requests -except ImportError: - requests = None # type: ignore - logger = logging.getLogger(__name__) # Quota monitor — optional, degrades gracefully if unavailable @@ -45,165 +44,25 @@ except Exception as _exc: # pragma: no cover _quota_monitor = None -class ProviderStatus(Enum): - """Health status of a provider.""" +def _resolve_complexity( + messages: list[dict], complexity_hint: str | None +) -> "TaskComplexity": + """Resolve task complexity from a hint string or by auto-classifying messages.""" + from infrastructure.router.classifier import TaskComplexity, classify_task - HEALTHY = "healthy" - DEGRADED = "degraded" # Working but slow or occasional errors - UNHEALTHY = "unhealthy" # Circuit breaker open - DISABLED = "disabled" - - -class CircuitState(Enum): - """Circuit breaker state.""" - - CLOSED = "closed" # Normal operation - OPEN = "open" # Failing, rejecting requests - HALF_OPEN = "half_open" # Testing if recovered - - -class ContentType(Enum): - """Type of content in the request.""" - - TEXT = "text" - VISION = "vision" # Contains images - AUDIO = "audio" # Contains audio - MULTIMODAL = "multimodal" # Multiple content types - - -@dataclass -class ProviderMetrics: - """Metrics for a single provider.""" - - total_requests: int = 0 - successful_requests: int = 0 - failed_requests: int = 0 - total_latency_ms: float = 0.0 - last_request_time: str | None = None - last_error_time: str | None = None - consecutive_failures: int = 0 - - @property - def avg_latency_ms(self) -> float: - if self.total_requests == 0: - return 0.0 - return self.total_latency_ms / self.total_requests - - @property - def error_rate(self) -> float: - if self.total_requests == 0: - return 0.0 - return self.failed_requests / self.total_requests - - -@dataclass -class ModelCapability: - """Capabilities a model supports.""" - - name: str - supports_vision: bool 
= False - supports_audio: bool = False - supports_tools: bool = False - supports_json: bool = False - supports_streaming: bool = True - context_window: int = 4096 - - -@dataclass -class Provider: - """LLM provider configuration and state.""" - - name: str - type: str # ollama, openai, anthropic - enabled: bool - priority: int - tier: str | None = None # e.g., "local", "standard_cloud", "frontier" - url: str | None = None - api_key: str | None = None - base_url: str | None = None - models: list[dict] = field(default_factory=list) - - # Runtime state - status: ProviderStatus = ProviderStatus.HEALTHY - metrics: ProviderMetrics = field(default_factory=ProviderMetrics) - circuit_state: CircuitState = CircuitState.CLOSED - circuit_opened_at: float | None = None - half_open_calls: int = 0 - - def get_default_model(self) -> str | None: - """Get the default model for this provider.""" - for model in self.models: - if model.get("default"): - return model["name"] - if self.models: - return self.models[0]["name"] - return None - - def get_model_with_capability(self, capability: str) -> str | None: - """Get a model that supports the given capability.""" - for model in self.models: - capabilities = model.get("capabilities", []) - if capability in capabilities: - return model["name"] - # Fall back to default - return self.get_default_model() - - def model_has_capability(self, model_name: str, capability: str) -> bool: - """Check if a specific model has a capability.""" - for model in self.models: - if model["name"] == model_name: - capabilities = model.get("capabilities", []) - return capability in capabilities - return False - - -@dataclass -class RouterConfig: - """Cascade router configuration.""" - - timeout_seconds: int = 30 - max_retries_per_provider: int = 2 - retry_delay_seconds: int = 1 - circuit_breaker_failure_threshold: int = 5 - circuit_breaker_recovery_timeout: int = 60 - circuit_breaker_half_open_max_calls: int = 2 - cost_tracking_enabled: bool = True - budget_daily_usd: float = 10.0 - # Multi-modal settings - auto_pull_models: bool = True - fallback_chains: dict = field(default_factory=dict) + if complexity_hint is not None: + try: + return TaskComplexity(complexity_hint.lower()) + except ValueError: + logger.warning("Unknown complexity_hint %r, auto-classifying", complexity_hint) + return classify_task(messages) class CascadeRouter: - """Routes LLM requests with automatic failover. + """Routes LLM requests with automatic failover and multi-modal support. - Now with multi-modal support: - - Automatically detects content type (text, vision, audio) - - Selects appropriate models based on capabilities - - Falls back through capability-specific model chains - - Supports image URLs and base64 encoding - - Usage: - router = CascadeRouter() - - # Text request - response = await router.complete( - messages=[{"role": "user", "content": "Hello"}], - model="llama3.2" - ) - - # Vision request (automatically detects and selects vision model) - response = await router.complete( - messages=[{ - "role": "user", - "content": "What's in this image?", - "images": ["path/to/image.jpg"] - }], - model="llava:7b" - ) - - # Check metrics - metrics = router.get_metrics() + Detects content type (text/vision/audio), selects models by capability, + falls back through capability chains, and tracks circuit-breaker health. 
""" def __init__(self, config_path: Path | None = None) -> None: @@ -234,217 +93,92 @@ class CascadeRouter: raise RuntimeError("PyYAML not installed") content = self.config_path.read_text() - content = self._expand_env_vars(content) + content = _config_loader.expand_env_vars(content) data = yaml.safe_load(content) - self.config = self._parse_router_config(data) - self._load_providers(data) + self.config = _config_loader.parse_router_config(data) + providers = _config_loader.load_providers(data) + providers.sort(key=lambda p: p.priority) + self.providers = providers except Exception as exc: logger.error("Failed to load config: %s", exc) - def _parse_router_config(self, data: dict) -> RouterConfig: - """Build a RouterConfig from parsed YAML data.""" - cascade = data.get("cascade", {}) - cb = cascade.get("circuit_breaker", {}) - multimodal = data.get("multimodal", {}) - - return RouterConfig( - timeout_seconds=cascade.get("timeout_seconds", 30), - max_retries_per_provider=cascade.get("max_retries_per_provider", 2), - retry_delay_seconds=cascade.get("retry_delay_seconds", 1), - circuit_breaker_failure_threshold=cb.get("failure_threshold", 5), - circuit_breaker_recovery_timeout=cb.get("recovery_timeout", 60), - circuit_breaker_half_open_max_calls=cb.get("half_open_max_calls", 2), - auto_pull_models=multimodal.get("auto_pull", True), - fallback_chains=data.get("fallback_chains", {}), - ) - - def _load_providers(self, data: dict) -> None: - """Load, filter, and sort providers from parsed YAML data.""" - for p_data in data.get("providers", []): - if not p_data.get("enabled", False): - continue - - provider = Provider( - name=p_data["name"], - type=p_data["type"], - enabled=p_data.get("enabled", True), - priority=p_data.get("priority", 99), - tier=p_data.get("tier"), - url=p_data.get("url"), - api_key=p_data.get("api_key"), - base_url=p_data.get("base_url"), - models=p_data.get("models", []), - ) - - if self._check_provider_available(provider): - self.providers.append(provider) - else: - logger.warning("Provider %s not available, skipping", provider.name) - - self.providers.sort(key=lambda p: p.priority) - - def _expand_env_vars(self, content: str) -> str: - """Expand ${VAR} syntax in YAML content. - - Uses os.environ directly (not settings) because this is a generic - YAML config loader that must expand arbitrary variable references. 
- """ - import os - import re - - def replace_var(match: "re.Match[str]") -> str: - var_name = match.group(1) - return os.environ.get(var_name, match.group(0)) - - return re.sub(r"\$\{(\w+)\}", replace_var, content) - def _check_provider_available(self, provider: Provider) -> bool: - """Check if a provider is actually available.""" - if provider.type == "ollama": - # Check if Ollama is running - if requests is None: - # Can't check without requests, assume available - return True - try: - url = provider.url or settings.ollama_url - response = requests.get(f"{url}/api/tags", timeout=5) - return response.status_code == 200 - except Exception as exc: - logger.debug("Ollama provider check error: %s", exc) - return False - - elif provider.type == "vllm_mlx": - # Check if local vllm-mlx server is running (OpenAI-compatible) - if requests is None: - return True - try: - base_url = provider.base_url or provider.url or "http://localhost:8000" - # Strip /v1 suffix — health endpoint is at the root - server_root = base_url.rstrip("/") - if server_root.endswith("/v1"): - server_root = server_root[:-3] - response = requests.get(f"{server_root}/health", timeout=5) - return response.status_code == 200 - except Exception as exc: - logger.debug("vllm-mlx provider check error: %s", exc) - return False - - elif provider.type in ("openai", "anthropic", "grok"): - # Check if API key is set - return provider.api_key is not None and provider.api_key != "" - - return True + return _config_loader.check_provider_available(provider) def _detect_content_type(self, messages: list[dict]) -> ContentType: - """Detect the type of content in the messages. - - Checks for images, audio, etc. in the message content. - """ - has_image = False - has_audio = False - - for msg in messages: - content = msg.get("content", "") - - # Check for image URLs/paths - if msg.get("images"): - has_image = True - - # Check for image URLs in content - if isinstance(content, str): - image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp") - if any(ext in content.lower() for ext in image_extensions): - has_image = True - if content.startswith("data:image/"): - has_image = True - - # Check for audio - if msg.get("audio"): - has_audio = True - - # Check for multimodal content structure - if isinstance(content, list): - for item in content: - if isinstance(item, dict): - if item.get("type") == "image_url": - has_image = True - elif item.get("type") == "audio": - has_audio = True - - if has_image and has_audio: - return ContentType.MULTIMODAL - elif has_image: - return ContentType.VISION - elif has_audio: - return ContentType.AUDIO - return ContentType.TEXT + return _content.detect_content_type(messages) def _get_fallback_model( self, provider: Provider, original_model: str, content_type: ContentType ) -> str | None: - """Get a fallback model for the given content type.""" - # Map content type to capability - capability_map = { - ContentType.VISION: "vision", - ContentType.AUDIO: "audio", - ContentType.MULTIMODAL: "vision", # Vision models often do both - } - - capability = capability_map.get(content_type) - if not capability: - return None - - # Check provider's models for capability - fallback_model = provider.get_model_with_capability(capability) - if fallback_model and fallback_model != original_model: - return fallback_model - - # Use fallback chains from config - fallback_chain = self.config.fallback_chains.get(capability, []) - for model_name in fallback_chain: - if provider.model_has_capability(model_name, capability): - return 
model_name - - return None + return _content.get_fallback_model( + provider, original_model, content_type, self.config.fallback_chains + ) def _select_model( self, provider: Provider, model: str | None, content_type: ContentType ) -> tuple[str | None, bool]: - """Select the best model for the request, with vision fallback. + return _content.select_model( + provider, model, content_type, self._mm_manager, self.config.fallback_chains + ) - Returns: - Tuple of (selected_model, is_fallback_model). - """ - selected_model = model or provider.get_default_model() - is_fallback = False + def _record_success(self, provider: Provider, latency_ms: float) -> None: + _health.record_success(provider, latency_ms, self.config) - if content_type != ContentType.TEXT and selected_model: - if provider.type == "ollama" and self._mm_manager: - from infrastructure.models.multimodal import ModelCapability + def _record_failure(self, provider: Provider) -> None: + _health.record_failure(provider, self.config) - if content_type == ContentType.VISION: - supports = self._mm_manager.model_supports( - selected_model, ModelCapability.VISION - ) - if not supports: - fallback = self._get_fallback_model(provider, selected_model, content_type) - if fallback: - logger.info( - "Model %s doesn't support vision, falling back to %s", - selected_model, - fallback, - ) - selected_model = fallback - is_fallback = True - else: - logger.warning( - "No vision-capable model found on %s, trying anyway", - provider.name, - ) + def _can_close_circuit(self, provider: Provider) -> bool: + return _health.can_close_circuit(provider, self.config) - return selected_model, is_fallback + def _quota_allows_cloud(self, provider: Provider) -> bool: + """Return True when quota allows a cloud provider call (metabolic protocol).""" + if _quota_monitor is None: + return True + try: + suggested = _quota_monitor.select_model("high") + allows = suggested == "claude-sonnet-4-6" + if not allows: + status = _quota_monitor.check() + tier = status.recommended_tier.value if status else "unknown" + logger.info("Metabolic protocol: %s tier — downshifting %s (%s)", tier, provider.name, suggested) + return allows + except Exception as exc: + logger.warning("Quota check failed, allowing cloud: %s", exc) + return True + + def _is_provider_available(self, provider: Provider) -> bool: + """Check if a provider should be tried (enabled + circuit breaker).""" + if not provider.enabled: + logger.debug("Skipping %s (disabled)", provider.name) + return False + + if provider.status == ProviderStatus.UNHEALTHY: + if self._can_close_circuit(provider): + provider.circuit_state = CircuitState.HALF_OPEN + provider.half_open_calls = 0 + logger.info("Circuit breaker half-open for %s", provider.name) + else: + logger.debug("Skipping %s (circuit open)", provider.name) + return False + + return True + + def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]: + """Return providers filtered by tier. 
Raises RuntimeError if tier has no providers.""" + if cascade_tier == "frontier_required": + ps = [p for p in self.providers if p.type == "anthropic"] + if not ps: + raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.") + return ps + if cascade_tier: + ps = [p for p in self.providers if p.tier == cascade_tier] + if not ps: + raise RuntimeError(f"No providers found for tier: {cascade_tier}") + return ps + return self.providers async def _attempt_with_retry( self, @@ -455,12 +189,7 @@ class CascadeRouter: max_tokens: int | None, content_type: ContentType, ) -> dict: - """Try a provider with retries, returning the result dict. - - Raises: - RuntimeError: If all retry attempts fail. - Returns error strings collected during retries via the exception message. - """ + """Try a provider with retries. Raises RuntimeError if all attempts fail.""" errors: list[str] = [] for attempt in range(self.config.max_retries_per_provider): try: @@ -487,67 +216,94 @@ class CascadeRouter: raise RuntimeError("; ".join(errors)) - def _quota_allows_cloud(self, provider: Provider) -> bool: - """Check quota before routing to a cloud provider. + async def _try_provider( + self, + provider: Provider, + messages: list[dict], + model: str, + temperature: float, + max_tokens: int | None, + content_type: ContentType = ContentType.TEXT, + ) -> dict: + """Dispatch a single request to the correct provider implementation.""" + from infrastructure.router.providers.dispatch import call_provider - Uses the metabolic protocol via select_model(): cloud calls are only - allowed when the quota monitor recommends a cloud model (BURST tier). - Returns True (allow cloud) if quota monitor is unavailable or returns None. + return await call_provider( + provider=provider, + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + timeout_seconds=self.config.timeout_seconds, + content_type=content_type, + ) + + def _get_model_for_complexity( + self, provider: Provider, complexity: "TaskComplexity" + ) -> str | None: + """Return the best model on *provider* for the given complexity tier.""" + from infrastructure.router.classifier import TaskComplexity + + chain_key = "routine" if complexity == TaskComplexity.SIMPLE else "complex" + # Fallback chain: first model present on this provider wins + for model_name in self.config.fallback_chains.get(chain_key, []): + if any(m["name"] == model_name for m in provider.models): + return model_name + return next( + (m["name"] for m in provider.models if chain_key in m.get("capabilities", [])), + None, + ) + + async def complete( # noqa: PLR0912 + self, + messages: list[dict], + model: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + cascade_tier: str | None = None, + complexity_hint: str | None = None, + ) -> dict: + """Complete a chat conversation with automatic failover. + + Supports complexity-based routing and multi-modal content. + Raises RuntimeError if all providers fail. 
""" - if _quota_monitor is None: - return True - try: - suggested = _quota_monitor.select_model("high") - # Cloud is allowed only when select_model recommends the cloud model - allows = suggested == "claude-sonnet-4-6" - if not allows: - status = _quota_monitor.check() - tier = status.recommended_tier.value if status else "unknown" - logger.info( - "Metabolic protocol: %s tier — downshifting %s to local (%s)", - tier, - provider.name, - suggested, - ) - return allows - except Exception as exc: - logger.warning("Quota check failed, allowing cloud: %s", exc) - return True + from infrastructure.router.classifier import TaskComplexity - def _is_provider_available(self, provider: Provider) -> bool: - """Check if a provider should be tried (enabled + circuit breaker).""" - if not provider.enabled: - logger.debug("Skipping %s (disabled)", provider.name) - return False + content_type = self._detect_content_type(messages) + if content_type != ContentType.TEXT: + logger.debug("Detected %s content, selecting appropriate model", content_type.value) - if provider.status == ProviderStatus.UNHEALTHY: - if self._can_close_circuit(provider): - provider.circuit_state = CircuitState.HALF_OPEN - provider.half_open_calls = 0 - logger.info("Circuit breaker half-open for %s", provider.name) - else: - logger.debug("Skipping %s (circuit open)", provider.name) - return False + # Resolve task complexity (skipped when caller provides an explicit model) + complexity: TaskComplexity | None = None + if model is None: + complexity = _resolve_complexity(messages, complexity_hint) + logger.debug("Task complexity: %s", complexity.value) - return True + errors: list[str] = [] + providers = self._filter_providers(cascade_tier) - def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]: - """Return the provider list filtered by tier. + for provider in providers: + # Complexity-based model selection (only when no explicit model) + effective_model = model + if effective_model is None and complexity is not None: + effective_model = self._get_model_for_complexity(provider, complexity) + if effective_model: + logger.debug( + "Complexity routing [%s]: %s → %s", + complexity.value, + provider.name, + effective_model, + ) - Raises: - RuntimeError: If a tier is specified but no matching providers exist. - """ - if cascade_tier == "frontier_required": - providers = [p for p in self.providers if p.type == "anthropic"] - if not providers: - raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.") - return providers - if cascade_tier: - providers = [p for p in self.providers if p.tier == cascade_tier] - if not providers: - raise RuntimeError(f"No providers found for tier: {cascade_tier}") - return providers - return self.providers + result = await self._try_single_provider( + provider, messages, effective_model, temperature, max_tokens, content_type, errors + ) + if result is not None: + result["complexity"] = complexity.value if complexity is not None else None + return result + + raise RuntimeError(f"All providers failed: {'; '.join(errors)}") async def _try_single_provider( self, @@ -559,11 +315,7 @@ class CascadeRouter: content_type: ContentType, errors: list[str], ) -> dict | None: - """Attempt one provider, returning a result dict on success or None on failure. - - On failure the error string is appended to *errors* and the provider's - failure metrics are updated so the caller can move on to the next provider. 
- """ + """Try one provider; append to errors and return None on failure.""" if not self._is_provider_available(provider): return None @@ -596,553 +348,20 @@ class CascadeRouter: "is_fallback_model": is_fallback_model, } - def _get_model_for_complexity( - self, provider: Provider, complexity: "TaskComplexity" - ) -> str | None: - """Return the best model on *provider* for the given complexity tier. - - Checks fallback chains first (routine / complex), then falls back to - any model with the matching capability tag, then the provider default. - """ - from infrastructure.router.classifier import TaskComplexity - - chain_key = "routine" if complexity == TaskComplexity.SIMPLE else "complex" - - # Walk the capability fallback chain — first model present on this provider wins - for model_name in self.config.fallback_chains.get(chain_key, []): - if any(m["name"] == model_name for m in provider.models): - return model_name - - # Direct capability lookup — only return if a model explicitly has the tag - # (do not use get_model_with_capability here as it falls back to the default) - cap_model = next( - (m["name"] for m in provider.models if chain_key in m.get("capabilities", [])), - None, - ) - if cap_model: - return cap_model - - return None # Caller will use provider default - - async def complete( - self, - messages: list[dict], - model: str | None = None, - temperature: float = 0.7, - max_tokens: int | None = None, - cascade_tier: str | None = None, - complexity_hint: str | None = None, - ) -> dict: - """Complete a chat conversation with automatic failover. - - Multi-modal support: - - Automatically detects if messages contain images - - Falls back to vision-capable models when needed - - Supports image URLs, paths, and base64 encoding - - Complexity-based routing (issue #1065): - - ``complexity_hint="simple"`` → routes to Qwen3-8B (low-latency) - - ``complexity_hint="complex"`` → routes to Qwen3-14B (quality) - - ``complexity_hint=None`` (default) → auto-classifies from messages - - Args: - messages: List of message dicts with role and content - model: Preferred model (tries this first; complexity routing is - skipped when an explicit model is given) - temperature: Sampling temperature - max_tokens: Maximum tokens to generate - cascade_tier: If specified, filters providers by this tier. - - "frontier_required": Uses only Anthropic provider for top-tier models. - complexity_hint: "simple", "complex", or None (auto-detect). - - Returns: - Dict with content, provider_used, model, latency_ms, - is_fallback_model, and complexity fields. - - Raises: - RuntimeError: If all providers fail - """ - from infrastructure.router.classifier import TaskComplexity, classify_task - - content_type = self._detect_content_type(messages) - if content_type != ContentType.TEXT: - logger.debug("Detected %s content, selecting appropriate model", content_type.value) - - # Resolve task complexity ───────────────────────────────────────────── - # Skip complexity routing when caller explicitly specifies a model. 
- complexity: TaskComplexity | None = None - if model is None: - if complexity_hint is not None: - try: - complexity = TaskComplexity(complexity_hint.lower()) - except ValueError: - logger.warning("Unknown complexity_hint %r, auto-classifying", complexity_hint) - complexity = classify_task(messages) - else: - complexity = classify_task(messages) - logger.debug("Task complexity: %s", complexity.value) - - errors: list[str] = [] - providers = self._filter_providers(cascade_tier) - - for provider in providers: - if not self._is_provider_available(provider): - continue - - # Metabolic protocol: skip cloud providers when quota is low - if provider.type in ("anthropic", "openai", "grok"): - if not self._quota_allows_cloud(provider): - logger.info( - "Metabolic protocol: skipping cloud provider %s (quota too low)", - provider.name, - ) - continue - - # Complexity-based model selection (only when no explicit model) ── - effective_model = model - if effective_model is None and complexity is not None: - effective_model = self._get_model_for_complexity(provider, complexity) - if effective_model: - logger.debug( - "Complexity routing [%s]: %s → %s", - complexity.value, - provider.name, - effective_model, - ) - - selected_model, is_fallback_model = self._select_model( - provider, effective_model, content_type - ) - - try: - result = await self._attempt_with_retry( - provider, - messages, - selected_model, - temperature, - max_tokens, - content_type, - ) - except RuntimeError as exc: - errors.append(str(exc)) - self._record_failure(provider) - continue - - self._record_success(provider, result.get("latency_ms", 0)) - return { - "content": result["content"], - "provider": provider.name, - "model": result.get("model", selected_model or provider.get_default_model()), - "latency_ms": result.get("latency_ms", 0), - "is_fallback_model": is_fallback_model, - "complexity": complexity.value if complexity is not None else None, - } - - raise RuntimeError(f"All providers failed: {'; '.join(errors)}") - - async def _try_provider( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - content_type: ContentType = ContentType.TEXT, - ) -> dict: - """Try a single provider request.""" - start_time = time.time() - - if provider.type == "ollama": - result = await self._call_ollama( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - content_type=content_type, - ) - elif provider.type == "openai": - result = await self._call_openai( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - ) - elif provider.type == "anthropic": - result = await self._call_anthropic( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - ) - elif provider.type == "grok": - result = await self._call_grok( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - ) - elif provider.type == "vllm_mlx": - result = await self._call_vllm_mlx( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - ) - else: - raise ValueError(f"Unknown provider type: {provider.type}") - - latency_ms = (time.time() - start_time) * 1000 - result["latency_ms"] = latency_ms - - 
return result - - async def _call_ollama( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None = None, - content_type: ContentType = ContentType.TEXT, - ) -> dict: - """Call Ollama API with multi-modal support.""" - import aiohttp - - url = f"{provider.url or settings.ollama_url}/api/chat" - - # Transform messages for Ollama format (including images) - transformed_messages = self._transform_messages_for_ollama(messages) - - options = {"temperature": temperature} - if max_tokens: - options["num_predict"] = max_tokens - - payload = { - "model": model, - "messages": transformed_messages, - "stream": False, - "options": options, - } - - timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds) - - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(url, json=payload) as response: - if response.status != 200: - text = await response.text() - raise RuntimeError(f"Ollama error {response.status}: {text}") - - data = await response.json() - return { - "content": data["message"]["content"], - "model": model, - } - - def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]: - """Transform messages to Ollama format, handling images.""" - transformed = [] - - for msg in messages: - new_msg = { - "role": msg.get("role", "user"), - "content": msg.get("content", ""), - } - - # Handle images - images = msg.get("images", []) - if images: - new_msg["images"] = [] - for img in images: - if isinstance(img, str): - if img.startswith("data:image/"): - # Base64 encoded image - new_msg["images"].append(img.split(",")[1]) - elif img.startswith("http://") or img.startswith("https://"): - # URL - would need to download, skip for now - logger.warning("Image URLs not yet supported, skipping: %s", img) - elif Path(img).exists(): - # Local file path - read and encode - try: - with open(img, "rb") as f: - img_data = base64.b64encode(f.read()).decode() - new_msg["images"].append(img_data) - except Exception as exc: - logger.error("Failed to read image %s: %s", img, exc) - - transformed.append(new_msg) - - return transformed - - async def _call_openai( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - ) -> dict: - """Call OpenAI API.""" - import openai - - client = openai.AsyncOpenAI( - api_key=provider.api_key, - base_url=provider.base_url, - timeout=self.config.timeout_seconds, - ) - - kwargs = { - "model": model, - "messages": messages, - "temperature": temperature, - } - if max_tokens: - kwargs["max_tokens"] = max_tokens - - response = await client.chat.completions.create(**kwargs) - - return { - "content": response.choices[0].message.content, - "model": response.model, - } - - async def _call_anthropic( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - ) -> dict: - """Call Anthropic API.""" - import anthropic - - client = anthropic.AsyncAnthropic( - api_key=provider.api_key, - timeout=self.config.timeout_seconds, - ) - - # Convert messages to Anthropic format - system_msg = None - conversation = [] - for msg in messages: - if msg["role"] == "system": - system_msg = msg["content"] - else: - conversation.append( - { - "role": msg["role"], - "content": msg["content"], - } - ) - - kwargs = { - "model": model, - "messages": conversation, - "temperature": temperature, - "max_tokens": max_tokens or 1024, - } - if system_msg: - kwargs["system"] = system_msg - - 
response = await client.messages.create(**kwargs) - - return { - "content": response.content[0].text, - "model": response.model, - } - - async def _call_grok( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - ) -> dict: - """Call xAI Grok API via OpenAI-compatible SDK.""" - import httpx - import openai - - client = openai.AsyncOpenAI( - api_key=provider.api_key, - base_url=provider.base_url or settings.xai_base_url, - timeout=httpx.Timeout(300.0), - ) - - kwargs = { - "model": model, - "messages": messages, - "temperature": temperature, - } - if max_tokens: - kwargs["max_tokens"] = max_tokens - - response = await client.chat.completions.create(**kwargs) - - return { - "content": response.choices[0].message.content, - "model": response.model, - } - - async def _call_vllm_mlx( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - ) -> dict: - """Call vllm-mlx via its OpenAI-compatible API. - - vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI, - so we reuse the OpenAI client pointed at the local server. - No API key is required for local deployments. - """ - import openai - - base_url = provider.base_url or provider.url or "http://localhost:8000" - # Ensure the base_url ends with /v1 as expected by the OpenAI client - if not base_url.rstrip("/").endswith("/v1"): - base_url = base_url.rstrip("/") + "/v1" - - client = openai.AsyncOpenAI( - api_key=provider.api_key or "no-key-required", - base_url=base_url, - timeout=self.config.timeout_seconds, - ) - - kwargs: dict = { - "model": model, - "messages": messages, - "temperature": temperature, - } - if max_tokens: - kwargs["max_tokens"] = max_tokens - - response = await client.chat.completions.create(**kwargs) - - return { - "content": response.choices[0].message.content, - "model": response.model, - } - - def _record_success(self, provider: Provider, latency_ms: float) -> None: - """Record a successful request.""" - provider.metrics.total_requests += 1 - provider.metrics.successful_requests += 1 - provider.metrics.total_latency_ms += latency_ms - provider.metrics.last_request_time = datetime.now(UTC).isoformat() - provider.metrics.consecutive_failures = 0 - - # Close circuit breaker if half-open - if provider.circuit_state == CircuitState.HALF_OPEN: - provider.half_open_calls += 1 - if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls: - self._close_circuit(provider) - - # Update status based on error rate - if provider.metrics.error_rate < 0.1: - provider.status = ProviderStatus.HEALTHY - elif provider.metrics.error_rate < 0.3: - provider.status = ProviderStatus.DEGRADED - - def _record_failure(self, provider: Provider) -> None: - """Record a failed request.""" - provider.metrics.total_requests += 1 - provider.metrics.failed_requests += 1 - provider.metrics.last_error_time = datetime.now(UTC).isoformat() - provider.metrics.consecutive_failures += 1 - - # Check if we should open circuit breaker - if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold: - self._open_circuit(provider) - - # Update status - if provider.metrics.error_rate > 0.3: - provider.status = ProviderStatus.DEGRADED - if provider.metrics.error_rate > 0.5: - provider.status = ProviderStatus.UNHEALTHY - - def _open_circuit(self, provider: Provider) -> None: - """Open the circuit breaker for a provider.""" - provider.circuit_state = CircuitState.OPEN - 
provider.circuit_opened_at = time.time() - provider.status = ProviderStatus.UNHEALTHY - logger.warning("Circuit breaker OPEN for %s", provider.name) - - def _can_close_circuit(self, provider: Provider) -> bool: - """Check if circuit breaker can transition to half-open.""" - if provider.circuit_opened_at is None: - return False - elapsed = time.time() - provider.circuit_opened_at - return elapsed >= self.config.circuit_breaker_recovery_timeout - - def _close_circuit(self, provider: Provider) -> None: - """Close the circuit breaker (provider healthy again).""" - provider.circuit_state = CircuitState.CLOSED - provider.circuit_opened_at = None - provider.half_open_calls = 0 - provider.metrics.consecutive_failures = 0 - provider.status = ProviderStatus.HEALTHY - logger.info("Circuit breaker CLOSED for %s", provider.name) - def reload_config(self) -> dict: - """Hot-reload providers.yaml, preserving runtime state. - - Re-reads the config file, rebuilds the provider list, and - preserves circuit breaker state and metrics for providers - that still exist after reload. - - Returns: - Summary dict with added/removed/preserved counts. - """ - # Snapshot current runtime state keyed by provider name - old_state: dict[ - str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus] - ] = {} - for p in self.providers: - old_state[p.name] = ( - p.metrics, - p.circuit_state, - p.circuit_opened_at, - p.half_open_calls, - p.status, - ) - + """Re-read providers.yaml and preserve runtime circuit/metrics state.""" + old_state = _reporting.snapshot_provider_state(self.providers) old_names = set(old_state.keys()) - # Reload from disk self.providers = [] self._load_config() - # Restore preserved state + preserved = _reporting.restore_provider_state(self.providers, old_state) new_names = {p.name for p in self.providers} - preserved = 0 - for p in self.providers: - if p.name in old_state: - metrics, circuit, opened_at, half_open, status = old_state[p.name] - p.metrics = metrics - p.circuit_state = circuit - p.circuit_opened_at = opened_at - p.half_open_calls = half_open - p.status = status - preserved += 1 - added = new_names - old_names removed = old_names - new_names - logger.info( - "Config reloaded: %d providers (%d preserved, %d added, %d removed)", - len(self.providers), - preserved, - len(added), - len(removed), - ) + logger.info("Config reloaded: %d providers (%d preserved, +%d, -%d)", len(self.providers), preserved, len(added), len(removed)) return { "total_providers": len(self.providers), @@ -1153,49 +372,11 @@ class CascadeRouter: def get_metrics(self) -> dict: """Get metrics for all providers.""" - return { - "providers": [ - { - "name": p.name, - "type": p.type, - "status": p.status.value, - "circuit_state": p.circuit_state.value, - "metrics": { - "total_requests": p.metrics.total_requests, - "successful": p.metrics.successful_requests, - "failed": p.metrics.failed_requests, - "error_rate": round(p.metrics.error_rate, 3), - "avg_latency_ms": round(p.metrics.avg_latency_ms, 2), - }, - } - for p in self.providers - ] - } + return _reporting.build_metrics(self.providers) def get_status(self) -> dict: """Get current router status.""" - healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY) - - return { - "total_providers": len(self.providers), - "healthy_providers": healthy, - "degraded_providers": sum( - 1 for p in self.providers if p.status == ProviderStatus.DEGRADED - ), - "unhealthy_providers": sum( - 1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY 
- ), - "providers": [ - { - "name": p.name, - "type": p.type, - "status": p.status.value, - "priority": p.priority, - "default_model": p.get_default_model(), - } - for p in self.providers - ], - } + return _reporting.build_status(self.providers) async def generate_with_image( self, @@ -1204,17 +385,7 @@ class CascadeRouter: model: str | None = None, temperature: float = 0.7, ) -> dict: - """Convenience method for vision requests. - - Args: - prompt: Text prompt about the image - image_path: Path to image file - model: Vision-capable model (auto-selected if not provided) - temperature: Sampling temperature - - Returns: - Response dict with content and metadata - """ + """Convenience wrapper for single-image vision requests.""" messages = [ { "role": "user", diff --git a/src/infrastructure/router/config_loader.py b/src/infrastructure/router/config_loader.py new file mode 100644 index 00000000..f6dbafb6 --- /dev/null +++ b/src/infrastructure/router/config_loader.py @@ -0,0 +1,123 @@ +"""Config loading helpers for the Cascade LLM Router. + +Parses providers.yaml, expands env vars, and checks provider availability. +""" + +from __future__ import annotations + +import logging + +from infrastructure.router.models import Provider, RouterConfig + +logger = logging.getLogger(__name__) + +try: + import yaml +except ImportError: + yaml = None # type: ignore + +try: + import requests +except ImportError: + requests = None # type: ignore + + +def expand_env_vars(content: str) -> str: + """Expand ${VAR} syntax in YAML content. + + Uses os.environ directly (not settings) because this is a generic + YAML config loader that must expand arbitrary variable references. + """ + import os + import re + + def replace_var(match: "re.Match[str]") -> str: + var_name = match.group(1) + return os.environ.get(var_name, match.group(0)) + + return re.sub(r"\$\{(\w+)\}", replace_var, content) + + +def parse_router_config(data: dict) -> RouterConfig: + """Build a RouterConfig from parsed YAML data.""" + cascade = data.get("cascade", {}) + cb = cascade.get("circuit_breaker", {}) + multimodal = data.get("multimodal", {}) + + return RouterConfig( + timeout_seconds=cascade.get("timeout_seconds", 30), + max_retries_per_provider=cascade.get("max_retries_per_provider", 2), + retry_delay_seconds=cascade.get("retry_delay_seconds", 1), + circuit_breaker_failure_threshold=cb.get("failure_threshold", 5), + circuit_breaker_recovery_timeout=cb.get("recovery_timeout", 60), + circuit_breaker_half_open_max_calls=cb.get("half_open_max_calls", 2), + auto_pull_models=multimodal.get("auto_pull", True), + fallback_chains=data.get("fallback_chains", {}), + ) + + +def load_providers(data: dict) -> list[Provider]: + """Load and filter providers from parsed YAML data (unsorted).""" + providers: list[Provider] = [] + for p_data in data.get("providers", []): + if not p_data.get("enabled", False): + continue + + provider = Provider( + name=p_data["name"], + type=p_data["type"], + enabled=p_data.get("enabled", True), + priority=p_data.get("priority", 99), + tier=p_data.get("tier"), + url=p_data.get("url"), + api_key=p_data.get("api_key"), + base_url=p_data.get("base_url"), + models=p_data.get("models", []), + ) + + if check_provider_available(provider): + providers.append(provider) + else: + logger.warning("Provider %s not available, skipping", provider.name) + + return providers + + +def check_provider_available(provider: Provider) -> bool: + """Check if a provider is actually available.""" + from config import settings + + if provider.type == 
"ollama": + # Check if Ollama is running + if requests is None: + # Can't check without requests, assume available + return True + try: + url = provider.url or settings.ollama_url + response = requests.get(f"{url}/api/tags", timeout=5) + return response.status_code == 200 + except Exception as exc: + logger.debug("Ollama provider check error: %s", exc) + return False + + elif provider.type == "vllm_mlx": + # Check if local vllm-mlx server is running (OpenAI-compatible) + if requests is None: + return True + try: + base_url = provider.base_url or provider.url or "http://localhost:8000" + # Strip /v1 suffix — health endpoint is at the root + server_root = base_url.rstrip("/") + if server_root.endswith("/v1"): + server_root = server_root[:-3] + response = requests.get(f"{server_root}/health", timeout=5) + return response.status_code == 200 + except Exception as exc: + logger.debug("vllm-mlx provider check error: %s", exc) + return False + + elif provider.type in ("openai", "anthropic", "grok"): + # Check if API key is set + return provider.api_key is not None and provider.api_key != "" + + return True diff --git a/src/infrastructure/router/content.py b/src/infrastructure/router/content.py new file mode 100644 index 00000000..9cda7393 --- /dev/null +++ b/src/infrastructure/router/content.py @@ -0,0 +1,129 @@ +"""Content-type detection and model selection for the Cascade LLM Router.""" + +from __future__ import annotations + +import logging +from typing import Any + +from infrastructure.router.models import ContentType, Provider + +logger = logging.getLogger(__name__) + + +def detect_content_type(messages: list[dict]) -> ContentType: + """Detect the type of content in the messages. + + Checks for images, audio, etc. in the message content. + """ + has_image = False + has_audio = False + + for msg in messages: + content = msg.get("content", "") + + # Check for image URLs/paths + if msg.get("images"): + has_image = True + + # Check for image URLs in content + if isinstance(content, str): + image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp") + if any(ext in content.lower() for ext in image_extensions): + has_image = True + if content.startswith("data:image/"): + has_image = True + + # Check for audio + if msg.get("audio"): + has_audio = True + + # Check for multimodal content structure + if isinstance(content, list): + for item in content: + if isinstance(item, dict): + if item.get("type") == "image_url": + has_image = True + elif item.get("type") == "audio": + has_audio = True + + if has_image and has_audio: + return ContentType.MULTIMODAL + elif has_image: + return ContentType.VISION + elif has_audio: + return ContentType.AUDIO + return ContentType.TEXT + + +def get_fallback_model( + provider: Provider, + original_model: str, + content_type: ContentType, + fallback_chains: dict, +) -> str | None: + """Get a fallback model for the given content type.""" + # Map content type to capability + capability_map = { + ContentType.VISION: "vision", + ContentType.AUDIO: "audio", + ContentType.MULTIMODAL: "vision", # Vision models often do both + } + + capability = capability_map.get(content_type) + if not capability: + return None + + # Check provider's models for capability + fallback_model = provider.get_model_with_capability(capability) + if fallback_model and fallback_model != original_model: + return fallback_model + + # Use fallback chains from config + fallback_chain = fallback_chains.get(capability, []) + for model_name in fallback_chain: + if 
provider.model_has_capability(model_name, capability): + return model_name + + return None + + +def select_model( + provider: Provider, + model: str | None, + content_type: ContentType, + mm_manager: Any, + fallback_chains: dict, +) -> tuple[str | None, bool]: + """Select the best model for the request, with vision fallback. + + Returns: + Tuple of (selected_model, is_fallback_model). + """ + selected_model = model or provider.get_default_model() + is_fallback = False + + if content_type != ContentType.TEXT and selected_model: + if provider.type == "ollama" and mm_manager: + from infrastructure.models.multimodal import ModelCapability + + if content_type == ContentType.VISION: + supports = mm_manager.model_supports(selected_model, ModelCapability.VISION) + if not supports: + fallback = get_fallback_model( + provider, selected_model, content_type, fallback_chains + ) + if fallback: + logger.info( + "Model %s doesn't support vision, falling back to %s", + selected_model, + fallback, + ) + selected_model = fallback + is_fallback = True + else: + logger.warning( + "No vision-capable model found on %s, trying anyway", + provider.name, + ) + + return selected_model, is_fallback diff --git a/src/infrastructure/router/health.py b/src/infrastructure/router/health.py new file mode 100644 index 00000000..1a0184ce --- /dev/null +++ b/src/infrastructure/router/health.py @@ -0,0 +1,79 @@ +"""Circuit-breaker and health tracking for the Cascade LLM Router. + +Standalone functions that mutate Provider state in place. +""" + +from __future__ import annotations + +import logging +import time +from datetime import UTC, datetime + +from infrastructure.router.models import CircuitState, Provider, ProviderStatus, RouterConfig + +logger = logging.getLogger(__name__) + + +def record_success(provider: Provider, latency_ms: float, config: RouterConfig) -> None: + """Record a successful request.""" + provider.metrics.total_requests += 1 + provider.metrics.successful_requests += 1 + provider.metrics.total_latency_ms += latency_ms + provider.metrics.last_request_time = datetime.now(UTC).isoformat() + provider.metrics.consecutive_failures = 0 + + # Close circuit breaker if half-open + if provider.circuit_state == CircuitState.HALF_OPEN: + provider.half_open_calls += 1 + if provider.half_open_calls >= config.circuit_breaker_half_open_max_calls: + close_circuit(provider) + + # Update status based on error rate + if provider.metrics.error_rate < 0.1: + provider.status = ProviderStatus.HEALTHY + elif provider.metrics.error_rate < 0.3: + provider.status = ProviderStatus.DEGRADED + + +def record_failure(provider: Provider, config: RouterConfig) -> None: + """Record a failed request.""" + provider.metrics.total_requests += 1 + provider.metrics.failed_requests += 1 + provider.metrics.last_error_time = datetime.now(UTC).isoformat() + provider.metrics.consecutive_failures += 1 + + # Check if we should open circuit breaker + if provider.metrics.consecutive_failures >= config.circuit_breaker_failure_threshold: + open_circuit(provider) + + # Update status + if provider.metrics.error_rate > 0.3: + provider.status = ProviderStatus.DEGRADED + if provider.metrics.error_rate > 0.5: + provider.status = ProviderStatus.UNHEALTHY + + +def open_circuit(provider: Provider) -> None: + """Open the circuit breaker for a provider.""" + provider.circuit_state = CircuitState.OPEN + provider.circuit_opened_at = time.time() + provider.status = ProviderStatus.UNHEALTHY + logger.warning("Circuit breaker OPEN for %s", provider.name) + + +def 
can_close_circuit(provider: Provider, config: RouterConfig) -> bool: + """Check if circuit breaker can transition to half-open.""" + if provider.circuit_opened_at is None: + return False + elapsed = time.time() - provider.circuit_opened_at + return elapsed >= config.circuit_breaker_recovery_timeout + + +def close_circuit(provider: Provider) -> None: + """Close the circuit breaker (provider healthy again).""" + provider.circuit_state = CircuitState.CLOSED + provider.circuit_opened_at = None + provider.half_open_calls = 0 + provider.metrics.consecutive_failures = 0 + provider.status = ProviderStatus.HEALTHY + logger.info("Circuit breaker CLOSED for %s", provider.name) diff --git a/src/infrastructure/router/models.py b/src/infrastructure/router/models.py new file mode 100644 index 00000000..2b912152 --- /dev/null +++ b/src/infrastructure/router/models.py @@ -0,0 +1,141 @@ +"""Data models for the Cascade LLM Router. + +Enums, dataclasses, and provider configuration shared across +router sub-modules. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from datetime import UTC, datetime +from enum import Enum + + +class ProviderStatus(Enum): + """Health status of a provider.""" + + HEALTHY = "healthy" + DEGRADED = "degraded" # Working but slow or occasional errors + UNHEALTHY = "unhealthy" # Circuit breaker open + DISABLED = "disabled" + + +class CircuitState(Enum): + """Circuit breaker state.""" + + CLOSED = "closed" # Normal operation + OPEN = "open" # Failing, rejecting requests + HALF_OPEN = "half_open" # Testing if recovered + + +class ContentType(Enum): + """Type of content in the request.""" + + TEXT = "text" + VISION = "vision" # Contains images + AUDIO = "audio" # Contains audio + MULTIMODAL = "multimodal" # Multiple content types + + +@dataclass +class ProviderMetrics: + """Metrics for a single provider.""" + + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + total_latency_ms: float = 0.0 + last_request_time: str | None = None + last_error_time: str | None = None + consecutive_failures: int = 0 + + @property + def avg_latency_ms(self) -> float: + if self.total_requests == 0: + return 0.0 + return self.total_latency_ms / self.total_requests + + @property + def error_rate(self) -> float: + if self.total_requests == 0: + return 0.0 + return self.failed_requests / self.total_requests + + +@dataclass +class ModelCapability: + """Capabilities a model supports.""" + + name: str + supports_vision: bool = False + supports_audio: bool = False + supports_tools: bool = False + supports_json: bool = False + supports_streaming: bool = True + context_window: int = 4096 + + +@dataclass +class Provider: + """LLM provider configuration and state.""" + + name: str + type: str # ollama, openai, anthropic + enabled: bool + priority: int + tier: str | None = None # e.g., "local", "standard_cloud", "frontier" + url: str | None = None + api_key: str | None = None + base_url: str | None = None + models: list[dict] = field(default_factory=list) + + # Runtime state + status: ProviderStatus = ProviderStatus.HEALTHY + metrics: ProviderMetrics = field(default_factory=ProviderMetrics) + circuit_state: CircuitState = CircuitState.CLOSED + circuit_opened_at: float | None = None + half_open_calls: int = 0 + + def get_default_model(self) -> str | None: + """Get the default model for this provider.""" + for model in self.models: + if model.get("default"): + return model["name"] + if self.models: + return self.models[0]["name"] + return 
None + + def get_model_with_capability(self, capability: str) -> str | None: + """Get a model that supports the given capability.""" + for model in self.models: + capabilities = model.get("capabilities", []) + if capability in capabilities: + return model["name"] + # Fall back to default + return self.get_default_model() + + def model_has_capability(self, model_name: str, capability: str) -> bool: + """Check if a specific model has a capability.""" + for model in self.models: + if model["name"] == model_name: + capabilities = model.get("capabilities", []) + return capability in capabilities + return False + + +@dataclass +class RouterConfig: + """Cascade router configuration.""" + + timeout_seconds: int = 30 + max_retries_per_provider: int = 2 + retry_delay_seconds: int = 1 + circuit_breaker_failure_threshold: int = 5 + circuit_breaker_recovery_timeout: int = 60 + circuit_breaker_half_open_max_calls: int = 2 + cost_tracking_enabled: bool = True + budget_daily_usd: float = 10.0 + # Multi-modal settings + auto_pull_models: bool = True + fallback_chains: dict = field(default_factory=dict) diff --git a/src/infrastructure/router/providers/__init__.py b/src/infrastructure/router/providers/__init__.py new file mode 100644 index 00000000..2f4a9996 --- /dev/null +++ b/src/infrastructure/router/providers/__init__.py @@ -0,0 +1 @@ +# Provider implementations diff --git a/src/infrastructure/router/providers/anthropic.py b/src/infrastructure/router/providers/anthropic.py new file mode 100644 index 00000000..0f432dcc --- /dev/null +++ b/src/infrastructure/router/providers/anthropic.py @@ -0,0 +1,56 @@ +"""Anthropic provider implementation for the Cascade LLM Router.""" + +from __future__ import annotations + +import logging + +from infrastructure.router.models import Provider + +logger = logging.getLogger(__name__) + + +async def call_anthropic( + provider: Provider, + messages: list[dict], + model: str, + temperature: float, + max_tokens: int | None, + timeout_seconds: int, +) -> dict: + """Call Anthropic API.""" + import anthropic + + client = anthropic.AsyncAnthropic( + api_key=provider.api_key, + timeout=timeout_seconds, + ) + + # Convert messages to Anthropic format + system_msg = None + conversation = [] + for msg in messages: + if msg["role"] == "system": + system_msg = msg["content"] + else: + conversation.append( + { + "role": msg["role"], + "content": msg["content"], + } + ) + + kwargs: dict = { + "model": model, + "messages": conversation, + "temperature": temperature, + "max_tokens": max_tokens or 1024, + } + if system_msg: + kwargs["system"] = system_msg + + response = await client.messages.create(**kwargs) + + return { + "content": response.content[0].text, + "model": response.model, + } diff --git a/src/infrastructure/router/providers/dispatch.py b/src/infrastructure/router/providers/dispatch.py new file mode 100644 index 00000000..cbea465c --- /dev/null +++ b/src/infrastructure/router/providers/dispatch.py @@ -0,0 +1,80 @@ +"""Provider dispatch — routes a single request to the correct provider module.""" + +from __future__ import annotations + +import time + +from infrastructure.router.models import ContentType, Provider + + +async def call_provider( + provider: Provider, + messages: list[dict], + model: str, + temperature: float, + max_tokens: int | None, + timeout_seconds: int, + content_type: ContentType = ContentType.TEXT, +) -> dict: + """Dispatch a request to the correct provider implementation. + + Returns a result dict with ``content``, ``model``, and ``latency_ms`` keys. 
+    Raises ValueError for unknown provider types.
+    """
+    from infrastructure.router.providers import ollama as _ollama
+    from infrastructure.router.providers import openai_compat as _openai_compat
+    from infrastructure.router.providers import anthropic as _anthropic
+    from infrastructure.router.providers import grok as _grok
+
+    start_time = time.time()
+
+    if provider.type == "ollama":
+        result = await _ollama.call_ollama(
+            provider=provider,
+            messages=messages,
+            model=model or provider.get_default_model(),
+            temperature=temperature,
+            max_tokens=max_tokens,
+            content_type=content_type,
+            timeout_seconds=timeout_seconds,
+        )
+    elif provider.type == "openai":
+        result = await _openai_compat.call_openai(
+            provider=provider,
+            messages=messages,
+            model=model or provider.get_default_model(),
+            temperature=temperature,
+            max_tokens=max_tokens,
+            timeout_seconds=timeout_seconds,
+        )
+    elif provider.type == "anthropic":
+        result = await _anthropic.call_anthropic(
+            provider=provider,
+            messages=messages,
+            model=model or provider.get_default_model(),
+            temperature=temperature,
+            max_tokens=max_tokens,
+            timeout_seconds=timeout_seconds,
+        )
+    elif provider.type == "grok":
+        result = await _grok.call_grok(
+            provider=provider,
+            messages=messages,
+            model=model or provider.get_default_model(),
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+    elif provider.type == "vllm_mlx":
+        result = await _openai_compat.call_vllm_mlx(
+            provider=provider,
+            messages=messages,
+            model=model or provider.get_default_model(),
+            temperature=temperature,
+            max_tokens=max_tokens,
+            timeout_seconds=timeout_seconds,
+        )
+    else:
+        raise ValueError(f"Unknown provider type: {provider.type}")
+
+    result["latency_ms"] = (time.time() - start_time) * 1000
+    return result
diff --git a/src/infrastructure/router/providers/grok.py b/src/infrastructure/router/providers/grok.py
new file mode 100644
index 00000000..ebc7abf9
--- /dev/null
+++ b/src/infrastructure/router/providers/grok.py
@@ -0,0 +1,44 @@
+"""Grok (xAI) provider implementation for the Cascade LLM Router."""
+
+from __future__ import annotations
+
+import logging
+
+from infrastructure.router.models import Provider
+
+logger = logging.getLogger(__name__)
+
+
+async def call_grok(
+    provider: Provider,
+    messages: list[dict],
+    model: str,
+    temperature: float,
+    max_tokens: int | None,
+) -> dict:
+    """Call xAI Grok API via OpenAI-compatible SDK."""
+    import httpx
+    import openai
+
+    from config import settings
+
+    client = openai.AsyncOpenAI(
+        api_key=provider.api_key,
+        base_url=provider.base_url or settings.xai_base_url,
+        timeout=httpx.Timeout(300.0),
+    )
+
+    kwargs: dict = {
+        "model": model,
+        "messages": messages,
+        "temperature": temperature,
+    }
+    if max_tokens:
+        kwargs["max_tokens"] = max_tokens
+
+    response = await client.chat.completions.create(**kwargs)
+
+    return {
+        "content": response.choices[0].message.content,
+        "model": response.model,
+    }
diff --git a/src/infrastructure/router/providers/ollama.py b/src/infrastructure/router/providers/ollama.py
new file mode 100644
index 00000000..e3900737
--- /dev/null
+++ b/src/infrastructure/router/providers/ollama.py
@@ -0,0 +1,92 @@
+"""Ollama provider implementation for the Cascade LLM Router."""
+
+from __future__ import annotations
+
+import base64
+import logging
+from pathlib import Path
+
+import aiohttp
+
+from infrastructure.router.models import ContentType, Provider
+
+logger = logging.getLogger(__name__)
+
+
+async def call_ollama(
+    provider: Provider,
+    messages: list[dict],
+    model: str,
+    temperature: float,
+    max_tokens: int | None,
+    content_type: ContentType,
+    timeout_seconds: int,
+) -> dict:
+    """Call Ollama API with multi-modal support."""
+    from config import settings
+
+    url = f"{provider.url or settings.ollama_url}/api/chat"
+
+    # Transform messages for Ollama format (including images)
+    transformed_messages = transform_messages_for_ollama(messages)
+
+    options: dict = {"temperature": temperature}
+    if max_tokens:
+        options["num_predict"] = max_tokens
+
+    payload = {
+        "model": model,
+        "messages": transformed_messages,
+        "stream": False,
+        "options": options,
+    }
+
+    timeout = aiohttp.ClientTimeout(total=timeout_seconds)
+
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.post(url, json=payload) as response:
+            if response.status != 200:
+                text = await response.text()
+                raise RuntimeError(f"Ollama error {response.status}: {text}")
+
+            data = await response.json()
+            return {
+                "content": data["message"]["content"],
+                "model": model,
+            }
+
+
+def transform_messages_for_ollama(messages: list[dict]) -> list[dict]:
+    """Transform messages to Ollama format, handling images."""
+    transformed = []
+
+    for msg in messages:
+        new_msg = {
+            "role": msg.get("role", "user"),
+            "content": msg.get("content", ""),
+        }
+
+        # Handle images
+        images = msg.get("images", [])
+        if images:
+            new_msg["images"] = []
+            for img in images:
+                if isinstance(img, str):
+                    if img.startswith("data:image/"):
+                        # Base64 encoded image
+                        new_msg["images"].append(img.split(",")[1])
+                    elif img.startswith("http://") or img.startswith("https://"):
+                        # URL - would need to download, skip for now
+                        logger.warning("Image URLs not yet supported, skipping: %s", img)
+                    elif Path(img).exists():
+                        # Local file path - read and encode
+                        try:
+                            with open(img, "rb") as f:
+                                img_data = base64.b64encode(f.read()).decode()
+                            new_msg["images"].append(img_data)
+                        except Exception as exc:
+                            logger.error("Failed to read image %s: %s", img, exc)
+
+        transformed.append(new_msg)
+
+    return transformed
diff --git a/src/infrastructure/router/providers/openai_compat.py b/src/infrastructure/router/providers/openai_compat.py
new file mode 100644
index 00000000..11434488
--- /dev/null
+++ b/src/infrastructure/router/providers/openai_compat.py
@@ -0,0 +1,88 @@
+"""OpenAI-compatible provider implementations for the Cascade LLM Router.
+
+Covers the ``openai`` and ``vllm_mlx`` provider types.
+"""
+
+from __future__ import annotations
+
+import logging
+
+from infrastructure.router.models import Provider
+
+logger = logging.getLogger(__name__)
+
+
+async def call_openai(
+    provider: Provider,
+    messages: list[dict],
+    model: str,
+    temperature: float,
+    max_tokens: int | None,
+    timeout_seconds: int,
+) -> dict:
+    """Call OpenAI API."""
+    import openai
+
+    client = openai.AsyncOpenAI(
+        api_key=provider.api_key,
+        base_url=provider.base_url,
+        timeout=timeout_seconds,
+    )
+
+    kwargs: dict = {
+        "model": model,
+        "messages": messages,
+        "temperature": temperature,
+    }
+    if max_tokens:
+        kwargs["max_tokens"] = max_tokens
+
+    response = await client.chat.completions.create(**kwargs)
+
+    return {
+        "content": response.choices[0].message.content,
+        "model": response.model,
+    }
+
+
+async def call_vllm_mlx(
+    provider: Provider,
+    messages: list[dict],
+    model: str,
+    temperature: float,
+    max_tokens: int | None,
+    timeout_seconds: int,
+) -> dict:
+    """Call vllm-mlx via its OpenAI-compatible API.
+
+    vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI,
+    so we reuse the OpenAI client pointed at the local server.
+    No API key is required for local deployments.
+    """
+    import openai
+
+    base_url = provider.base_url or provider.url or "http://localhost:8000"
+    # Ensure the base_url ends with /v1 as expected by the OpenAI client
+    if not base_url.rstrip("/").endswith("/v1"):
+        base_url = base_url.rstrip("/") + "/v1"
+
+    client = openai.AsyncOpenAI(
+        api_key=provider.api_key or "no-key-required",
+        base_url=base_url,
+        timeout=timeout_seconds,
+    )
+
+    kwargs: dict = {
+        "model": model,
+        "messages": messages,
+        "temperature": temperature,
+    }
+    if max_tokens:
+        kwargs["max_tokens"] = max_tokens
+
+    response = await client.chat.completions.create(**kwargs)
+
+    return {
+        "content": response.choices[0].message.content,
+        "model": response.model,
+    }
diff --git a/src/infrastructure/router/reporting.py b/src/infrastructure/router/reporting.py
new file mode 100644
index 00000000..e9040697
--- /dev/null
+++ b/src/infrastructure/router/reporting.py
@@ -0,0 +1,85 @@
+"""Metrics, status, and config-reload helpers for the Cascade LLM Router."""
+
+from __future__ import annotations
+
+import logging
+
+from infrastructure.router.models import (
+    CircuitState,
+    Provider,
+    ProviderMetrics,
+    ProviderStatus,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def build_metrics(providers: list[Provider]) -> dict:
+    """Build a metrics summary dict for all providers."""
+    return {
+        "providers": [
+            {
+                "name": p.name,
+                "type": p.type,
+                "status": p.status.value,
+                "circuit_state": p.circuit_state.value,
+                "metrics": {
+                    "total_requests": p.metrics.total_requests,
+                    "successful": p.metrics.successful_requests,
+                    "failed": p.metrics.failed_requests,
+                    "error_rate": round(p.metrics.error_rate, 3),
+                    "avg_latency_ms": round(p.metrics.avg_latency_ms, 2),
+                },
+            }
+            for p in providers
+        ]
+    }
+
+
+def build_status(providers: list[Provider]) -> dict:
+    """Build a status summary dict for all providers."""
+    healthy = sum(1 for p in providers if p.status == ProviderStatus.HEALTHY)
+    return {
+        "total_providers": len(providers),
+        "healthy_providers": healthy,
+        "degraded_providers": sum(1 for p in providers if p.status == ProviderStatus.DEGRADED),
+        "unhealthy_providers": sum(1 for p in providers if p.status == ProviderStatus.UNHEALTHY),
+        "providers": [
+            {
+                "name": p.name,
+                "type": p.type,
+                "status": p.status.value,
+                "priority": p.priority,
+                "default_model": p.get_default_model(),
+            }
+            for p in providers
+        ],
+    }
+
+
+def snapshot_provider_state(
+    providers: list[Provider],
+) -> dict[str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]]:
+    """Capture current runtime state keyed by provider name."""
+    return {
+        p.name: (p.metrics, p.circuit_state, p.circuit_opened_at, p.half_open_calls, p.status)
+        for p in providers
+    }
+
+
+def restore_provider_state(
+    providers: list[Provider],
+    old_state: dict[str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]],
+) -> int:
+    """Restore saved runtime state to matching providers. Returns count of restored providers."""
+    preserved = 0
+    for p in providers:
+        if p.name in old_state:
+            metrics, circuit, opened_at, half_open, status = old_state[p.name]
+            p.metrics = metrics
+            p.circuit_state = circuit
+            p.circuit_opened_at = opened_at
+            p.half_open_calls = half_open
+            p.status = status
+            preserved += 1
+    return preserved
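A minimal usage sketch of how the split-out modules compose, assuming a reachable local Ollama server and that `aiohttp` plus the project's `config` module are importable (the Ollama path reads `config.settings` internally); the provider definition and the "llama3" model entry are illustrative, not part of the patch:

    import asyncio

    from infrastructure.router.models import Provider, RouterConfig
    from infrastructure.router.providers.dispatch import call_provider
    from infrastructure.router.reporting import (
        build_status,
        restore_provider_state,
        snapshot_provider_state,
    )

    # Hand-built provider for illustration; the router normally builds these from its config.
    provider = Provider(
        name="local-ollama",
        type="ollama",
        enabled=True,
        priority=1,
        url="http://localhost:11434",
        models=[{"name": "llama3", "default": True, "capabilities": ["text"]}],
    )
    config = RouterConfig()

    async def main() -> None:
        # Dispatch one text request; the result carries content, model, and latency_ms.
        result = await call_provider(
            provider=provider,
            messages=[{"role": "user", "content": "Say hello."}],
            model=provider.get_default_model(),
            temperature=0.2,
            max_tokens=64,
            timeout_seconds=config.timeout_seconds,
        )
        print(result["model"], round(result["latency_ms"], 1), result["content"])

        # Runtime state (metrics, circuit breaker) can be carried across a config reload.
        saved = snapshot_provider_state([provider])
        restore_provider_state([provider], saved)
        print(build_status([provider])["healthy_providers"])

    asyncio.run(main())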