Remove the airllm provider block from providers.yaml, the airllm availability check from cascade.py, and associated tests. Renumber remaining provider priorities (OpenAI → 2, Anthropic → 3). Fixes #459 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
957 lines
32 KiB
Python
957 lines
32 KiB
Python
"""Cascade LLM Router — automatic failover between providers.

Routes requests through an ordered list of LLM providers,
automatically failing over on rate limits or errors.
Tracks metrics for latency, errors, and cost.

Includes multi-modal support: automatically selects vision-capable
models for image inputs and falls back through capability chains.
"""
|
|
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
try:
|
|
import yaml
|
|
except ImportError:
|
|
yaml = None # type: ignore
|
|
|
|
try:
|
|
import requests
|
|
except ImportError:
|
|
requests = None # type: ignore
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ProviderStatus(Enum):
    """Health classification the router assigns to a provider at runtime."""

    HEALTHY = "healthy"
    # Still usable, but slow or producing occasional errors.
    DEGRADED = "degraded"
    # The circuit breaker is open for this provider.
    UNHEALTHY = "unhealthy"
    DISABLED = "disabled"
|
|
|
|
|
|
class CircuitState(Enum):
    """Position of a provider's circuit breaker."""

    # Requests flow normally.
    CLOSED = "closed"
    # The provider is failing and requests are rejected.
    OPEN = "open"
    # A limited number of requests are let through to test recovery.
    HALF_OPEN = "half_open"
|
|
|
|
|
|
class ContentType(Enum):
    """Kind of payload detected in a request's messages."""

    TEXT = "text"
    # At least one image is present.
    VISION = "vision"
    # At least one audio attachment is present.
    AUDIO = "audio"
    # More than one non-text content type is present.
    MULTIMODAL = "multimodal"
|
|
|
|
|
|
@dataclass
|
|
class ProviderMetrics:
|
|
"""Metrics for a single provider."""
|
|
|
|
total_requests: int = 0
|
|
successful_requests: int = 0
|
|
failed_requests: int = 0
|
|
total_latency_ms: float = 0.0
|
|
last_request_time: str | None = None
|
|
last_error_time: str | None = None
|
|
consecutive_failures: int = 0
|
|
|
|
@property
|
|
def avg_latency_ms(self) -> float:
|
|
if self.total_requests == 0:
|
|
return 0.0
|
|
return self.total_latency_ms / self.total_requests
|
|
|
|
@property
|
|
def error_rate(self) -> float:
|
|
if self.total_requests == 0:
|
|
return 0.0
|
|
return self.failed_requests / self.total_requests
|
|
|
|
|
|
@dataclass
class ModelCapability:
    """Feature flags and context size declared for one model."""

    name: str
    # Input/feature support flags; everything except streaming is off by default.
    supports_vision: bool = False
    supports_audio: bool = False
    supports_tools: bool = False
    supports_json: bool = False
    supports_streaming: bool = True
    # Default context window size.
    context_window: int = 4096
|
|
|
|
|
|
@dataclass
class Provider:
    """Static configuration plus mutable runtime state for one LLM provider.

    Every entry of ``models`` is a dict. The helper methods read its
    ``name`` key and the optional ``default`` and ``capabilities`` keys.
    """

    # --- configuration ---
    name: str
    type: str  # ollama, openai, anthropic, grok
    enabled: bool
    priority: int
    url: str | None = None
    api_key: str | None = None
    base_url: str | None = None
    models: list[dict] = field(default_factory=list)

    # --- runtime state (mutated by the router while serving requests) ---
    status: ProviderStatus = ProviderStatus.HEALTHY
    metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
    circuit_state: CircuitState = CircuitState.CLOSED
    circuit_opened_at: float | None = None
    half_open_calls: int = 0

    def get_default_model(self) -> str | None:
        """Return the name of the default model.

        The first entry flagged ``default`` wins; without such a flag the
        first configured model is used. Returns ``None`` when no models
        are configured.
        """
        for entry in self.models:
            if entry.get("default"):
                return entry["name"]
        return self.models[0]["name"] if self.models else None

    def get_model_with_capability(self, capability: str) -> str | None:
        """Return the first model listing *capability*.

        When no model lists it, the result of ``get_default_model()`` is
        returned instead, so the returned model is not guaranteed to have
        the capability.
        """
        not_found = object()
        hit = next(
            (
                entry["name"]
                for entry in self.models
                if capability in entry.get("capabilities", [])
            ),
            not_found,
        )
        return self.get_default_model() if hit is not_found else hit

    def model_has_capability(self, model_name: str, capability: str) -> bool:
        """Return True if the first model named *model_name* lists *capability*.

        Unknown model names yield ``False``.
        """
        for entry in self.models:
            if entry["name"] != model_name:
                continue
            return capability in entry.get("capabilities", [])
        return False
