refactor: break up cascade.py complete() into helper functions
Extract _filter_providers() and _try_single_provider() from complete(), reducing it from 84 lines to under 25. Each helper has a single clear responsibility and is tested directly. - _filter_providers(cascade_tier): isolates tier-based provider filtering - _try_single_provider(...): encapsulates per-provider attempt logic (availability check, quota gate, model selection, retry, metrics) Fixes #1185 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -528,6 +528,71 @@ class CascadeRouter:
|
||||
|
||||
return True
|
||||
|
||||
def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]:
|
||||
"""Return the provider list filtered by tier.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If a tier is specified but no matching providers exist.
|
||||
"""
|
||||
if cascade_tier == "frontier_required":
|
||||
providers = [p for p in self.providers if p.type == "anthropic"]
|
||||
if not providers:
|
||||
raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.")
|
||||
return providers
|
||||
if cascade_tier:
|
||||
providers = [p for p in self.providers if p.tier == cascade_tier]
|
||||
if not providers:
|
||||
raise RuntimeError(f"No providers found for tier: {cascade_tier}")
|
||||
return providers
|
||||
return self.providers
|
||||
|
||||
async def _try_single_provider(
|
||||
self,
|
||||
provider: "Provider",
|
||||
messages: list[dict],
|
||||
model: str | None,
|
||||
temperature: float,
|
||||
max_tokens: int | None,
|
||||
content_type: ContentType,
|
||||
errors: list[str],
|
||||
) -> dict | None:
|
||||
"""Attempt one provider, returning a result dict on success or None on failure.
|
||||
|
||||
On failure the error string is appended to *errors* and the provider's
|
||||
failure metrics are updated so the caller can move on to the next provider.
|
||||
"""
|
||||
if not self._is_provider_available(provider):
|
||||
return None
|
||||
|
||||
# Metabolic protocol: skip cloud providers when quota is low
|
||||
if provider.type in ("anthropic", "openai", "grok"):
|
||||
if not self._quota_allows_cloud(provider):
|
||||
logger.info(
|
||||
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
||||
provider.name,
|
||||
)
|
||||
return None
|
||||
|
||||
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
|
||||
|
||||
try:
|
||||
result = await self._attempt_with_retry(
|
||||
provider, messages, selected_model, temperature, max_tokens, content_type
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
errors.append(str(exc))
|
||||
self._record_failure(provider)
|
||||
return None
|
||||
|
||||
self._record_success(provider, result.get("latency_ms", 0))
|
||||
return {
|
||||
"content": result["content"],
|
||||
"provider": provider.name,
|
||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||||
"latency_ms": result.get("latency_ms", 0),
|
||||
"is_fallback_model": is_fallback_model,
|
||||
}
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict],
|
||||
@@ -561,55 +626,15 @@ class CascadeRouter:
|
||||
if content_type != ContentType.TEXT:
|
||||
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
||||
|
||||
errors = []
|
||||
|
||||
providers = self.providers
|
||||
if cascade_tier == "frontier_required":
|
||||
providers = [p for p in self.providers if p.type == "anthropic"]
|
||||
if not providers:
|
||||
raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.")
|
||||
elif cascade_tier:
|
||||
providers = [p for p in self.providers if p.tier == cascade_tier]
|
||||
if not providers:
|
||||
raise RuntimeError(f"No providers found for tier: {cascade_tier}")
|
||||
errors: list[str] = []
|
||||
providers = self._filter_providers(cascade_tier)
|
||||
|
||||
for provider in providers:
|
||||
if not self._is_provider_available(provider):
|
||||
continue
|
||||
|
||||
# Metabolic protocol: skip cloud providers when quota is low
|
||||
if provider.type in ("anthropic", "openai", "grok"):
|
||||
if not self._quota_allows_cloud(provider):
|
||||
logger.info(
|
||||
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
||||
provider.name,
|
||||
)
|
||||
continue
|
||||
|
||||
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
|
||||
|
||||
try:
|
||||
result = await self._attempt_with_retry(
|
||||
provider,
|
||||
messages,
|
||||
selected_model,
|
||||
temperature,
|
||||
max_tokens,
|
||||
content_type,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
errors.append(str(exc))
|
||||
self._record_failure(provider)
|
||||
continue
|
||||
|
||||
self._record_success(provider, result.get("latency_ms", 0))
|
||||
return {
|
||||
"content": result["content"],
|
||||
"provider": provider.name,
|
||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||||
"latency_ms": result.get("latency_ms", 0),
|
||||
"is_fallback_model": is_fallback_model,
|
||||
}
|
||||
result = await self._try_single_provider(
|
||||
provider, messages, model, temperature, max_tokens, content_type, errors
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||
|
||||
|
||||
@@ -1376,3 +1376,141 @@ class TestIsProviderAvailable:
|
||||
result = router._is_provider_available(provider)
|
||||
assert result is True
|
||||
assert provider.circuit_state == CircuitState.HALF_OPEN
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFilterProviders:
|
||||
"""Test _filter_providers helper extracted from complete()."""
|
||||
|
||||
def _router(self) -> CascadeRouter:
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.providers = [
|
||||
Provider(
|
||||
name="anthropic-p",
|
||||
type="anthropic",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
api_key="key",
|
||||
tier="frontier",
|
||||
),
|
||||
Provider(
|
||||
name="ollama-p",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=2,
|
||||
tier="local",
|
||||
),
|
||||
]
|
||||
return router
|
||||
|
||||
def test_no_tier_returns_all_providers(self):
|
||||
router = self._router()
|
||||
result = router._filter_providers(None)
|
||||
assert result is router.providers
|
||||
|
||||
def test_frontier_required_returns_only_anthropic(self):
|
||||
router = self._router()
|
||||
result = router._filter_providers("frontier_required")
|
||||
assert len(result) == 1
|
||||
assert result[0].type == "anthropic"
|
||||
|
||||
def test_frontier_required_no_anthropic_raises(self):
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.providers = [
|
||||
Provider(name="ollama-p", type="ollama", enabled=True, priority=1)
|
||||
]
|
||||
with pytest.raises(RuntimeError, match="No Anthropic provider configured"):
|
||||
router._filter_providers("frontier_required")
|
||||
|
||||
def test_named_tier_filters_by_tier(self):
|
||||
router = self._router()
|
||||
result = router._filter_providers("local")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "ollama-p"
|
||||
|
||||
def test_named_tier_not_found_raises(self):
|
||||
router = self._router()
|
||||
with pytest.raises(RuntimeError, match="No providers found for tier"):
|
||||
router._filter_providers("nonexistent")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
class TestTrySingleProvider:
|
||||
"""Test _try_single_provider helper extracted from complete()."""
|
||||
|
||||
def _router(self) -> CascadeRouter:
|
||||
return CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
def _provider(self, name: str = "test", ptype: str = "ollama") -> Provider:
|
||||
return Provider(
|
||||
name=name,
|
||||
type=ptype,
|
||||
enabled=True,
|
||||
priority=1,
|
||||
models=[{"name": "llama3.2", "default": True}],
|
||||
)
|
||||
|
||||
async def test_unavailable_provider_returns_none(self):
|
||||
router = self._router()
|
||||
provider = self._provider()
|
||||
provider.enabled = False
|
||||
errors: list[str] = []
|
||||
result = await router._try_single_provider(
|
||||
provider, [], None, 0.7, None, ContentType.TEXT, errors
|
||||
)
|
||||
assert result is None
|
||||
assert errors == []
|
||||
|
||||
async def test_quota_blocked_cloud_provider_returns_none(self):
|
||||
router = self._router()
|
||||
provider = self._provider(ptype="anthropic")
|
||||
errors: list[str] = []
|
||||
with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
|
||||
mock_qm.select_model.return_value = "qwen3:14b" # non-cloud → ACTIVE tier
|
||||
mock_qm.check.return_value = None
|
||||
result = await router._try_single_provider(
|
||||
provider, [], None, 0.7, None, ContentType.TEXT, errors
|
||||
)
|
||||
assert result is None
|
||||
assert errors == []
|
||||
|
||||
async def test_success_returns_result_dict(self):
|
||||
router = self._router()
|
||||
provider = self._provider()
|
||||
errors: list[str] = []
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.return_value = {"content": "hi", "model": "llama3.2"}
|
||||
result = await router._try_single_provider(
|
||||
provider,
|
||||
[{"role": "user", "content": "hi"}],
|
||||
None,
|
||||
0.7,
|
||||
None,
|
||||
ContentType.TEXT,
|
||||
errors,
|
||||
)
|
||||
assert result is not None
|
||||
assert result["content"] == "hi"
|
||||
assert result["provider"] == "test"
|
||||
assert errors == []
|
||||
|
||||
async def test_failure_appends_error_and_returns_none(self):
|
||||
router = self._router()
|
||||
provider = self._provider()
|
||||
errors: list[str] = []
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.side_effect = RuntimeError("boom")
|
||||
result = await router._try_single_provider(
|
||||
provider,
|
||||
[{"role": "user", "content": "hi"}],
|
||||
None,
|
||||
0.7,
|
||||
None,
|
||||
ContentType.TEXT,
|
||||
errors,
|
||||
)
|
||||
assert result is None
|
||||
assert len(errors) == 1
|
||||
assert "boom" in errors[0]
|
||||
assert provider.metrics.failed_requests == 1
|
||||
|
||||
Reference in New Issue
Block a user