diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index 306b157f..c34ce17c 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -9,12 +9,7 @@ models for image inputs and falls back through capability chains. """ import asyncio -import base64 import logging -import time -from dataclasses import dataclass, field -from datetime import UTC, datetime -from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any @@ -33,148 +28,25 @@ try: except ImportError: requests = None # type: ignore +# Re-export data models so existing ``from …cascade import X`` keeps working. +from .models import ( # noqa: F401 – re-exports + CircuitState, + ContentType, + ModelCapability, + Provider, + ProviderMetrics, + ProviderStatus, + RouterConfig, +) + +# Mixins +from .health import HealthMixin +from .providers import ProviderCallsMixin + logger = logging.getLogger(__name__) -# Quota monitor — optional, degrades gracefully if unavailable -try: - from infrastructure.claude_quota import QuotaMonitor, get_quota_monitor - _quota_monitor: "QuotaMonitor | None" = get_quota_monitor() -except Exception as _exc: # pragma: no cover - logger.debug("Quota monitor not available: %s", _exc) - _quota_monitor = None - - -class ProviderStatus(Enum): - """Health status of a provider.""" - - HEALTHY = "healthy" - DEGRADED = "degraded" # Working but slow or occasional errors - UNHEALTHY = "unhealthy" # Circuit breaker open - DISABLED = "disabled" - - -class CircuitState(Enum): - """Circuit breaker state.""" - - CLOSED = "closed" # Normal operation - OPEN = "open" # Failing, rejecting requests - HALF_OPEN = "half_open" # Testing if recovered - - -class ContentType(Enum): - """Type of content in the request.""" - - TEXT = "text" - VISION = "vision" # Contains images - AUDIO = "audio" # Contains audio - MULTIMODAL = "multimodal" # Multiple content types - - -@dataclass -class ProviderMetrics: - """Metrics for a 
single provider.""" - - total_requests: int = 0 - successful_requests: int = 0 - failed_requests: int = 0 - total_latency_ms: float = 0.0 - last_request_time: str | None = None - last_error_time: str | None = None - consecutive_failures: int = 0 - - @property - def avg_latency_ms(self) -> float: - if self.total_requests == 0: - return 0.0 - return self.total_latency_ms / self.total_requests - - @property - def error_rate(self) -> float: - if self.total_requests == 0: - return 0.0 - return self.failed_requests / self.total_requests - - -@dataclass -class ModelCapability: - """Capabilities a model supports.""" - - name: str - supports_vision: bool = False - supports_audio: bool = False - supports_tools: bool = False - supports_json: bool = False - supports_streaming: bool = True - context_window: int = 4096 - - -@dataclass -class Provider: - """LLM provider configuration and state.""" - - name: str - type: str # ollama, openai, anthropic - enabled: bool - priority: int - tier: str | None = None # e.g., "local", "standard_cloud", "frontier" - url: str | None = None - api_key: str | None = None - base_url: str | None = None - models: list[dict] = field(default_factory=list) - - # Runtime state - status: ProviderStatus = ProviderStatus.HEALTHY - metrics: ProviderMetrics = field(default_factory=ProviderMetrics) - circuit_state: CircuitState = CircuitState.CLOSED - circuit_opened_at: float | None = None - half_open_calls: int = 0 - - def get_default_model(self) -> str | None: - """Get the default model for this provider.""" - for model in self.models: - if model.get("default"): - return model["name"] - if self.models: - return self.models[0]["name"] - return None - - def get_model_with_capability(self, capability: str) -> str | None: - """Get a model that supports the given capability.""" - for model in self.models: - capabilities = model.get("capabilities", []) - if capability in capabilities: - return model["name"] - # Fall back to default - return 
self.get_default_model() - - def model_has_capability(self, model_name: str, capability: str) -> bool: - """Check if a specific model has a capability.""" - for model in self.models: - if model["name"] == model_name: - capabilities = model.get("capabilities", []) - return capability in capabilities - return False - - -@dataclass -class RouterConfig: - """Cascade router configuration.""" - - timeout_seconds: int = 30 - max_retries_per_provider: int = 2 - retry_delay_seconds: int = 1 - circuit_breaker_failure_threshold: int = 5 - circuit_breaker_recovery_timeout: int = 60 - circuit_breaker_half_open_max_calls: int = 2 - cost_tracking_enabled: bool = True - budget_daily_usd: float = 10.0 - # Multi-modal settings - auto_pull_models: bool = True - fallback_chains: dict = field(default_factory=dict) - - -class CascadeRouter: +class CascadeRouter(HealthMixin, ProviderCallsMixin): """Routes LLM requests with automatic failover. Now with multi-modal support: @@ -487,50 +359,6 @@ class CascadeRouter: raise RuntimeError("; ".join(errors)) - def _quota_allows_cloud(self, provider: Provider) -> bool: - """Check quota before routing to a cloud provider. - - Uses the metabolic protocol via select_model(): cloud calls are only - allowed when the quota monitor recommends a cloud model (BURST tier). - Returns True (allow cloud) if quota monitor is unavailable or returns None. 
- """ - if _quota_monitor is None: - return True - try: - suggested = _quota_monitor.select_model("high") - # Cloud is allowed only when select_model recommends the cloud model - allows = suggested == "claude-sonnet-4-6" - if not allows: - status = _quota_monitor.check() - tier = status.recommended_tier.value if status else "unknown" - logger.info( - "Metabolic protocol: %s tier — downshifting %s to local (%s)", - tier, - provider.name, - suggested, - ) - return allows - except Exception as exc: - logger.warning("Quota check failed, allowing cloud: %s", exc) - return True - - def _is_provider_available(self, provider: Provider) -> bool: - """Check if a provider should be tried (enabled + circuit breaker).""" - if not provider.enabled: - logger.debug("Skipping %s (disabled)", provider.name) - return False - - if provider.status == ProviderStatus.UNHEALTHY: - if self._can_close_circuit(provider): - provider.circuit_state = CircuitState.HALF_OPEN - provider.half_open_calls = 0 - logger.info("Circuit breaker half-open for %s", provider.name) - else: - logger.debug("Skipping %s (circuit open)", provider.name) - return False - - return True - def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]: """Return the provider list filtered by tier. 
@@ -641,9 +469,9 @@ class CascadeRouter: - Supports image URLs, paths, and base64 encoding Complexity-based routing (issue #1065): - - ``complexity_hint="simple"`` → routes to Qwen3-8B (low-latency) - - ``complexity_hint="complex"`` → routes to Qwen3-14B (quality) - - ``complexity_hint=None`` (default) → auto-classifies from messages + - ``complexity_hint="simple"`` -> routes to Qwen3-8B (low-latency) + - ``complexity_hint="complex"`` -> routes to Qwen3-14B (quality) + - ``complexity_hint=None`` (default) -> auto-classifies from messages Args: messages: List of message dicts with role and content @@ -668,7 +496,7 @@ class CascadeRouter: if content_type != ContentType.TEXT: logger.debug("Detected %s content, selecting appropriate model", content_type.value) - # Resolve task complexity ───────────────────────────────────────────── + # Resolve task complexity # Skip complexity routing when caller explicitly specifies a model. complexity: TaskComplexity | None = None if model is None: @@ -698,7 +526,7 @@ class CascadeRouter: ) continue - # Complexity-based model selection (only when no explicit model) ── + # Complexity-based model selection (only when no explicit model) effective_model = model if effective_model is None and complexity is not None: effective_model = self._get_model_for_complexity(provider, complexity) @@ -740,357 +568,6 @@ class CascadeRouter: raise RuntimeError(f"All providers failed: {'; '.join(errors)}") - async def _try_provider( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - content_type: ContentType = ContentType.TEXT, - ) -> dict: - """Try a single provider request.""" - start_time = time.time() - - if provider.type == "ollama": - result = await self._call_ollama( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - content_type=content_type, - ) - elif provider.type == "openai": - result 
= await self._call_openai( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - ) - elif provider.type == "anthropic": - result = await self._call_anthropic( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - ) - elif provider.type == "grok": - result = await self._call_grok( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - ) - elif provider.type == "vllm_mlx": - result = await self._call_vllm_mlx( - provider=provider, - messages=messages, - model=model or provider.get_default_model(), - temperature=temperature, - max_tokens=max_tokens, - ) - else: - raise ValueError(f"Unknown provider type: {provider.type}") - - latency_ms = (time.time() - start_time) * 1000 - result["latency_ms"] = latency_ms - - return result - - async def _call_ollama( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None = None, - content_type: ContentType = ContentType.TEXT, - ) -> dict: - """Call Ollama API with multi-modal support.""" - import aiohttp - - url = f"{provider.url or settings.ollama_url}/api/chat" - - # Transform messages for Ollama format (including images) - transformed_messages = self._transform_messages_for_ollama(messages) - - options = {"temperature": temperature} - if max_tokens: - options["num_predict"] = max_tokens - - payload = { - "model": model, - "messages": transformed_messages, - "stream": False, - "options": options, - } - - timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds) - - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(url, json=payload) as response: - if response.status != 200: - text = await response.text() - raise RuntimeError(f"Ollama error {response.status}: 
{text}") - - data = await response.json() - return { - "content": data["message"]["content"], - "model": model, - } - - def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]: - """Transform messages to Ollama format, handling images.""" - transformed = [] - - for msg in messages: - new_msg = { - "role": msg.get("role", "user"), - "content": msg.get("content", ""), - } - - # Handle images - images = msg.get("images", []) - if images: - new_msg["images"] = [] - for img in images: - if isinstance(img, str): - if img.startswith("data:image/"): - # Base64 encoded image - new_msg["images"].append(img.split(",")[1]) - elif img.startswith("http://") or img.startswith("https://"): - # URL - would need to download, skip for now - logger.warning("Image URLs not yet supported, skipping: %s", img) - elif Path(img).exists(): - # Local file path - read and encode - try: - with open(img, "rb") as f: - img_data = base64.b64encode(f.read()).decode() - new_msg["images"].append(img_data) - except Exception as exc: - logger.error("Failed to read image %s: %s", img, exc) - - transformed.append(new_msg) - - return transformed - - async def _call_openai( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - ) -> dict: - """Call OpenAI API.""" - import openai - - client = openai.AsyncOpenAI( - api_key=provider.api_key, - base_url=provider.base_url, - timeout=self.config.timeout_seconds, - ) - - kwargs = { - "model": model, - "messages": messages, - "temperature": temperature, - } - if max_tokens: - kwargs["max_tokens"] = max_tokens - - response = await client.chat.completions.create(**kwargs) - - return { - "content": response.choices[0].message.content, - "model": response.model, - } - - async def _call_anthropic( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - ) -> dict: - """Call Anthropic API.""" - import anthropic - - client = 
anthropic.AsyncAnthropic( - api_key=provider.api_key, - timeout=self.config.timeout_seconds, - ) - - # Convert messages to Anthropic format - system_msg = None - conversation = [] - for msg in messages: - if msg["role"] == "system": - system_msg = msg["content"] - else: - conversation.append( - { - "role": msg["role"], - "content": msg["content"], - } - ) - - kwargs = { - "model": model, - "messages": conversation, - "temperature": temperature, - "max_tokens": max_tokens or 1024, - } - if system_msg: - kwargs["system"] = system_msg - - response = await client.messages.create(**kwargs) - - return { - "content": response.content[0].text, - "model": response.model, - } - - async def _call_grok( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - ) -> dict: - """Call xAI Grok API via OpenAI-compatible SDK.""" - import httpx - import openai - - client = openai.AsyncOpenAI( - api_key=provider.api_key, - base_url=provider.base_url or settings.xai_base_url, - timeout=httpx.Timeout(300.0), - ) - - kwargs = { - "model": model, - "messages": messages, - "temperature": temperature, - } - if max_tokens: - kwargs["max_tokens"] = max_tokens - - response = await client.chat.completions.create(**kwargs) - - return { - "content": response.choices[0].message.content, - "model": response.model, - } - - async def _call_vllm_mlx( - self, - provider: Provider, - messages: list[dict], - model: str, - temperature: float, - max_tokens: int | None, - ) -> dict: - """Call vllm-mlx via its OpenAI-compatible API. - - vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI, - so we reuse the OpenAI client pointed at the local server. - No API key is required for local deployments. 
- """ - import openai - - base_url = provider.base_url or provider.url or "http://localhost:8000" - # Ensure the base_url ends with /v1 as expected by the OpenAI client - if not base_url.rstrip("/").endswith("/v1"): - base_url = base_url.rstrip("/") + "/v1" - - client = openai.AsyncOpenAI( - api_key=provider.api_key or "no-key-required", - base_url=base_url, - timeout=self.config.timeout_seconds, - ) - - kwargs: dict = { - "model": model, - "messages": messages, - "temperature": temperature, - } - if max_tokens: - kwargs["max_tokens"] = max_tokens - - response = await client.chat.completions.create(**kwargs) - - return { - "content": response.choices[0].message.content, - "model": response.model, - } - - def _record_success(self, provider: Provider, latency_ms: float) -> None: - """Record a successful request.""" - provider.metrics.total_requests += 1 - provider.metrics.successful_requests += 1 - provider.metrics.total_latency_ms += latency_ms - provider.metrics.last_request_time = datetime.now(UTC).isoformat() - provider.metrics.consecutive_failures = 0 - - # Close circuit breaker if half-open - if provider.circuit_state == CircuitState.HALF_OPEN: - provider.half_open_calls += 1 - if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls: - self._close_circuit(provider) - - # Update status based on error rate - if provider.metrics.error_rate < 0.1: - provider.status = ProviderStatus.HEALTHY - elif provider.metrics.error_rate < 0.3: - provider.status = ProviderStatus.DEGRADED - - def _record_failure(self, provider: Provider) -> None: - """Record a failed request.""" - provider.metrics.total_requests += 1 - provider.metrics.failed_requests += 1 - provider.metrics.last_error_time = datetime.now(UTC).isoformat() - provider.metrics.consecutive_failures += 1 - - # Check if we should open circuit breaker - if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold: - self._open_circuit(provider) - - # Update status - if 
provider.metrics.error_rate > 0.3: - provider.status = ProviderStatus.DEGRADED - if provider.metrics.error_rate > 0.5: - provider.status = ProviderStatus.UNHEALTHY - - def _open_circuit(self, provider: Provider) -> None: - """Open the circuit breaker for a provider.""" - provider.circuit_state = CircuitState.OPEN - provider.circuit_opened_at = time.time() - provider.status = ProviderStatus.UNHEALTHY - logger.warning("Circuit breaker OPEN for %s", provider.name) - - def _can_close_circuit(self, provider: Provider) -> bool: - """Check if circuit breaker can transition to half-open.""" - if provider.circuit_opened_at is None: - return False - elapsed = time.time() - provider.circuit_opened_at - return elapsed >= self.config.circuit_breaker_recovery_timeout - - def _close_circuit(self, provider: Provider) -> None: - """Close the circuit breaker (provider healthy again).""" - provider.circuit_state = CircuitState.CLOSED - provider.circuit_opened_at = None - provider.half_open_calls = 0 - provider.metrics.consecutive_failures = 0 - provider.status = ProviderStatus.HEALTHY - logger.info("Circuit breaker CLOSED for %s", provider.name) - def reload_config(self) -> dict: """Hot-reload providers.yaml, preserving runtime state. diff --git a/src/infrastructure/router/health.py b/src/infrastructure/router/health.py new file mode 100644 index 00000000..7b77318f --- /dev/null +++ b/src/infrastructure/router/health.py @@ -0,0 +1,137 @@ +"""Health monitoring and circuit breaker mixin for the Cascade Router. + +Provides failure tracking, circuit breaker state transitions, +and quota-based cloud provider gating. 
+""" + +from __future__ import annotations + +import logging +import time +from datetime import UTC, datetime + +from .models import CircuitState, Provider, ProviderMetrics, ProviderStatus + +logger = logging.getLogger(__name__) + +# Quota monitor — optional, degrades gracefully if unavailable +try: + from infrastructure.claude_quota import QuotaMonitor, get_quota_monitor + + _quota_monitor: "QuotaMonitor | None" = get_quota_monitor() +except Exception as _exc: # pragma: no cover + logger.debug("Quota monitor not available: %s", _exc) + _quota_monitor = None + + +class HealthMixin: + """Mixin providing health tracking, circuit breaker, and quota checks. + + Expects the consuming class to have: + - self.config: RouterConfig + - self.providers: list[Provider] + """ + + def _record_success(self, provider: Provider, latency_ms: float) -> None: + """Record a successful request.""" + provider.metrics.total_requests += 1 + provider.metrics.successful_requests += 1 + provider.metrics.total_latency_ms += latency_ms + provider.metrics.last_request_time = datetime.now(UTC).isoformat() + provider.metrics.consecutive_failures = 0 + + # Close circuit breaker if half-open + if provider.circuit_state == CircuitState.HALF_OPEN: + provider.half_open_calls += 1 + if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls: + self._close_circuit(provider) + + # Update status based on error rate + if provider.metrics.error_rate < 0.1: + provider.status = ProviderStatus.HEALTHY + elif provider.metrics.error_rate < 0.3: + provider.status = ProviderStatus.DEGRADED + + def _record_failure(self, provider: Provider) -> None: + """Record a failed request.""" + provider.metrics.total_requests += 1 + provider.metrics.failed_requests += 1 + provider.metrics.last_error_time = datetime.now(UTC).isoformat() + provider.metrics.consecutive_failures += 1 + + # Check if we should open circuit breaker + if provider.metrics.consecutive_failures >= 
self.config.circuit_breaker_failure_threshold: + self._open_circuit(provider) + + # Update status + if provider.metrics.error_rate > 0.3: + provider.status = ProviderStatus.DEGRADED + if provider.metrics.error_rate > 0.5: + provider.status = ProviderStatus.UNHEALTHY + + def _open_circuit(self, provider: Provider) -> None: + """Open the circuit breaker for a provider.""" + provider.circuit_state = CircuitState.OPEN + provider.circuit_opened_at = time.time() + provider.status = ProviderStatus.UNHEALTHY + logger.warning("Circuit breaker OPEN for %s", provider.name) + + def _can_close_circuit(self, provider: Provider) -> bool: + """Check if circuit breaker can transition to half-open.""" + if provider.circuit_opened_at is None: + return False + elapsed = time.time() - provider.circuit_opened_at + return elapsed >= self.config.circuit_breaker_recovery_timeout + + def _close_circuit(self, provider: Provider) -> None: + """Close the circuit breaker (provider healthy again).""" + provider.circuit_state = CircuitState.CLOSED + provider.circuit_opened_at = None + provider.half_open_calls = 0 + provider.metrics.consecutive_failures = 0 + provider.status = ProviderStatus.HEALTHY + logger.info("Circuit breaker CLOSED for %s", provider.name) + + def _is_provider_available(self, provider: Provider) -> bool: + """Check if a provider should be tried (enabled + circuit breaker).""" + if not provider.enabled: + logger.debug("Skipping %s (disabled)", provider.name) + return False + + if provider.status == ProviderStatus.UNHEALTHY: + if self._can_close_circuit(provider): + provider.circuit_state = CircuitState.HALF_OPEN + provider.half_open_calls = 0 + logger.info("Circuit breaker half-open for %s", provider.name) + else: + logger.debug("Skipping %s (circuit open)", provider.name) + return False + + return True + + def _quota_allows_cloud(self, provider: Provider) -> bool: + """Check quota before routing to a cloud provider. 
+ + Uses the metabolic protocol via select_model(): cloud calls are only + allowed when the quota monitor recommends a cloud model (BURST tier). + Returns True (allow cloud) if the quota monitor is unavailable or the check raises. + """ + if _quota_monitor is None: + return True + try: + suggested = _quota_monitor.select_model("high") + # Cloud is allowed only when select_model recommends the cloud model + allows = suggested == "claude-sonnet-4-6" + if not allows: + status = _quota_monitor.check() + tier = status.recommended_tier.value if status else "unknown" + logger.info( + "Metabolic protocol: %s tier — downshifting %s to local (%s)", + tier, + provider.name, + suggested, + ) + return allows + except Exception as exc: + logger.warning("Quota check failed, allowing cloud: %s", exc) + return True diff --git a/src/infrastructure/router/models.py b/src/infrastructure/router/models.py new file mode 100644 index 00000000..9acbecf9 --- /dev/null +++ b/src/infrastructure/router/models.py @@ -0,0 +1,138 @@ +"""Data models for the Cascade LLM Router. + +Enums, dataclasses, and configuration objects shared across router modules. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum + + +class ProviderStatus(Enum): + """Health status of a provider.""" + + HEALTHY = "healthy" + DEGRADED = "degraded" # Working but slow or occasional errors + UNHEALTHY = "unhealthy" # Circuit breaker open or error rate above 50% + DISABLED = "disabled" + + +class CircuitState(Enum): + """Circuit breaker state.""" + + CLOSED = "closed" # Normal operation + OPEN = "open" # Failing, rejecting requests + HALF_OPEN = "half_open" # Testing if recovered + + +class ContentType(Enum): + """Type of content in the request.""" + + TEXT = "text" + VISION = "vision" # Contains images + AUDIO = "audio" # Contains audio + MULTIMODAL = "multimodal" # Multiple content types + + +@dataclass +class ProviderMetrics: + """Metrics for a single provider.""" + + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + total_latency_ms: float = 0.0 + last_request_time: str | None = None + last_error_time: str | None = None + consecutive_failures: int = 0 + + @property + def avg_latency_ms(self) -> float: + if self.total_requests == 0: + return 0.0 + return self.total_latency_ms / self.total_requests + + @property + def error_rate(self) -> float: + if self.total_requests == 0: + return 0.0 + return self.failed_requests / self.total_requests + + +@dataclass +class ModelCapability: + """Capabilities a model supports.""" + + name: str + supports_vision: bool = False + supports_audio: bool = False + supports_tools: bool = False + supports_json: bool = False + supports_streaming: bool = True + context_window: int = 4096 + + +@dataclass +class Provider: + """LLM provider configuration and state.""" + + name: str + type: str # ollama, openai, anthropic, grok, vllm_mlx + enabled: bool + priority: int + tier: str | None = None # e.g., "local", "standard_cloud", "frontier" + url: str | None = None + api_key: str | None = None + base_url: str | None = None + models: list[dict] = 
field(default_factory=list) + + # Runtime state + status: ProviderStatus = ProviderStatus.HEALTHY + metrics: ProviderMetrics = field(default_factory=ProviderMetrics) + circuit_state: CircuitState = CircuitState.CLOSED + circuit_opened_at: float | None = None + half_open_calls: int = 0 + + def get_default_model(self) -> str | None: + """Get the default model for this provider.""" + for model in self.models: + if model.get("default"): + return model["name"] + if self.models: + return self.models[0]["name"] + return None + + def get_model_with_capability(self, capability: str) -> str | None: + """Get a model that supports the given capability.""" + for model in self.models: + capabilities = model.get("capabilities", []) + if capability in capabilities: + return model["name"] + # Fall back to default + return self.get_default_model() + + def model_has_capability(self, model_name: str, capability: str) -> bool: + """Check if a specific model has a capability.""" + for model in self.models: + if model["name"] == model_name: + capabilities = model.get("capabilities", []) + return capability in capabilities + return False + + +@dataclass +class RouterConfig: + """Cascade router configuration.""" + + timeout_seconds: int = 30 + max_retries_per_provider: int = 2 + retry_delay_seconds: int = 1 + circuit_breaker_failure_threshold: int = 5 + circuit_breaker_recovery_timeout: int = 60 + circuit_breaker_half_open_max_calls: int = 2 + cost_tracking_enabled: bool = True + budget_daily_usd: float = 10.0 + # Multi-modal settings + auto_pull_models: bool = True + fallback_chains: dict = field(default_factory=dict) diff --git a/src/infrastructure/router/providers.py b/src/infrastructure/router/providers.py new file mode 100644 index 00000000..ed3965b0 --- /dev/null +++ b/src/infrastructure/router/providers.py @@ -0,0 +1,318 @@ +"""Provider API call mixin for the Cascade Router. + +Contains methods for calling individual LLM provider APIs +(Ollama, OpenAI, Anthropic, Grok, vllm-mlx). 
+""" + +from __future__ import annotations + +import base64 +import logging +import time +from pathlib import Path +from typing import Any + +from config import settings + +from .models import ContentType, Provider + +logger = logging.getLogger(__name__) + + +class ProviderCallsMixin: + """Mixin providing LLM provider API call methods. + + Expects the consuming class to have: + - self.config: RouterConfig + """ + + async def _try_provider( + self, + provider: Provider, + messages: list[dict], + model: str, + temperature: float, + max_tokens: int | None, + content_type: ContentType = ContentType.TEXT, + ) -> dict: + """Try a single provider request.""" + start_time = time.time() + + if provider.type == "ollama": + result = await self._call_ollama( + provider=provider, + messages=messages, + model=model or provider.get_default_model(), + temperature=temperature, + max_tokens=max_tokens, + content_type=content_type, + ) + elif provider.type == "openai": + result = await self._call_openai( + provider=provider, + messages=messages, + model=model or provider.get_default_model(), + temperature=temperature, + max_tokens=max_tokens, + ) + elif provider.type == "anthropic": + result = await self._call_anthropic( + provider=provider, + messages=messages, + model=model or provider.get_default_model(), + temperature=temperature, + max_tokens=max_tokens, + ) + elif provider.type == "grok": + result = await self._call_grok( + provider=provider, + messages=messages, + model=model or provider.get_default_model(), + temperature=temperature, + max_tokens=max_tokens, + ) + elif provider.type == "vllm_mlx": + result = await self._call_vllm_mlx( + provider=provider, + messages=messages, + model=model or provider.get_default_model(), + temperature=temperature, + max_tokens=max_tokens, + ) + else: + raise ValueError(f"Unknown provider type: {provider.type}") + + latency_ms = (time.time() - start_time) * 1000 + result["latency_ms"] = latency_ms + + return result + + async def _call_ollama( 
+ self, + provider: Provider, + messages: list[dict], + model: str, + temperature: float, + max_tokens: int | None = None, + content_type: ContentType = ContentType.TEXT, + ) -> dict: + """Call Ollama API with multi-modal support.""" + import aiohttp + + url = f"{provider.url or settings.ollama_url}/api/chat" + + # Transform messages for Ollama format (including images) + transformed_messages = self._transform_messages_for_ollama(messages) + + options: dict[str, Any] = {"temperature": temperature} + if max_tokens: + options["num_predict"] = max_tokens + + payload = { + "model": model, + "messages": transformed_messages, + "stream": False, + "options": options, + } + + timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds) + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, json=payload) as response: + if response.status != 200: + text = await response.text() + raise RuntimeError(f"Ollama error {response.status}: {text}") + + data = await response.json() + return { + "content": data["message"]["content"], + "model": model, + } + + def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]: + """Transform messages to Ollama format, handling images.""" + transformed = [] + + for msg in messages: + new_msg: dict[str, Any] = { + "role": msg.get("role", "user"), + "content": msg.get("content", ""), + } + + # Handle images + images = msg.get("images", []) + if images: + new_msg["images"] = [] + for img in images: + if isinstance(img, str): + if img.startswith("data:image/"): + # Base64 encoded image + new_msg["images"].append(img.split(",")[1]) + elif img.startswith("http://") or img.startswith("https://"): + # URL - would need to download, skip for now + logger.warning("Image URLs not yet supported, skipping: %s", img) + elif Path(img).exists(): + # Local file path - read and encode + try: + with open(img, "rb") as f: + img_data = base64.b64encode(f.read()).decode() + 
new_msg["images"].append(img_data) + except Exception as exc: + logger.error("Failed to read image %s: %s", img, exc) + + transformed.append(new_msg) + + return transformed + + async def _call_openai( + self, + provider: Provider, + messages: list[dict], + model: str, + temperature: float, + max_tokens: int | None, + ) -> dict: + """Call OpenAI API.""" + import openai + + client = openai.AsyncOpenAI( + api_key=provider.api_key, + base_url=provider.base_url, + timeout=self.config.timeout_seconds, + ) + + kwargs: dict[str, Any] = { + "model": model, + "messages": messages, + "temperature": temperature, + } + if max_tokens: + kwargs["max_tokens"] = max_tokens + + response = await client.chat.completions.create(**kwargs) + + return { + "content": response.choices[0].message.content, + "model": response.model, + } + + async def _call_anthropic( + self, + provider: Provider, + messages: list[dict], + model: str, + temperature: float, + max_tokens: int | None, + ) -> dict: + """Call Anthropic API.""" + import anthropic + + client = anthropic.AsyncAnthropic( + api_key=provider.api_key, + timeout=self.config.timeout_seconds, + ) + + # Convert messages to Anthropic format + system_msg = None + conversation = [] + for msg in messages: + if msg["role"] == "system": + system_msg = msg["content"] + else: + conversation.append( + { + "role": msg["role"], + "content": msg["content"], + } + ) + + kwargs: dict[str, Any] = { + "model": model, + "messages": conversation, + "temperature": temperature, + "max_tokens": max_tokens or 1024, + } + if system_msg: + kwargs["system"] = system_msg + + response = await client.messages.create(**kwargs) + + return { + "content": response.content[0].text, + "model": response.model, + } + + async def _call_grok( + self, + provider: Provider, + messages: list[dict], + model: str, + temperature: float, + max_tokens: int | None, + ) -> dict: + """Call xAI Grok API via OpenAI-compatible SDK.""" + import httpx + import openai + + client = 
openai.AsyncOpenAI( + api_key=provider.api_key, + base_url=provider.base_url or settings.xai_base_url, + timeout=httpx.Timeout(300.0), + ) + + kwargs: dict[str, Any] = { + "model": model, + "messages": messages, + "temperature": temperature, + } + if max_tokens: + kwargs["max_tokens"] = max_tokens + + response = await client.chat.completions.create(**kwargs) + + return { + "content": response.choices[0].message.content, + "model": response.model, + } + + async def _call_vllm_mlx( + self, + provider: Provider, + messages: list[dict], + model: str, + temperature: float, + max_tokens: int | None, + ) -> dict: + """Call vllm-mlx via its OpenAI-compatible API. + + vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI, + so we reuse the OpenAI client pointed at the local server. + No API key is required for local deployments. + """ + import openai + + base_url = provider.base_url or provider.url or "http://localhost:8000" + # Ensure the base_url ends with /v1 as expected by the OpenAI client + if not base_url.rstrip("/").endswith("/v1"): + base_url = base_url.rstrip("/") + "/v1" + + client = openai.AsyncOpenAI( + api_key=provider.api_key or "no-key-required", + base_url=base_url, + timeout=self.config.timeout_seconds, + ) + + kwargs: dict[str, Any] = { + "model": model, + "messages": messages, + "temperature": temperature, + } + if max_tokens: + kwargs["max_tokens"] = max_tokens + + response = await client.chat.completions.create(**kwargs) + + return { + "content": response.choices[0].message.content, + "model": response.model, + } diff --git a/tests/infrastructure/test_router_cascade.py b/tests/infrastructure/test_router_cascade.py index 5d2c7788..9df41062 100644 --- a/tests/infrastructure/test_router_cascade.py +++ b/tests/infrastructure/test_router_cascade.py @@ -677,7 +677,7 @@ class TestVllmMlxProvider: router.providers = [provider] # Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried - with 
patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + with patch("infrastructure.router.health._quota_monitor") as mock_qm: mock_qm.select_model.return_value = "qwen3:14b" mock_qm.check.return_value = None @@ -713,7 +713,7 @@ class TestMetabolicProtocol: router = CascadeRouter(config_path=Path("/nonexistent")) router.providers = [self._make_anthropic_provider()] - with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + with patch("infrastructure.router.health._quota_monitor") as mock_qm: # select_model returns cloud model → BURST tier mock_qm.select_model.return_value = "claude-sonnet-4-6" mock_qm.check.return_value = None @@ -732,7 +732,7 @@ class TestMetabolicProtocol: router = CascadeRouter(config_path=Path("/nonexistent")) router.providers = [self._make_anthropic_provider()] - with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + with patch("infrastructure.router.health._quota_monitor") as mock_qm: # select_model returns local 14B → ACTIVE tier mock_qm.select_model.return_value = "qwen3:14b" mock_qm.check.return_value = None @@ -750,7 +750,7 @@ class TestMetabolicProtocol: router = CascadeRouter(config_path=Path("/nonexistent")) router.providers = [self._make_anthropic_provider()] - with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + with patch("infrastructure.router.health._quota_monitor") as mock_qm: # select_model returns local 8B → RESTING tier mock_qm.select_model.return_value = "qwen3:8b" mock_qm.check.return_value = None @@ -776,7 +776,7 @@ class TestMetabolicProtocol: ) router.providers = [provider] - with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + with patch("infrastructure.router.health._quota_monitor") as mock_qm: mock_qm.select_model.return_value = "qwen3:8b" # RESTING tier with patch.object(router, "_call_ollama") as mock_call: @@ -793,7 +793,7 @@ class TestMetabolicProtocol: router = CascadeRouter(config_path=Path("/nonexistent")) router.providers = 
[self._make_anthropic_provider()] - with patch("infrastructure.router.cascade._quota_monitor", None): + with patch("infrastructure.router.health._quota_monitor", None): with patch.object(router, "_call_anthropic") as mock_call: mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"} result = await router.complete( @@ -1200,7 +1200,7 @@ class TestCascadeTierFiltering: async def test_frontier_required_uses_anthropic(self): router = self._make_router() - with patch("infrastructure.router.cascade._quota_monitor", None): + with patch("infrastructure.router.health._quota_monitor", None): with patch.object(router, "_call_anthropic") as mock_call: mock_call.return_value = { "content": "frontier response", @@ -1464,7 +1464,7 @@ class TestTrySingleProvider: router = self._router() provider = self._provider(ptype="anthropic") errors: list[str] = [] - with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + with patch("infrastructure.router.health._quota_monitor") as mock_qm: mock_qm.select_model.return_value = "qwen3:14b" # non-cloud → ACTIVE tier mock_qm.check.return_value = None result = await router._try_single_provider(