This commit was merged in pull request #1448.
This commit is contained in:
@@ -9,12 +9,7 @@ models for image inputs and falls back through capability chains.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
@@ -33,148 +28,25 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
requests = None # type: ignore
|
requests = None # type: ignore
|
||||||
|
|
||||||
|
# Re-export data models so existing ``from …cascade import X`` keeps working.
|
||||||
|
from .models import ( # noqa: F401 – re-exports
|
||||||
|
CircuitState,
|
||||||
|
ContentType,
|
||||||
|
ModelCapability,
|
||||||
|
Provider,
|
||||||
|
ProviderMetrics,
|
||||||
|
ProviderStatus,
|
||||||
|
RouterConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mixins
|
||||||
|
from .health import HealthMixin
|
||||||
|
from .providers import ProviderCallsMixin
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Quota monitor — optional, degrades gracefully if unavailable
|
|
||||||
try:
|
|
||||||
from infrastructure.claude_quota import QuotaMonitor, get_quota_monitor
|
|
||||||
|
|
||||||
_quota_monitor: "QuotaMonitor | None" = get_quota_monitor()
|
class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
||||||
except Exception as _exc: # pragma: no cover
|
|
||||||
logger.debug("Quota monitor not available: %s", _exc)
|
|
||||||
_quota_monitor = None
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderStatus(Enum):
|
|
||||||
"""Health status of a provider."""
|
|
||||||
|
|
||||||
HEALTHY = "healthy"
|
|
||||||
DEGRADED = "degraded" # Working but slow or occasional errors
|
|
||||||
UNHEALTHY = "unhealthy" # Circuit breaker open
|
|
||||||
DISABLED = "disabled"
|
|
||||||
|
|
||||||
|
|
||||||
class CircuitState(Enum):
|
|
||||||
"""Circuit breaker state."""
|
|
||||||
|
|
||||||
CLOSED = "closed" # Normal operation
|
|
||||||
OPEN = "open" # Failing, rejecting requests
|
|
||||||
HALF_OPEN = "half_open" # Testing if recovered
|
|
||||||
|
|
||||||
|
|
||||||
class ContentType(Enum):
|
|
||||||
"""Type of content in the request."""
|
|
||||||
|
|
||||||
TEXT = "text"
|
|
||||||
VISION = "vision" # Contains images
|
|
||||||
AUDIO = "audio" # Contains audio
|
|
||||||
MULTIMODAL = "multimodal" # Multiple content types
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProviderMetrics:
|
|
||||||
"""Metrics for a single provider."""
|
|
||||||
|
|
||||||
total_requests: int = 0
|
|
||||||
successful_requests: int = 0
|
|
||||||
failed_requests: int = 0
|
|
||||||
total_latency_ms: float = 0.0
|
|
||||||
last_request_time: str | None = None
|
|
||||||
last_error_time: str | None = None
|
|
||||||
consecutive_failures: int = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def avg_latency_ms(self) -> float:
|
|
||||||
if self.total_requests == 0:
|
|
||||||
return 0.0
|
|
||||||
return self.total_latency_ms / self.total_requests
|
|
||||||
|
|
||||||
@property
|
|
||||||
def error_rate(self) -> float:
|
|
||||||
if self.total_requests == 0:
|
|
||||||
return 0.0
|
|
||||||
return self.failed_requests / self.total_requests
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelCapability:
|
|
||||||
"""Capabilities a model supports."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
supports_vision: bool = False
|
|
||||||
supports_audio: bool = False
|
|
||||||
supports_tools: bool = False
|
|
||||||
supports_json: bool = False
|
|
||||||
supports_streaming: bool = True
|
|
||||||
context_window: int = 4096
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Provider:
|
|
||||||
"""LLM provider configuration and state."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
type: str # ollama, openai, anthropic
|
|
||||||
enabled: bool
|
|
||||||
priority: int
|
|
||||||
tier: str | None = None # e.g., "local", "standard_cloud", "frontier"
|
|
||||||
url: str | None = None
|
|
||||||
api_key: str | None = None
|
|
||||||
base_url: str | None = None
|
|
||||||
models: list[dict] = field(default_factory=list)
|
|
||||||
|
|
||||||
# Runtime state
|
|
||||||
status: ProviderStatus = ProviderStatus.HEALTHY
|
|
||||||
metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
|
|
||||||
circuit_state: CircuitState = CircuitState.CLOSED
|
|
||||||
circuit_opened_at: float | None = None
|
|
||||||
half_open_calls: int = 0
|
|
||||||
|
|
||||||
def get_default_model(self) -> str | None:
|
|
||||||
"""Get the default model for this provider."""
|
|
||||||
for model in self.models:
|
|
||||||
if model.get("default"):
|
|
||||||
return model["name"]
|
|
||||||
if self.models:
|
|
||||||
return self.models[0]["name"]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_model_with_capability(self, capability: str) -> str | None:
|
|
||||||
"""Get a model that supports the given capability."""
|
|
||||||
for model in self.models:
|
|
||||||
capabilities = model.get("capabilities", [])
|
|
||||||
if capability in capabilities:
|
|
||||||
return model["name"]
|
|
||||||
# Fall back to default
|
|
||||||
return self.get_default_model()
|
|
||||||
|
|
||||||
def model_has_capability(self, model_name: str, capability: str) -> bool:
|
|
||||||
"""Check if a specific model has a capability."""
|
|
||||||
for model in self.models:
|
|
||||||
if model["name"] == model_name:
|
|
||||||
capabilities = model.get("capabilities", [])
|
|
||||||
return capability in capabilities
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RouterConfig:
|
|
||||||
"""Cascade router configuration."""
|
|
||||||
|
|
||||||
timeout_seconds: int = 30
|
|
||||||
max_retries_per_provider: int = 2
|
|
||||||
retry_delay_seconds: int = 1
|
|
||||||
circuit_breaker_failure_threshold: int = 5
|
|
||||||
circuit_breaker_recovery_timeout: int = 60
|
|
||||||
circuit_breaker_half_open_max_calls: int = 2
|
|
||||||
cost_tracking_enabled: bool = True
|
|
||||||
budget_daily_usd: float = 10.0
|
|
||||||
# Multi-modal settings
|
|
||||||
auto_pull_models: bool = True
|
|
||||||
fallback_chains: dict = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class CascadeRouter:
|
|
||||||
"""Routes LLM requests with automatic failover.
|
"""Routes LLM requests with automatic failover.
|
||||||
|
|
||||||
Now with multi-modal support:
|
Now with multi-modal support:
|
||||||
@@ -487,50 +359,6 @@ class CascadeRouter:
|
|||||||
|
|
||||||
raise RuntimeError("; ".join(errors))
|
raise RuntimeError("; ".join(errors))
|
||||||
|
|
||||||
def _quota_allows_cloud(self, provider: Provider) -> bool:
|
|
||||||
"""Check quota before routing to a cloud provider.
|
|
||||||
|
|
||||||
Uses the metabolic protocol via select_model(): cloud calls are only
|
|
||||||
allowed when the quota monitor recommends a cloud model (BURST tier).
|
|
||||||
Returns True (allow cloud) if quota monitor is unavailable or returns None.
|
|
||||||
"""
|
|
||||||
if _quota_monitor is None:
|
|
||||||
return True
|
|
||||||
try:
|
|
||||||
suggested = _quota_monitor.select_model("high")
|
|
||||||
# Cloud is allowed only when select_model recommends the cloud model
|
|
||||||
allows = suggested == "claude-sonnet-4-6"
|
|
||||||
if not allows:
|
|
||||||
status = _quota_monitor.check()
|
|
||||||
tier = status.recommended_tier.value if status else "unknown"
|
|
||||||
logger.info(
|
|
||||||
"Metabolic protocol: %s tier — downshifting %s to local (%s)",
|
|
||||||
tier,
|
|
||||||
provider.name,
|
|
||||||
suggested,
|
|
||||||
)
|
|
||||||
return allows
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Quota check failed, allowing cloud: %s", exc)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _is_provider_available(self, provider: Provider) -> bool:
|
|
||||||
"""Check if a provider should be tried (enabled + circuit breaker)."""
|
|
||||||
if not provider.enabled:
|
|
||||||
logger.debug("Skipping %s (disabled)", provider.name)
|
|
||||||
return False
|
|
||||||
|
|
||||||
if provider.status == ProviderStatus.UNHEALTHY:
|
|
||||||
if self._can_close_circuit(provider):
|
|
||||||
provider.circuit_state = CircuitState.HALF_OPEN
|
|
||||||
provider.half_open_calls = 0
|
|
||||||
logger.info("Circuit breaker half-open for %s", provider.name)
|
|
||||||
else:
|
|
||||||
logger.debug("Skipping %s (circuit open)", provider.name)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]:
|
def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]:
|
||||||
"""Return the provider list filtered by tier.
|
"""Return the provider list filtered by tier.
|
||||||
|
|
||||||
@@ -641,9 +469,9 @@ class CascadeRouter:
|
|||||||
- Supports image URLs, paths, and base64 encoding
|
- Supports image URLs, paths, and base64 encoding
|
||||||
|
|
||||||
Complexity-based routing (issue #1065):
|
Complexity-based routing (issue #1065):
|
||||||
- ``complexity_hint="simple"`` → routes to Qwen3-8B (low-latency)
|
- ``complexity_hint="simple"`` -> routes to Qwen3-8B (low-latency)
|
||||||
- ``complexity_hint="complex"`` → routes to Qwen3-14B (quality)
|
- ``complexity_hint="complex"`` -> routes to Qwen3-14B (quality)
|
||||||
- ``complexity_hint=None`` (default) → auto-classifies from messages
|
- ``complexity_hint=None`` (default) -> auto-classifies from messages
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of message dicts with role and content
|
messages: List of message dicts with role and content
|
||||||
@@ -668,7 +496,7 @@ class CascadeRouter:
|
|||||||
if content_type != ContentType.TEXT:
|
if content_type != ContentType.TEXT:
|
||||||
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
||||||
|
|
||||||
# Resolve task complexity ─────────────────────────────────────────────
|
# Resolve task complexity
|
||||||
# Skip complexity routing when caller explicitly specifies a model.
|
# Skip complexity routing when caller explicitly specifies a model.
|
||||||
complexity: TaskComplexity | None = None
|
complexity: TaskComplexity | None = None
|
||||||
if model is None:
|
if model is None:
|
||||||
@@ -698,7 +526,7 @@ class CascadeRouter:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Complexity-based model selection (only when no explicit model) ──
|
# Complexity-based model selection (only when no explicit model)
|
||||||
effective_model = model
|
effective_model = model
|
||||||
if effective_model is None and complexity is not None:
|
if effective_model is None and complexity is not None:
|
||||||
effective_model = self._get_model_for_complexity(provider, complexity)
|
effective_model = self._get_model_for_complexity(provider, complexity)
|
||||||
@@ -740,357 +568,6 @@ class CascadeRouter:
|
|||||||
|
|
||||||
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||||
|
|
||||||
async def _try_provider(
|
|
||||||
self,
|
|
||||||
provider: Provider,
|
|
||||||
messages: list[dict],
|
|
||||||
model: str,
|
|
||||||
temperature: float,
|
|
||||||
max_tokens: int | None,
|
|
||||||
content_type: ContentType = ContentType.TEXT,
|
|
||||||
) -> dict:
|
|
||||||
"""Try a single provider request."""
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
if provider.type == "ollama":
|
|
||||||
result = await self._call_ollama(
|
|
||||||
provider=provider,
|
|
||||||
messages=messages,
|
|
||||||
model=model or provider.get_default_model(),
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
|
||||||
elif provider.type == "openai":
|
|
||||||
result = await self._call_openai(
|
|
||||||
provider=provider,
|
|
||||||
messages=messages,
|
|
||||||
model=model or provider.get_default_model(),
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
)
|
|
||||||
elif provider.type == "anthropic":
|
|
||||||
result = await self._call_anthropic(
|
|
||||||
provider=provider,
|
|
||||||
messages=messages,
|
|
||||||
model=model or provider.get_default_model(),
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
)
|
|
||||||
elif provider.type == "grok":
|
|
||||||
result = await self._call_grok(
|
|
||||||
provider=provider,
|
|
||||||
messages=messages,
|
|
||||||
model=model or provider.get_default_model(),
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
)
|
|
||||||
elif provider.type == "vllm_mlx":
|
|
||||||
result = await self._call_vllm_mlx(
|
|
||||||
provider=provider,
|
|
||||||
messages=messages,
|
|
||||||
model=model or provider.get_default_model(),
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown provider type: {provider.type}")
|
|
||||||
|
|
||||||
latency_ms = (time.time() - start_time) * 1000
|
|
||||||
result["latency_ms"] = latency_ms
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def _call_ollama(
|
|
||||||
self,
|
|
||||||
provider: Provider,
|
|
||||||
messages: list[dict],
|
|
||||||
model: str,
|
|
||||||
temperature: float,
|
|
||||||
max_tokens: int | None = None,
|
|
||||||
content_type: ContentType = ContentType.TEXT,
|
|
||||||
) -> dict:
|
|
||||||
"""Call Ollama API with multi-modal support."""
|
|
||||||
import aiohttp
|
|
||||||
|
|
||||||
url = f"{provider.url or settings.ollama_url}/api/chat"
|
|
||||||
|
|
||||||
# Transform messages for Ollama format (including images)
|
|
||||||
transformed_messages = self._transform_messages_for_ollama(messages)
|
|
||||||
|
|
||||||
options = {"temperature": temperature}
|
|
||||||
if max_tokens:
|
|
||||||
options["num_predict"] = max_tokens
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": model,
|
|
||||||
"messages": transformed_messages,
|
|
||||||
"stream": False,
|
|
||||||
"options": options,
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
||||||
async with session.post(url, json=payload) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
text = await response.text()
|
|
||||||
raise RuntimeError(f"Ollama error {response.status}: {text}")
|
|
||||||
|
|
||||||
data = await response.json()
|
|
||||||
return {
|
|
||||||
"content": data["message"]["content"],
|
|
||||||
"model": model,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
|
|
||||||
"""Transform messages to Ollama format, handling images."""
|
|
||||||
transformed = []
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
new_msg = {
|
|
||||||
"role": msg.get("role", "user"),
|
|
||||||
"content": msg.get("content", ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Handle images
|
|
||||||
images = msg.get("images", [])
|
|
||||||
if images:
|
|
||||||
new_msg["images"] = []
|
|
||||||
for img in images:
|
|
||||||
if isinstance(img, str):
|
|
||||||
if img.startswith("data:image/"):
|
|
||||||
# Base64 encoded image
|
|
||||||
new_msg["images"].append(img.split(",")[1])
|
|
||||||
elif img.startswith("http://") or img.startswith("https://"):
|
|
||||||
# URL - would need to download, skip for now
|
|
||||||
logger.warning("Image URLs not yet supported, skipping: %s", img)
|
|
||||||
elif Path(img).exists():
|
|
||||||
# Local file path - read and encode
|
|
||||||
try:
|
|
||||||
with open(img, "rb") as f:
|
|
||||||
img_data = base64.b64encode(f.read()).decode()
|
|
||||||
new_msg["images"].append(img_data)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("Failed to read image %s: %s", img, exc)
|
|
||||||
|
|
||||||
transformed.append(new_msg)
|
|
||||||
|
|
||||||
return transformed
|
|
||||||
|
|
||||||
async def _call_openai(
|
|
||||||
self,
|
|
||||||
provider: Provider,
|
|
||||||
messages: list[dict],
|
|
||||||
model: str,
|
|
||||||
temperature: float,
|
|
||||||
max_tokens: int | None,
|
|
||||||
) -> dict:
|
|
||||||
"""Call OpenAI API."""
|
|
||||||
import openai
|
|
||||||
|
|
||||||
client = openai.AsyncOpenAI(
|
|
||||||
api_key=provider.api_key,
|
|
||||||
base_url=provider.base_url,
|
|
||||||
timeout=self.config.timeout_seconds,
|
|
||||||
)
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": temperature,
|
|
||||||
}
|
|
||||||
if max_tokens:
|
|
||||||
kwargs["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
response = await client.chat.completions.create(**kwargs)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": response.choices[0].message.content,
|
|
||||||
"model": response.model,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _call_anthropic(
|
|
||||||
self,
|
|
||||||
provider: Provider,
|
|
||||||
messages: list[dict],
|
|
||||||
model: str,
|
|
||||||
temperature: float,
|
|
||||||
max_tokens: int | None,
|
|
||||||
) -> dict:
|
|
||||||
"""Call Anthropic API."""
|
|
||||||
import anthropic
|
|
||||||
|
|
||||||
client = anthropic.AsyncAnthropic(
|
|
||||||
api_key=provider.api_key,
|
|
||||||
timeout=self.config.timeout_seconds,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert messages to Anthropic format
|
|
||||||
system_msg = None
|
|
||||||
conversation = []
|
|
||||||
for msg in messages:
|
|
||||||
if msg["role"] == "system":
|
|
||||||
system_msg = msg["content"]
|
|
||||||
else:
|
|
||||||
conversation.append(
|
|
||||||
{
|
|
||||||
"role": msg["role"],
|
|
||||||
"content": msg["content"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"model": model,
|
|
||||||
"messages": conversation,
|
|
||||||
"temperature": temperature,
|
|
||||||
"max_tokens": max_tokens or 1024,
|
|
||||||
}
|
|
||||||
if system_msg:
|
|
||||||
kwargs["system"] = system_msg
|
|
||||||
|
|
||||||
response = await client.messages.create(**kwargs)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": response.content[0].text,
|
|
||||||
"model": response.model,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _call_grok(
|
|
||||||
self,
|
|
||||||
provider: Provider,
|
|
||||||
messages: list[dict],
|
|
||||||
model: str,
|
|
||||||
temperature: float,
|
|
||||||
max_tokens: int | None,
|
|
||||||
) -> dict:
|
|
||||||
"""Call xAI Grok API via OpenAI-compatible SDK."""
|
|
||||||
import httpx
|
|
||||||
import openai
|
|
||||||
|
|
||||||
client = openai.AsyncOpenAI(
|
|
||||||
api_key=provider.api_key,
|
|
||||||
base_url=provider.base_url or settings.xai_base_url,
|
|
||||||
timeout=httpx.Timeout(300.0),
|
|
||||||
)
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": temperature,
|
|
||||||
}
|
|
||||||
if max_tokens:
|
|
||||||
kwargs["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
response = await client.chat.completions.create(**kwargs)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": response.choices[0].message.content,
|
|
||||||
"model": response.model,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _call_vllm_mlx(
|
|
||||||
self,
|
|
||||||
provider: Provider,
|
|
||||||
messages: list[dict],
|
|
||||||
model: str,
|
|
||||||
temperature: float,
|
|
||||||
max_tokens: int | None,
|
|
||||||
) -> dict:
|
|
||||||
"""Call vllm-mlx via its OpenAI-compatible API.
|
|
||||||
|
|
||||||
vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI,
|
|
||||||
so we reuse the OpenAI client pointed at the local server.
|
|
||||||
No API key is required for local deployments.
|
|
||||||
"""
|
|
||||||
import openai
|
|
||||||
|
|
||||||
base_url = provider.base_url or provider.url or "http://localhost:8000"
|
|
||||||
# Ensure the base_url ends with /v1 as expected by the OpenAI client
|
|
||||||
if not base_url.rstrip("/").endswith("/v1"):
|
|
||||||
base_url = base_url.rstrip("/") + "/v1"
|
|
||||||
|
|
||||||
client = openai.AsyncOpenAI(
|
|
||||||
api_key=provider.api_key or "no-key-required",
|
|
||||||
base_url=base_url,
|
|
||||||
timeout=self.config.timeout_seconds,
|
|
||||||
)
|
|
||||||
|
|
||||||
kwargs: dict = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": temperature,
|
|
||||||
}
|
|
||||||
if max_tokens:
|
|
||||||
kwargs["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
response = await client.chat.completions.create(**kwargs)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": response.choices[0].message.content,
|
|
||||||
"model": response.model,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _record_success(self, provider: Provider, latency_ms: float) -> None:
|
|
||||||
"""Record a successful request."""
|
|
||||||
provider.metrics.total_requests += 1
|
|
||||||
provider.metrics.successful_requests += 1
|
|
||||||
provider.metrics.total_latency_ms += latency_ms
|
|
||||||
provider.metrics.last_request_time = datetime.now(UTC).isoformat()
|
|
||||||
provider.metrics.consecutive_failures = 0
|
|
||||||
|
|
||||||
# Close circuit breaker if half-open
|
|
||||||
if provider.circuit_state == CircuitState.HALF_OPEN:
|
|
||||||
provider.half_open_calls += 1
|
|
||||||
if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
|
|
||||||
self._close_circuit(provider)
|
|
||||||
|
|
||||||
# Update status based on error rate
|
|
||||||
if provider.metrics.error_rate < 0.1:
|
|
||||||
provider.status = ProviderStatus.HEALTHY
|
|
||||||
elif provider.metrics.error_rate < 0.3:
|
|
||||||
provider.status = ProviderStatus.DEGRADED
|
|
||||||
|
|
||||||
def _record_failure(self, provider: Provider) -> None:
|
|
||||||
"""Record a failed request."""
|
|
||||||
provider.metrics.total_requests += 1
|
|
||||||
provider.metrics.failed_requests += 1
|
|
||||||
provider.metrics.last_error_time = datetime.now(UTC).isoformat()
|
|
||||||
provider.metrics.consecutive_failures += 1
|
|
||||||
|
|
||||||
# Check if we should open circuit breaker
|
|
||||||
if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold:
|
|
||||||
self._open_circuit(provider)
|
|
||||||
|
|
||||||
# Update status
|
|
||||||
if provider.metrics.error_rate > 0.3:
|
|
||||||
provider.status = ProviderStatus.DEGRADED
|
|
||||||
if provider.metrics.error_rate > 0.5:
|
|
||||||
provider.status = ProviderStatus.UNHEALTHY
|
|
||||||
|
|
||||||
def _open_circuit(self, provider: Provider) -> None:
|
|
||||||
"""Open the circuit breaker for a provider."""
|
|
||||||
provider.circuit_state = CircuitState.OPEN
|
|
||||||
provider.circuit_opened_at = time.time()
|
|
||||||
provider.status = ProviderStatus.UNHEALTHY
|
|
||||||
logger.warning("Circuit breaker OPEN for %s", provider.name)
|
|
||||||
|
|
||||||
def _can_close_circuit(self, provider: Provider) -> bool:
|
|
||||||
"""Check if circuit breaker can transition to half-open."""
|
|
||||||
if provider.circuit_opened_at is None:
|
|
||||||
return False
|
|
||||||
elapsed = time.time() - provider.circuit_opened_at
|
|
||||||
return elapsed >= self.config.circuit_breaker_recovery_timeout
|
|
||||||
|
|
||||||
def _close_circuit(self, provider: Provider) -> None:
|
|
||||||
"""Close the circuit breaker (provider healthy again)."""
|
|
||||||
provider.circuit_state = CircuitState.CLOSED
|
|
||||||
provider.circuit_opened_at = None
|
|
||||||
provider.half_open_calls = 0
|
|
||||||
provider.metrics.consecutive_failures = 0
|
|
||||||
provider.status = ProviderStatus.HEALTHY
|
|
||||||
logger.info("Circuit breaker CLOSED for %s", provider.name)
|
|
||||||
|
|
||||||
def reload_config(self) -> dict:
|
def reload_config(self) -> dict:
|
||||||
"""Hot-reload providers.yaml, preserving runtime state.
|
"""Hot-reload providers.yaml, preserving runtime state.
|
||||||
|
|
||||||
|
|||||||
137
src/infrastructure/router/health.py
Normal file
137
src/infrastructure/router/health.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Health monitoring and circuit breaker mixin for the Cascade Router.
|
||||||
|
|
||||||
|
Provides failure tracking, circuit breaker state transitions,
|
||||||
|
and quota-based cloud provider gating.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from .models import CircuitState, Provider, ProviderMetrics, ProviderStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Quota monitor — optional, degrades gracefully if unavailable
|
||||||
|
try:
|
||||||
|
from infrastructure.claude_quota import QuotaMonitor, get_quota_monitor
|
||||||
|
|
||||||
|
_quota_monitor: "QuotaMonitor | None" = get_quota_monitor()
|
||||||
|
except Exception as _exc: # pragma: no cover
|
||||||
|
logger.debug("Quota monitor not available: %s", _exc)
|
||||||
|
_quota_monitor = None
|
||||||
|
|
||||||
|
|
||||||
|
class HealthMixin:
|
||||||
|
"""Mixin providing health tracking, circuit breaker, and quota checks.
|
||||||
|
|
||||||
|
Expects the consuming class to have:
|
||||||
|
- self.config: RouterConfig
|
||||||
|
- self.providers: list[Provider]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _record_success(self, provider: Provider, latency_ms: float) -> None:
|
||||||
|
"""Record a successful request."""
|
||||||
|
provider.metrics.total_requests += 1
|
||||||
|
provider.metrics.successful_requests += 1
|
||||||
|
provider.metrics.total_latency_ms += latency_ms
|
||||||
|
provider.metrics.last_request_time = datetime.now(UTC).isoformat()
|
||||||
|
provider.metrics.consecutive_failures = 0
|
||||||
|
|
||||||
|
# Close circuit breaker if half-open
|
||||||
|
if provider.circuit_state == CircuitState.HALF_OPEN:
|
||||||
|
provider.half_open_calls += 1
|
||||||
|
if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
|
||||||
|
self._close_circuit(provider)
|
||||||
|
|
||||||
|
# Update status based on error rate
|
||||||
|
if provider.metrics.error_rate < 0.1:
|
||||||
|
provider.status = ProviderStatus.HEALTHY
|
||||||
|
elif provider.metrics.error_rate < 0.3:
|
||||||
|
provider.status = ProviderStatus.DEGRADED
|
||||||
|
|
||||||
|
def _record_failure(self, provider: Provider) -> None:
|
||||||
|
"""Record a failed request."""
|
||||||
|
provider.metrics.total_requests += 1
|
||||||
|
provider.metrics.failed_requests += 1
|
||||||
|
provider.metrics.last_error_time = datetime.now(UTC).isoformat()
|
||||||
|
provider.metrics.consecutive_failures += 1
|
||||||
|
|
||||||
|
# Check if we should open circuit breaker
|
||||||
|
if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold:
|
||||||
|
self._open_circuit(provider)
|
||||||
|
|
||||||
|
# Update status
|
||||||
|
if provider.metrics.error_rate > 0.3:
|
||||||
|
provider.status = ProviderStatus.DEGRADED
|
||||||
|
if provider.metrics.error_rate > 0.5:
|
||||||
|
provider.status = ProviderStatus.UNHEALTHY
|
||||||
|
|
||||||
|
def _open_circuit(self, provider: Provider) -> None:
|
||||||
|
"""Open the circuit breaker for a provider."""
|
||||||
|
provider.circuit_state = CircuitState.OPEN
|
||||||
|
provider.circuit_opened_at = time.time()
|
||||||
|
provider.status = ProviderStatus.UNHEALTHY
|
||||||
|
logger.warning("Circuit breaker OPEN for %s", provider.name)
|
||||||
|
|
||||||
|
def _can_close_circuit(self, provider: Provider) -> bool:
|
||||||
|
"""Check if circuit breaker can transition to half-open."""
|
||||||
|
if provider.circuit_opened_at is None:
|
||||||
|
return False
|
||||||
|
elapsed = time.time() - provider.circuit_opened_at
|
||||||
|
return elapsed >= self.config.circuit_breaker_recovery_timeout
|
||||||
|
|
||||||
|
def _close_circuit(self, provider: Provider) -> None:
|
||||||
|
"""Close the circuit breaker (provider healthy again)."""
|
||||||
|
provider.circuit_state = CircuitState.CLOSED
|
||||||
|
provider.circuit_opened_at = None
|
||||||
|
provider.half_open_calls = 0
|
||||||
|
provider.metrics.consecutive_failures = 0
|
||||||
|
provider.status = ProviderStatus.HEALTHY
|
||||||
|
logger.info("Circuit breaker CLOSED for %s", provider.name)
|
||||||
|
|
||||||
|
def _is_provider_available(self, provider: Provider) -> bool:
|
||||||
|
"""Check if a provider should be tried (enabled + circuit breaker)."""
|
||||||
|
if not provider.enabled:
|
||||||
|
logger.debug("Skipping %s (disabled)", provider.name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if provider.status == ProviderStatus.UNHEALTHY:
|
||||||
|
if self._can_close_circuit(provider):
|
||||||
|
provider.circuit_state = CircuitState.HALF_OPEN
|
||||||
|
provider.half_open_calls = 0
|
||||||
|
logger.info("Circuit breaker half-open for %s", provider.name)
|
||||||
|
else:
|
||||||
|
logger.debug("Skipping %s (circuit open)", provider.name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _quota_allows_cloud(self, provider: Provider) -> bool:
|
||||||
|
"""Check quota before routing to a cloud provider.
|
||||||
|
|
||||||
|
Uses the metabolic protocol via select_model(): cloud calls are only
|
||||||
|
allowed when the quota monitor recommends a cloud model (BURST tier).
|
||||||
|
Returns True (allow cloud) if quota monitor is unavailable or returns None.
|
||||||
|
"""
|
||||||
|
if _quota_monitor is None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
suggested = _quota_monitor.select_model("high")
|
||||||
|
# Cloud is allowed only when select_model recommends the cloud model
|
||||||
|
allows = suggested == "claude-sonnet-4-6"
|
||||||
|
if not allows:
|
||||||
|
status = _quota_monitor.check()
|
||||||
|
tier = status.recommended_tier.value if status else "unknown"
|
||||||
|
logger.info(
|
||||||
|
"Metabolic protocol: %s tier — downshifting %s to local (%s)",
|
||||||
|
tier,
|
||||||
|
provider.name,
|
||||||
|
suggested,
|
||||||
|
)
|
||||||
|
return allows
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Quota check failed, allowing cloud: %s", exc)
|
||||||
|
return True
|
||||||
138
src/infrastructure/router/models.py
Normal file
138
src/infrastructure/router/models.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Data models for the Cascade LLM Router.
|
||||||
|
|
||||||
|
Enums, dataclasses, and configuration objects shared across router modules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderStatus(Enum):
|
||||||
|
"""Health status of a provider."""
|
||||||
|
|
||||||
|
HEALTHY = "healthy"
|
||||||
|
DEGRADED = "degraded" # Working but slow or occasional errors
|
||||||
|
UNHEALTHY = "unhealthy" # Circuit breaker open
|
||||||
|
DISABLED = "disabled"
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitState(Enum):
|
||||||
|
"""Circuit breaker state."""
|
||||||
|
|
||||||
|
CLOSED = "closed" # Normal operation
|
||||||
|
OPEN = "open" # Failing, rejecting requests
|
||||||
|
HALF_OPEN = "half_open" # Testing if recovered
|
||||||
|
|
||||||
|
|
||||||
|
class ContentType(Enum):
|
||||||
|
"""Type of content in the request."""
|
||||||
|
|
||||||
|
TEXT = "text"
|
||||||
|
VISION = "vision" # Contains images
|
||||||
|
AUDIO = "audio" # Contains audio
|
||||||
|
MULTIMODAL = "multimodal" # Multiple content types
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ProviderMetrics:
    """Running request, latency, and error counters for one provider."""

    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_latency_ms: float = 0.0
    last_request_time: str | None = None  # ISO timestamp of the latest request
    last_error_time: str | None = None  # ISO timestamp of the latest failure
    consecutive_failures: int = 0  # Feeds the circuit-breaker threshold

    @property
    def avg_latency_ms(self) -> float:
        """Mean latency per request in milliseconds; 0.0 before any requests."""
        if not self.total_requests:
            return 0.0
        return self.total_latency_ms / self.total_requests

    @property
    def error_rate(self) -> float:
        """Fraction of requests that failed; 0.0 before any requests."""
        if not self.total_requests:
            return 0.0
        return self.failed_requests / self.total_requests
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelCapability:
    """Feature flags and limits describing what one model supports."""

    name: str
    supports_vision: bool = False
    supports_audio: bool = False
    supports_tools: bool = False
    supports_json: bool = False
    supports_streaming: bool = True  # Streaming is assumed unless opted out
    context_window: int = 4096  # Maximum context size in tokens
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Provider:
    """Configuration plus mutable runtime state for one LLM provider."""

    name: str
    type: str  # ollama, openai, anthropic
    enabled: bool
    priority: int
    tier: str | None = None  # e.g., "local", "standard_cloud", "frontier"
    url: str | None = None
    api_key: str | None = None
    base_url: str | None = None
    models: list[dict] = field(default_factory=list)

    # Runtime state
    status: ProviderStatus = ProviderStatus.HEALTHY
    metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
    circuit_state: CircuitState = CircuitState.CLOSED
    circuit_opened_at: float | None = None  # time of breaker trip
    half_open_calls: int = 0  # probes issued while half-open

    def get_default_model(self) -> str | None:
        """Return the model flagged ``default``, else the first model, else None."""
        for candidate in self.models:
            if candidate.get("default"):
                return candidate["name"]
        # No explicit default: fall back to the first configured model.
        return self.models[0]["name"] if self.models else None

    def get_model_with_capability(self, capability: str) -> str | None:
        """Return the first model advertising *capability*, else the default model."""
        for candidate in self.models:
            if capability in candidate.get("capabilities", []):
                return candidate["name"]
        # Nothing matched: fall back to whatever the default would be.
        return self.get_default_model()

    def model_has_capability(self, model_name: str, capability: str) -> bool:
        """Return True when the model named *model_name* advertises *capability*."""
        entry = next((m for m in self.models if m["name"] == model_name), None)
        if entry is None:
            return False
        return capability in entry.get("capabilities", [])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RouterConfig:
    """Tunable knobs for the cascade router."""

    timeout_seconds: int = 30
    max_retries_per_provider: int = 2
    retry_delay_seconds: int = 1
    circuit_breaker_failure_threshold: int = 5  # failures before the breaker opens
    circuit_breaker_recovery_timeout: int = 60  # seconds before probing again
    circuit_breaker_half_open_max_calls: int = 2  # probes allowed while half-open
    cost_tracking_enabled: bool = True
    budget_daily_usd: float = 10.0
    # Multi-modal settings
    auto_pull_models: bool = True
    fallback_chains: dict = field(default_factory=dict)
|
||||||
318
src/infrastructure/router/providers.py
Normal file
318
src/infrastructure/router/providers.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
"""Provider API call mixin for the Cascade Router.
|
||||||
|
|
||||||
|
Contains methods for calling individual LLM provider APIs
|
||||||
|
(Ollama, OpenAI, Anthropic, Grok, vllm-mlx).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
from .models import ContentType, Provider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCallsMixin:
    """Mixin providing LLM provider API call methods.

    Expects the consuming class to have:
    - self.config: RouterConfig
    """

    async def _try_provider(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
        """Try a single provider request.

        Dispatches to the provider-type-specific ``_call_*`` method, measures
        wall-clock latency, and annotates the result dict with ``latency_ms``.

        Raises:
            ValueError: if ``provider.type`` is not a known provider type.
        """
        start_time = time.time()

        # Resolve the model once instead of repeating the fallback per branch.
        chosen_model = model or provider.get_default_model()

        if provider.type == "ollama":
            # Ollama is the only backend that takes content_type (it drives
            # the image transformation for multi-modal messages).
            result = await self._call_ollama(
                provider=provider,
                messages=messages,
                model=chosen_model,
                temperature=temperature,
                max_tokens=max_tokens,
                content_type=content_type,
            )
        else:
            # The remaining backends share an identical call signature, so a
            # dispatch table replaces the repetitive if/elif chain.
            dispatch = {
                "openai": self._call_openai,
                "anthropic": self._call_anthropic,
                "grok": self._call_grok,
                "vllm_mlx": self._call_vllm_mlx,
            }
            call = dispatch.get(provider.type)
            if call is None:
                raise ValueError(f"Unknown provider type: {provider.type}")
            result = await call(
                provider=provider,
                messages=messages,
                model=chosen_model,
                temperature=temperature,
                max_tokens=max_tokens,
            )

        result["latency_ms"] = (time.time() - start_time) * 1000
        return result

    async def _call_ollama(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None = None,
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
        """Call Ollama's ``/api/chat`` endpoint with multi-modal support.

        Returns a dict with ``content`` and ``model`` keys.

        Raises:
            RuntimeError: on any non-200 HTTP response.
        """
        import aiohttp

        url = f"{provider.url or settings.ollama_url}/api/chat"

        # Ollama expects images as per-message lists of raw base64 strings.
        transformed_messages = self._transform_messages_for_ollama(messages)

        options: dict[str, Any] = {"temperature": temperature}
        if max_tokens:
            options["num_predict"] = max_tokens

        payload = {
            "model": model,
            "messages": transformed_messages,
            "stream": False,
            "options": options,
        }

        timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    text = await response.text()
                    raise RuntimeError(f"Ollama error {response.status}: {text}")

                data = await response.json()
                return {
                    "content": data["message"]["content"],
                    "model": model,
                }

    def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
        """Transform messages to Ollama format, handling images.

        Supported forms under a message's ``images`` key:
        - ``data:image/...`` data URIs — the base64 payload after the comma
          is extracted;
        - local file paths — read and base64-encoded;
        - http(s) URLs — not supported yet, skipped with a warning.
        Non-string entries are silently ignored.
        """
        transformed = []

        for msg in messages:
            new_msg: dict[str, Any] = {
                "role": msg.get("role", "user"),
                "content": msg.get("content", ""),
            }

            images = msg.get("images", [])
            if images:
                encoded: list[str] = []
                for img in images:
                    if not isinstance(img, str):
                        continue
                    if img.startswith("data:image/"):
                        # Data URI: keep only the base64 payload.
                        encoded.append(img.split(",")[1])
                    elif img.startswith(("http://", "https://")):
                        # URL - would need to download, skip for now
                        logger.warning("Image URLs not yet supported, skipping: %s", img)
                    elif Path(img).exists():
                        # Local file path - read and encode
                        try:
                            encoded.append(base64.b64encode(Path(img).read_bytes()).decode())
                        except Exception as exc:
                            logger.error("Failed to read image %s: %s", img, exc)
                new_msg["images"] = encoded

            transformed.append(new_msg)

        return transformed

    async def _call_openai(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call the OpenAI chat completions API.

        Returns a dict with ``content`` and ``model`` keys.
        """
        import openai

        client = openai.AsyncOpenAI(
            api_key=provider.api_key,
            base_url=provider.base_url,
            timeout=self.config.timeout_seconds,
        )

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens

        response = await client.chat.completions.create(**kwargs)

        return {
            "content": response.choices[0].message.content,
            "model": response.model,
        }

    async def _call_anthropic(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call the Anthropic messages API.

        System-role messages are lifted out into the top-level ``system``
        parameter (if several are present, the last one wins); all other
        messages are forwarded as the conversation.

        Returns a dict with ``content`` and ``model`` keys.
        """
        import anthropic

        client = anthropic.AsyncAnthropic(
            api_key=provider.api_key,
            timeout=self.config.timeout_seconds,
        )

        # Convert messages to Anthropic format
        system_msg = None
        conversation = []
        for msg in messages:
            if msg["role"] == "system":
                system_msg = msg["content"]
            else:
                conversation.append(
                    {
                        "role": msg["role"],
                        "content": msg["content"],
                    }
                )

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": conversation,
            "temperature": temperature,
            # Anthropic requires max_tokens; default to 1024 when unspecified.
            "max_tokens": max_tokens or 1024,
        }
        if system_msg:
            kwargs["system"] = system_msg

        response = await client.messages.create(**kwargs)

        return {
            "content": response.content[0].text,
            "model": response.model,
        }

    async def _call_grok(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call xAI Grok API via OpenAI-compatible SDK.

        NOTE(review): this uses a hard-coded 300 s httpx timeout rather than
        self.config.timeout_seconds — presumably intentional for long Grok
        generations; confirm before unifying.

        Returns a dict with ``content`` and ``model`` keys.
        """
        import httpx
        import openai

        client = openai.AsyncOpenAI(
            api_key=provider.api_key,
            base_url=provider.base_url or settings.xai_base_url,
            timeout=httpx.Timeout(300.0),
        )

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens

        response = await client.chat.completions.create(**kwargs)

        return {
            "content": response.choices[0].message.content,
            "model": response.model,
        }

    async def _call_vllm_mlx(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call vllm-mlx via its OpenAI-compatible API.

        vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI,
        so we reuse the OpenAI client pointed at the local server.
        No API key is required for local deployments.
        """
        import openai

        base_url = provider.base_url or provider.url or "http://localhost:8000"
        # Ensure the base_url ends with /v1 as expected by the OpenAI client
        if not base_url.rstrip("/").endswith("/v1"):
            base_url = base_url.rstrip("/") + "/v1"

        client = openai.AsyncOpenAI(
            api_key=provider.api_key or "no-key-required",
            base_url=base_url,
            timeout=self.config.timeout_seconds,
        )

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens

        response = await client.chat.completions.create(**kwargs)

        return {
            "content": response.choices[0].message.content,
            "model": response.model,
        }
|
||||||
@@ -677,7 +677,7 @@ class TestVllmMlxProvider:
|
|||||||
router.providers = [provider]
|
router.providers = [provider]
|
||||||
|
|
||||||
# Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried
|
# Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried
|
||||||
with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
|
with patch("infrastructure.router.health._quota_monitor") as mock_qm:
|
||||||
mock_qm.select_model.return_value = "qwen3:14b"
|
mock_qm.select_model.return_value = "qwen3:14b"
|
||||||
mock_qm.check.return_value = None
|
mock_qm.check.return_value = None
|
||||||
|
|
||||||
@@ -713,7 +713,7 @@ class TestMetabolicProtocol:
|
|||||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||||
router.providers = [self._make_anthropic_provider()]
|
router.providers = [self._make_anthropic_provider()]
|
||||||
|
|
||||||
with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
|
with patch("infrastructure.router.health._quota_monitor") as mock_qm:
|
||||||
# select_model returns cloud model → BURST tier
|
# select_model returns cloud model → BURST tier
|
||||||
mock_qm.select_model.return_value = "claude-sonnet-4-6"
|
mock_qm.select_model.return_value = "claude-sonnet-4-6"
|
||||||
mock_qm.check.return_value = None
|
mock_qm.check.return_value = None
|
||||||
@@ -732,7 +732,7 @@ class TestMetabolicProtocol:
|
|||||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||||
router.providers = [self._make_anthropic_provider()]
|
router.providers = [self._make_anthropic_provider()]
|
||||||
|
|
||||||
with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
|
with patch("infrastructure.router.health._quota_monitor") as mock_qm:
|
||||||
# select_model returns local 14B → ACTIVE tier
|
# select_model returns local 14B → ACTIVE tier
|
||||||
mock_qm.select_model.return_value = "qwen3:14b"
|
mock_qm.select_model.return_value = "qwen3:14b"
|
||||||
mock_qm.check.return_value = None
|
mock_qm.check.return_value = None
|
||||||
@@ -750,7 +750,7 @@ class TestMetabolicProtocol:
|
|||||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||||
router.providers = [self._make_anthropic_provider()]
|
router.providers = [self._make_anthropic_provider()]
|
||||||
|
|
||||||
with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
|
with patch("infrastructure.router.health._quota_monitor") as mock_qm:
|
||||||
# select_model returns local 8B → RESTING tier
|
# select_model returns local 8B → RESTING tier
|
||||||
mock_qm.select_model.return_value = "qwen3:8b"
|
mock_qm.select_model.return_value = "qwen3:8b"
|
||||||
mock_qm.check.return_value = None
|
mock_qm.check.return_value = None
|
||||||
@@ -776,7 +776,7 @@ class TestMetabolicProtocol:
|
|||||||
)
|
)
|
||||||
router.providers = [provider]
|
router.providers = [provider]
|
||||||
|
|
||||||
with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
|
with patch("infrastructure.router.health._quota_monitor") as mock_qm:
|
||||||
mock_qm.select_model.return_value = "qwen3:8b" # RESTING tier
|
mock_qm.select_model.return_value = "qwen3:8b" # RESTING tier
|
||||||
|
|
||||||
with patch.object(router, "_call_ollama") as mock_call:
|
with patch.object(router, "_call_ollama") as mock_call:
|
||||||
@@ -793,7 +793,7 @@ class TestMetabolicProtocol:
|
|||||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||||
router.providers = [self._make_anthropic_provider()]
|
router.providers = [self._make_anthropic_provider()]
|
||||||
|
|
||||||
with patch("infrastructure.router.cascade._quota_monitor", None):
|
with patch("infrastructure.router.health._quota_monitor", None):
|
||||||
with patch.object(router, "_call_anthropic") as mock_call:
|
with patch.object(router, "_call_anthropic") as mock_call:
|
||||||
mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
|
mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
|
||||||
result = await router.complete(
|
result = await router.complete(
|
||||||
@@ -1200,7 +1200,7 @@ class TestCascadeTierFiltering:
|
|||||||
|
|
||||||
async def test_frontier_required_uses_anthropic(self):
|
async def test_frontier_required_uses_anthropic(self):
|
||||||
router = self._make_router()
|
router = self._make_router()
|
||||||
with patch("infrastructure.router.cascade._quota_monitor", None):
|
with patch("infrastructure.router.health._quota_monitor", None):
|
||||||
with patch.object(router, "_call_anthropic") as mock_call:
|
with patch.object(router, "_call_anthropic") as mock_call:
|
||||||
mock_call.return_value = {
|
mock_call.return_value = {
|
||||||
"content": "frontier response",
|
"content": "frontier response",
|
||||||
@@ -1464,7 +1464,7 @@ class TestTrySingleProvider:
|
|||||||
router = self._router()
|
router = self._router()
|
||||||
provider = self._provider(ptype="anthropic")
|
provider = self._provider(ptype="anthropic")
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
|
with patch("infrastructure.router.health._quota_monitor") as mock_qm:
|
||||||
mock_qm.select_model.return_value = "qwen3:14b" # non-cloud → ACTIVE tier
|
mock_qm.select_model.return_value = "qwen3:14b" # non-cloud → ACTIVE tier
|
||||||
mock_qm.check.return_value = None
|
mock_qm.check.return_value = None
|
||||||
result = await router._try_single_provider(
|
result = await router._try_single_provider(
|
||||||
|
|||||||
Reference in New Issue
Block a user