WIP: Claude Code progress on #1342

Automated salvage commit — agent session ended (exit 124).
Work in progress, may need continuation.
Alexander Whitestone
2026-03-23 23:03:51 -04:00
parent 3349948f7f
commit ab4b2f938d
12 changed files with 1108 additions and 1015 deletions

File diff suppressed because it is too large


@@ -0,0 +1,123 @@
"""Config loading helpers for the Cascade LLM Router.
Parses providers.yaml, expands env vars, and checks provider availability.
"""
from __future__ import annotations
import logging
from infrastructure.router.models import Provider, RouterConfig
logger = logging.getLogger(__name__)
try:
import yaml
except ImportError:
yaml = None # type: ignore
try:
import requests
except ImportError:
requests = None # type: ignore
def expand_env_vars(content: str) -> str:
"""Expand ${VAR} syntax in YAML content.
Uses os.environ directly (not settings) because this is a generic
YAML config loader that must expand arbitrary variable references.
"""
import os
import re
def replace_var(match: "re.Match[str]") -> str:
var_name = match.group(1)
return os.environ.get(var_name, match.group(0))
return re.sub(r"\$\{(\w+)\}", replace_var, content)
def parse_router_config(data: dict) -> RouterConfig:
"""Build a RouterConfig from parsed YAML data."""
cascade = data.get("cascade", {})
cb = cascade.get("circuit_breaker", {})
multimodal = data.get("multimodal", {})
return RouterConfig(
timeout_seconds=cascade.get("timeout_seconds", 30),
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
circuit_breaker_failure_threshold=cb.get("failure_threshold", 5),
circuit_breaker_recovery_timeout=cb.get("recovery_timeout", 60),
circuit_breaker_half_open_max_calls=cb.get("half_open_max_calls", 2),
auto_pull_models=multimodal.get("auto_pull", True),
fallback_chains=data.get("fallback_chains", {}),
)
def load_providers(data: dict) -> list[Provider]:
"""Load and filter providers from parsed YAML data (unsorted)."""
providers: list[Provider] = []
for p_data in data.get("providers", []):
if not p_data.get("enabled", False):
continue
provider = Provider(
name=p_data["name"],
type=p_data["type"],
enabled=p_data.get("enabled", True),
priority=p_data.get("priority", 99),
tier=p_data.get("tier"),
url=p_data.get("url"),
api_key=p_data.get("api_key"),
base_url=p_data.get("base_url"),
models=p_data.get("models", []),
)
if check_provider_available(provider):
providers.append(provider)
else:
logger.warning("Provider %s not available, skipping", provider.name)
return providers
def check_provider_available(provider: Provider) -> bool:
"""Check if a provider is actually available."""
from config import settings
if provider.type == "ollama":
# Check if Ollama is running
if requests is None:
# Can't check without requests, assume available
return True
try:
url = provider.url or settings.ollama_url
response = requests.get(f"{url}/api/tags", timeout=5)
return response.status_code == 200
except Exception as exc:
logger.debug("Ollama provider check error: %s", exc)
return False
elif provider.type == "vllm_mlx":
# Check if local vllm-mlx server is running (OpenAI-compatible)
if requests is None:
return True
try:
base_url = provider.base_url or provider.url or "http://localhost:8000"
# Strip /v1 suffix — health endpoint is at the root
server_root = base_url.rstrip("/")
if server_root.endswith("/v1"):
server_root = server_root[:-3]
response = requests.get(f"{server_root}/health", timeout=5)
return response.status_code == 200
except Exception as exc:
logger.debug("vllm-mlx provider check error: %s", exc)
return False
elif provider.type in ("openai", "anthropic", "grok"):
# Check if API key is set
return provider.api_key is not None and provider.api_key != ""
return True
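
A minimal usage sketch (not part of the diff) showing how these helpers compose; the YAML keys mirror the ones read above, while the concrete values and the OLLAMA_URL variable are illustrative:

import os
import yaml

raw = """
cascade:
  timeout_seconds: 20
  circuit_breaker:
    failure_threshold: 3
providers:
  - name: local-ollama
    type: ollama
    enabled: true
    priority: 1
    url: ${OLLAMA_URL}
    models:
      - name: llama3
        default: true
"""
os.environ.setdefault("OLLAMA_URL", "http://localhost:11434")
data = yaml.safe_load(expand_env_vars(raw))   # ${OLLAMA_URL} is expanded before parsing
config = parse_router_config(data)            # RouterConfig built from the cascade block
providers = load_providers(data)              # skips providers whose availability check fails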


@@ -0,0 +1,129 @@
"""Content-type detection and model selection for the Cascade LLM Router."""
from __future__ import annotations
import logging
from typing import Any
from infrastructure.router.models import ContentType, Provider
logger = logging.getLogger(__name__)
def detect_content_type(messages: list[dict]) -> ContentType:
"""Detect the type of content in the messages.
Checks for images, audio, etc. in the message content.
"""
has_image = False
has_audio = False
for msg in messages:
content = msg.get("content", "")
# Check for image URLs/paths
if msg.get("images"):
has_image = True
# Check for image URLs in content
if isinstance(content, str):
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
if any(ext in content.lower() for ext in image_extensions):
has_image = True
if content.startswith("data:image/"):
has_image = True
# Check for audio
if msg.get("audio"):
has_audio = True
# Check for multimodal content structure
if isinstance(content, list):
for item in content:
if isinstance(item, dict):
if item.get("type") == "image_url":
has_image = True
elif item.get("type") == "audio":
has_audio = True
if has_image and has_audio:
return ContentType.MULTIMODAL
elif has_image:
return ContentType.VISION
elif has_audio:
return ContentType.AUDIO
return ContentType.TEXT
def get_fallback_model(
provider: Provider,
original_model: str,
content_type: ContentType,
fallback_chains: dict,
) -> str | None:
"""Get a fallback model for the given content type."""
# Map content type to capability
capability_map = {
ContentType.VISION: "vision",
ContentType.AUDIO: "audio",
ContentType.MULTIMODAL: "vision", # Vision models often do both
}
capability = capability_map.get(content_type)
if not capability:
return None
# Check provider's models for capability
fallback_model = provider.get_model_with_capability(capability)
if fallback_model and fallback_model != original_model:
return fallback_model
# Use fallback chains from config
fallback_chain = fallback_chains.get(capability, [])
for model_name in fallback_chain:
if provider.model_has_capability(model_name, capability):
return model_name
return None
def select_model(
provider: Provider,
model: str | None,
content_type: ContentType,
mm_manager: Any,
fallback_chains: dict,
) -> tuple[str | None, bool]:
"""Select the best model for the request, with vision fallback.
Returns:
Tuple of (selected_model, is_fallback_model).
"""
selected_model = model or provider.get_default_model()
is_fallback = False
if content_type != ContentType.TEXT and selected_model:
if provider.type == "ollama" and mm_manager:
from infrastructure.models.multimodal import ModelCapability
if content_type == ContentType.VISION:
supports = mm_manager.model_supports(selected_model, ModelCapability.VISION)
if not supports:
fallback = get_fallback_model(
provider, selected_model, content_type, fallback_chains
)
if fallback:
logger.info(
"Model %s doesn't support vision, falling back to %s",
selected_model,
fallback,
)
selected_model = fallback
is_fallback = True
else:
logger.warning(
"No vision-capable model found on %s, trying anyway",
provider.name,
)
return selected_model, is_fallback
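
A small sketch of the detector on an OpenAI-style multimodal message list (values illustrative):

msgs = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]
assert detect_content_type(msgs) == ContentType.VISION
assert detect_content_type([{"role": "user", "content": "hello"}]) == ContentType.TEXT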


@@ -0,0 +1,79 @@
"""Circuit-breaker and health tracking for the Cascade LLM Router.
Standalone functions that mutate Provider state in place.
"""
from __future__ import annotations
import logging
import time
from datetime import UTC, datetime
from infrastructure.router.models import CircuitState, Provider, ProviderStatus, RouterConfig
logger = logging.getLogger(__name__)
def record_success(provider: Provider, latency_ms: float, config: RouterConfig) -> None:
"""Record a successful request."""
provider.metrics.total_requests += 1
provider.metrics.successful_requests += 1
provider.metrics.total_latency_ms += latency_ms
provider.metrics.last_request_time = datetime.now(UTC).isoformat()
provider.metrics.consecutive_failures = 0
# Close circuit breaker if half-open
if provider.circuit_state == CircuitState.HALF_OPEN:
provider.half_open_calls += 1
if provider.half_open_calls >= config.circuit_breaker_half_open_max_calls:
close_circuit(provider)
# Update status based on error rate
if provider.metrics.error_rate < 0.1:
provider.status = ProviderStatus.HEALTHY
elif provider.metrics.error_rate < 0.3:
provider.status = ProviderStatus.DEGRADED
def record_failure(provider: Provider, config: RouterConfig) -> None:
"""Record a failed request."""
provider.metrics.total_requests += 1
provider.metrics.failed_requests += 1
provider.metrics.last_error_time = datetime.now(UTC).isoformat()
provider.metrics.consecutive_failures += 1
# Check if we should open circuit breaker
if provider.metrics.consecutive_failures >= config.circuit_breaker_failure_threshold:
open_circuit(provider)
# Update status
if provider.metrics.error_rate > 0.3:
provider.status = ProviderStatus.DEGRADED
if provider.metrics.error_rate > 0.5:
provider.status = ProviderStatus.UNHEALTHY
def open_circuit(provider: Provider) -> None:
"""Open the circuit breaker for a provider."""
provider.circuit_state = CircuitState.OPEN
provider.circuit_opened_at = time.time()
provider.status = ProviderStatus.UNHEALTHY
logger.warning("Circuit breaker OPEN for %s", provider.name)
def can_close_circuit(provider: Provider, config: RouterConfig) -> bool:
"""Check if circuit breaker can transition to half-open."""
if provider.circuit_opened_at is None:
return False
elapsed = time.time() - provider.circuit_opened_at
return elapsed >= config.circuit_breaker_recovery_timeout
def close_circuit(provider: Provider) -> None:
"""Close the circuit breaker (provider healthy again)."""
provider.circuit_state = CircuitState.CLOSED
provider.circuit_opened_at = None
provider.half_open_calls = 0
provider.metrics.consecutive_failures = 0
provider.status = ProviderStatus.HEALTHY
logger.info("Circuit breaker CLOSED for %s", provider.name)


@@ -0,0 +1,141 @@
"""Data models for the Cascade LLM Router.
Enums, dataclasses, and provider configuration shared across
router sub-modules.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
class ProviderStatus(Enum):
"""Health status of a provider."""
HEALTHY = "healthy"
DEGRADED = "degraded" # Working but slow or occasional errors
UNHEALTHY = "unhealthy" # Circuit breaker open
DISABLED = "disabled"
class CircuitState(Enum):
"""Circuit breaker state."""
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, rejecting requests
HALF_OPEN = "half_open" # Testing if recovered
class ContentType(Enum):
"""Type of content in the request."""
TEXT = "text"
VISION = "vision" # Contains images
AUDIO = "audio" # Contains audio
MULTIMODAL = "multimodal" # Multiple content types
@dataclass
class ProviderMetrics:
"""Metrics for a single provider."""
total_requests: int = 0
successful_requests: int = 0
failed_requests: int = 0
total_latency_ms: float = 0.0
last_request_time: str | None = None
last_error_time: str | None = None
consecutive_failures: int = 0
@property
def avg_latency_ms(self) -> float:
if self.total_requests == 0:
return 0.0
return self.total_latency_ms / self.total_requests
@property
def error_rate(self) -> float:
if self.total_requests == 0:
return 0.0
return self.failed_requests / self.total_requests
@dataclass
class ModelCapability:
"""Capabilities a model supports."""
name: str
supports_vision: bool = False
supports_audio: bool = False
supports_tools: bool = False
supports_json: bool = False
supports_streaming: bool = True
context_window: int = 4096
@dataclass
class Provider:
"""LLM provider configuration and state."""
name: str
type: str # ollama, openai, anthropic
enabled: bool
priority: int
tier: str | None = None # e.g., "local", "standard_cloud", "frontier"
url: str | None = None
api_key: str | None = None
base_url: str | None = None
models: list[dict] = field(default_factory=list)
# Runtime state
status: ProviderStatus = ProviderStatus.HEALTHY
metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
circuit_state: CircuitState = CircuitState.CLOSED
circuit_opened_at: float | None = None
half_open_calls: int = 0
def get_default_model(self) -> str | None:
"""Get the default model for this provider."""
for model in self.models:
if model.get("default"):
return model["name"]
if self.models:
return self.models[0]["name"]
return None
def get_model_with_capability(self, capability: str) -> str | None:
"""Get a model that supports the given capability."""
for model in self.models:
capabilities = model.get("capabilities", [])
if capability in capabilities:
return model["name"]
# Fall back to default
return self.get_default_model()
def model_has_capability(self, model_name: str, capability: str) -> bool:
"""Check if a specific model has a capability."""
for model in self.models:
if model["name"] == model_name:
capabilities = model.get("capabilities", [])
return capability in capabilities
return False
@dataclass
class RouterConfig:
"""Cascade router configuration."""
timeout_seconds: int = 30
max_retries_per_provider: int = 2
retry_delay_seconds: int = 1
circuit_breaker_failure_threshold: int = 5
circuit_breaker_recovery_timeout: int = 60
circuit_breaker_half_open_max_calls: int = 2
cost_tracking_enabled: bool = True
budget_daily_usd: float = 10.0
# Multi-modal settings
auto_pull_models: bool = True
fallback_chains: dict = field(default_factory=dict)
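
A short sketch of the Provider model lookups; model names and capability tags are illustrative:

p = Provider(
    name="local",
    type="ollama",
    enabled=True,
    priority=1,
    models=[
        {"name": "llama3", "default": True, "capabilities": ["tools"]},
        {"name": "llava", "capabilities": ["vision"]},
    ],
)
p.get_default_model()                         # -> "llama3"
p.get_model_with_capability("vision")         # -> "llava"
p.model_has_capability("llama3", "vision")    # -> False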


@@ -0,0 +1 @@
# Provider implementations


@@ -0,0 +1,56 @@
"""Anthropic provider implementation for the Cascade LLM Router."""
from __future__ import annotations
import logging
from infrastructure.router.models import Provider
logger = logging.getLogger(__name__)
async def call_anthropic(
provider: Provider,
messages: list[dict],
model: str,
temperature: float,
max_tokens: int | None,
timeout_seconds: int,
) -> dict:
"""Call Anthropic API."""
import anthropic
client = anthropic.AsyncAnthropic(
api_key=provider.api_key,
timeout=timeout_seconds,
)
# Convert messages to Anthropic format
system_msg = None
conversation = []
for msg in messages:
if msg["role"] == "system":
system_msg = msg["content"]
else:
conversation.append(
{
"role": msg["role"],
"content": msg["content"],
}
)
kwargs: dict = {
"model": model,
"messages": conversation,
"temperature": temperature,
"max_tokens": max_tokens or 1024,
}
if system_msg:
kwargs["system"] = system_msg
response = await client.messages.create(**kwargs)
return {
"content": response.content[0].text,
"model": response.model,
}
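
A hedged sketch of calling this from async code; the model name is a placeholder and the environment variable is the conventional one rather than something this diff defines. The system message is lifted out of the list and passed via the Anthropic "system" parameter.

import asyncio
import os

provider = Provider(
    name="anthropic", type="anthropic", enabled=True, priority=2,
    api_key=os.environ.get("ANTHROPIC_API_KEY"),
)
messages = [
    {"role": "system", "content": "Answer in one sentence."},
    {"role": "user", "content": "What does a circuit breaker do?"},
]
# asyncio.run(call_anthropic(provider, messages, model="claude-placeholder",
#                            temperature=0.2, max_tokens=256, timeout_seconds=30))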


@@ -0,0 +1,80 @@
"""Provider dispatch — routes a single request to the correct provider module."""
from __future__ import annotations
import time
from infrastructure.router.models import ContentType, Provider
async def call_provider(
provider: Provider,
messages: list[dict],
model: str,
temperature: float,
max_tokens: int | None,
timeout_seconds: int,
content_type: ContentType = ContentType.TEXT,
) -> dict:
"""Dispatch a request to the correct provider implementation.
Returns a result dict with ``content``, ``model``, and ``latency_ms`` keys.
Raises ValueError for unknown provider types.
"""
from infrastructure.router.providers import ollama as _ollama
from infrastructure.router.providers import openai_compat as _openai_compat
from infrastructure.router.providers import anthropic as _anthropic
from infrastructure.router.providers import grok as _grok
start_time = time.time()
if provider.type == "ollama":
result = await _ollama.call_ollama(
provider=provider,
messages=messages,
model=model or provider.get_default_model(),
temperature=temperature,
max_tokens=max_tokens,
content_type=content_type,
timeout_seconds=timeout_seconds,
)
elif provider.type == "openai":
result = await _openai_compat.call_openai(
provider=provider,
messages=messages,
model=model or provider.get_default_model(),
temperature=temperature,
max_tokens=max_tokens,
timeout_seconds=timeout_seconds,
)
elif provider.type == "anthropic":
result = await _anthropic.call_anthropic(
provider=provider,
messages=messages,
model=model or provider.get_default_model(),
temperature=temperature,
max_tokens=max_tokens,
timeout_seconds=timeout_seconds,
)
elif provider.type == "grok":
result = await _grok.call_grok(
provider=provider,
messages=messages,
model=model or provider.get_default_model(),
temperature=temperature,
max_tokens=max_tokens,
)
elif provider.type == "vllm_mlx":
result = await _openai_compat.call_vllm_mlx(
provider=provider,
messages=messages,
model=model or provider.get_default_model(),
temperature=temperature,
max_tokens=max_tokens,
timeout_seconds=timeout_seconds,
)
else:
raise ValueError(f"Unknown provider type: {provider.type}")
result["latency_ms"] = (time.time() - start_time) * 1000
return result
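
A usage sketch for the dispatcher, assuming the infrastructure.router package is importable and an Ollama instance is reachable at the given URL:

import asyncio

from infrastructure.router.models import Provider

async def _demo() -> None:
    p = Provider(
        name="local", type="ollama", enabled=True, priority=1,
        url="http://localhost:11434",
        models=[{"name": "llama3", "default": True}],
    )
    result = await call_provider(
        provider=p,
        messages=[{"role": "user", "content": "ping"}],
        model="llama3",
        temperature=0.2,
        max_tokens=64,
        timeout_seconds=30,
    )
    print(result["content"], round(result["latency_ms"], 1))

# asyncio.run(_demo())  # requires a running Ollama server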


@@ -0,0 +1,44 @@
"""Grok (xAI) provider implementation for the Cascade LLM Router."""
from __future__ import annotations
import logging
from infrastructure.router.models import Provider
logger = logging.getLogger(__name__)
async def call_grok(
provider: Provider,
messages: list[dict],
model: str,
temperature: float,
max_tokens: int | None,
) -> dict:
"""Call xAI Grok API via OpenAI-compatible SDK."""
import httpx
import openai
from config import settings
client = openai.AsyncOpenAI(
api_key=provider.api_key,
base_url=provider.base_url or settings.xai_base_url,
timeout=httpx.Timeout(300.0),
)
kwargs: dict = {
"model": model,
"messages": messages,
"temperature": temperature,
}
if max_tokens:
kwargs["max_tokens"] = max_tokens
response = await client.chat.completions.create(**kwargs)
return {
"content": response.choices[0].message.content,
"model": response.model,
}


@@ -0,0 +1,92 @@
"""Ollama provider implementation for the Cascade LLM Router."""
from __future__ import annotations
import base64
import logging
from pathlib import Path
import aiohttp
from infrastructure.router.models import ContentType, Provider
logger = logging.getLogger(__name__)
async def call_ollama(
provider: Provider,
messages: list[dict],
model: str,
temperature: float,
max_tokens: int | None,
content_type: ContentType,
timeout_seconds: int,
) -> dict:
"""Call Ollama API with multi-modal support."""
from config import settings
url = f"{provider.url or settings.ollama_url}/api/chat"
# Transform messages for Ollama format (including images)
transformed_messages = transform_messages_for_ollama(messages)
options: dict = {"temperature": temperature}
if max_tokens:
options["num_predict"] = max_tokens
payload = {
"model": model,
"messages": transformed_messages,
"stream": False,
"options": options,
}
timeout = aiohttp.ClientTimeout(total=timeout_seconds)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(url, json=payload) as response:
if response.status != 200:
text = await response.text()
raise RuntimeError(f"Ollama error {response.status}: {text}")
data = await response.json()
return {
"content": data["message"]["content"],
"model": model,
}
def transform_messages_for_ollama(messages: list[dict]) -> list[dict]:
"""Transform messages to Ollama format, handling images."""
transformed = []
for msg in messages:
new_msg = {
"role": msg.get("role", "user"),
"content": msg.get("content", ""),
}
# Handle images
images = msg.get("images", [])
if images:
new_msg["images"] = []
for img in images:
if isinstance(img, str):
if img.startswith("data:image/"):
# Base64 encoded image
new_msg["images"].append(img.split(",")[1])
elif img.startswith("http://") or img.startswith("https://"):
# URL - would need to download, skip for now
logger.warning("Image URLs not yet supported, skipping: %s", img)
elif Path(img).exists():
# Local file path - read and encode
try:
with open(img, "rb") as f:
img_data = base64.b64encode(f.read()).decode()
new_msg["images"].append(img_data)
except Exception as exc:
logger.error("Failed to read image %s: %s", img, exc)
transformed.append(new_msg)
return transformed
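
A small sketch of the image handling: data-URI images are reduced to their base64 payload before being sent to Ollama (values illustrative):

msgs = [{
    "role": "user",
    "content": "Describe this image",
    "images": ["data:image/png;base64,iVBORw0KGgo="],
}]
transform_messages_for_ollama(msgs)
# -> [{"role": "user", "content": "Describe this image", "images": ["iVBORw0KGgo="]}]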


@@ -0,0 +1,88 @@
"""OpenAI-compatible provider implementations for the Cascade LLM Router.
Covers the ``openai`` and ``vllm_mlx`` provider types.
"""
from __future__ import annotations
import logging
from infrastructure.router.models import Provider
logger = logging.getLogger(__name__)
async def call_openai(
provider: Provider,
messages: list[dict],
model: str,
temperature: float,
max_tokens: int | None,
timeout_seconds: int,
) -> dict:
"""Call OpenAI API."""
import openai
client = openai.AsyncOpenAI(
api_key=provider.api_key,
base_url=provider.base_url,
timeout=timeout_seconds,
)
kwargs: dict = {
"model": model,
"messages": messages,
"temperature": temperature,
}
if max_tokens:
kwargs["max_tokens"] = max_tokens
response = await client.chat.completions.create(**kwargs)
return {
"content": response.choices[0].message.content,
"model": response.model,
}
async def call_vllm_mlx(
provider: Provider,
messages: list[dict],
model: str,
temperature: float,
max_tokens: int | None,
timeout_seconds: int,
) -> dict:
"""Call vllm-mlx via its OpenAI-compatible API.
vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI,
so we reuse the OpenAI client pointed at the local server.
No API key is required for local deployments.
"""
import openai
base_url = provider.base_url or provider.url or "http://localhost:8000"
# Ensure the base_url ends with /v1 as expected by the OpenAI client
if not base_url.rstrip("/").endswith("/v1"):
base_url = base_url.rstrip("/") + "/v1"
client = openai.AsyncOpenAI(
api_key=provider.api_key or "no-key-required",
base_url=base_url,
timeout=timeout_seconds,
)
kwargs: dict = {
"model": model,
"messages": messages,
"temperature": temperature,
}
if max_tokens:
kwargs["max_tokens"] = max_tokens
response = await client.chat.completions.create(**kwargs)
return {
"content": response.choices[0].message.content,
"model": response.model,
}
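
For reference, a standalone sketch of the /v1 normalization applied above; the expression mirrors the in-function logic and the URLs are illustrative:

for raw in ("http://localhost:8000", "http://localhost:8000/", "http://localhost:8000/v1"):
    base = raw if raw.rstrip("/").endswith("/v1") else raw.rstrip("/") + "/v1"
    print(f"{raw} -> {base}")   # all three resolve to http://localhost:8000/v1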


@@ -0,0 +1,89 @@
"""Metrics, status, and config-reload helpers for the Cascade LLM Router."""
from __future__ import annotations
import logging
from infrastructure.router.models import (
CircuitState,
Provider,
ProviderMetrics,
ProviderStatus,
)
logger = logging.getLogger(__name__)
def build_metrics(providers: list[Provider]) -> dict:
"""Build a metrics summary dict for all providers."""
return {
"providers": [
{
"name": p.name,
"type": p.type,
"status": p.status.value,
"circuit_state": p.circuit_state.value,
"metrics": {
"total_requests": p.metrics.total_requests,
"successful": p.metrics.successful_requests,
"failed": p.metrics.failed_requests,
"error_rate": round(p.metrics.error_rate, 3),
"avg_latency_ms": round(p.metrics.avg_latency_ms, 2),
},
}
for p in providers
]
}
def build_status(providers: list[Provider]) -> dict:
"""Build a status summary dict for all providers."""
healthy = sum(1 for p in providers if p.status == ProviderStatus.HEALTHY)
return {
"total_providers": len(providers),
"healthy_providers": healthy,
"degraded_providers": sum(1 for p in providers if p.status == ProviderStatus.DEGRADED),
"unhealthy_providers": sum(1 for p in providers if p.status == ProviderStatus.UNHEALTHY),
"providers": [
{
"name": p.name,
"type": p.type,
"status": p.status.value,
"priority": p.priority,
"default_model": p.get_default_model(),
}
for p in providers
],
}
def snapshot_provider_state(
providers: list[Provider],
) -> dict[str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]]:
"""Capture current runtime state keyed by provider name."""
return {
p.name: (p.metrics, p.circuit_state, p.circuit_opened_at, p.half_open_calls, p.status)
for p in providers
}
def restore_provider_state(
providers: list[Provider],
old_state: dict[str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]],
) -> int:
"""Restore saved runtime state to matching providers. Returns count of restored providers."""
preserved = 0
for p in providers:
if p.name in old_state:
metrics, circuit, opened_at, half_open, status = old_state[p.name]
p.metrics = metrics
p.circuit_state = circuit
p.circuit_opened_at = opened_at
p.half_open_calls = half_open
p.status = status
preserved += 1
return preserved
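
A hedged sketch of a config-reload cycle that preserves runtime state; providers and new_data stand for the router's current provider list and the freshly re-parsed YAML, and load_providers refers to the config helper shown earlier:

old_state = snapshot_provider_state(providers)
providers = load_providers(new_data)                  # rebuild Provider objects from the new YAML
kept = restore_provider_state(providers, old_state)   # carry metrics/circuit state across the reload
logger.info("Config reloaded; runtime state preserved for %d providers", kept)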