This commit was merged in pull request #1468.
This commit is contained in:
@@ -10,6 +10,8 @@ models for image inputs and falls back through capability chains.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -28,6 +30,15 @@ try:
|
||||
except ImportError:
|
||||
requests = None # type: ignore
|
||||
|
||||
# Pre-compiled regex for env-var expansion (avoids re-compilation per call)
|
||||
_ENV_VAR_RE = re.compile(r"\$\{(\w+)\}")
|
||||
|
||||
# Constant tuples for content-type detection (avoids per-call allocation)
|
||||
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
|
||||
# Constant set for cloud provider types (avoids per-call tuple creation)
|
||||
_CLOUD_PROVIDER_TYPES = frozenset(("anthropic", "openai", "grok"))
|
||||
|
||||
# Re-export data models so existing ``from …cascade import X`` keeps working.
|
||||
# Mixins
|
||||
from .health import HealthMixin
|
||||
@@ -156,20 +167,19 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
||||
|
||||
self.providers.sort(key=lambda p: p.priority)
|
||||
|
||||
def _expand_env_vars(self, content: str) -> str:
|
||||
@staticmethod
|
||||
def _expand_env_vars(content: str) -> str:
|
||||
"""Expand ${VAR} syntax in YAML content.
|
||||
|
||||
Uses os.environ directly (not settings) because this is a generic
|
||||
YAML config loader that must expand arbitrary variable references.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
|
||||
def replace_var(match: "re.Match[str]") -> str:
|
||||
var_name = match.group(1)
|
||||
return os.environ.get(var_name, match.group(0))
|
||||
|
||||
return re.sub(r"\$\{(\w+)\}", replace_var, content)
|
||||
return _ENV_VAR_RE.sub(replace_var, content)
|
||||
|
||||
def _check_provider_available(self, provider: Provider) -> bool:
|
||||
"""Check if a provider is actually available."""
|
||||
@@ -225,8 +235,7 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
||||
|
||||
# Check for image URLs in content
|
||||
if isinstance(content, str):
|
||||
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
if any(ext in content.lower() for ext in image_extensions):
|
||||
if any(ext in content.lower() for ext in _IMAGE_EXTENSIONS):
|
||||
has_image = True
|
||||
if content.startswith("data:image/"):
|
||||
has_image = True
|
||||
@@ -395,7 +404,7 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
||||
return None
|
||||
|
||||
# Metabolic protocol: skip cloud providers when quota is low
|
||||
if provider.type in ("anthropic", "openai", "grok"):
|
||||
if provider.type in _CLOUD_PROVIDER_TYPES:
|
||||
if not self._quota_allows_cloud(provider):
|
||||
logger.info(
|
||||
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
||||
@@ -513,18 +522,6 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
||||
providers = self._filter_providers(cascade_tier)
|
||||
|
||||
for provider in providers:
|
||||
if not self._is_provider_available(provider):
|
||||
continue
|
||||
|
||||
# Metabolic protocol: skip cloud providers when quota is low
|
||||
if provider.type in ("anthropic", "openai", "grok"):
|
||||
if not self._quota_allows_cloud(provider):
|
||||
logger.info(
|
||||
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
||||
provider.name,
|
||||
)
|
||||
continue
|
||||
|
||||
# Complexity-based model selection (only when no explicit model)
|
||||
effective_model = model
|
||||
if effective_model is None and complexity is not None:
|
||||
@@ -537,33 +534,13 @@ class CascadeRouter(HealthMixin, ProviderCallsMixin):
|
||||
effective_model,
|
||||
)
|
||||
|
||||
selected_model, is_fallback_model = self._select_model(
|
||||
provider, effective_model, content_type
|
||||
result = await self._try_single_provider(
|
||||
provider, messages, effective_model, temperature,
|
||||
max_tokens, content_type, errors,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._attempt_with_retry(
|
||||
provider,
|
||||
messages,
|
||||
selected_model,
|
||||
temperature,
|
||||
max_tokens,
|
||||
content_type,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
errors.append(str(exc))
|
||||
self._record_failure(provider)
|
||||
continue
|
||||
|
||||
self._record_success(provider, result.get("latency_ms", 0))
|
||||
return {
|
||||
"content": result["content"],
|
||||
"provider": provider.name,
|
||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||||
"latency_ms": result.get("latency_ms", 0),
|
||||
"is_fallback_model": is_fallback_model,
|
||||
"complexity": complexity.value if complexity is not None else None,
|
||||
}
|
||||
if result is not None:
|
||||
result["complexity"] = complexity.value if complexity is not None else None
|
||||
return result
|
||||
|
||||
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user