[claude] refactor: break up cascade.py complete() (#1185) (#1190)

This commit is contained in:
2026-03-23 21:52:27 +00:00
parent c78922ccbc
commit 8f8061e224
2 changed files with 210 additions and 47 deletions

View File

@@ -528,6 +528,71 @@ class CascadeRouter:
return True
def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]:
"""Return the provider list filtered by tier.
Raises:
RuntimeError: If a tier is specified but no matching providers exist.
"""
if cascade_tier == "frontier_required":
providers = [p for p in self.providers if p.type == "anthropic"]
if not providers:
raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.")
return providers
if cascade_tier:
providers = [p for p in self.providers if p.tier == cascade_tier]
if not providers:
raise RuntimeError(f"No providers found for tier: {cascade_tier}")
return providers
return self.providers
async def _try_single_provider(
self,
provider: "Provider",
messages: list[dict],
model: str | None,
temperature: float,
max_tokens: int | None,
content_type: ContentType,
errors: list[str],
) -> dict | None:
"""Attempt one provider, returning a result dict on success or None on failure.
On failure the error string is appended to *errors* and the provider's
failure metrics are updated so the caller can move on to the next provider.
"""
if not self._is_provider_available(provider):
return None
# Metabolic protocol: skip cloud providers when quota is low
if provider.type in ("anthropic", "openai", "grok"):
if not self._quota_allows_cloud(provider):
logger.info(
"Metabolic protocol: skipping cloud provider %s (quota too low)",
provider.name,
)
return None
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
try:
result = await self._attempt_with_retry(
provider, messages, selected_model, temperature, max_tokens, content_type
)
except RuntimeError as exc:
errors.append(str(exc))
self._record_failure(provider)
return None
self._record_success(provider, result.get("latency_ms", 0))
return {
"content": result["content"],
"provider": provider.name,
"model": result.get("model", selected_model or provider.get_default_model()),
"latency_ms": result.get("latency_ms", 0),
"is_fallback_model": is_fallback_model,
}
async def complete(
self,
messages: list[dict],
@@ -561,55 +626,15 @@ class CascadeRouter:
if content_type != ContentType.TEXT:
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
errors = []
providers = self.providers
if cascade_tier == "frontier_required":
providers = [p for p in self.providers if p.type == "anthropic"]
if not providers:
raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.")
elif cascade_tier:
providers = [p for p in self.providers if p.tier == cascade_tier]
if not providers:
raise RuntimeError(f"No providers found for tier: {cascade_tier}")
errors: list[str] = []
providers = self._filter_providers(cascade_tier)
for provider in providers:
if not self._is_provider_available(provider):
continue
# Metabolic protocol: skip cloud providers when quota is low
if provider.type in ("anthropic", "openai", "grok"):
if not self._quota_allows_cloud(provider):
logger.info(
"Metabolic protocol: skipping cloud provider %s (quota too low)",
provider.name,
)
continue
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
try:
result = await self._attempt_with_retry(
provider,
messages,
selected_model,
temperature,
max_tokens,
content_type,
)
except RuntimeError as exc:
errors.append(str(exc))
self._record_failure(provider)
continue
self._record_success(provider, result.get("latency_ms", 0))
return {
"content": result["content"],
"provider": provider.name,
"model": result.get("model", selected_model or provider.get_default_model()),
"latency_ms": result.get("latency_ms", 0),
"is_fallback_model": is_fallback_model,
}
result = await self._try_single_provider(
provider, messages, model, temperature, max_tokens, content_type, errors
)
if result is not None:
return result
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")