diff --git a/src/infrastructure/router/api.py b/src/infrastructure/router/api.py
index a6157c1..4175a8c 100644
--- a/src/infrastructure/router/api.py
+++ b/src/infrastructure/router/api.py
@@ -183,6 +183,31 @@ async def run_health_check(
     }
 
 
+@router.post("/reload")
+async def reload_config(
+    cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
+) -> dict[str, Any]:
+    """Hot-reload providers.yaml without restart.
+
+    Preserves circuit breaker state and metrics for existing providers.
+
+    Returns:
+        ``{"status": "ok"}`` merged with the summary returned by
+        ``CascadeRouter.reload_config``.
+
+    Raises:
+        HTTPException: 500 if the reload raises for any reason.
+    """
+    try:
+        result = cascade.reload_config()
+    except Exception as exc:
+        # logger.exception records the traceback, which the formatted
+        # message alone would lose.
+        logger.exception("Config reload failed: %s", exc)
+        raise HTTPException(status_code=500, detail=f"Reload failed: {exc}") from exc
+    return {"status": "ok", **result}
+
+
 @router.get("/config")
 async def get_config(
     cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py
index 3b5e23a..aacec03 100644
--- a/src/infrastructure/router/cascade.py
+++ b/src/infrastructure/router/cascade.py
@@ -815,6 +815,73 @@ class CascadeRouter:
         provider.status = ProviderStatus.HEALTHY
         logger.info("Circuit breaker CLOSED for %s", provider.name)
 
+    def reload_config(self) -> dict:
+        """Hot-reload providers.yaml, preserving runtime state.
+
+        Re-reads the config file, rebuilds the provider list, and
+        preserves circuit breaker state and metrics for providers
+        that still exist after reload (matched by name).
+
+        If loading fails, the previous provider list is restored before
+        the exception propagates, so a broken config file does not
+        leave the router without providers.
+
+        Returns:
+            Summary dict with total/preserved counts and sorted
+            added/removed provider names.
+
+        Raises:
+            Exception: Whatever ``_load_config`` raises, re-raised
+                after the previous providers are restored.
+        """
+        # The live provider objects already hold the runtime state to
+        # carry over, and are what gets restored if the reload fails.
+        previous_providers = self.providers
+        old_by_name = {p.name: p for p in previous_providers}
+
+        # Reload from disk
+        self.providers = []
+        try:
+            self._load_config()
+        except Exception:
+            # NOTE(review): only the provider list is rolled back here;
+            # confirm whether _load_config mutates other router state.
+            self.providers = previous_providers
+            raise
+
+        # Restore preserved state
+        preserved = 0
+        for p in self.providers:
+            old = old_by_name.get(p.name)
+            if old is None:
+                continue
+            p.metrics = old.metrics
+            p.circuit_state = old.circuit_state
+            p.circuit_opened_at = old.circuit_opened_at
+            p.half_open_calls = old.half_open_calls
+            p.status = old.status
+            preserved += 1
+
+        old_names = set(old_by_name)
+        new_names = {p.name for p in self.providers}
+        added = new_names - old_names
+        removed = old_names - new_names
+
+        logger.info(
+            "Config reloaded: %d providers (%d preserved, %d added, %d removed)",
+            len(self.providers),
+            preserved,
+            len(added),
+            len(removed),
+        )
+
+        return {
+            "total_providers": len(self.providers),
+            "preserved": preserved,
+            "added": sorted(added),
+            "removed": sorted(removed),
+        }
+
     def get_metrics(self) -> dict:
         """Get metrics for all providers."""
         return {
diff --git a/tests/infrastructure/test_router_cascade.py b/tests/infrastructure/test_router_cascade.py
index 92ab705..d482d7a 100644
--- a/tests/infrastructure/test_router_cascade.py
+++ b/tests/infrastructure/test_router_cascade.py
@@ -516,3 +516,214 @@ class TestProviderAvailabilityCheck:
 
         with patch("importlib.util.find_spec", return_value=None):
             assert router._check_provider_available(provider) is False
+
+
+class TestCascadeRouterReload:
+    """Test hot-reload of providers.yaml."""
+
+    def test_reload_preserves_metrics(self, tmp_path):
+        """Test that reload preserves metrics for existing providers."""
+        config = {
+            "providers": [
+                {
+                    "name": "test-openai",
+                    "type": "openai",
+                    "enabled": True,
+                    "priority": 1,
+                    "api_key": "sk-test",
+                }
+            ],
+        }
+        config_path = tmp_path / "providers.yaml"
+        config_path.write_text(yaml.dump(config))
+
+        router = CascadeRouter(config_path=config_path)
+        assert len(router.providers) == 1
+
+        # Simulate some traffic
+        router._record_success(router.providers[0], 150.0)
+        router._record_success(router.providers[0], 250.0)
+        assert router.providers[0].metrics.total_requests == 2
+
+        # Reload
+        result = router.reload_config()
+
+        assert result["total_providers"] == 1
+        assert result["preserved"] == 1
+        assert result["added"] == []
+        assert result["removed"] == []
+        # Metrics survived
+        assert router.providers[0].metrics.total_requests == 2
+        assert router.providers[0].metrics.total_latency_ms == 400.0
+
+    def test_reload_preserves_circuit_breaker(self, tmp_path):
+        """Test that reload preserves circuit breaker state."""
+        config = {
+            "cascade": {"circuit_breaker": {"failure_threshold": 2}},
+            "providers": [
+                {
+                    "name": "test-openai",
+                    "type": "openai",
+                    "enabled": True,
+                    "priority": 1,
+                    "api_key": "sk-test",
+                }
+            ],
+        }
+        config_path = tmp_path / "providers.yaml"
+        config_path.write_text(yaml.dump(config))
+
+        router = CascadeRouter(config_path=config_path)
+
+        # Open circuit breaker
+        for _ in range(2):
+            router._record_failure(router.providers[0])
+        assert router.providers[0].circuit_state == CircuitState.OPEN
+
+        # Reload
+        router.reload_config()
+
+        # Circuit breaker state preserved
+        assert router.providers[0].circuit_state == CircuitState.OPEN
+        assert router.providers[0].status == ProviderStatus.UNHEALTHY
+
+    def test_reload_detects_added_provider(self, tmp_path):
+        """Test that reload detects newly added providers."""
+        config = {
+            "providers": [
+                {
+                    "name": "openai-1",
+                    "type": "openai",
+                    "enabled": True,
+                    "priority": 1,
+                    "api_key": "sk-test",
+                }
+            ],
+        }
+        config_path = tmp_path / "providers.yaml"
+        config_path.write_text(yaml.dump(config))
+
+        router = CascadeRouter(config_path=config_path)
+        assert len(router.providers) == 1
+
+        # Add a second provider to config
+        config["providers"].append(
+            {
+                "name": "anthropic-1",
+                "type": "anthropic",
+                "enabled": True,
+                "priority": 2,
+                "api_key": "sk-ant-test",
+            }
+        )
+        config_path.write_text(yaml.dump(config))
+
+        result = router.reload_config()
+
+        assert result["total_providers"] == 2
+        assert result["preserved"] == 1
+        assert result["added"] == ["anthropic-1"]
+        assert result["removed"] == []
+
+    def test_reload_detects_removed_provider(self, tmp_path):
+        """Test that reload detects removed providers."""
+        config = {
+            "providers": [
+                {
+                    "name": "openai-1",
+                    "type": "openai",
+                    "enabled": True,
+                    "priority": 1,
+                    "api_key": "sk-test",
+                },
+                {
+                    "name": "anthropic-1",
+                    "type": "anthropic",
+                    "enabled": True,
+                    "priority": 2,
+                    "api_key": "sk-ant-test",
+                },
+            ],
+        }
+        config_path = tmp_path / "providers.yaml"
+        config_path.write_text(yaml.dump(config))
+
+        router = CascadeRouter(config_path=config_path)
+        assert len(router.providers) == 2
+
+        # Remove anthropic
+        config["providers"] = [config["providers"][0]]
+        config_path.write_text(yaml.dump(config))
+
+        result = router.reload_config()
+
+        assert result["total_providers"] == 1
+        assert result["preserved"] == 1
+        assert result["added"] == []
+        assert result["removed"] == ["anthropic-1"]
+
+    def test_reload_re_sorts_by_priority(self, tmp_path):
+        """Test that providers are re-sorted by priority after reload."""
+        config = {
+            "providers": [
+                {
+                    "name": "low-priority",
+                    "type": "openai",
+                    "enabled": True,
+                    "priority": 10,
+                    "api_key": "sk-test",
+                },
+                {
+                    "name": "high-priority",
+                    "type": "openai",
+                    "enabled": True,
+                    "priority": 1,
+                    "api_key": "sk-test2",
+                },
+            ],
+        }
+        config_path = tmp_path / "providers.yaml"
+        config_path.write_text(yaml.dump(config))
+
+        router = CascadeRouter(config_path=config_path)
+        assert router.providers[0].name == "high-priority"
+
+        # Swap priorities
+        config["providers"][0]["priority"] = 1
+        config["providers"][1]["priority"] = 10
+        config_path.write_text(yaml.dump(config))
+
+        router.reload_config()
+
+        assert router.providers[0].name == "low-priority"
+        assert router.providers[1].name == "high-priority"
+
+    def test_reload_failure_keeps_existing_providers(self, tmp_path):
+        """Test that a failed reload leaves the previous providers in place."""
+        config = {
+            "providers": [
+                {
+                    "name": "test-openai",
+                    "type": "openai",
+                    "enabled": True,
+                    "priority": 1,
+                    "api_key": "sk-test",
+                }
+            ],
+        }
+        config_path = tmp_path / "providers.yaml"
+        config_path.write_text(yaml.dump(config))
+
+        router = CascadeRouter(config_path=config_path)
+        providers_before = router.providers
+
+        with patch.object(router, "_load_config", side_effect=RuntimeError("boom")):
+            try:
+                router.reload_config()
+            except RuntimeError:
+                pass
+            else:
+                raise AssertionError("reload_config should re-raise the load failure")
+
+        assert router.providers is providers_before
+        assert len(router.providers) == 1