perf: optimize cascade router memory — deduplicate provider loop, hoist constants (#1376)
Some checks failed
Tests / lint (pull_request) Failing after 28s
Tests / test (pull_request) Has been skipped

This commit is contained in:
Alexander Whitestone
2026-03-24 17:04:06 -04:00
parent e5373119cc
commit ebfb9fadd2

View File

@@ -10,6 +10,8 @@ models for image inputs and falls back through capability chains.
import asyncio
import logging
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -28,6 +30,15 @@ try:
except ImportError:
requests = None # type: ignore
# Pre-compiled regex for env-var expansion (avoids re-compilation per call)
_ENV_VAR_RE = re.compile(r"\$\{(\w+)\}")
# Constant tuples for content-type detection (avoids per-call allocation)
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
# Constant set for cloud provider types (avoids per-call tuple creation)
_CLOUD_PROVIDER_TYPES = frozenset(("anthropic", "openai", "grok"))
# Re-export data models so existing ``from …cascade import X`` keeps working.
# Mixins
from .health import HealthMixin
@@ -156,20 +167,19 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
self.providers.sort(key=lambda p: p.priority)
def _expand_env_vars(self, content: str) -> str:
@staticmethod
def _expand_env_vars(content: str) -> str:
"""Expand ${VAR} syntax in YAML content.
Uses os.environ directly (not settings) because this is a generic
YAML config loader that must expand arbitrary variable references.
"""
import os
import re
def replace_var(match: "re.Match[str]") -> str:
var_name = match.group(1)
return os.environ.get(var_name, match.group(0))
return re.sub(r"\$\{(\w+)\}", replace_var, content)
return _ENV_VAR_RE.sub(replace_var, content)
def _check_provider_available(self, provider: Provider) -> bool:
"""Check if a provider is actually available."""
@@ -225,8 +235,7 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
# Check for image URLs in content
if isinstance(content, str):
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
if any(ext in content.lower() for ext in image_extensions):
if any(ext in content.lower() for ext in _IMAGE_EXTENSIONS):
has_image = True
if content.startswith("data:image/"):
has_image = True
@@ -395,7 +404,7 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
return None
# Metabolic protocol: skip cloud providers when quota is low
if provider.type in ("anthropic", "openai", "grok"):
if provider.type in _CLOUD_PROVIDER_TYPES:
if not self._quota_allows_cloud(provider):
logger.info(
"Metabolic protocol: skipping cloud provider %s (quota too low)",
@@ -513,18 +522,6 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
providers = self._filter_providers(cascade_tier)
for provider in providers:
if not self._is_provider_available(provider):
continue
# Metabolic protocol: skip cloud providers when quota is low
if provider.type in ("anthropic", "openai", "grok"):
if not self._quota_allows_cloud(provider):
logger.info(
"Metabolic protocol: skipping cloud provider %s (quota too low)",
provider.name,
)
continue
# Complexity-based model selection (only when no explicit model)
effective_model = model
if effective_model is None and complexity is not None:
@@ -537,33 +534,13 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
effective_model,
)
selected_model, is_fallback_model = self._select_model(
provider, effective_model, content_type
result = await self._try_single_provider(
provider, messages, effective_model, temperature,
max_tokens, content_type, errors,
)
try:
result = await self._attempt_with_retry(
provider,
messages,
selected_model,
temperature,
max_tokens,
content_type,
)
except RuntimeError as exc:
errors.append(str(exc))
self._record_failure(provider)
continue
self._record_success(provider, result.get("latency_ms", 0))
return {
"content": result["content"],
"provider": provider.name,
"model": result.get("model", selected_model or provider.get_default_model()),
"latency_ms": result.get("latency_ms", 0),
"is_fallback_model": is_fallback_model,
"complexity": complexity.value if complexity is not None else None,
}
if result is not None:
result["complexity"] = complexity.value if complexity is not None else None
return result
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")