Files
Timmy-time-dashboard/src/infrastructure/router/cascade.py
Claude 17059bc0ea feat: add Grok (xAI) as opt-in premium backend with monetization
- Add GrokBackend class in src/timmy/backends.py with full sync/async
  support, health checks, usage stats, and cost estimation in sats
- Add consult_grok tool to Timmy's toolkit for proactive Grok queries
- Extend cascade router with Grok provider type for failover chain
- Add Grok Mode toggle card to Mission Control dashboard (HTMX live)
- Add "Ask Grok" button on chat input for direct Grok queries
- Add /grok/* routes: status, toggle, chat, stats endpoints
- Integrate Lightning invoice generation for Grok usage monetization
- Add GROK_ENABLED, XAI_API_KEY, GROK_DEFAULT_MODEL, GROK_MAX_SATS_PER_QUERY,
  GROK_FREE config settings via pydantic-settings
- Update .env.example and docker-compose.yml with Grok env vars
- Add 21 tests covering backend, tools, and route endpoints (all green)

Local-first ethos preserved: Grok is premium augmentation only,
disabled by default, and Lightning-payable when enabled.

https://claude.ai/code/session_01FygwN8wS8J6WGZ8FPb7XGV
2026-02-27 01:12:51 +00:00

613 lines
21 KiB
Python

"""Cascade LLM Router — Automatic failover between providers.
Routes requests through an ordered list of LLM providers,
automatically failing over on rate limits or errors.
Tracks metrics for latency, errors, and cost.
"""
import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Optional
from pathlib import Path
try:
import yaml
except ImportError:
yaml = None # type: ignore
try:
import requests
except ImportError:
requests = None # type: ignore
logger = logging.getLogger(__name__)
class ProviderStatus(Enum):
    """Operational health of a single LLM provider."""

    HEALTHY = "healthy"      # Fully operational
    DEGRADED = "degraded"    # Responding, but slow or with occasional errors
    UNHEALTHY = "unhealthy"  # Circuit breaker open
    DISABLED = "disabled"    # Turned off via configuration
class CircuitState(Enum):
    """State of a provider's circuit breaker."""

    CLOSED = "closed"        # Normal operation; requests flow through
    OPEN = "open"            # Tripped; requests are rejected
    HALF_OPEN = "half_open"  # Probing whether the provider has recovered
@dataclass
class ProviderMetrics:
    """Rolling request metrics for a single provider.

    Latency is only accumulated for successful requests (failures record
    no latency), so the average latency is computed over
    ``successful_requests`` rather than ``total_requests``.
    """
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_latency_ms: float = 0.0
    last_request_time: Optional[str] = None  # ISO-8601 UTC timestamp of last success
    last_error_time: Optional[str] = None    # ISO-8601 UTC timestamp of last failure
    consecutive_failures: int = 0            # Drives the circuit breaker threshold

    @property
    def avg_latency_ms(self) -> float:
        """Mean latency of successful requests, in milliseconds.

        BUG FIX: previously divided by total_requests even though latency
        is only recorded on success, understating the average whenever
        failures occurred.
        """
        if self.successful_requests == 0:
            return 0.0
        return self.total_latency_ms / self.successful_requests

    @property
    def error_rate(self) -> float:
        """Fraction of all requests that failed (0.0 before any request)."""
        if self.total_requests == 0:
            return 0.0
        return self.failed_requests / self.total_requests
@dataclass
class Provider:
    """Configuration plus runtime state for one LLM provider."""
    name: str
    type: str  # One of: ollama, openai, anthropic, airllm
    enabled: bool
    priority: int  # Lower values are tried first
    url: Optional[str] = None
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    models: list[dict] = field(default_factory=list)
    # --- Runtime state (not read from config) ---
    status: ProviderStatus = ProviderStatus.HEALTHY
    metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
    circuit_state: CircuitState = CircuitState.CLOSED
    circuit_opened_at: Optional[float] = None  # time.time() when breaker opened
    half_open_calls: int = 0  # Successful probes while half-open

    def get_default_model(self) -> Optional[str]:
        """Return the model marked default, else the first configured
        model, else None when no models are configured."""
        for entry in self.models:
            if entry.get("default"):
                return entry["name"]
        return self.models[0]["name"] if self.models else None
@dataclass
class RouterConfig:
    """Tunable settings for the cascade router.

    Mirrors the `cascade` section of the providers YAML file.
    """
    timeout_seconds: int = 30                     # Per-request timeout
    max_retries_per_provider: int = 2             # Attempts before failing over
    retry_delay_seconds: int = 1                  # Pause between retries
    circuit_breaker_failure_threshold: int = 5    # Consecutive failures to trip
    circuit_breaker_recovery_timeout: int = 60    # Seconds before half-open probe
    circuit_breaker_half_open_max_calls: int = 2  # Successes needed to close
    cost_tracking_enabled: bool = True
    budget_daily_usd: float = 10.0
class CascadeRouter:
    """Routes LLM requests with automatic failover.

    Providers are tried in ascending priority order. Each provider gets a
    bounded number of retries; a provider that keeps failing trips a
    per-provider circuit breaker and is skipped until a recovery timeout
    elapses, after which it is probed in a half-open state.

    Usage:
        router = CascadeRouter()
        response = await router.complete(
            messages=[{"role": "user", "content": "Hello"}],
            model="llama3.2"
        )
        # Check metrics
        metrics = router.get_metrics()
    """

    def __init__(self, config_path: Optional[Path] = None) -> None:
        """Initialize the router, loading config/providers.yaml by default."""
        self.config_path = config_path or Path("config/providers.yaml")
        self.providers: list[Provider] = []
        self.config: RouterConfig = RouterConfig()
        self._load_config()
        logger.info("CascadeRouter initialized with %d providers", len(self.providers))

    def _load_config(self) -> None:
        """Load router settings and the provider list from YAML.

        A missing config file is non-fatal (defaults are kept). Providers
        that are disabled or fail their availability probe are skipped.
        """
        if not self.config_path.exists():
            logger.warning("Config not found: %s, using defaults", self.config_path)
            return
        try:
            if yaml is None:
                raise RuntimeError("PyYAML not installed")
            content = self.config_path.read_text()
            # Expand environment variables
            content = self._expand_env_vars(content)
            data = yaml.safe_load(content)
            # Load cascade settings
            cascade = data.get("cascade", {})
            breaker = cascade.get("circuit_breaker", {})  # hoisted (was fetched 3x)
            self.config = RouterConfig(
                timeout_seconds=cascade.get("timeout_seconds", 30),
                max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
                retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
                circuit_breaker_failure_threshold=breaker.get("failure_threshold", 5),
                circuit_breaker_recovery_timeout=breaker.get("recovery_timeout", 60),
                circuit_breaker_half_open_max_calls=breaker.get("half_open_max_calls", 2),
                # FIX: these two fields exist on RouterConfig but were never
                # read from the config file, so YAML values were ignored.
                cost_tracking_enabled=cascade.get("cost_tracking_enabled", True),
                budget_daily_usd=cascade.get("budget_daily_usd", 10.0),
            )
            # Load providers
            for p_data in data.get("providers", []):
                # Skip disabled providers
                if not p_data.get("enabled", False):
                    continue
                provider = Provider(
                    name=p_data["name"],
                    type=p_data["type"],
                    enabled=p_data.get("enabled", True),
                    priority=p_data.get("priority", 99),
                    url=p_data.get("url"),
                    api_key=p_data.get("api_key"),
                    base_url=p_data.get("base_url"),
                    models=p_data.get("models", []),
                )
                # Check if provider is actually available
                if self._check_provider_available(provider):
                    self.providers.append(provider)
                else:
                    logger.warning("Provider %s not available, skipping", provider.name)
            # Sort by priority (lowest first)
            self.providers.sort(key=lambda p: p.priority)
        except Exception as exc:
            logger.error("Failed to load config: %s", exc)

    def _expand_env_vars(self, content: str) -> str:
        """Expand ${VAR} syntax in YAML content.

        Unknown variables are left untouched (the literal ${VAR} remains).
        """
        import os
        import re

        def replace_var(match):
            var_name = match.group(1)
            return os.environ.get(var_name, match.group(0))

        return re.sub(r"\$\{(\w+)\}", replace_var, content)

    def _check_provider_available(self, provider: Provider) -> bool:
        """Best-effort probe of whether a provider can serve requests.

        Ollama is probed over HTTP, airllm by import, API-key providers by
        presence of a key. Unknown types are assumed available.
        """
        if provider.type == "ollama":
            # Check if Ollama is running
            if requests is None:
                # Can't check without requests, assume available
                return True
            try:
                url = provider.url or "http://localhost:11434"
                response = requests.get(f"{url}/api/tags", timeout=5)
                return response.status_code == 200
            except Exception:
                return False
        elif provider.type == "airllm":
            # Check if airllm is installed
            try:
                import airllm  # noqa: F401
                return True
            except ImportError:
                return False
        elif provider.type in ("openai", "anthropic", "grok"):
            # Check if API key is set
            return provider.api_key is not None and provider.api_key != ""
        return True

    async def complete(
        self,
        messages: list[dict],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> dict:
        """Complete a chat conversation with automatic failover.

        Args:
            messages: List of message dicts with role and content
            model: Preferred model (tries this first, then provider defaults)
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Dict with content, provider_used, and metrics

        Raises:
            RuntimeError: If all providers fail (or none are available)
        """
        errors = []
        for provider in self.providers:
            # Skip disabled providers
            if not provider.enabled:
                logger.debug("Skipping %s (disabled)", provider.name)
                continue
            # Skip providers whose circuit breaker is open.
            # FIX: this previously gated on status == UNHEALTHY, which also
            # matches providers marked unhealthy by error rate alone (see
            # _record_failure). Those never opened the circuit, so
            # circuit_opened_at stayed None, _can_close_circuit() could never
            # return True, and the provider was skipped forever.
            if provider.circuit_state == CircuitState.OPEN:
                # Check if circuit breaker can transition to half-open
                if self._can_close_circuit(provider):
                    provider.circuit_state = CircuitState.HALF_OPEN
                    provider.half_open_calls = 0
                    logger.info("Circuit breaker half-open for %s", provider.name)
                else:
                    logger.debug("Skipping %s (circuit open)", provider.name)
                    continue
            # Try this provider
            for attempt in range(self.config.max_retries_per_provider):
                try:
                    result = await self._try_provider(
                        provider=provider,
                        messages=messages,
                        model=model,
                        temperature=temperature,
                        max_tokens=max_tokens,
                    )
                    # Success! Update metrics and return
                    self._record_success(provider, result.get("latency_ms", 0))
                    return {
                        "content": result["content"],
                        "provider": provider.name,
                        "model": result.get("model", model or provider.get_default_model()),
                        "latency_ms": result.get("latency_ms", 0),
                    }
                except Exception as exc:
                    error_msg = str(exc)
                    logger.warning(
                        "Provider %s attempt %d failed: %s",
                        provider.name, attempt + 1, error_msg
                    )
                    errors.append(f"{provider.name}: {error_msg}")
                    if attempt < self.config.max_retries_per_provider - 1:
                        await asyncio.sleep(self.config.retry_delay_seconds)
            # All retries failed for this provider
            self._record_failure(provider)
        # All providers failed (or none were eligible).
        # FIX: with an empty provider list the old message ended in a bare
        # colon; name the condition explicitly instead.
        detail = "; ".join(errors) if errors else "no providers available"
        raise RuntimeError(f"All providers failed: {detail}")

    async def _try_provider(
        self,
        provider: Provider,
        messages: list[dict],
        model: Optional[str],
        temperature: float,
        max_tokens: Optional[int],
    ) -> dict:
        """Dispatch one request to a single provider and time it.

        Raises ValueError for unknown provider types; propagates any
        transport/API error from the backend call.
        """
        start_time = time.time()
        if provider.type == "ollama":
            result = await self._call_ollama(
                provider=provider,
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
            )
        elif provider.type == "openai":
            result = await self._call_openai(
                provider=provider,
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                max_tokens=max_tokens,
            )
        elif provider.type == "anthropic":
            result = await self._call_anthropic(
                provider=provider,
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                max_tokens=max_tokens,
            )
        elif provider.type == "grok":
            result = await self._call_grok(
                provider=provider,
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                max_tokens=max_tokens,
            )
        else:
            raise ValueError(f"Unknown provider type: {provider.type}")
        latency_ms = (time.time() - start_time) * 1000
        result["latency_ms"] = latency_ms
        return result

    async def _call_ollama(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
    ) -> dict:
        """Call Ollama's /api/chat endpoint (non-streaming)."""
        import aiohttp
        url = f"{provider.url}/api/chat"
        payload = {
            "model": model,
            "messages": messages,
            "stream": False,
            "options": {
                "temperature": temperature,
            },
        }
        timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    text = await response.text()
                    raise RuntimeError(f"Ollama error {response.status}: {text}")
                data = await response.json()
                return {
                    "content": data["message"]["content"],
                    "model": model,
                }

    async def _call_openai(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: Optional[int],
    ) -> dict:
        """Call the OpenAI chat completions API."""
        import openai
        client = openai.AsyncOpenAI(
            api_key=provider.api_key,
            base_url=provider.base_url,
            timeout=self.config.timeout_seconds,
        )
        kwargs = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens
        response = await client.chat.completions.create(**kwargs)
        return {
            "content": response.choices[0].message.content,
            "model": response.model,
        }

    async def _call_anthropic(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: Optional[int],
    ) -> dict:
        """Call the Anthropic messages API.

        System messages are lifted out of the conversation into the
        dedicated `system` parameter, as the API requires.
        """
        import anthropic
        client = anthropic.AsyncAnthropic(
            api_key=provider.api_key,
            timeout=self.config.timeout_seconds,
        )
        # Convert messages to Anthropic format
        system_msg = None
        conversation = []
        for msg in messages:
            if msg["role"] == "system":
                system_msg = msg["content"]
            else:
                conversation.append({
                    "role": msg["role"],
                    "content": msg["content"],
                })
        kwargs = {
            "model": model,
            "messages": conversation,
            "temperature": temperature,
            "max_tokens": max_tokens or 1024,  # Anthropic requires max_tokens
        }
        if system_msg:
            kwargs["system"] = system_msg
        response = await client.messages.create(**kwargs)
        return {
            "content": response.content[0].text,
            "model": response.model,
        }

    async def _call_grok(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: Optional[int],
    ) -> dict:
        """Call xAI Grok API via OpenAI-compatible SDK."""
        import httpx
        import openai
        client = openai.AsyncOpenAI(
            api_key=provider.api_key,
            base_url=provider.base_url or "https://api.x.ai/v1",
            # NOTE(review): 300s here intentionally(?) ignores
            # config.timeout_seconds, unlike the other backends — confirm.
            timeout=httpx.Timeout(300.0),
        )
        kwargs = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens
        response = await client.chat.completions.create(**kwargs)
        return {
            "content": response.choices[0].message.content,
            "model": response.model,
        }

    def _record_success(self, provider: Provider, latency_ms: float) -> None:
        """Record a successful request; may close a half-open breaker."""
        provider.metrics.total_requests += 1
        provider.metrics.successful_requests += 1
        provider.metrics.total_latency_ms += latency_ms
        provider.metrics.last_request_time = datetime.now(timezone.utc).isoformat()
        provider.metrics.consecutive_failures = 0
        # Close circuit breaker once enough half-open probes have succeeded
        if provider.circuit_state == CircuitState.HALF_OPEN:
            provider.half_open_calls += 1
            if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
                self._close_circuit(provider)
        # Update status based on error rate
        if provider.metrics.error_rate < 0.1:
            provider.status = ProviderStatus.HEALTHY
        elif provider.metrics.error_rate < 0.3:
            provider.status = ProviderStatus.DEGRADED

    def _record_failure(self, provider: Provider) -> None:
        """Record a failed request; may open (or re-open) the breaker."""
        provider.metrics.total_requests += 1
        provider.metrics.failed_requests += 1
        provider.metrics.last_error_time = datetime.now(timezone.utc).isoformat()
        provider.metrics.consecutive_failures += 1
        # FIX: a failed half-open probe re-opens the breaker immediately,
        # even when consecutive_failures is below the threshold (it resets
        # on each success, including half-open probe successes).
        if provider.circuit_state == CircuitState.HALF_OPEN:
            self._open_circuit(provider)
        elif provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold:
            self._open_circuit(provider)
        # Update status
        if provider.metrics.error_rate > 0.3:
            provider.status = ProviderStatus.DEGRADED
        if provider.metrics.error_rate > 0.5:
            provider.status = ProviderStatus.UNHEALTHY

    def _open_circuit(self, provider: Provider) -> None:
        """Open the circuit breaker for a provider."""
        provider.circuit_state = CircuitState.OPEN
        provider.circuit_opened_at = time.time()
        provider.status = ProviderStatus.UNHEALTHY
        logger.warning("Circuit breaker OPEN for %s", provider.name)

    def _can_close_circuit(self, provider: Provider) -> bool:
        """Return True when the recovery timeout has elapsed since the
        breaker opened (i.e. a half-open probe is allowed)."""
        if provider.circuit_opened_at is None:
            return False
        elapsed = time.time() - provider.circuit_opened_at
        return elapsed >= self.config.circuit_breaker_recovery_timeout

    def _close_circuit(self, provider: Provider) -> None:
        """Close the circuit breaker (provider healthy again)."""
        provider.circuit_state = CircuitState.CLOSED
        provider.circuit_opened_at = None
        provider.half_open_calls = 0
        provider.metrics.consecutive_failures = 0
        provider.status = ProviderStatus.HEALTHY
        logger.info("Circuit breaker CLOSED for %s", provider.name)

    def get_metrics(self) -> dict:
        """Return per-provider request metrics as a JSON-serializable dict."""
        return {
            "providers": [
                {
                    "name": p.name,
                    "type": p.type,
                    "status": p.status.value,
                    "circuit_state": p.circuit_state.value,
                    "metrics": {
                        "total_requests": p.metrics.total_requests,
                        "successful": p.metrics.successful_requests,
                        "failed": p.metrics.failed_requests,
                        "error_rate": round(p.metrics.error_rate, 3),
                        "avg_latency_ms": round(p.metrics.avg_latency_ms, 2),
                    },
                }
                for p in self.providers
            ]
        }

    def get_status(self) -> dict:
        """Return a summary of provider health counts and per-provider info."""
        healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY)
        return {
            "total_providers": len(self.providers),
            "healthy_providers": healthy,
            "degraded_providers": sum(1 for p in self.providers if p.status == ProviderStatus.DEGRADED),
            "unhealthy_providers": sum(1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY),
            "providers": [
                {
                    "name": p.name,
                    "type": p.type,
                    "status": p.status.value,
                    "priority": p.priority,
                    "default_model": p.get_default_model(),
                }
                for p in self.providers
            ],
        }
# Module-level singleton, created lazily by get_router()
cascade_router: Optional[CascadeRouter] = None


def get_router() -> CascadeRouter:
    """Return the shared CascadeRouter, constructing it on first use."""
    global cascade_router
    if cascade_router is not None:
        return cascade_router
    cascade_router = CascadeRouter()
    return cascade_router