perf: optimize cascade router memory — deduplicate provider loop, hoist constants (#1376)
Some checks failed
Tests / lint (pull_request) Failing after 28s
Tests / test (pull_request) Has been skipped

This commit is contained in:
Alexander Whitestone
2026-03-24 17:04:06 -04:00
parent e5373119cc
commit ebfb9fadd2

View File

@@ -10,6 +10,8 @@ models for image inputs and falls back through capability chains.
import asyncio import asyncio
import logging import logging
import os
import re
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -28,6 +30,15 @@ try:
except ImportError: except ImportError:
requests = None # type: ignore requests = None # type: ignore
# Pre-compiled regex for env-var expansion (skips the per-call pattern-cache lookup done by re.sub with a string pattern)
_ENV_VAR_RE = re.compile(r"\$\{(\w+)\}")
# Constant tuple of image file extensions for content-type detection (one shared, named definition)
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
# Constant frozenset of cloud provider types (hash-based membership test; one shared definition for all call sites)
_CLOUD_PROVIDER_TYPES = frozenset(("anthropic", "openai", "grok"))
# Re-export data models so existing ``from …cascade import X`` keeps working. # Re-export data models so existing ``from …cascade import X`` keeps working.
# Mixins # Mixins
from .health import HealthMixin from .health import HealthMixin
@@ -156,20 +167,19 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
self.providers.sort(key=lambda p: p.priority) self.providers.sort(key=lambda p: p.priority)
def _expand_env_vars(self, content: str) -> str: @staticmethod
def _expand_env_vars(content: str) -> str:
"""Expand ${VAR} syntax in YAML content. """Expand ${VAR} syntax in YAML content.
Uses os.environ directly (not settings) because this is a generic Uses os.environ directly (not settings) because this is a generic
YAML config loader that must expand arbitrary variable references. YAML config loader that must expand arbitrary variable references.
""" """
import os
import re
def replace_var(match: "re.Match[str]") -> str: def replace_var(match: "re.Match[str]") -> str:
var_name = match.group(1) var_name = match.group(1)
return os.environ.get(var_name, match.group(0)) return os.environ.get(var_name, match.group(0))
return re.sub(r"\$\{(\w+)\}", replace_var, content) return _ENV_VAR_RE.sub(replace_var, content)
def _check_provider_available(self, provider: Provider) -> bool: def _check_provider_available(self, provider: Provider) -> bool:
"""Check if a provider is actually available.""" """Check if a provider is actually available."""
@@ -225,8 +235,7 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
# Check for image URLs in content # Check for image URLs in content
if isinstance(content, str): if isinstance(content, str):
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp") if any(ext in content.lower() for ext in _IMAGE_EXTENSIONS):
if any(ext in content.lower() for ext in image_extensions):
has_image = True has_image = True
if content.startswith("data:image/"): if content.startswith("data:image/"):
has_image = True has_image = True
@@ -395,7 +404,7 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
return None return None
# Metabolic protocol: skip cloud providers when quota is low # Metabolic protocol: skip cloud providers when quota is low
if provider.type in ("anthropic", "openai", "grok"): if provider.type in _CLOUD_PROVIDER_TYPES:
if not self._quota_allows_cloud(provider): if not self._quota_allows_cloud(provider):
logger.info( logger.info(
"Metabolic protocol: skipping cloud provider %s (quota too low)", "Metabolic protocol: skipping cloud provider %s (quota too low)",
@@ -513,18 +522,6 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
providers = self._filter_providers(cascade_tier) providers = self._filter_providers(cascade_tier)
for provider in providers: for provider in providers:
if not self._is_provider_available(provider):
continue
# Metabolic protocol: skip cloud providers when quota is low
if provider.type in ("anthropic", "openai", "grok"):
if not self._quota_allows_cloud(provider):
logger.info(
"Metabolic protocol: skipping cloud provider %s (quota too low)",
provider.name,
)
continue
# Complexity-based model selection (only when no explicit model) # Complexity-based model selection (only when no explicit model)
effective_model = model effective_model = model
if effective_model is None and complexity is not None: if effective_model is None and complexity is not None:
@@ -537,33 +534,13 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
effective_model, effective_model,
) )
selected_model, is_fallback_model = self._select_model( result = await self._try_single_provider(
provider, effective_model, content_type provider, messages, effective_model, temperature,
max_tokens, content_type, errors,
) )
if result is not None:
try: result["complexity"] = complexity.value if complexity is not None else None
result = await self._attempt_with_retry( return result
provider,
messages,
selected_model,
temperature,
max_tokens,
content_type,
)
except RuntimeError as exc:
errors.append(str(exc))
self._record_failure(provider)
continue
self._record_success(provider, result.get("latency_ms", 0))
return {
"content": result["content"],
"provider": provider.name,
"model": result.get("model", selected_model or provider.get_default_model()),
"latency_ms": result.get("latency_ms", 0),
"is_fallback_model": is_fallback_model,
"complexity": complexity.value if complexity is not None else None,
}
raise RuntimeError(f"All providers failed: {'; '.join(errors)}") raise RuntimeError(f"All providers failed: {'; '.join(errors)}")