|
|
|
|
|
|
@dataclass
class RouterConfig:
    """Tunable settings for the cascade router, with their defaults."""

    # Per-request timeout and retry policy.
    timeout_seconds: int = 30
    max_retries_per_provider: int = 2
    retry_delay_seconds: int = 1

    # Circuit breaker thresholds.
    circuit_breaker_failure_threshold: int = 5
    circuit_breaker_recovery_timeout: int = 60
    circuit_breaker_half_open_max_calls: int = 2

    # Cost tracking.
    cost_tracking_enabled: bool = True
    budget_daily_usd: float = 10.0

    # Multi-modal settings.
    auto_pull_models: bool = True
    fallback_chains: dict = field(default_factory=dict)
|
|
|
|
|
|
class CascadeRouter:
    """Routes LLM requests with automatic failover.

    Providers are loaded from a YAML config file, ordered by priority and
    tried in turn; each provider has its own metrics and circuit breaker.

    Multi-modal support:
    - Detects the content type of a request (text, vision, audio)
    - Selects appropriate models based on declared capabilities
    - Falls back through capability-specific model chains
    - For Ollama, accepts images as local file paths or base64 data URIs
      (http/https image URLs are currently skipped with a warning)

    Usage:
        router = CascadeRouter()

        # Text request
        response = await router.complete(
            messages=[{"role": "user", "content": "Hello"}],
            model="llama3.2"
        )

        # Vision request (automatically detects and selects vision model)
        response = await router.complete(
            messages=[{
                "role": "user",
                "content": "What's in this image?",
                "images": ["path/to/image.jpg"]
            }],
            model="llava:7b"
        )

        # Check metrics
        metrics = router.get_metrics()
    """
|
|
|
|
def __init__(self, config_path: Path | None = None) -> None:
    """Create a router and load its configuration.

    Args:
        config_path: YAML config location; ``config/providers.yaml`` is
            used when omitted.

    The multi-modal manager is optional: if importing or creating it
    fails for any reason, ``self._mm_manager`` stays ``None`` and the
    failure is only logged at debug level.
    """
    self.config_path = config_path or Path("config/providers.yaml")
    self.providers: list[Provider] = []
    self.config: RouterConfig = RouterConfig()
    self._load_config()

    # Resolve the optional multi-modal manager before publishing it on self.
    manager: Any | None = None
    try:
        from infrastructure.models.multimodal import (
            get_multimodal_manager as load_manager,
        )

        manager = load_manager()
    except Exception as exc:
        logger.debug("Multi-modal manager not available: %s", exc)
    self._mm_manager: Any | None = manager

    logger.info("CascadeRouter initialized with %d providers", len(self.providers))
