"""Tests for Cascade LLM Router.""" import time from pathlib import Path from unittest.mock import AsyncMock, patch import pytest import yaml from infrastructure.router.cascade import ( CascadeRouter, CircuitState, Provider, ProviderMetrics, ProviderStatus, RouterConfig, ) class TestProviderMetrics: """Test provider metrics tracking.""" def test_empty_metrics(self): """Test metrics with no requests.""" metrics = ProviderMetrics() assert metrics.total_requests == 0 assert metrics.avg_latency_ms == 0.0 assert metrics.error_rate == 0.0 def test_avg_latency_calculation(self): """Test average latency calculation.""" metrics = ProviderMetrics( total_requests=4, total_latency_ms=1000.0, # 4 requests, 1000ms total ) assert metrics.avg_latency_ms == 250.0 def test_error_rate_calculation(self): """Test error rate calculation.""" metrics = ProviderMetrics( total_requests=10, successful_requests=7, failed_requests=3, ) assert metrics.error_rate == 0.3 class TestProvider: """Test Provider dataclass.""" def test_get_default_model(self): """Test getting default model.""" provider = Provider( name="test", type="ollama", enabled=True, priority=1, models=[ {"name": "llama3", "default": True}, {"name": "mistral"}, ], ) assert provider.get_default_model() == "llama3" def test_get_default_model_no_default(self): """Test getting first model when no default set.""" provider = Provider( name="test", type="ollama", enabled=True, priority=1, models=[ {"name": "llama3"}, {"name": "mistral"}, ], ) assert provider.get_default_model() == "llama3" def test_get_default_model_empty(self): """Test with no models.""" provider = Provider( name="test", type="ollama", enabled=True, priority=1, models=[], ) assert provider.get_default_model() is None class TestRouterConfig: """Test router configuration.""" def test_default_config(self): """Test default configuration values.""" config = RouterConfig() assert config.timeout_seconds == 30 assert config.max_retries_per_provider == 2 assert config.retry_delay_seconds == 1 assert config.circuit_breaker_failure_threshold == 5 class TestCascadeRouterInit: """Test CascadeRouter initialization.""" def test_init_without_config(self, tmp_path): """Test initialization without config file.""" router = CascadeRouter(config_path=tmp_path / "nonexistent.yaml") assert len(router.providers) == 0 assert router.config.timeout_seconds == 30 def test_init_with_config(self, tmp_path): """Test initialization with config file.""" config = { "cascade": { "timeout_seconds": 60, "max_retries_per_provider": 3, }, "providers": [ { "name": "test-ollama", "type": "ollama", "enabled": False, # Disabled to avoid availability check "priority": 1, "url": "http://localhost:11434", } ], } config_path = tmp_path / "providers.yaml" config_path.write_text(yaml.dump(config)) router = CascadeRouter(config_path=config_path) assert router.config.timeout_seconds == 60 assert router.config.max_retries_per_provider == 3 assert len(router.providers) == 0 # Provider is disabled def test_env_var_expansion(self, tmp_path, monkeypatch): """Test environment variable expansion in config.""" monkeypatch.setenv("TEST_API_KEY", "secret123") config = { "cascade": {}, "providers": [ { "name": "test-openai", "type": "openai", "enabled": True, "priority": 1, "api_key": "${TEST_API_KEY}", } ], } config_path = tmp_path / "providers.yaml" config_path.write_text(yaml.dump(config)) router = CascadeRouter(config_path=config_path) assert len(router.providers) == 1 assert router.providers[0].api_key == "secret123" class TestCascadeRouterMetrics: """Test metrics tracking.""" def test_record_success(self): """Test recording successful request.""" provider = Provider(name="test", type="ollama", enabled=True, priority=1) router = CascadeRouter(config_path=Path("/nonexistent")) router._record_success(provider, 150.0) assert provider.metrics.total_requests == 1 assert provider.metrics.successful_requests == 1 assert provider.metrics.total_latency_ms == 150.0 assert provider.metrics.consecutive_failures == 0 def test_record_failure(self): """Test recording failed request.""" provider = Provider(name="test", type="ollama", enabled=True, priority=1) router = CascadeRouter(config_path=Path("/nonexistent")) router._record_failure(provider) assert provider.metrics.total_requests == 1 assert provider.metrics.failed_requests == 1 assert provider.metrics.consecutive_failures == 1 def test_circuit_breaker_opens(self): """Test circuit breaker opens after failures.""" provider = Provider(name="test", type="ollama", enabled=True, priority=1) router = CascadeRouter(config_path=Path("/nonexistent")) router.config.circuit_breaker_failure_threshold = 3 # Record 3 failures for _ in range(3): router._record_failure(provider) assert provider.circuit_state == CircuitState.OPEN assert provider.status == ProviderStatus.UNHEALTHY assert provider.circuit_opened_at is not None def test_circuit_breaker_can_close(self): """Test circuit breaker can transition to closed.""" provider = Provider(name="test", type="ollama", enabled=True, priority=1) router = CascadeRouter(config_path=Path("/nonexistent")) router.config.circuit_breaker_failure_threshold = 3 router.config.circuit_breaker_recovery_timeout = 0.1 # Open the circuit for _ in range(3): router._record_failure(provider) assert provider.circuit_state == CircuitState.OPEN # Wait for recovery timeout (reduced for faster tests) import time time.sleep(0.2) # Check if can close assert router._can_close_circuit(provider) is True def test_half_open_to_closed(self): """Test circuit breaker closes after successful test calls.""" provider = Provider(name="test", type="ollama", enabled=True, priority=1) router = CascadeRouter(config_path=Path("/nonexistent")) router.config.circuit_breaker_half_open_max_calls = 2 # Manually set to half-open provider.circuit_state = CircuitState.HALF_OPEN provider.half_open_calls = 0 # Record successful calls router._record_success(provider, 100.0) assert provider.circuit_state == CircuitState.HALF_OPEN # Still half-open router._record_success(provider, 100.0) assert provider.circuit_state == CircuitState.CLOSED # Now closed assert provider.status == ProviderStatus.HEALTHY class TestCascadeRouterGetMetrics: """Test get_metrics method.""" def test_get_metrics_empty(self): """Test getting metrics with no providers.""" router = CascadeRouter(config_path=Path("/nonexistent")) metrics = router.get_metrics() assert "providers" in metrics assert len(metrics["providers"]) == 0 def test_get_metrics_with_providers(self): """Test getting metrics with providers.""" router = CascadeRouter(config_path=Path("/nonexistent")) # Add a test provider provider = Provider( name="test", type="ollama", enabled=True, priority=1, ) provider.metrics.total_requests = 10 provider.metrics.successful_requests = 8 provider.metrics.failed_requests = 2 provider.metrics.total_latency_ms = 2000.0 router.providers = [provider] metrics = router.get_metrics() assert len(metrics["providers"]) == 1 p_metrics = metrics["providers"][0] assert p_metrics["name"] == "test" assert p_metrics["metrics"]["total_requests"] == 10 assert p_metrics["metrics"]["error_rate"] == 0.2 assert p_metrics["metrics"]["avg_latency_ms"] == 200.0 class TestCascadeRouterGetStatus: """Test get_status method.""" def test_get_status(self): """Test getting router status.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="test", type="ollama", enabled=True, priority=1, models=[{"name": "llama3", "default": True}], ) router.providers = [provider] status = router.get_status() assert status["total_providers"] == 1 assert status["healthy_providers"] == 1 assert status["degraded_providers"] == 0 assert status["unhealthy_providers"] == 0 assert len(status["providers"]) == 1 @pytest.mark.asyncio class TestCascadeRouterComplete: """Test complete method with failover.""" async def test_complete_with_ollama(self): """Test successful completion with Ollama.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="ollama-local", type="ollama", enabled=True, priority=1, url="http://localhost:11434", models=[{"name": "llama3.2", "default": True}], ) router.providers = [provider] # Mock the Ollama call with patch.object(router, "_call_ollama") as mock_call: mock_call.return_value = AsyncMock()() mock_call.return_value = { "content": "Hello, world!", "model": "llama3.2", } result = await router.complete( messages=[{"role": "user", "content": "Hi"}], ) assert result["content"] == "Hello, world!" assert result["provider"] == "ollama-local" assert result["model"] == "llama3.2" async def test_failover_to_second_provider(self): """Test failover when first provider fails.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider1 = Provider( name="ollama-failing", type="ollama", enabled=True, priority=1, url="http://localhost:11434", models=[{"name": "llama3.2", "default": True}], ) provider2 = Provider( name="ollama-backup", type="ollama", enabled=True, priority=2, url="http://backup:11434", models=[{"name": "llama3.2", "default": True}], ) router.providers = [provider1, provider2] # First provider fails, second succeeds call_count = [0] async def side_effect(*args, **kwargs): call_count[0] += 1 # First 2 retries for provider1 fail, then provider2 succeeds if call_count[0] <= router.config.max_retries_per_provider: raise RuntimeError("Connection failed") return {"content": "Backup response", "model": "llama3.2"} with patch.object(router, "_call_ollama") as mock_call: mock_call.side_effect = side_effect result = await router.complete( messages=[{"role": "user", "content": "Hi"}], ) assert result["content"] == "Backup response" assert result["provider"] == "ollama-backup" async def test_all_providers_fail(self): """Test error when all providers fail.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="failing", type="ollama", enabled=True, priority=1, models=[{"name": "llama3.2", "default": True}], ) router.providers = [provider] with patch.object(router, "_call_ollama") as mock_call: mock_call.side_effect = RuntimeError("Always fails") with pytest.raises(RuntimeError) as exc_info: await router.complete(messages=[{"role": "user", "content": "Hi"}]) assert "All providers failed" in str(exc_info.value) async def test_skips_unhealthy_provider(self): """Test that unhealthy providers are skipped.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider1 = Provider( name="unhealthy", type="ollama", enabled=True, priority=1, status=ProviderStatus.UNHEALTHY, circuit_state=CircuitState.OPEN, circuit_opened_at=time.time(), # Just opened models=[{"name": "llama3.2", "default": True}], ) provider2 = Provider( name="healthy", type="ollama", enabled=True, priority=2, models=[{"name": "llama3.2", "default": True}], ) router.providers = [provider1, provider2] with patch.object(router, "_call_ollama") as mock_call: mock_call.return_value = {"content": "Success", "model": "llama3.2"} result = await router.complete( messages=[{"role": "user", "content": "Hi"}], ) # Should use the healthy provider assert result["provider"] == "healthy" class TestProviderAvailabilityCheck: """Test provider availability checking.""" def test_check_ollama_without_requests(self): """Test Ollama returns True when requests not available (fallback).""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="ollama", type="ollama", enabled=True, priority=1, url="http://localhost:11434", ) # When requests is None, assume available import infrastructure.router.cascade as cascade_module old_requests = cascade_module.requests cascade_module.requests = None try: assert router._check_provider_available(provider) is True finally: cascade_module.requests = old_requests def test_check_openai_with_key(self): """Test OpenAI with API key.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="openai", type="openai", enabled=True, priority=1, api_key="sk-test123", ) assert router._check_provider_available(provider) is True def test_check_openai_without_key(self): """Test OpenAI without API key.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="openai", type="openai", enabled=True, priority=1, api_key=None, ) assert router._check_provider_available(provider) is False def test_check_vllm_mlx_without_requests(self): """Test vllm-mlx returns True when requests not available (fallback).""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="vllm-mlx-local", type="vllm_mlx", enabled=True, priority=2, base_url="http://localhost:8000/v1", ) import infrastructure.router.cascade as cascade_module old_requests = cascade_module.requests cascade_module.requests = None try: assert router._check_provider_available(provider) is True finally: cascade_module.requests = old_requests def test_check_vllm_mlx_server_healthy(self): """Test vllm-mlx when health check succeeds.""" from unittest.mock import MagicMock, patch router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="vllm-mlx-local", type="vllm_mlx", enabled=True, priority=2, base_url="http://localhost:8000/v1", ) mock_response = MagicMock() mock_response.status_code = 200 with patch("infrastructure.router.cascade.requests") as mock_requests: mock_requests.get.return_value = mock_response result = router._check_provider_available(provider) assert result is True mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5) def test_check_vllm_mlx_server_down(self): """Test vllm-mlx when server is not running.""" from unittest.mock import patch router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="vllm-mlx-local", type="vllm_mlx", enabled=True, priority=2, base_url="http://localhost:8000/v1", ) with patch("infrastructure.router.cascade.requests") as mock_requests: mock_requests.get.side_effect = ConnectionRefusedError("Connection refused") result = router._check_provider_available(provider) assert result is False def test_check_vllm_mlx_default_url(self): """Test vllm-mlx uses default localhost:8000 when no URL configured.""" from unittest.mock import MagicMock, patch router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="vllm-mlx-local", type="vllm_mlx", enabled=True, priority=2, ) mock_response = MagicMock() mock_response.status_code = 200 with patch("infrastructure.router.cascade.requests") as mock_requests: mock_requests.get.return_value = mock_response router._check_provider_available(provider) mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5) @pytest.mark.asyncio class TestVllmMlxProvider: """Test vllm-mlx provider integration.""" async def test_complete_with_vllm_mlx(self): """Test successful completion via vllm-mlx.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="vllm-mlx-local", type="vllm_mlx", enabled=True, priority=2, base_url="http://localhost:8000/v1", models=[{"name": "Qwen/Qwen2.5-14B-Instruct-MLX", "default": True}], ) router.providers = [provider] with patch.object(router, "_call_vllm_mlx") as mock_call: mock_call.return_value = { "content": "MLX response", "model": "Qwen/Qwen2.5-14B-Instruct-MLX", } result = await router.complete( messages=[{"role": "user", "content": "Hi"}], ) assert result["content"] == "MLX response" assert result["provider"] == "vllm-mlx-local" assert result["model"] == "Qwen/Qwen2.5-14B-Instruct-MLX" async def test_vllm_mlx_base_url_normalization(self): """Test _call_vllm_mlx appends /v1 when missing.""" from unittest.mock import AsyncMock, MagicMock, patch router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="vllm-mlx-local", type="vllm_mlx", enabled=True, priority=2, base_url="http://localhost:8000", # No /v1 models=[{"name": "qwen-mlx", "default": True}], ) mock_choice = MagicMock() mock_choice.message.content = "hello" mock_response = MagicMock() mock_response.choices = [mock_choice] mock_response.model = "qwen-mlx" async def fake_create(**kwargs): return mock_response with patch("openai.AsyncOpenAI") as mock_openai_cls: mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock(side_effect=fake_create) mock_openai_cls.return_value = mock_client await router._call_vllm_mlx( provider=provider, messages=[{"role": "user", "content": "hi"}], model="qwen-mlx", temperature=0.7, max_tokens=None, ) call_kwargs = mock_openai_cls.call_args base_url_used = call_kwargs.kwargs.get("base_url") or call_kwargs[1].get("base_url") assert base_url_used.endswith("/v1") async def test_vllm_mlx_is_local_not_cloud(self): """Confirm vllm_mlx is not subject to metabolic protocol cloud skip.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="vllm-mlx-local", type="vllm_mlx", enabled=True, priority=2, base_url="http://localhost:8000/v1", models=[{"name": "qwen-mlx", "default": True}], ) router.providers = [provider] # Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: mock_qm.select_model.return_value = "qwen3:14b" mock_qm.check.return_value = None with patch.object(router, "_call_vllm_mlx") as mock_call: mock_call.return_value = { "content": "Local MLX response", "model": "qwen-mlx", } result = await router.complete( messages=[{"role": "user", "content": "hi"}], ) assert result["content"] == "Local MLX response" class TestMetabolicProtocol: """Test metabolic protocol: cloud providers skip when quota is ACTIVE/RESTING.""" def _make_anthropic_provider(self) -> "Provider": return Provider( name="anthropic-primary", type="anthropic", enabled=True, priority=1, api_key="test-key", models=[{"name": "claude-sonnet-4-6", "default": True}], ) async def test_cloud_provider_allowed_in_burst_tier(self): """BURST tier (quota healthy): cloud provider is tried.""" router = CascadeRouter(config_path=Path("/nonexistent")) router.providers = [self._make_anthropic_provider()] with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: # select_model returns cloud model → BURST tier mock_qm.select_model.return_value = "claude-sonnet-4-6" mock_qm.check.return_value = None with patch.object(router, "_call_anthropic") as mock_call: mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"} result = await router.complete( messages=[{"role": "user", "content": "hard question"}], ) mock_call.assert_called_once() assert result["content"] == "Cloud response" async def test_cloud_provider_skipped_in_active_tier(self): """ACTIVE tier (5-hour >= 50%): cloud provider is skipped.""" router = CascadeRouter(config_path=Path("/nonexistent")) router.providers = [self._make_anthropic_provider()] with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: # select_model returns local 14B → ACTIVE tier mock_qm.select_model.return_value = "qwen3:14b" mock_qm.check.return_value = None with patch.object(router, "_call_anthropic") as mock_call: with pytest.raises(RuntimeError, match="All providers failed"): await router.complete( messages=[{"role": "user", "content": "question"}], ) mock_call.assert_not_called() async def test_cloud_provider_skipped_in_resting_tier(self): """RESTING tier (7-day >= 80%): cloud provider is skipped.""" router = CascadeRouter(config_path=Path("/nonexistent")) router.providers = [self._make_anthropic_provider()] with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: # select_model returns local 8B → RESTING tier mock_qm.select_model.return_value = "qwen3:8b" mock_qm.check.return_value = None with patch.object(router, "_call_anthropic") as mock_call: with pytest.raises(RuntimeError, match="All providers failed"): await router.complete( messages=[{"role": "user", "content": "simple question"}], ) mock_call.assert_not_called() async def test_local_provider_always_tried_regardless_of_quota(self): """Local (ollama/vllm_mlx) providers bypass the metabolic protocol.""" router = CascadeRouter(config_path=Path("/nonexistent")) provider = Provider( name="ollama-local", type="ollama", enabled=True, priority=1, url="http://localhost:11434", models=[{"name": "qwen3:14b", "default": True}], ) router.providers = [provider] with patch("infrastructure.router.cascade._quota_monitor") as mock_qm: mock_qm.select_model.return_value = "qwen3:8b" # RESTING tier with patch.object(router, "_call_ollama") as mock_call: mock_call.return_value = {"content": "Local response", "model": "qwen3:14b"} result = await router.complete( messages=[{"role": "user", "content": "hi"}], ) mock_call.assert_called_once() assert result["content"] == "Local response" async def test_no_quota_monitor_allows_cloud(self): """When quota monitor is None (unavailable), cloud providers are allowed.""" router = CascadeRouter(config_path=Path("/nonexistent")) router.providers = [self._make_anthropic_provider()] with patch("infrastructure.router.cascade._quota_monitor", None): with patch.object(router, "_call_anthropic") as mock_call: mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"} result = await router.complete( messages=[{"role": "user", "content": "question"}], ) mock_call.assert_called_once() assert result["content"] == "Cloud response" class TestCascadeRouterReload: """Test hot-reload of providers.yaml.""" def test_reload_preserves_metrics(self, tmp_path): """Test that reload preserves metrics for existing providers.""" config = { "providers": [ { "name": "test-openai", "type": "openai", "enabled": True, "priority": 1, "api_key": "sk-test", } ], } config_path = tmp_path / "providers.yaml" config_path.write_text(yaml.dump(config)) router = CascadeRouter(config_path=config_path) assert len(router.providers) == 1 # Simulate some traffic router._record_success(router.providers[0], 150.0) router._record_success(router.providers[0], 250.0) assert router.providers[0].metrics.total_requests == 2 # Reload result = router.reload_config() assert result["total_providers"] == 1 assert result["preserved"] == 1 assert result["added"] == [] assert result["removed"] == [] # Metrics survived assert router.providers[0].metrics.total_requests == 2 assert router.providers[0].metrics.total_latency_ms == 400.0 def test_reload_preserves_circuit_breaker(self, tmp_path): """Test that reload preserves circuit breaker state.""" config = { "cascade": {"circuit_breaker": {"failure_threshold": 2}}, "providers": [ { "name": "test-openai", "type": "openai", "enabled": True, "priority": 1, "api_key": "sk-test", } ], } config_path = tmp_path / "providers.yaml" config_path.write_text(yaml.dump(config)) router = CascadeRouter(config_path=config_path) # Open circuit breaker for _ in range(2): router._record_failure(router.providers[0]) assert router.providers[0].circuit_state == CircuitState.OPEN # Reload router.reload_config() # Circuit breaker state preserved assert router.providers[0].circuit_state == CircuitState.OPEN assert router.providers[0].status == ProviderStatus.UNHEALTHY def test_reload_detects_added_provider(self, tmp_path): """Test that reload detects newly added providers.""" config = { "providers": [ { "name": "openai-1", "type": "openai", "enabled": True, "priority": 1, "api_key": "sk-test", } ], } config_path = tmp_path / "providers.yaml" config_path.write_text(yaml.dump(config)) router = CascadeRouter(config_path=config_path) assert len(router.providers) == 1 # Add a second provider to config config["providers"].append( { "name": "anthropic-1", "type": "anthropic", "enabled": True, "priority": 2, "api_key": "sk-ant-test", } ) config_path.write_text(yaml.dump(config)) result = router.reload_config() assert result["total_providers"] == 2 assert result["preserved"] == 1 assert result["added"] == ["anthropic-1"] assert result["removed"] == [] def test_reload_detects_removed_provider(self, tmp_path): """Test that reload detects removed providers.""" config = { "providers": [ { "name": "openai-1", "type": "openai", "enabled": True, "priority": 1, "api_key": "sk-test", }, { "name": "anthropic-1", "type": "anthropic", "enabled": True, "priority": 2, "api_key": "sk-ant-test", }, ], } config_path = tmp_path / "providers.yaml" config_path.write_text(yaml.dump(config)) router = CascadeRouter(config_path=config_path) assert len(router.providers) == 2 # Remove anthropic config["providers"] = [config["providers"][0]] config_path.write_text(yaml.dump(config)) result = router.reload_config() assert result["total_providers"] == 1 assert result["preserved"] == 1 assert result["removed"] == ["anthropic-1"] def test_reload_re_sorts_by_priority(self, tmp_path): """Test that providers are re-sorted by priority after reload.""" config = { "providers": [ { "name": "low-priority", "type": "openai", "enabled": True, "priority": 10, "api_key": "sk-test", }, { "name": "high-priority", "type": "openai", "enabled": True, "priority": 1, "api_key": "sk-test2", }, ], } config_path = tmp_path / "providers.yaml" config_path.write_text(yaml.dump(config)) router = CascadeRouter(config_path=config_path) assert router.providers[0].name == "high-priority" # Swap priorities config["providers"][0]["priority"] = 1 config["providers"][1]["priority"] = 10 config_path.write_text(yaml.dump(config)) router.reload_config() assert router.providers[0].name == "low-priority" assert router.providers[1].name == "high-priority"