diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index b9c07514..4f7510ae 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -485,18 +485,26 @@ class CascadeRouter: def _quota_allows_cloud(self, provider: Provider) -> bool: """Check quota before routing to a cloud provider. - Uses the metabolic protocol: cloud calls are gated by 5-hour quota. + Uses the metabolic protocol via select_model(): cloud calls are only + allowed when the quota monitor recommends a cloud model (BURST tier). Returns True (allow cloud) if quota monitor is unavailable or returns None. """ if _quota_monitor is None: return True try: - # Map provider type to task_value heuristic - task_value = "high" # conservative default - status = _quota_monitor.check() - if status is None: - return True # No credentials — caller decides based on config - return _quota_monitor.should_use_cloud(task_value) + suggested = _quota_monitor.select_model("high") + # Cloud is allowed only when select_model recommends the cloud model + allows = suggested == "claude-sonnet-4-6" + if not allows: + status = _quota_monitor.check() + tier = status.recommended_tier.value if status else "unknown" + logger.info( + "Metabolic protocol: %s tier — downshifting %s to local (%s)", + tier, + provider.name, + suggested, + ) + return allows except Exception as exc: logger.warning("Quota check failed, allowing cloud: %s", exc) return True diff --git a/tests/infrastructure/test_router_cascade.py b/tests/infrastructure/test_router_cascade.py index fc55f6e6..ca881c6a 100644 --- a/tests/infrastructure/test_router_cascade.py +++ b/tests/infrastructure/test_router_cascade.py @@ -664,10 +664,10 @@ class TestVllmMlxProvider: ) router.providers = [provider] - # Quota monitor returns False (block cloud) — vllm_mlx should still be tried + # Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: - mock_qm.check.return_value = object() - mock_qm.should_use_cloud.return_value = False + mock_qm.select_model.return_value = "qwen3:14b" + mock_qm.check.return_value = None with patch.object(router, "_call_vllm_mlx") as mock_call: mock_call.return_value = { @@ -681,6 +681,115 @@ class TestVllmMlxProvider: assert result["content"] == "Local MLX response" +class TestMetabolicProtocol: + """Test metabolic protocol: cloud providers skip when quota is ACTIVE/RESTING.""" + + def _make_anthropic_provider(self) -> "Provider": + return Provider( + name="anthropic-primary", + type="anthropic", + enabled=True, + priority=1, + api_key="test-key", + models=[{"name": "claude-sonnet-4-6", "default": True}], + ) + + async def test_cloud_provider_allowed_in_burst_tier(self): + """BURST tier (quota healthy): cloud provider is tried.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.providers = [self._make_anthropic_provider()] + + with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + # select_model returns cloud model → BURST tier + mock_qm.select_model.return_value = "claude-sonnet-4-6" + mock_qm.check.return_value = None + + with patch.object(router, "_call_anthropic") as mock_call: + mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"} + result = await router.complete( + messages=[{"role": "user", "content": "hard question"}], + ) + + mock_call.assert_called_once() + assert result["content"] == "Cloud response" + + async def test_cloud_provider_skipped_in_active_tier(self): + """ACTIVE tier (5-hour >= 50%): cloud provider is skipped.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.providers = [self._make_anthropic_provider()] + + with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + # select_model returns local 14B → ACTIVE tier + mock_qm.select_model.return_value = "qwen3:14b" + mock_qm.check.return_value = None + + with patch.object(router, "_call_anthropic") as mock_call: + with pytest.raises(RuntimeError, match="All providers failed"): + await router.complete( + messages=[{"role": "user", "content": "question"}], + ) + + mock_call.assert_not_called() + + async def test_cloud_provider_skipped_in_resting_tier(self): + """RESTING tier (7-day >= 80%): cloud provider is skipped.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.providers = [self._make_anthropic_provider()] + + with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + # select_model returns local 8B → RESTING tier + mock_qm.select_model.return_value = "qwen3:8b" + mock_qm.check.return_value = None + + with patch.object(router, "_call_anthropic") as mock_call: + with pytest.raises(RuntimeError, match="All providers failed"): + await router.complete( + messages=[{"role": "user", "content": "simple question"}], + ) + + mock_call.assert_not_called() + + async def test_local_provider_always_tried_regardless_of_quota(self): + """Local (ollama/vllm_mlx) providers bypass the metabolic protocol.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + provider = Provider( + name="ollama-local", + type="ollama", + enabled=True, + priority=1, + url="http://localhost:11434", + models=[{"name": "qwen3:14b", "default": True}], + ) + router.providers = [provider] + + with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: + mock_qm.select_model.return_value = "qwen3:8b" # RESTING tier + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "Local response", "model": "qwen3:14b"} + result = await router.complete( + messages=[{"role": "user", "content": "hi"}], + ) + + mock_call.assert_called_once() + assert result["content"] == "Local response" + + async def test_no_quota_monitor_allows_cloud(self): + """When quota monitor is None (unavailable), cloud providers are allowed.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.providers = [self._make_anthropic_provider()] + + with patch("infrastructure.router.cascade._quota_monitor", None): + with patch.object(router, "_call_anthropic") as mock_call: + mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"} + result = await router.complete( + messages=[{"role": "user", "content": "question"}], + ) + + mock_call.assert_called_once() + assert result["content"] == "Cloud response" + + class TestCascadeRouterReload: """Test hot-reload of providers.yaml."""