|
|
|
|
def _load_config(self) -> None:
    """Load router settings and providers from ``self.config_path``.

    The file is read as UTF-8, ``${VAR}`` references are expanded with
    ``_expand_env_vars`` and the result is parsed with ``yaml.safe_load``.
    The parsed settings replace ``self.config``. Every enabled provider
    that passes ``_check_provider_available`` is appended to
    ``self.providers``, which is then sorted by ascending ``priority``.

    Failure handling (nothing is raised to the caller):
    - Missing file: a warning is logged and the current state is kept.
    - Empty file or sections without a value (``cascade:``): treated as
      empty, so defaults apply.
    - A malformed provider entry (not a mapping, or missing ``name`` /
      ``type``): logged and skipped; the remaining entries still load.
    - Anything else (PyYAML missing, unreadable or invalid YAML): logged
      as an error.
    """
    if not self.config_path.exists():
        logger.warning("Config not found: %s, using defaults", self.config_path)
        return

    try:
        if yaml is None:
            raise RuntimeError("PyYAML not installed")

        raw = self.config_path.read_text(encoding="utf-8")
        # safe_load returns None for an empty document.
        data = yaml.safe_load(self._expand_env_vars(raw)) or {}
        if not isinstance(data, dict):
            raise ValueError(
                f"expected a mapping at the top level, got {type(data).__name__}"
            )

        # A key written without a value parses as None, hence "or {}".
        cascade = data.get("cascade") or {}
        breaker = cascade.get("circuit_breaker") or {}
        multimodal = data.get("multimodal") or {}

        self.config = RouterConfig(
            timeout_seconds=cascade.get("timeout_seconds", 30),
            max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
            retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
            circuit_breaker_failure_threshold=breaker.get("failure_threshold", 5),
            circuit_breaker_recovery_timeout=breaker.get("recovery_timeout", 60),
            circuit_breaker_half_open_max_calls=breaker.get("half_open_max_calls", 2),
            auto_pull_models=multimodal.get("auto_pull", True),
            fallback_chains=data.get("fallback_chains") or {},
        )

        for p_data in data.get("providers") or []:
            try:
                # Providers are opt-in: a missing "enabled" key means disabled.
                if not p_data.get("enabled", False):
                    continue

                provider = Provider(
                    name=p_data["name"],
                    type=p_data["type"],
                    enabled=p_data.get("enabled", True),
                    priority=p_data.get("priority", 99),
                    url=p_data.get("url"),
                    api_key=p_data.get("api_key"),
                    base_url=p_data.get("base_url"),
                    models=p_data.get("models") or [],
                )
            except (AttributeError, KeyError, TypeError) as exc:
                # The entry itself is not logged: it may contain an API key.
                logger.error("Skipping malformed provider entry: %s", exc)
                continue

            # Only keep providers that are actually reachable/configured.
            if self._check_provider_available(provider):
                self.providers.append(provider)
            else:
                logger.warning("Provider %s not available, skipping", provider.name)

        # Lower priority number = tried first.
        self.providers.sort(key=lambda p: p.priority)

    except Exception as exc:
        logger.error("Failed to load config: %s", exc)
|
|
|
|
def _expand_env_vars(self, content: str) -> str:
|
|
"""Expand ${VAR} syntax in YAML content.
|
|
|
|
Uses os.environ directly (not settings) because this is a generic
|
|
YAML config loader that must expand arbitrary variable references.
|
|
"""
|
|
import os
|
|
import re
|
|
|
|
def replace_var(match: "re.Match[str]") -> str:
|
|
var_name = match.group(1)
|
|
return os.environ.get(var_name, match.group(0))
|
|
|
|
return re.sub(r"\$\{(\w+)\}", replace_var, content)
|
|
|
|
def _check_provider_available(self, provider: Provider) -> bool:
    """Return True if *provider* looks usable right now.

    - ``ollama``: issues ``GET {url}/api/tags`` (5 s timeout, URL defaults
      to ``http://localhost:11434``) and requires HTTP 200. When the
      ``requests`` package is not installed the check cannot run and the
      provider is assumed to be available.
    - ``openai`` / ``anthropic`` / ``grok``: requires a non-empty API key.
      A key that is still an unexpanded ``${VAR}`` placeholder means the
      environment variable was not set, so it counts as missing.
    - Any other type: assumed available.
    """
    if provider.type == "ollama":
        if requests is None:
            # Can't check without requests, assume available
            return True
        try:
            url = provider.url or "http://localhost:11434"
            response = requests.get(f"{url}/api/tags", timeout=5)
            return response.status_code == 200
        except Exception as exc:
            logger.debug("Ollama provider check error: %s", exc)
            return False

    if provider.type in ("openai", "anthropic", "grok"):
        key = provider.api_key
        if key is None or key == "":
            return False
        if isinstance(key, str):
            stripped = key.strip()
            if not stripped:
                return False
            # _expand_env_vars leaves "${VAR}" as-is when VAR is unset.
            if stripped.startswith("${") and stripped.endswith("}"):
                return False
        return True

    return True
