From 0e5948632ddcf4a921d0a5b013ac177984352b68 Mon Sep 17 00:00:00 2001
From: "Claude (Opus 4.6)"
Date: Mon, 23 Mar 2026 18:47:28 +0000
Subject: [PATCH] [claude] Add unit tests for cascade.py (#1138) (#1150)

---
NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy. Hunk line counts were checked against the
@@ headers and the diffstat below, but the leading whitespace of context
lines is assumed (8 spaces inside test methods) -- confirm against the
target file, or apply with `git apply --ignore-whitespace`.

 tests/infrastructure/test_router_cascade.py | 409 +++++++++++++++++++-
 1 file changed, 408 insertions(+), 1 deletion(-)

diff --git a/tests/infrastructure/test_router_cascade.py b/tests/infrastructure/test_router_cascade.py
index ca881c6..5b539e9 100644
--- a/tests/infrastructure/test_router_cascade.py
+++ b/tests/infrastructure/test_router_cascade.py
@@ -2,7 +2,7 @@
 
 import time
 from pathlib import Path
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 import yaml
@@ -10,13 +10,16 @@ import yaml
 from infrastructure.router.cascade import (
     CascadeRouter,
     CircuitState,
+    ContentType,
     Provider,
     ProviderMetrics,
     ProviderStatus,
     RouterConfig,
+    get_router,
 )
 
 
+@pytest.mark.unit
 class TestProviderMetrics:
     """Test provider metrics tracking."""
 
@@ -45,6 +48,7 @@ class TestProviderMetrics:
         assert metrics.error_rate == 0.3
 
 
+@pytest.mark.unit
 class TestProvider:
     """Test Provider dataclass."""
 
@@ -88,6 +92,7 @@ class TestProvider:
         assert provider.get_default_model() is None
 
 
+@pytest.mark.unit
 class TestRouterConfig:
     """Test router configuration."""
 
@@ -100,6 +105,7 @@ class TestRouterConfig:
         assert config.circuit_breaker_failure_threshold == 5
 
 
+@pytest.mark.unit
 class TestCascadeRouterInit:
     """Test CascadeRouter initialization."""
 
@@ -158,6 +164,7 @@ class TestCascadeRouterInit:
         assert router.providers[0].api_key == "secret123"
 
 
+@pytest.mark.unit
 class TestCascadeRouterMetrics:
     """Test metrics tracking."""
 
@@ -241,6 +248,7 @@ class TestCascadeRouterMetrics:
         assert provider.status == ProviderStatus.HEALTHY
 
 
+@pytest.mark.unit
 class TestCascadeRouterGetMetrics:
     """Test get_metrics method."""
 
@@ -280,6 +288,7 @@ class TestCascadeRouterGetMetrics:
         assert p_metrics["metrics"]["avg_latency_ms"] == 200.0
 
 
+@pytest.mark.unit
 class TestCascadeRouterGetStatus:
     """Test get_status method."""
 
@@ -305,6 +314,7 @@ class TestCascadeRouterGetStatus:
         assert len(status["providers"]) == 1
 
 
+@pytest.mark.unit
 @pytest.mark.asyncio
 class TestCascadeRouterComplete:
     """Test complete method with failover."""
@@ -436,6 +446,7 @@ class TestCascadeRouterComplete:
         assert result["provider"] == "healthy"
 
 
+@pytest.mark.unit
 class TestProviderAvailabilityCheck:
     """Test provider availability checking."""
 
@@ -577,6 +588,7 @@ class TestProviderAvailabilityCheck:
         mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5)
 
 
+@pytest.mark.unit
 @pytest.mark.asyncio
 class TestVllmMlxProvider:
     """Test vllm-mlx provider integration."""
@@ -681,6 +693,8 @@ class TestVllmMlxProvider:
         assert result["content"] == "Local MLX response"
 
 
+@pytest.mark.unit
+@pytest.mark.asyncio
 class TestMetabolicProtocol:
     """Test metabolic protocol: cloud providers skip when quota is ACTIVE/RESTING."""
 
@@ -790,6 +804,7 @@ class TestMetabolicProtocol:
         assert result["content"] == "Cloud response"
 
 
+@pytest.mark.unit
 class TestCascadeRouterReload:
     """Test hot-reload of providers.yaml."""
 
