From cd3dc5d98969d6c33806ab36f82bacda7dcae773 Mon Sep 17 00:00:00 2001 From: Kimi Agent Date: Thu, 19 Mar 2026 19:24:36 -0400 Subject: [PATCH] refactor: break up CascadeRouter.complete() into focused helpers (#510) Co-authored-by: Kimi Agent Co-committed-by: Kimi Agent --- src/infrastructure/router/cascade.py | 201 ++++++++++++++++----------- 1 file changed, 118 insertions(+), 83 deletions(-) diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index d5aba23e..40b1304b 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -388,6 +388,101 @@ class CascadeRouter: return None + def _select_model( + self, provider: Provider, model: str | None, content_type: ContentType + ) -> tuple[str | None, bool]: + """Select the best model for the request, with vision fallback. + + Returns: + Tuple of (selected_model, is_fallback_model). + """ + selected_model = model or provider.get_default_model() + is_fallback = False + + if content_type != ContentType.TEXT and selected_model: + if provider.type == "ollama" and self._mm_manager: + from infrastructure.models.multimodal import ModelCapability + + if content_type == ContentType.VISION: + supports = self._mm_manager.model_supports( + selected_model, ModelCapability.VISION + ) + if not supports: + fallback = self._get_fallback_model(provider, selected_model, content_type) + if fallback: + logger.info( + "Model %s doesn't support vision, falling back to %s", + selected_model, + fallback, + ) + selected_model = fallback + is_fallback = True + else: + logger.warning( + "No vision-capable model found on %s, trying anyway", + provider.name, + ) + + return selected_model, is_fallback + + async def _attempt_with_retry( + self, + provider: Provider, + messages: list[dict], + model: str | None, + temperature: float, + max_tokens: int | None, + content_type: ContentType, + ) -> dict: + """Try a provider with retries, returning the result dict. + + Raises: + RuntimeError: If all retry attempts fail. + Returns error strings collected during retries via the exception message. + """ + errors: list[str] = [] + for attempt in range(self.config.max_retries_per_provider): + try: + return await self._try_provider( + provider=provider, + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + content_type=content_type, + ) + except Exception as exc: + error_msg = str(exc) + logger.warning( + "Provider %s attempt %d failed: %s", + provider.name, + attempt + 1, + error_msg, + ) + errors.append(f"{provider.name}: {error_msg}") + + if attempt < self.config.max_retries_per_provider - 1: + await asyncio.sleep(self.config.retry_delay_seconds) + + raise RuntimeError("; ".join(errors)) + + def _is_provider_available(self, provider: Provider) -> bool: + """Check if a provider should be tried (enabled + circuit breaker).""" + if not provider.enabled: + logger.debug("Skipping %s (disabled)", provider.name) + return False + + if provider.status == ProviderStatus.UNHEALTHY: + if self._can_close_circuit(provider): + provider.circuit_state = CircuitState.HALF_OPEN + provider.half_open_calls = 0 + logger.info("Circuit breaker half-open for %s", provider.name) + else: + logger.debug("Skipping %s (circuit open)", provider.name) + return False + + return True + async def complete( self, messages: list[dict], @@ -414,7 +509,6 @@ class CascadeRouter: Raises: RuntimeError: If all providers fail """ - # Detect content type for multi-modal routing content_type = self._detect_content_type(messages) if content_type != ContentType.TEXT: logger.debug("Detected %s content, selecting appropriate model", content_type.value) @@ -422,93 +516,34 @@ class CascadeRouter: errors = [] for provider in self.providers: - # Skip disabled providers - if not provider.enabled: - logger.debug("Skipping %s (disabled)", provider.name) + if not self._is_provider_available(provider): continue - # Skip unhealthy providers (circuit breaker) - if provider.status == ProviderStatus.UNHEALTHY: - # Check if circuit breaker can close - if self._can_close_circuit(provider): - provider.circuit_state = CircuitState.HALF_OPEN - provider.half_open_calls = 0 - logger.info("Circuit breaker half-open for %s", provider.name) - else: - logger.debug("Skipping %s (circuit open)", provider.name) - continue + selected_model, is_fallback_model = self._select_model(provider, model, content_type) - # Determine which model to use - selected_model = model or provider.get_default_model() - is_fallback_model = False + try: + result = await self._attempt_with_retry( + provider, + messages, + selected_model, + temperature, + max_tokens, + content_type, + ) + except RuntimeError as exc: + errors.append(str(exc)) + self._record_failure(provider) + continue - # For non-text content, check if model supports it - if content_type != ContentType.TEXT and selected_model: - if provider.type == "ollama" and self._mm_manager: - from infrastructure.models.multimodal import ModelCapability + self._record_success(provider, result.get("latency_ms", 0)) + return { + "content": result["content"], + "provider": provider.name, + "model": result.get("model", selected_model or provider.get_default_model()), + "latency_ms": result.get("latency_ms", 0), + "is_fallback_model": is_fallback_model, + } - # Check if selected model supports the required capability - if content_type == ContentType.VISION: - supports = self._mm_manager.model_supports( - selected_model, ModelCapability.VISION - ) - if not supports: - # Find fallback model - fallback = self._get_fallback_model( - provider, selected_model, content_type - ) - if fallback: - logger.info( - "Model %s doesn't support vision, falling back to %s", - selected_model, - fallback, - ) - selected_model = fallback - is_fallback_model = True - else: - logger.warning( - "No vision-capable model found on %s, trying anyway", - provider.name, - ) - - # Try this provider - for attempt in range(self.config.max_retries_per_provider): - try: - result = await self._try_provider( - provider=provider, - messages=messages, - model=selected_model, - temperature=temperature, - max_tokens=max_tokens, - content_type=content_type, - ) - - # Success! Update metrics and return - self._record_success(provider, result.get("latency_ms", 0)) - return { - "content": result["content"], - "provider": provider.name, - "model": result.get( - "model", selected_model or provider.get_default_model() - ), - "latency_ms": result.get("latency_ms", 0), - "is_fallback_model": is_fallback_model, - } - - except Exception as exc: - error_msg = str(exc) - logger.warning( - "Provider %s attempt %d failed: %s", provider.name, attempt + 1, error_msg - ) - errors.append(f"{provider.name}: {error_msg}") - - if attempt < self.config.max_retries_per_provider - 1: - await asyncio.sleep(self.config.retry_delay_seconds) - - # All retries failed for this provider - self._record_failure(provider) - - # All providers failed raise RuntimeError(f"All providers failed: {'; '.join(errors)}") async def _try_provider(