|
|
|
|
def _detect_content_type(self, messages: list[dict]) -> ContentType:
    """Classify the request by the kinds of content its messages carry.

    A message counts as containing an image when it has a truthy
    ``images`` entry, when its string content contains a known image file
    extension or starts with ``data:image/``, or when its list content has
    an item of type ``image_url``. Audio is detected through a truthy
    ``audio`` entry or a list item of type ``audio``.

    Returns MULTIMODAL when both were seen, VISION or AUDIO when only one
    was seen, and TEXT otherwise.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
    saw_image = False
    saw_audio = False

    for message in messages:
        body = message.get("content", "")

        # Explicit attachments.
        if message.get("images"):
            saw_image = True
        if message.get("audio"):
            saw_audio = True

        if isinstance(body, str):
            # Image references embedded in plain text.
            lowered = body.lower()
            if any(suffix in lowered for suffix in image_suffixes):
                saw_image = True
            if body.startswith("data:image/"):
                saw_image = True
        elif isinstance(body, list):
            # Structured multimodal content parts.
            for part in body:
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if part_type == "image_url":
                    saw_image = True
                elif part_type == "audio":
                    saw_audio = True

    if saw_image:
        return ContentType.MULTIMODAL if saw_audio else ContentType.VISION
    return ContentType.AUDIO if saw_audio else ContentType.TEXT
|
|
|
|
def _get_fallback_model(
    self, provider: Provider, original_model: str, content_type: ContentType
) -> str | None:
    """Pick a different model on *provider* that can handle *content_type*.

    Lookup order:
    1. The provider's first model that lists the required capability.
    2. The first model in the configured fallback chain for that
       capability that the provider declares as capable.

    A candidate is only returned when it differs from *original_model* and
    the provider's model list confirms the capability.
    ``Provider.get_model_with_capability`` falls back to the provider's
    default model when nothing matches, so its result is re-checked here
    rather than trusted.

    Returns:
        The fallback model name, or ``None`` for text content or when no
        suitable model exists.
    """
    # Map content type to the capability a model must declare.
    capability_map = {
        ContentType.VISION: "vision",
        ContentType.AUDIO: "audio",
        ContentType.MULTIMODAL: "vision",  # Vision models often do both
    }

    capability = capability_map.get(content_type)
    if not capability:
        return None

    # Check the provider's own models first.
    candidate = provider.get_model_with_capability(capability)
    if (
        candidate
        and candidate != original_model
        and provider.model_has_capability(candidate, capability)
    ):
        return candidate

    # Then walk the configured fallback chain.
    for model_name in self.config.fallback_chains.get(capability, []):
        if model_name == original_model:
            continue  # not a fallback
        if provider.model_has_capability(model_name, capability):
            return model_name

    return None
|
|
|
|
async def complete(
    self,
    messages: list[dict],
    model: str | None = None,
    temperature: float = 0.7,
    max_tokens: int | None = None,
) -> dict:
    """Complete a chat conversation with automatic failover.

    Providers are tried in ``self.providers`` order. Disabled providers
    and providers whose circuit is open are skipped; each remaining
    provider gets up to ``max_retries_per_provider`` attempts before the
    next one is tried.

    Multi-modal support:
    - Automatically detects if messages contain images
    - For Ollama providers (when the multi-modal manager is available),
      swaps in a vision-capable model if the selected one lacks vision

    Args:
        messages: List of message dicts with role and content
        model: Preferred model (tries this first, then provider defaults)
        temperature: Sampling temperature
        max_tokens: Maximum tokens to generate

    Returns:
        Dict with ``content``, ``provider``, ``model``, ``latency_ms`` and
        ``is_fallback_model``.

    Raises:
        RuntimeError: If all providers fail
    """
    # Detect content type for multi-modal routing
    content_type = self._detect_content_type(messages)
    if content_type != ContentType.TEXT:
        logger.debug("Detected %s content, selecting appropriate model", content_type.value)

    errors: list[str] = []

    for provider in self.providers:
        if not provider.enabled:
            logger.debug("Skipping %s (disabled)", provider.name)
            continue

        # Circuit breaker gate. A provider that is already HALF_OPEN is in
        # the middle of its recovery probe: let the request through without
        # re-entering HALF_OPEN, because that would reset half_open_calls
        # and the breaker could never collect enough successes to close.
        if (
            provider.status == ProviderStatus.UNHEALTHY
            and provider.circuit_state != CircuitState.HALF_OPEN
        ):
            if not self._can_close_circuit(provider):
                logger.debug("Skipping %s (circuit open)", provider.name)
                continue
            provider.circuit_state = CircuitState.HALF_OPEN
            provider.half_open_calls = 0
            logger.info("Circuit breaker half-open for %s", provider.name)

        # Determine which model to use
        selected_model = model or provider.get_default_model()
        is_fallback_model = False

        # For non-text content, check if the model supports it
        if content_type != ContentType.TEXT and selected_model:
            if provider.type == "ollama" and self._mm_manager:
                from infrastructure.models.multimodal import ModelCapability

                if content_type == ContentType.VISION:
                    supports = self._mm_manager.model_supports(
                        selected_model, ModelCapability.VISION
                    )
                    if not supports:
                        fallback = self._get_fallback_model(
                            provider, selected_model, content_type
                        )
                        if fallback:
                            logger.info(
                                "Model %s doesn't support vision, falling back to %s",
                                selected_model,
                                fallback,
                            )
                            selected_model = fallback
                            is_fallback_model = True
                        else:
                            logger.warning(
                                "No vision-capable model found on %s, trying anyway",
                                provider.name,
                            )

        # Try this provider
        for attempt in range(self.config.max_retries_per_provider):
            try:
                result = await self._try_provider(
                    provider=provider,
                    messages=messages,
                    model=selected_model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    content_type=content_type,
                )

                # Success: update metrics and return
                self._record_success(provider, result.get("latency_ms", 0))
                return {
                    "content": result["content"],
                    "provider": provider.name,
                    "model": result.get(
                        "model", selected_model or provider.get_default_model()
                    ),
                    "latency_ms": result.get("latency_ms", 0),
                    "is_fallback_model": is_fallback_model,
                }

            except Exception as exc:
                error_msg = str(exc)
                logger.warning(
                    "Provider %s attempt %d failed: %s", provider.name, attempt + 1, error_msg
                )
                errors.append(f"{provider.name}: {error_msg}")

                # No delay after the final attempt.
                if attempt < self.config.max_retries_per_provider - 1:
                    await asyncio.sleep(self.config.retry_delay_seconds)

        # All retries failed for this provider (counted as one failure)
        self._record_failure(provider)

    # All providers failed
    raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
|
|
|
async def _try_provider(
    self,
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
    content_type: ContentType = ContentType.TEXT,
) -> dict:
    """Send one request to *provider* and time it.

    Dispatches on ``provider.type`` to the matching ``_call_*`` method.
    When *model* is falsy the provider's default model is used. The
    Ollama backend receives ``content_type``; the other backends receive
    ``max_tokens``.

    Returns:
        The backend's result dict with ``latency_ms`` (wall-clock
        milliseconds for the call) added.

    Raises:
        ValueError: If the provider type is not recognized.
    """
    started = time.time()
    kind = provider.type

    # Choose the backend and the one keyword argument that differs.
    if kind == "ollama":
        backend, extra = self._call_ollama, {"content_type": content_type}
    elif kind == "openai":
        backend, extra = self._call_openai, {"max_tokens": max_tokens}
    elif kind == "anthropic":
        backend, extra = self._call_anthropic, {"max_tokens": max_tokens}
    elif kind == "grok":
        backend, extra = self._call_grok, {"max_tokens": max_tokens}
    else:
        raise ValueError(f"Unknown provider type: {provider.type}")

    result = await backend(
        provider=provider,
        messages=messages,
        model=model or provider.get_default_model(),
        temperature=temperature,
        **extra,
    )

    result["latency_ms"] = (time.time() - started) * 1000
    return result
|
|
|
|
async def _call_ollama(
    self,
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    content_type: ContentType = ContentType.TEXT,
) -> dict:
    """Call the Ollama chat API (non-streaming) with multi-modal support.

    Messages are converted with ``_transform_messages_for_ollama`` so that
    images are sent in Ollama's format. ``content_type`` is accepted for
    interface symmetry but is not used by this method.

    Args:
        provider: Provider whose ``url`` is the Ollama base URL. When it
            is unset, ``http://localhost:11434`` is used — the same
            default the availability check applies.
        messages: Chat messages to send.
        model: Model name to request.
        temperature: Sampling temperature.
        content_type: Detected content type of the request (unused here).

    Returns:
        Dict with ``content`` (the reply text) and ``model``.

    Raises:
        RuntimeError: If Ollama answers with a non-200 status.
    """
    import aiohttp

    # Without the default a provider with no URL would post to "None/api/chat";
    # stripping the trailing slash avoids a "//api/chat" path.
    base_url = (provider.url or "http://localhost:11434").rstrip("/")
    url = f"{base_url}/api/chat"

    # Transform messages for Ollama format (including images)
    transformed_messages = self._transform_messages_for_ollama(messages)

    payload = {
        "model": model,
        "messages": transformed_messages,
        "stream": False,
        "options": {
            "temperature": temperature,
        },
    }

    timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)

    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(url, json=payload) as response:
            if response.status != 200:
                text = await response.text()
                raise RuntimeError(f"Ollama error {response.status}: {text}")

            data = await response.json()
            return {
                "content": data["message"]["content"],
                "model": model,
            }
|
|
|
|
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
|
|
"""Transform messages to Ollama format, handling images."""
|
|
transformed = []
|
|
|
|
for msg in messages:
|
|
new_msg = {
|
|
"role": msg.get("role", "user"),
|
|
"content": msg.get("content", ""),
|
|
}
|
|
|
|
# Handle images
|
|
images = msg.get("images", [])
|
|
if images:
|
|
new_msg["images"] = []
|
|
for img in images:
|
|
if isinstance(img, str):
|
|
if img.startswith("data:image/"):
|
|
# Base64 encoded image
|
|
new_msg["images"].append(img.split(",")[1])
|
|
elif img.startswith("http://") or img.startswith("https://"):
|
|
# URL - would need to download, skip for now
|
|
logger.warning("Image URLs not yet supported, skipping: %s", img)
|
|
elif Path(img).exists():
|
|
# Local file path - read and encode
|
|
try:
|
|
with open(img, "rb") as f:
|
|
img_data = base64.b64encode(f.read()).decode()
|
|
new_msg["images"].append(img_data)
|
|
except Exception as exc:
|
|
logger.error("Failed to read image %s: %s", img, exc)
|
|
|
|
transformed.append(new_msg)
|
|
|
|
return transformed
|
|
|
|
async def _call_openai(
    self,
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
) -> dict:
    """Send a chat completion request through the OpenAI SDK.

    A new ``AsyncOpenAI`` client is built per call from the provider's
    ``api_key`` and ``base_url`` with the router's timeout. ``max_tokens``
    is only forwarded when it is truthy.

    Returns:
        Dict with ``content`` (text of the first choice) and ``model``
        (as reported by the API).
    """
    import openai

    request = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if max_tokens:
        request["max_tokens"] = max_tokens

    client = openai.AsyncOpenAI(
        api_key=provider.api_key,
        base_url=provider.base_url,
        timeout=self.config.timeout_seconds,
    )
    reply = await client.chat.completions.create(**request)

    first_choice = reply.choices[0]
    return {
        "content": first_choice.message.content,
        "model": reply.model,
    }
|
|
|
|
async def _call_anthropic(
    self,
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
) -> dict:
    """Send a messages request through the Anthropic SDK.

    Messages with role ``system`` are lifted out of the conversation and
    passed via the ``system`` argument (if several are present, the last
    one wins). ``max_tokens`` falls back to 1024 when it is not given.

    Returns:
        Dict with ``content`` (text of the first content block) and
        ``model`` (as reported by the API).
    """
    import anthropic

    client = anthropic.AsyncAnthropic(
        api_key=provider.api_key,
        timeout=self.config.timeout_seconds,
    )

    # Split the system prompt from the user/assistant turns.
    system_prompt = None
    turns = []
    for entry in messages:
        role = entry["role"]
        if role == "system":
            system_prompt = entry["content"]
            continue
        turns.append({"role": role, "content": entry["content"]})

    request = {
        "model": model,
        "messages": turns,
        "temperature": temperature,
        "max_tokens": max_tokens or 1024,
    }
    if system_prompt:
        request["system"] = system_prompt

    reply = await client.messages.create(**request)

    return {
        "content": reply.content[0].text,
        "model": reply.model,
    }
|
|
|
|
async def _call_grok(
    self,
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
) -> dict:
    """Send a chat completion request to xAI Grok via the OpenAI SDK.

    The base URL defaults to ``https://api.x.ai/v1`` when the provider
    does not set one. This call uses a fixed 300-second timeout rather
    than the router's ``timeout_seconds``. ``max_tokens`` is only
    forwarded when it is truthy.

    Returns:
        Dict with ``content`` (text of the first choice) and ``model``
        (as reported by the API).
    """
    import httpx
    import openai

    request = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if max_tokens:
        request["max_tokens"] = max_tokens

    endpoint = provider.base_url or "https://api.x.ai/v1"
    client = openai.AsyncOpenAI(
        api_key=provider.api_key,
        base_url=endpoint,
        timeout=httpx.Timeout(300.0),
    )
    reply = await client.chat.completions.create(**request)

    first_choice = reply.choices[0]
    return {
        "content": first_choice.message.content,
        "model": reply.model,
    }
|
|
|
|
def _record_success(self, provider: Provider, latency_ms: float) -> None:
    """Update *provider*'s counters and health after a successful request.

    Adds the request and its latency to the metrics, clears the
    consecutive-failure streak, counts the call towards closing a
    half-open circuit, and then re-derives the status from the error rate
    (below 0.1 is HEALTHY, below 0.3 is DEGRADED, otherwise the status is
    left unchanged).
    """
    stats = provider.metrics
    stats.total_requests += 1
    stats.successful_requests += 1
    stats.total_latency_ms += latency_ms
    stats.last_request_time = datetime.now(UTC).isoformat()
    stats.consecutive_failures = 0

    # A half-open circuit closes once enough probe calls have succeeded.
    if provider.circuit_state == CircuitState.HALF_OPEN:
        provider.half_open_calls += 1
        enough_probes = (
            provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls
        )
        if enough_probes:
            self._close_circuit(provider)

    rate = stats.error_rate
    if rate < 0.1:
        provider.status = ProviderStatus.HEALTHY
    elif rate < 0.3:
        provider.status = ProviderStatus.DEGRADED
|
|
|
|
def _record_failure(self, provider: Provider) -> None:
    """Update *provider*'s counters and health after a failed request.

    The circuit breaker opens (status UNHEALTHY) when the
    consecutive-failure streak reaches
    ``circuit_breaker_failure_threshold``, or when the failure happened
    during a half-open recovery probe. Otherwise an error rate above 0.3
    marks the provider DEGRADED.

    UNHEALTHY is only ever set by opening the circuit. ``complete`` skips
    UNHEALTHY providers until the recovery timeout, measured from
    ``circuit_opened_at``, has elapsed; an UNHEALTHY provider whose
    circuit was never opened has no such timestamp and would be skipped
    forever. For the same reason the error-rate rule never overrides the
    status of an open circuit.
    """
    stats = provider.metrics
    stats.total_requests += 1
    stats.failed_requests += 1
    stats.last_error_time = datetime.now(UTC).isoformat()
    stats.consecutive_failures += 1

    probe_failed = provider.circuit_state == CircuitState.HALF_OPEN
    streak_exceeded = (
        stats.consecutive_failures >= self.config.circuit_breaker_failure_threshold
    )
    if probe_failed or streak_exceeded:
        # Sets OPEN / UNHEALTHY and restarts the recovery timer.
        self._open_circuit(provider)
        return

    if provider.circuit_state == CircuitState.OPEN:
        # Keep the breaker's UNHEALTHY status.
        return

    if stats.error_rate > 0.3:
        provider.status = ProviderStatus.DEGRADED
|
|
|
|
def _open_circuit(self, provider: Provider) -> None:
    """Trip the breaker: mark *provider* OPEN and UNHEALTHY, stamping the time."""
    tripped_at = time.time()
    provider.circuit_state = CircuitState.OPEN
    provider.circuit_opened_at = tripped_at
    provider.status = ProviderStatus.UNHEALTHY
    logger.warning("Circuit breaker OPEN for %s", provider.name)
|
|
|
|
def _can_close_circuit(self, provider: Provider) -> bool:
    """Return True once the recovery timeout has passed since the circuit opened.

    A provider with no recorded opening time never qualifies.
    """
    opened_at = provider.circuit_opened_at
    if opened_at is None:
        return False
    waited = time.time() - opened_at
    return waited >= self.config.circuit_breaker_recovery_timeout
|
|
|
|
def _close_circuit(self, provider: Provider) -> None:
    """Reset *provider*'s breaker to CLOSED and mark it HEALTHY again.

    Also clears the opening timestamp, the half-open call counter and the
    consecutive-failure streak.
    """
    # Breaker bookkeeping.
    provider.circuit_state = CircuitState.CLOSED
    provider.circuit_opened_at = None
    provider.half_open_calls = 0

    # Health bookkeeping.
    provider.metrics.consecutive_failures = 0
    provider.status = ProviderStatus.HEALTHY

    logger.info("Circuit breaker CLOSED for %s", provider.name)
|
|
|
|
def reload_config(self) -> dict:
    """Re-read the config file while keeping runtime state.

    The provider list is rebuilt from disk by ``_load_config``. Providers
    whose name existed before the reload get their previous metrics,
    circuit breaker fields and status copied onto the new object.

    Returns:
        Dict with ``total_providers``, ``preserved`` (count) and the
        sorted name lists ``added`` and ``removed``.
    """
    state_fields = (
        "metrics",
        "circuit_state",
        "circuit_opened_at",
        "half_open_calls",
        "status",
    )

    # Snapshot runtime state by provider name before the list is replaced.
    snapshot = {
        existing.name: tuple(getattr(existing, attr) for attr in state_fields)
        for existing in self.providers
    }
    names_before = set(snapshot)

    # Reload from disk.
    self.providers = []
    self._load_config()

    names_after = {loaded.name for loaded in self.providers}

    # Copy the saved state onto providers that survived the reload.
    preserved = 0
    for loaded in self.providers:
        if loaded.name not in snapshot:
            continue
        for attr, value in zip(state_fields, snapshot[loaded.name]):
            setattr(loaded, attr, value)
        preserved += 1

    added = names_after - names_before
    removed = names_before - names_after

    logger.info(
        "Config reloaded: %d providers (%d preserved, %d added, %d removed)",
        len(self.providers),
        preserved,
        len(added),
        len(removed),
    )

    return {
        "total_providers": len(self.providers),
        "preserved": preserved,
        "added": sorted(added),
        "removed": sorted(removed),
    }
|
|
|
|
def get_metrics(self) -> dict:
    """Return per-provider metrics as plain dicts.

    The result has a single ``providers`` key holding one entry per
    provider with its name, type, status, circuit state and counters.
    ``error_rate`` is rounded to 3 decimals and ``avg_latency_ms`` to 2.
    """

    def describe(provider):
        stats = provider.metrics
        return {
            "name": provider.name,
            "type": provider.type,
            "status": provider.status.value,
            "circuit_state": provider.circuit_state.value,
            "metrics": {
                "total_requests": stats.total_requests,
                "successful": stats.successful_requests,
                "failed": stats.failed_requests,
                "error_rate": round(stats.error_rate, 3),
                "avg_latency_ms": round(stats.avg_latency_ms, 2),
            },
        }

    return {"providers": [describe(provider) for provider in self.providers]}
|
|
|
|
def get_status(self) -> dict:
    """Summarize the router's current state.

    Returns:
        Dict with the total provider count, the number of providers in
        each of the HEALTHY / DEGRADED / UNHEALTHY states, and a
        ``providers`` list giving each provider's name, type, status,
        priority and default model.
    """

    def tally(state) -> int:
        return sum(1 for provider in self.providers if provider.status == state)

    overview = [
        {
            "name": provider.name,
            "type": provider.type,
            "status": provider.status.value,
            "priority": provider.priority,
            "default_model": provider.get_default_model(),
        }
        for provider in self.providers
    ]

    return {
        "total_providers": len(self.providers),
        "healthy_providers": tally(ProviderStatus.HEALTHY),
        "degraded_providers": tally(ProviderStatus.DEGRADED),
        "unhealthy_providers": tally(ProviderStatus.UNHEALTHY),
        "providers": overview,
    }
|
|
|
|
async def generate_with_image(
|
|
self,
|
|
prompt: str,
|
|
image_path: str,
|
|
model: str | None = None,
|
|
temperature: float = 0.7,
|
|
) -> dict:
|
|
"""Convenience method for vision requests.
|
|
|
|
Args:
|
|
prompt: Text prompt about the image
|
|
image_path: Path to image file
|
|
model: Vision-capable model (auto-selected if not provided)
|
|
temperature: Sampling temperature
|
|
|
|
Returns:
|
|
Response dict with content and metadata
|
|
"""
|
|
messages = [
|
|
{
|
|
"role": "user",
|
|
"content": prompt,
|
|
"images": [image_path],
|
|
}
|
|
]
|
|
return await self.complete(
|
|
messages=messages,
|
|
model=model,
|
|
temperature=temperature,
|
|
)
|
|
|
|
|
|
# Module-level singleton
|
|
# Stays None until get_router() creates the shared instance on first use.
cascade_router: CascadeRouter | None = None
|
|
|
|
|
|
def get_router() -> CascadeRouter:
    """Return the shared CascadeRouter, creating it on the first call."""
    global cascade_router
    router = cascade_router
    if router is None:
        router = cascade_router = CascadeRouter()
    return router
|