perf: optimize cascade router memory — deduplicate provider loop, hoist constants (#1376)
This commit is contained in:
@@ -10,6 +10,8 @@ models for image inputs and falls back through capability chains.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
@@ -28,6 +30,15 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
requests = None # type: ignore
|
requests = None # type: ignore
|
||||||
|
|
||||||
|
# Pre-compiled regex for env-var expansion (avoids re-compilation per call)
|
||||||
|
_ENV_VAR_RE = re.compile(r"\$\{(\w+)\}")
|
||||||
|
|
||||||
|
# Constant tuples for content-type detection (avoids per-call allocation)
|
||||||
|
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||||
|
|
||||||
|
# Constant set for cloud provider types (avoids per-call tuple creation)
|
||||||
|
_CLOUD_PROVIDER_TYPES = frozenset(("anthropic", "openai", "grok"))
|
||||||
|
|
||||||
# Re-export data models so existing ``from …cascade import X`` keeps working.
|
# Re-export data models so existing ``from …cascade import X`` keeps working.
|
||||||
# Mixins
|
# Mixins
|
||||||
from .health import HealthMixin
|
from .health import HealthMixin
|
||||||
@@ -156,20 +167,19 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
|||||||
|
|
||||||
self.providers.sort(key=lambda p: p.priority)
|
self.providers.sort(key=lambda p: p.priority)
|
||||||
|
|
||||||
def _expand_env_vars(self, content: str) -> str:
|
@staticmethod
|
||||||
|
def _expand_env_vars(content: str) -> str:
|
||||||
"""Expand ${VAR} syntax in YAML content.
|
"""Expand ${VAR} syntax in YAML content.
|
||||||
|
|
||||||
Uses os.environ directly (not settings) because this is a generic
|
Uses os.environ directly (not settings) because this is a generic
|
||||||
YAML config loader that must expand arbitrary variable references.
|
YAML config loader that must expand arbitrary variable references.
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
def replace_var(match: "re.Match[str]") -> str:
|
def replace_var(match: "re.Match[str]") -> str:
|
||||||
var_name = match.group(1)
|
var_name = match.group(1)
|
||||||
return os.environ.get(var_name, match.group(0))
|
return os.environ.get(var_name, match.group(0))
|
||||||
|
|
||||||
return re.sub(r"\$\{(\w+)\}", replace_var, content)
|
return _ENV_VAR_RE.sub(replace_var, content)
|
||||||
|
|
||||||
def _check_provider_available(self, provider: Provider) -> bool:
|
def _check_provider_available(self, provider: Provider) -> bool:
|
||||||
"""Check if a provider is actually available."""
|
"""Check if a provider is actually available."""
|
||||||
@@ -225,8 +235,7 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
|||||||
|
|
||||||
# Check for image URLs in content
|
# Check for image URLs in content
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
if any(ext in content.lower() for ext in _IMAGE_EXTENSIONS):
|
||||||
if any(ext in content.lower() for ext in image_extensions):
|
|
||||||
has_image = True
|
has_image = True
|
||||||
if content.startswith("data:image/"):
|
if content.startswith("data:image/"):
|
||||||
has_image = True
|
has_image = True
|
||||||
@@ -395,7 +404,7 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Metabolic protocol: skip cloud providers when quota is low
|
# Metabolic protocol: skip cloud providers when quota is low
|
||||||
if provider.type in ("anthropic", "openai", "grok"):
|
if provider.type in _CLOUD_PROVIDER_TYPES:
|
||||||
if not self._quota_allows_cloud(provider):
|
if not self._quota_allows_cloud(provider):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
||||||
@@ -513,18 +522,6 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
|||||||
providers = self._filter_providers(cascade_tier)
|
providers = self._filter_providers(cascade_tier)
|
||||||
|
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
if not self._is_provider_available(provider):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Metabolic protocol: skip cloud providers when quota is low
|
|
||||||
if provider.type in ("anthropic", "openai", "grok"):
|
|
||||||
if not self._quota_allows_cloud(provider):
|
|
||||||
logger.info(
|
|
||||||
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
|
||||||
provider.name,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Complexity-based model selection (only when no explicit model)
|
# Complexity-based model selection (only when no explicit model)
|
||||||
effective_model = model
|
effective_model = model
|
||||||
if effective_model is None and complexity is not None:
|
if effective_model is None and complexity is not None:
|
||||||
@@ -537,33 +534,13 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
|||||||
effective_model,
|
effective_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
selected_model, is_fallback_model = self._select_model(
|
result = await self._try_single_provider(
|
||||||
provider, effective_model, content_type
|
provider, messages, effective_model, temperature,
|
||||||
|
max_tokens, content_type, errors,
|
||||||
)
|
)
|
||||||
|
if result is not None:
|
||||||
try:
|
result["complexity"] = complexity.value if complexity is not None else None
|
||||||
result = await self._attempt_with_retry(
|
return result
|
||||||
provider,
|
|
||||||
messages,
|
|
||||||
selected_model,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
content_type,
|
|
||||||
)
|
|
||||||
except RuntimeError as exc:
|
|
||||||
errors.append(str(exc))
|
|
||||||
self._record_failure(provider)
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._record_success(provider, result.get("latency_ms", 0))
|
|
||||||
return {
|
|
||||||
"content": result["content"],
|
|
||||||
"provider": provider.name,
|
|
||||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
|
||||||
"latency_ms": result.get("latency_ms", 0),
|
|
||||||
"is_fallback_model": is_fallback_model,
|
|
||||||
"complexity": complexity.value if complexity is not None else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user