From e2d662257ab8b2fda5afe2d5b54610acc69accad Mon Sep 17 00:00:00 2001 From: kimi Date: Mon, 23 Mar 2026 17:55:30 -0400 Subject: [PATCH] refactor: break up cascade.py::complete() into helper methods (#1185) Extract _get_providers_for_tier() and _try_single_provider() to reduce the complexity of complete(). The method was 84 lines; now the main logic is clearer and each helper has a single responsibility. - _get_providers_for_tier(): Filters providers by cascade_tier - _try_single_provider(): Attempts a single provider with metabolic protocol Fixes #1185 --- src/infrastructure/router/cascade.py | 105 +++++++++++++++++---------- 1 file changed, 68 insertions(+), 37 deletions(-) diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index 1cb9747a..bef9d19e 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -528,6 +528,68 @@ class CascadeRouter: return True + def _get_providers_for_tier(self, cascade_tier: str | None) -> list[Provider]: + """Filter providers by tier, returning eligible providers.""" + if cascade_tier == "frontier_required": + providers = [p for p in self.providers if p.type == "anthropic"] + if not providers: + raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.") + return providers + elif cascade_tier: + providers = [p for p in self.providers if p.tier == cascade_tier] + if not providers: + raise RuntimeError(f"No providers found for tier: {cascade_tier}") + return providers + return self.providers + + async def _try_single_provider( + self, + provider: Provider, + messages: list[dict], + model: str | None, + temperature: float, + max_tokens: int | None, + content_type: ContentType, + ) -> dict | None: + """Attempt a single provider request. + + Returns: + Response dict on success, None if provider should be skipped. + Raises: + RuntimeError: If the provider attempt fails. + """ + if not self._is_provider_available(provider): + return None + + # Metabolic protocol: skip cloud providers when quota is low + if provider.type in ("anthropic", "openai", "grok"): + if not self._quota_allows_cloud(provider): + logger.info( + "Metabolic protocol: skipping cloud provider %s (quota too low)", + provider.name, + ) + return None + + selected_model, is_fallback_model = self._select_model(provider, model, content_type) + + result = await self._attempt_with_retry( + provider, + messages, + selected_model, + temperature, + max_tokens, + content_type, + ) + + self._record_success(provider, result.get("latency_ms", 0)) + return { + "content": result["content"], + "provider": provider.name, + "model": result.get("model", selected_model or provider.get_default_model()), + "latency_ms": result.get("latency_ms", 0), + "is_fallback_model": is_fallback_model, + } + async def complete( self, messages: list[dict], @@ -561,55 +623,24 @@ class CascadeRouter: if content_type != ContentType.TEXT: logger.debug("Detected %s content, selecting appropriate model", content_type.value) - errors = [] - - providers = self.providers - if cascade_tier == "frontier_required": - providers = [p for p in self.providers if p.type == "anthropic"] - if not providers: - raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.") - elif cascade_tier: - providers = [p for p in self.providers if p.tier == cascade_tier] - if not providers: - raise RuntimeError(f"No providers found for tier: {cascade_tier}") + errors: list[str] = [] + providers = self._get_providers_for_tier(cascade_tier) for provider in providers: - if not self._is_provider_available(provider): - continue - - # Metabolic protocol: skip cloud providers when quota is low - if provider.type in ("anthropic", "openai", "grok"): - if not self._quota_allows_cloud(provider): - logger.info( - "Metabolic protocol: skipping cloud provider %s (quota too low)", - provider.name, - ) - continue - - selected_model, is_fallback_model = self._select_model(provider, model, content_type) - try: - result = await self._attempt_with_retry( + result = await self._try_single_provider( provider, messages, - selected_model, + model, temperature, max_tokens, content_type, ) + if result is not None: + return result except RuntimeError as exc: errors.append(str(exc)) self._record_failure(provider) - continue - - self._record_success(provider, result.get("latency_ms", 0)) - return { - "content": result["content"], - "provider": provider.name, - "model": result.get("model", selected_model or provider.get_default_model()), - "latency_ms": result.get("latency_ms", 0), - "is_fallback_model": is_fallback_model, - } raise RuntimeError(f"All providers failed: {'; '.join(errors)}") -- 2.43.0