diff --git a/cli.py b/cli.py index b09a17355..4d8146bb6 100644 --- a/cli.py +++ b/cli.py @@ -1182,9 +1182,13 @@ class HermesCLI: self._provider_require_params = pr.get("require_parameters", False) self._provider_data_collection = pr.get("data_collection") - # Fallback model config β€” tried when primary provider fails after retries - fb = CLI_CONFIG.get("fallback_model") or {} - self._fallback_model = fb if fb.get("provider") and fb.get("model") else None + # Fallback provider chain β€” tried in order when primary fails after retries. + # Supports new list format (fallback_providers) and legacy single-dict (fallback_model). + fb = CLI_CONFIG.get("fallback_providers") or CLI_CONFIG.get("fallback_model") or [] + # Normalize legacy single-dict to a one-element list + if isinstance(fb, dict): + fb = [fb] if fb.get("provider") and fb.get("model") else [] + self._fallback_model = fb # Optional cheap-vs-strong routing for simple turns self._smart_model_routing = CLI_CONFIG.get("smart_model_routing", {}) or {} diff --git a/gateway/run.py b/gateway/run.py index b335e42af..07eaa84d1 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -919,11 +919,12 @@ class GatewayRunner: return {} @staticmethod - def _load_fallback_model() -> dict | None: - """Load fallback model config from config.yaml. + def _load_fallback_model() -> list | dict | None: + """Load fallback provider chain from config.yaml. - Returns a dict with 'provider' and 'model' keys, or None if - not configured / both fields empty. + Returns a list of provider dicts (``fallback_providers``), a single + dict (legacy ``fallback_model``), or None if not configured. + AIAgent.__init__ normalizes both formats into a chain. """ try: import yaml as _y @@ -931,8 +932,8 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - fb = cfg.get("fallback_model", {}) or {} - if fb.get("provider") and fb.get("model"): + fb = cfg.get("fallback_providers") or cfg.get("fallback_model") or None + if fb: return fb except Exception: pass diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 7486f34b9..3304a187e 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -135,6 +135,7 @@ def ensure_hermes_home(): DEFAULT_CONFIG = { "model": "anthropic/claude-opus-4.6", + "fallback_providers": [], "toolsets": ["hermes-cli"], "agent": { "max_turns": 90, diff --git a/run_agent.py b/run_agent.py index b3b8a1679..8674149ea 100644 --- a/run_agent.py +++ b/run_agent.py @@ -896,16 +896,30 @@ class AIAgent: except Exception as e: raise RuntimeError(f"Failed to initialize OpenAI client: {e}") - # Provider fallback β€” a single backup model/provider tried when the - # primary is exhausted (rate-limit, overload, connection failure). - # Config shape: {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"} - self._fallback_model = fallback_model if isinstance(fallback_model, dict) else None + # Provider fallback chain β€” ordered list of backup providers tried + # when the primary is exhausted (rate-limit, overload, connection + # failure). Supports both legacy single-dict ``fallback_model`` and + # new list ``fallback_providers`` format. + if isinstance(fallback_model, list): + self._fallback_chain = [ + f for f in fallback_model + if isinstance(f, dict) and f.get("provider") and f.get("model") + ] + elif isinstance(fallback_model, dict) and fallback_model.get("provider") and fallback_model.get("model"): + self._fallback_chain = [fallback_model] + else: + self._fallback_chain = [] + self._fallback_index = 0 self._fallback_activated = False - if self._fallback_model: - fb_p = self._fallback_model.get("provider", "") - fb_m = self._fallback_model.get("model", "") - if fb_p and fb_m and not self.quiet_mode: - print(f"πŸ”„ Fallback model: {fb_m} ({fb_p})") + # Legacy attribute kept for backward compat (tests, external callers) + self._fallback_model = self._fallback_chain[0] if self._fallback_chain else None + if self._fallback_chain and not self.quiet_mode: + if len(self._fallback_chain) == 1: + fb = self._fallback_chain[0] + print(f"πŸ”„ Fallback model: {fb['model']} ({fb['provider']})") + else: + print(f"πŸ”„ Fallback chain ({len(self._fallback_chain)} providers): " + + " β†’ ".join(f"{f['model']} ({f['provider']})" for f in self._fallback_chain)) # Get available tools with filtering self.tools = get_tool_definitions( @@ -4318,25 +4332,26 @@ class AIAgent: # ── Provider fallback ────────────────────────────────────────────────── def _try_activate_fallback(self) -> bool: - """Switch to the configured fallback model/provider. + """Switch to the next fallback model/provider in the chain. - Called when the primary model is failing after retries. Swaps the + Called when the current model is failing after retries. Swaps the OpenAI client, model slug, and provider in-place so the retry loop - can continue with the new backend. One-shot: returns False if - already activated or not configured. + can continue with the new backend. Advances through the chain on + each call; returns False when exhausted. Uses the centralized provider router (resolve_provider_client) for auth resolution and client construction β€” no duplicated providerβ†’key mappings. """ - if self._fallback_activated or not self._fallback_model: + if self._fallback_index >= len(self._fallback_chain): return False - fb = self._fallback_model + fb = self._fallback_chain[self._fallback_index] + self._fallback_index += 1 fb_provider = (fb.get("provider") or "").strip().lower() fb_model = (fb.get("model") or "").strip() if not fb_provider or not fb_model: - return False + return self._try_activate_fallback() # skip invalid, try next # Use centralized router for client construction. # raw_codex=True because the main agent needs direct responses.stream() @@ -4349,7 +4364,7 @@ class AIAgent: logging.warning( "Fallback to %s failed: provider not configured", fb_provider) - return False + return self._try_activate_fallback() # try next in chain # Determine api_mode from provider / base URL fb_api_mode = "chat_completions" @@ -4424,8 +4439,8 @@ class AIAgent: ) return True except Exception as e: - logging.error("Failed to activate fallback model: %s", e) - return False + logging.error("Failed to activate fallback %s: %s", fb_model, e) + return self._try_activate_fallback() # try next in chain # ── End provider fallback ────────────────────────────────────────────── @@ -6528,9 +6543,9 @@ class AIAgent: # Eager fallback: empty/malformed responses are a common # rate-limit symptom. Switch to fallback immediately # rather than retrying with extended backoff. - if not self._fallback_activated: + if self._fallback_index < len(self._fallback_chain): self._emit_status("⚠️ Empty/malformed response β€” switching to fallback...") - if not self._fallback_activated and self._try_activate_fallback(): + if self._try_activate_fallback(): retry_count = 0 continue @@ -6993,7 +7008,7 @@ class AIAgent: or "usage limit" in error_msg or "quota" in error_msg ) - if is_rate_limited and not self._fallback_activated: + if is_rate_limited and self._fallback_index < len(self._fallback_chain): self._emit_status("⚠️ Rate limited β€” switching to fallback provider...") if self._try_activate_fallback(): retry_count = 0 diff --git a/tests/test_compressor_fallback_update.py b/tests/test_compressor_fallback_update.py index 570238b02..064fd9b67 100644 --- a/tests/test_compressor_fallback_update.py +++ b/tests/test_compressor_fallback_update.py @@ -25,6 +25,8 @@ def _make_agent_with_compressor() -> AIAgent: "provider": "openai", "model": "gpt-4o", } + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 # Context compressor with primary model values compressor = ContextCompressor( diff --git a/tests/test_provider_fallback.py b/tests/test_provider_fallback.py new file mode 100644 index 000000000..2bb210955 --- /dev/null +++ b/tests/test_provider_fallback.py @@ -0,0 +1,156 @@ +"""Tests for ordered provider fallback chain (salvage of PR #1761). + +Extends the single-fallback tests in test_fallback_model.py to cover +the new list-based ``fallback_providers`` config format and chain +advancement through multiple providers. +""" + +from unittest.mock import MagicMock, patch + +from run_agent import AIAgent + + +def _make_agent(fallback_model=None): + """Create a minimal AIAgent with optional fallback config.""" + with ( + patch("run_agent.get_tool_definitions", return_value=[]), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + agent = AIAgent( + api_key="test-key", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + fallback_model=fallback_model, + ) + agent.client = MagicMock() + return agent + + +def _mock_client(base_url="https://openrouter.ai/api/v1", api_key="fb-key"): + mock = MagicMock() + mock.base_url = base_url + mock.api_key = api_key + return mock + + +# ── Chain initialisation ────────────────────────────────────────────────── + + +class TestFallbackChainInit: + def test_no_fallback(self): + agent = _make_agent(fallback_model=None) + assert agent._fallback_chain == [] + assert agent._fallback_index == 0 + assert agent._fallback_model is None + + def test_single_dict_backwards_compat(self): + fb = {"provider": "openai", "model": "gpt-4o"} + agent = _make_agent(fallback_model=fb) + assert agent._fallback_chain == [fb] + assert agent._fallback_model == fb + + def test_list_of_providers(self): + fbs = [ + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "zai", "model": "glm-4.7"}, + ] + agent = _make_agent(fallback_model=fbs) + assert len(agent._fallback_chain) == 2 + assert agent._fallback_model == fbs[0] + + def test_invalid_entries_filtered(self): + fbs = [ + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "", "model": "glm-4.7"}, + {"provider": "zai"}, + "not-a-dict", + ] + agent = _make_agent(fallback_model=fbs) + assert len(agent._fallback_chain) == 1 + assert agent._fallback_chain[0]["provider"] == "openai" + + def test_empty_list(self): + agent = _make_agent(fallback_model=[]) + assert agent._fallback_chain == [] + assert agent._fallback_model is None + + def test_invalid_dict_no_provider(self): + agent = _make_agent(fallback_model={"model": "gpt-4o"}) + assert agent._fallback_chain == [] + + +# ── Chain advancement ───────────────────────────────────────────────────── + + +class TestFallbackChainAdvancement: + def test_exhausted_returns_false(self): + agent = _make_agent(fallback_model=None) + assert agent._try_activate_fallback() is False + + def test_advances_index(self): + fbs = [ + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "zai", "model": "glm-4.7"}, + ] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client", + return_value=(_mock_client(), "gpt-4o")): + assert agent._try_activate_fallback() is True + assert agent._fallback_index == 1 + assert agent.model == "gpt-4o" + assert agent._fallback_activated is True + + def test_second_fallback_works(self): + fbs = [ + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "zai", "model": "glm-4.7"}, + ] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client", + return_value=(_mock_client(), "resolved")): + assert agent._try_activate_fallback() is True + assert agent.model == "gpt-4o" + assert agent._try_activate_fallback() is True + assert agent.model == "glm-4.7" + assert agent._fallback_index == 2 + + def test_all_exhausted_returns_false(self): + fbs = [{"provider": "openai", "model": "gpt-4o"}] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client", + return_value=(_mock_client(), "gpt-4o")): + assert agent._try_activate_fallback() is True + assert agent._try_activate_fallback() is False + + def test_skips_unconfigured_provider_to_next(self): + """If resolve_provider_client returns None, skip to next in chain.""" + fbs = [ + {"provider": "broken", "model": "nope"}, + {"provider": "openai", "model": "gpt-4o"}, + ] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client") as mock_rpc: + mock_rpc.side_effect = [ + (None, None), # broken provider + (_mock_client(), "gpt-4o"), # fallback succeeds + ] + assert agent._try_activate_fallback() is True + assert agent.model == "gpt-4o" + assert agent._fallback_index == 2 + + def test_skips_provider_that_raises_to_next(self): + """If resolve_provider_client raises, skip to next in chain.""" + fbs = [ + {"provider": "broken", "model": "nope"}, + {"provider": "openai", "model": "gpt-4o"}, + ] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client") as mock_rpc: + mock_rpc.side_effect = [ + RuntimeError("auth failed"), + (_mock_client(), "gpt-4o"), + ] + assert agent._try_activate_fallback() is True + assert agent.model == "gpt-4o" diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index cbfe14f68..c42ee29f2 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -2507,6 +2507,8 @@ class TestFallbackAnthropicProvider: def test_fallback_to_anthropic_sets_api_mode(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-20250514"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://api.anthropic.com/v1" @@ -2528,6 +2530,8 @@ class TestFallbackAnthropicProvider: def test_fallback_to_anthropic_enables_prompt_caching(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-20250514"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://api.anthropic.com/v1" @@ -2545,6 +2549,8 @@ class TestFallbackAnthropicProvider: def test_fallback_to_openrouter_uses_openai_client(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://openrouter.ai/api/v1" @@ -3238,6 +3244,8 @@ class TestFallbackSetsOAuthFlag: def test_fallback_to_anthropic_oauth_sets_flag(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-6"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://api.anthropic.com/v1" @@ -3259,6 +3267,8 @@ class TestFallbackSetsOAuthFlag: def test_fallback_to_anthropic_api_key_clears_flag(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-6"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://api.anthropic.com/v1"