diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index 1cb9747a..be85939f 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -528,6 +528,71 @@ class CascadeRouter: return True + def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]: + """Return the provider list filtered by tier. + + Raises: + RuntimeError: If a tier is specified but no matching providers exist. + """ + if cascade_tier == "frontier_required": + providers = [p for p in self.providers if p.type == "anthropic"] + if not providers: + raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.") + return providers + if cascade_tier: + providers = [p for p in self.providers if p.tier == cascade_tier] + if not providers: + raise RuntimeError(f"No providers found for tier: {cascade_tier}") + return providers + return self.providers + + async def _try_single_provider( + self, + provider: "Provider", + messages: list[dict], + model: str | None, + temperature: float, + max_tokens: int | None, + content_type: ContentType, + errors: list[str], + ) -> dict | None: + """Attempt one provider, returning a result dict on success or None on failure. + + On failure the error string is appended to *errors* and the provider's + failure metrics are updated so the caller can move on to the next provider. + """ + if not self._is_provider_available(provider): + return None + + # Metabolic protocol: skip cloud providers when quota is low + if provider.type in ("anthropic", "openai", "grok"): + if not self._quota_allows_cloud(provider): + logger.info( + "Metabolic protocol: skipping cloud provider %s (quota too low)", + provider.name, + ) + return None + + selected_model, is_fallback_model = self._select_model(provider, model, content_type) + + try: + result = await self._attempt_with_retry( + provider, messages, selected_model, temperature, max_tokens, content_type + ) + except RuntimeError as exc: + errors.append(str(exc)) + self._record_failure(provider) + return None + + self._record_success(provider, result.get("latency_ms", 0)) + return { + "content": result["content"], + "provider": provider.name, + "model": result.get("model", selected_model or provider.get_default_model()), + "latency_ms": result.get("latency_ms", 0), + "is_fallback_model": is_fallback_model, + } + async def complete( self, messages: list[dict], @@ -561,55 +626,15 @@ class CascadeRouter: if content_type != ContentType.TEXT: logger.debug("Detected %s content, selecting appropriate model", content_type.value) - errors = [] - - providers = self.providers - if cascade_tier == "frontier_required": - providers = [p for p in self.providers if p.type == "anthropic"] - if not providers: - raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.") - elif cascade_tier: - providers = [p for p in self.providers if p.tier == cascade_tier] - if not providers: - raise RuntimeError(f"No providers found for tier: {cascade_tier}") + errors: list[str] = [] + providers = self._filter_providers(cascade_tier) for provider in providers: - if not self._is_provider_available(provider): - continue - - # Metabolic protocol: skip cloud providers when quota is low - if provider.type in ("anthropic", "openai", "grok"): - if not self._quota_allows_cloud(provider): - logger.info( - "Metabolic protocol: skipping cloud provider %s (quota too low)", - provider.name, - ) - continue - - selected_model, is_fallback_model = self._select_model(provider, model, content_type) - - try: - result = await self._attempt_with_retry( - provider, - messages, - selected_model, - temperature, - max_tokens, - content_type, - ) - except RuntimeError as exc: - errors.append(str(exc)) - self._record_failure(provider) - continue - - self._record_success(provider, result.get("latency_ms", 0)) - return { - "content": result["content"], - "provider": provider.name, - "model": result.get("model", selected_model or provider.get_default_model()), - "latency_ms": result.get("latency_ms", 0), - "is_fallback_model": is_fallback_model, - } + result = await self._try_single_provider( + provider, messages, model, temperature, max_tokens, content_type, errors + ) + if result is not None: + return result raise RuntimeError(f"All providers failed: {'; '.join(errors)}") diff --git a/tests/infrastructure/test_router_cascade.py b/tests/infrastructure/test_router_cascade.py index 10aa9ac0..aaa79bbc 100644 --- a/tests/infrastructure/test_router_cascade.py +++ b/tests/infrastructure/test_router_cascade.py @@ -1376,3 +1376,141 @@ class TestIsProviderAvailable: result = router._is_provider_available(provider) assert result is True assert provider.circuit_state == CircuitState.HALF_OPEN + + +@pytest.mark.unit +class TestFilterProviders: + """Test _filter_providers helper extracted from complete().""" + + def _router(self) -> CascadeRouter: + router = CascadeRouter(config_path=Path("/nonexistent")) + router.providers = [ + Provider( + name="anthropic-p", + type="anthropic", + enabled=True, + priority=1, + api_key="key", + tier="frontier", + ), + Provider( + name="ollama-p", + type="ollama", + enabled=True, + priority=2, + tier="local", + ), + ] + return router + + def test_no_tier_returns_all_providers(self): + router = self._router() + result = router._filter_providers(None) + assert result is router.providers + + def test_frontier_required_returns_only_anthropic(self): + router = self._router() + result = router._filter_providers("frontier_required") + assert len(result) == 1 + assert result[0].type == "anthropic" + + def test_frontier_required_no_anthropic_raises(self): + router = CascadeRouter(config_path=Path("/nonexistent")) + router.providers = [ + Provider(name="ollama-p", type="ollama", enabled=True, priority=1) + ] + with pytest.raises(RuntimeError, match="No Anthropic provider configured"): + router._filter_providers("frontier_required") + + def test_named_tier_filters_by_tier(self): + router = self._router() + result = router._filter_providers("local") + assert len(result) == 1 + assert result[0].name == "ollama-p" + + def test_named_tier_not_found_raises(self): + router = self._router() + with pytest.raises(RuntimeError, match="No providers found for tier"): + router._filter_providers("nonexistent") + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestTrySingleProvider: + """Test _try_single_provider helper extracted from complete().""" + + def _router(self) -> CascadeRouter: + return CascadeRouter(config_path=Path("/nonexistent")) + + def _provider(self, name: str = "test", ptype: str = "ollama") -> Provider: + return Provider( + name=name, + type=ptype, + enabled=True, + priority=1, + models=[{"name": "llama3.2", "default": True}], + ) + + async def test_unavailable_provider_returns_none(self): + router = self._router() + provider = self._provider() + provider.enabled = False + errors: list[str] = [] + result = await router._try_single_provider( + provider, [], None, 0.7, None, ContentType.TEXT, errors + ) + assert result is None + assert errors == [] + + async def test_quota_blocked_cloud_provider_returns_none(self): + router = self._router() + provider = self._provider(ptype="anthropic") + errors: list[str] = [] + with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + mock_qm.select_model.return_value = "qwen3:14b" # non-cloud → ACTIVE tier + mock_qm.check.return_value = None + result = await router._try_single_provider( + provider, [], None, 0.7, None, ContentType.TEXT, errors + ) + assert result is None + assert errors == [] + + async def test_success_returns_result_dict(self): + router = self._router() + provider = self._provider() + errors: list[str] = [] + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "hi", "model": "llama3.2"} + result = await router._try_single_provider( + provider, + [{"role": "user", "content": "hi"}], + None, + 0.7, + None, + ContentType.TEXT, + errors, + ) + assert result is not None + assert result["content"] == "hi" + assert result["provider"] == "test" + assert errors == [] + + async def test_failure_appends_error_and_returns_none(self): + router = self._router() + provider = self._provider() + errors: list[str] = [] + with patch.object(router, "_call_ollama") as mock_call: + mock_call.side_effect = RuntimeError("boom") + result = await router._try_single_provider( + provider, + [{"role": "user", "content": "hi"}], + None, + 0.7, + None, + ContentType.TEXT, + errors, + ) + assert result is None + assert len(errors) == 1 + assert "boom" in errors[0] + assert provider.metrics.failed_requests == 1