1
0

refactor: break up CascadeRouter.complete() into focused helpers (#510)

Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
This commit is contained in:
2026-03-19 19:24:36 -04:00
committed by Timmy Time
parent e4de539bf3
commit cd3dc5d989

View File

@@ -388,6 +388,101 @@ class CascadeRouter:
return None
def _select_model(
self, provider: Provider, model: str | None, content_type: ContentType
) -> tuple[str | None, bool]:
"""Select the best model for the request, with vision fallback.
Returns:
Tuple of (selected_model, is_fallback_model).
"""
selected_model = model or provider.get_default_model()
is_fallback = False
if content_type != ContentType.TEXT and selected_model:
if provider.type == "ollama" and self._mm_manager:
from infrastructure.models.multimodal import ModelCapability
if content_type == ContentType.VISION:
supports = self._mm_manager.model_supports(
selected_model, ModelCapability.VISION
)
if not supports:
fallback = self._get_fallback_model(provider, selected_model, content_type)
if fallback:
logger.info(
"Model %s doesn't support vision, falling back to %s",
selected_model,
fallback,
)
selected_model = fallback
is_fallback = True
else:
logger.warning(
"No vision-capable model found on %s, trying anyway",
provider.name,
)
return selected_model, is_fallback
async def _attempt_with_retry(
self,
provider: Provider,
messages: list[dict],
model: str | None,
temperature: float,
max_tokens: int | None,
content_type: ContentType,
) -> dict:
"""Try a provider with retries, returning the result dict.
Raises:
RuntimeError: If all retry attempts fail.
Returns error strings collected during retries via the exception message.
"""
errors: list[str] = []
for attempt in range(self.config.max_retries_per_provider):
try:
return await self._try_provider(
provider=provider,
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
content_type=content_type,
)
except Exception as exc:
error_msg = str(exc)
logger.warning(
"Provider %s attempt %d failed: %s",
provider.name,
attempt + 1,
error_msg,
)
errors.append(f"{provider.name}: {error_msg}")
if attempt < self.config.max_retries_per_provider - 1:
await asyncio.sleep(self.config.retry_delay_seconds)
raise RuntimeError("; ".join(errors))
def _is_provider_available(self, provider: Provider) -> bool:
"""Check if a provider should be tried (enabled + circuit breaker)."""
if not provider.enabled:
logger.debug("Skipping %s (disabled)", provider.name)
return False
if provider.status == ProviderStatus.UNHEALTHY:
if self._can_close_circuit(provider):
provider.circuit_state = CircuitState.HALF_OPEN
provider.half_open_calls = 0
logger.info("Circuit breaker half-open for %s", provider.name)
else:
logger.debug("Skipping %s (circuit open)", provider.name)
return False
return True
async def complete(
self,
messages: list[dict],
@@ -414,7 +509,6 @@ class CascadeRouter:
Raises:
RuntimeError: If all providers fail
"""
# Detect content type for multi-modal routing
content_type = self._detect_content_type(messages)
if content_type != ContentType.TEXT:
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
@@ -422,93 +516,34 @@ class CascadeRouter:
errors = []
for provider in self.providers:
# Skip disabled providers
if not provider.enabled:
logger.debug("Skipping %s (disabled)", provider.name)
if not self._is_provider_available(provider):
continue
# Skip unhealthy providers (circuit breaker)
if provider.status == ProviderStatus.UNHEALTHY:
# Check if circuit breaker can close
if self._can_close_circuit(provider):
provider.circuit_state = CircuitState.HALF_OPEN
provider.half_open_calls = 0
logger.info("Circuit breaker half-open for %s", provider.name)
else:
logger.debug("Skipping %s (circuit open)", provider.name)
continue
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
# Determine which model to use
selected_model = model or provider.get_default_model()
is_fallback_model = False
try:
result = await self._attempt_with_retry(
provider,
messages,
selected_model,
temperature,
max_tokens,
content_type,
)
except RuntimeError as exc:
errors.append(str(exc))
self._record_failure(provider)
continue
# For non-text content, check if model supports it
if content_type != ContentType.TEXT and selected_model:
if provider.type == "ollama" and self._mm_manager:
from infrastructure.models.multimodal import ModelCapability
self._record_success(provider, result.get("latency_ms", 0))
return {
"content": result["content"],
"provider": provider.name,
"model": result.get("model", selected_model or provider.get_default_model()),
"latency_ms": result.get("latency_ms", 0),
"is_fallback_model": is_fallback_model,
}
# Check if selected model supports the required capability
if content_type == ContentType.VISION:
supports = self._mm_manager.model_supports(
selected_model, ModelCapability.VISION
)
if not supports:
# Find fallback model
fallback = self._get_fallback_model(
provider, selected_model, content_type
)
if fallback:
logger.info(
"Model %s doesn't support vision, falling back to %s",
selected_model,
fallback,
)
selected_model = fallback
is_fallback_model = True
else:
logger.warning(
"No vision-capable model found on %s, trying anyway",
provider.name,
)
# Try this provider
for attempt in range(self.config.max_retries_per_provider):
try:
result = await self._try_provider(
provider=provider,
messages=messages,
model=selected_model,
temperature=temperature,
max_tokens=max_tokens,
content_type=content_type,
)
# Success! Update metrics and return
self._record_success(provider, result.get("latency_ms", 0))
return {
"content": result["content"],
"provider": provider.name,
"model": result.get(
"model", selected_model or provider.get_default_model()
),
"latency_ms": result.get("latency_ms", 0),
"is_fallback_model": is_fallback_model,
}
except Exception as exc:
error_msg = str(exc)
logger.warning(
"Provider %s attempt %d failed: %s", provider.name, attempt + 1, error_msg
)
errors.append(f"{provider.name}: {error_msg}")
if attempt < self.config.max_retries_per_provider - 1:
await asyncio.sleep(self.config.retry_delay_seconds)
# All retries failed for this provider
self._record_failure(provider)
# All providers failed
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
async def _try_provider(