@@ -968,3 +983,395 @@ class TestCascadeRouterReload:
 
         assert router.providers[0].name == "low-priority"
         assert router.providers[1].name == "high-priority"
+
+
+@pytest.mark.unit
+class TestContentTypeDetection:
+    """Test _detect_content_type logic."""
+
+    def _router(self) -> CascadeRouter:
+        return CascadeRouter(config_path=Path("/nonexistent"))
+
+    def test_text_only(self):
+        router = self._router()
+        msgs = [{"role": "user", "content": "Hello"}]
+        assert router._detect_content_type(msgs) == ContentType.TEXT
+
+    def test_images_key_triggers_vision(self):
+        router = self._router()
+        msgs = [{"role": "user", "content": "Describe this", "images": ["pic.jpg"]}]
+        assert router._detect_content_type(msgs) == ContentType.VISION
+
+    def test_image_extension_in_content_triggers_vision(self):
+        router = self._router()
+        msgs = [{"role": "user", "content": "Look at photo.png please"}]
+        assert router._detect_content_type(msgs) == ContentType.VISION
+
+    def test_base64_data_uri_triggers_vision(self):
+        router = self._router()
+        msgs = [{"role": "user", "content": "data:image/jpeg;base64,/9j/4AA..."}]
+        assert router._detect_content_type(msgs) == ContentType.VISION
+
+    def test_audio_key_triggers_audio(self):
+        router = self._router()
+        msgs = [{"role": "user", "content": "", "audio": b"bytes"}]
+        assert router._detect_content_type(msgs) == ContentType.AUDIO
+
+    def test_image_and_audio_triggers_multimodal(self):
+        router = self._router()
+        msgs = [
+            {"role": "user", "content": "check photo.jpg", "audio": b"bytes"},
+        ]
+        assert router._detect_content_type(msgs) == ContentType.MULTIMODAL
+
+    def test_list_content_image_url_type(self):
+        router = self._router()
+        msgs = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What?"},
+                    {"type": "image_url", "image_url": {"url": "http://example.com/a.jpg"}},
+                ],
+            }
+        ]
+        assert router._detect_content_type(msgs) == ContentType.VISION
+
+    def test_list_content_audio_type(self):
+        router = self._router()
+        msgs = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "audio", "data": "base64..."},
+                ],
+            }
+        ]
+        assert router._detect_content_type(msgs) == ContentType.AUDIO
+
+
+@pytest.mark.unit
+class TestTransformMessagesForOllama:
+    """Test _transform_messages_for_ollama."""
+
+    def _router(self) -> CascadeRouter:
+        return CascadeRouter(config_path=Path("/nonexistent"))
+
+    def test_plain_text_message(self):
+        router = self._router()
+        result = router._transform_messages_for_ollama(
+            [{"role": "user", "content": "Hello"}]
+        )
+        assert result == [{"role": "user", "content": "Hello"}]
+
+    def test_base64_image_stripped(self):
+        router = self._router()
+        msgs = [
+            {
+                "role": "user",
+                "content": "Describe",
+                "images": ["data:image/png;base64,abc123"],
+            }
+        ]
+        result = router._transform_messages_for_ollama(msgs)
+        assert result[0]["images"] == ["abc123"]
+
+    def test_http_url_skipped(self):
+        router = self._router()
+        msgs = [
+            {
+                "role": "user",
+                "content": "Describe",
+                "images": ["http://example.com/img.jpg"],
+            }
+        ]
+        result = router._transform_messages_for_ollama(msgs)
+        # URL is skipped — images list should be empty or absent
+        assert result[0].get("images", []) == []
+
+    def test_missing_local_file_skipped(self):
+        router = self._router()
+        msgs = [
+            {
+                "role": "user",
+                "content": "Describe",
+                "images": ["/nonexistent/path/image.png"],
+            }
+        ]
+        result = router._transform_messages_for_ollama(msgs)
+        assert result[0].get("images", []) == []
+
+
+@pytest.mark.unit
+class TestProviderCapabilityMethods:
+    """Test Provider.get_model_with_capability and model_has_capability."""
+
+    def _provider(self) -> Provider:
+        return Provider(
+            name="test",
+            type="ollama",
+            enabled=True,
+            priority=1,
+            models=[
+                {"name": "llava:7b", "capabilities": ["vision"]},
+                {"name": "llama3.2", "default": True},
+            ],
+        )
+
+    def test_get_model_with_capability_found(self):
+        p = self._provider()
+        assert p.get_model_with_capability("vision") == "llava:7b"
+
+    def test_get_model_with_capability_falls_back_to_default(self):
+        p = self._provider()
+        assert p.get_model_with_capability("audio") == "llama3.2"
+
+    def test_model_has_capability_true(self):
+        p = self._provider()
+        assert p.model_has_capability("llava:7b", "vision") is True
+
+    def test_model_has_capability_false(self):
+        p = self._provider()
+        assert p.model_has_capability("llama3.2", "vision") is False
+
+    def test_model_has_capability_unknown_model(self):
+        p = self._provider()
+        assert p.model_has_capability("unknown-model", "vision") is False
+
+
+@pytest.mark.unit
+class TestGetFallbackModel:
+    """Test _get_fallback_model."""
+
+    def _router_with_provider(self) -> tuple[CascadeRouter, Provider]:
+        router = CascadeRouter(config_path=Path("/nonexistent"))
+        provider = Provider(
+            name="test",
+            type="ollama",
+            enabled=True,
+            priority=1,
+            models=[
+                {"name": "llava:7b", "capabilities": ["vision"]},
+                {"name": "llama3.2", "default": True},
+            ],
+        )
+        return router, provider
+
+    def test_returns_vision_model(self):
+        router, provider = self._router_with_provider()
+        result = router._get_fallback_model(provider, "llama3.2", ContentType.VISION)
+        assert result == "llava:7b"
+
+    def test_returns_none_if_no_capability(self):
+        router, provider = self._router_with_provider()
+        result = router._get_fallback_model(provider, "llama3.2", ContentType.AUDIO)
+        # No audio model; falls back to default which is same as original
+        assert result is None or result == "llama3.2"
+
+    def test_text_content_returns_none(self):
+        router, provider = self._router_with_provider()
+        result = router._get_fallback_model(provider, "llama3.2", ContentType.TEXT)
+        assert result is None
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+class TestCascadeTierFiltering:
+    """Test cascade_tier parameter in complete()."""
+
+    def _make_router(self) -> CascadeRouter:
+        router = CascadeRouter(config_path=Path("/nonexistent"))
+        router.providers = [
+            Provider(
+                name="anthropic-primary",
+                type="anthropic",
+                enabled=True,
+                priority=1,
+                api_key="test-key",
+                models=[{"name": "claude-sonnet-4-6", "default": True}],
+            ),
+            Provider(
+                name="ollama-local",
+                type="ollama",
+                enabled=True,
+                priority=2,
+                models=[{"name": "llama3.2", "default": True}],
+            ),
+        ]
+        return router
+
+    async def test_frontier_required_uses_anthropic(self):
+        router = self._make_router()
+        with patch("infrastructure.router.cascade._quota_monitor", None):
+            with patch.object(router, "_call_anthropic") as mock_call:
+                mock_call.return_value = {"content": "frontier response", "model": "claude-sonnet-4-6"}
+                result = await router.complete(
+                    messages=[{"role": "user", "content": "hi"}],
+                    cascade_tier="frontier_required",
+                )
+                assert result["provider"] == "anthropic-primary"
+                mock_call.assert_called_once()
+
+    async def test_frontier_required_no_anthropic_raises(self):
+        router = CascadeRouter(config_path=Path("/nonexistent"))
+        router.providers = [
+            Provider(
+                name="ollama-local",
+                type="ollama",
+                enabled=True,
+                priority=1,
+                models=[{"name": "llama3.2", "default": True}],
+            )
+        ]
+        with pytest.raises(RuntimeError, match="No Anthropic provider configured"):
+            await router.complete(
+                messages=[{"role": "user", "content": "hi"}],
+                cascade_tier="frontier_required",
+            )
+
+    async def test_unknown_tier_raises(self):
+        router = self._make_router()
+        with pytest.raises(RuntimeError, match="No providers found for tier"):
+            await router.complete(
+                messages=[{"role": "user", "content": "hi"}],
+                cascade_tier="nonexistent_tier",
+            )
+
+    async def test_tier_filter_only_matching_providers(self):
+        router = CascadeRouter(config_path=Path("/nonexistent"))
+        router.providers = [
+            Provider(
+                name="local-primary",
+                type="ollama",
+                enabled=True,
+                priority=1,
+                tier="local",
+                models=[{"name": "llama3.2", "default": True}],
+            ),
+            Provider(
+                name="cloud-secondary",
+                type="anthropic",
+                enabled=True,
+                priority=2,
+                tier="cloud",
+                api_key="key",
+                models=[{"name": "claude-sonnet-4-6", "default": True}],
+            ),
+        ]
+        with patch.object(router, "_call_ollama") as mock_call:
+            mock_call.return_value = {"content": "local response", "model": "llama3.2"}
+            result = await router.complete(
+                messages=[{"role": "user", "content": "hi"}],
+                cascade_tier="local",
+            )
+            assert result["provider"] == "local-primary"
+            mock_call.assert_called_once()
+
+
+@pytest.mark.unit
+@pytest.mark.asyncio
+class TestGenerateWithImage:
+    """Test generate_with_image convenience method."""
+
+    async def test_delegates_to_complete(self):
+        router = CascadeRouter(config_path=Path("/nonexistent"))
+        router.providers = [
+            Provider(
+                name="ollama-vision",
+                type="ollama",
+                enabled=True,
+                priority=1,
+                models=[{"name": "llava:7b", "capabilities": ["vision"], "default": True}],
+            )
+        ]
+
+        with patch.object(router, "_call_ollama") as mock_call:
+            mock_call.return_value = {"content": "A cat", "model": "llava:7b"}
+            result = await router.generate_with_image(
+                prompt="What is this?",
+                image_path="/tmp/cat.jpg",
+                model="llava:7b",
+            )
+
+        assert result["content"] == "A cat"
+        assert result["provider"] == "ollama-vision"
+        # complete() should have been called with images in messages
+        call_kwargs = mock_call.call_args
+        messages_passed = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages")
+        assert messages_passed[0]["images"] == ["/tmp/cat.jpg"]
+
+
+@pytest.mark.unit
+class TestGetRouterSingleton:
+    """Test get_router() returns a singleton and creates CascadeRouter."""
+
+    def test_get_router_returns_cascade_router(self):
+        import infrastructure.router.cascade as cascade_module
+
+        # Reset singleton to test creation
+        original = cascade_module.cascade_router
+        cascade_module.cascade_router = None
+        try:
+            router = get_router()
+            assert isinstance(router, CascadeRouter)
+        finally:
+            cascade_module.cascade_router = original
+
+    def test_get_router_returns_same_instance(self):
+        import infrastructure.router.cascade as cascade_module
+
+        original = cascade_module.cascade_router
+        cascade_module.cascade_router = None
+        try:
+            r1 = get_router()
+            r2 = get_router()
+            assert r1 is r2
+        finally:
+            cascade_module.cascade_router = original
+
+
+@pytest.mark.unit
+class TestIsProviderAvailable:
+    """Test _is_provider_available with circuit breaker transitions."""
+
+    def _router(self) -> CascadeRouter:
+        return CascadeRouter(config_path=Path("/nonexistent"))
+
+    def test_disabled_provider_not_available(self):
+        router = self._router()
+        provider = Provider(name="p", type="ollama", enabled=False, priority=1)
+        assert router._is_provider_available(provider) is False
+
+    def test_healthy_provider_available(self):
+        router = self._router()
+        provider = Provider(name="p", type="ollama", enabled=True, priority=1)
+        assert router._is_provider_available(provider) is True
+
+    def test_unhealthy_open_circuit_not_available(self):
+        router = self._router()
+        provider = Provider(
+            name="p",
+            type="ollama",
+            enabled=True,
+            priority=1,
+            status=ProviderStatus.UNHEALTHY,
+            circuit_state=CircuitState.OPEN,
+            circuit_opened_at=time.time(),  # Just opened — not yet recoverable
+        )
+        assert router._is_provider_available(provider) is False
+
+    def test_unhealthy_after_timeout_transitions_to_half_open(self):
+        router = self._router()
+        router.config.circuit_breaker_recovery_timeout = 0
+        provider = Provider(
+            name="p",
+            type="ollama",
+            enabled=True,
+            priority=1,
+            status=ProviderStatus.UNHEALTHY,
+            circuit_state=CircuitState.OPEN,
+            circuit_opened_at=time.time() - 10,  # Long ago
+        )
+        result = router._is_provider_available(provider)
+        assert result is True
+        assert provider.circuit_state == CircuitState.HALF_OPEN