Extends the single fallback_model mechanism into an ordered chain.
When the primary model fails, Hermes tries each fallback provider in
sequence until one succeeds or the chain is exhausted.
Config format (new):
fallback_providers:
- provider: openrouter
model: anthropic/claude-sonnet-4
- provider: openai
model: gpt-4o
Legacy single-dict fallback_model format still works unchanged.
Key fix vs original PR: the call sites in the retry loop now use
_fallback_index < len(_fallback_chain) instead of the old one-shot
_fallback_activated guard, so the chain actually advances through
all configured providers.
Changes:
- run_agent.py: _fallback_chain list + _fallback_index replaces
one-shot _fallback_model; _try_activate_fallback() advances
through chain; failed provider resolution skips to next entry;
call sites updated to allow chain advancement
- cli.py: reads fallback_providers with legacy fallback_model compat
- gateway/run.py: same
- hermes_cli/config.py: fallback_providers: [] in DEFAULT_CONFIG
- tests: 12 new chain tests + 6 existing test fixtures updated
Co-authored-by: uzaylisak <uzaylisak@users.noreply.github.com>
This commit is contained in:
10
cli.py
10
cli.py
@@ -1182,9 +1182,13 @@ class HermesCLI:
|
||||
self._provider_require_params = pr.get("require_parameters", False)
|
||||
self._provider_data_collection = pr.get("data_collection")
|
||||
|
||||
# Fallback model config — tried when primary provider fails after retries
|
||||
fb = CLI_CONFIG.get("fallback_model") or {}
|
||||
self._fallback_model = fb if fb.get("provider") and fb.get("model") else None
|
||||
# Fallback provider chain — tried in order when primary fails after retries.
|
||||
# Supports new list format (fallback_providers) and legacy single-dict (fallback_model).
|
||||
fb = CLI_CONFIG.get("fallback_providers") or CLI_CONFIG.get("fallback_model") or []
|
||||
# Normalize legacy single-dict to a one-element list
|
||||
if isinstance(fb, dict):
|
||||
fb = [fb] if fb.get("provider") and fb.get("model") else []
|
||||
self._fallback_model = fb
|
||||
|
||||
# Optional cheap-vs-strong routing for simple turns
|
||||
self._smart_model_routing = CLI_CONFIG.get("smart_model_routing", {}) or {}
|
||||
|
||||
@@ -919,11 +919,12 @@ class GatewayRunner:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _load_fallback_model() -> dict | None:
|
||||
"""Load fallback model config from config.yaml.
|
||||
def _load_fallback_model() -> list | dict | None:
|
||||
"""Load fallback provider chain from config.yaml.
|
||||
|
||||
Returns a dict with 'provider' and 'model' keys, or None if
|
||||
not configured / both fields empty.
|
||||
Returns a list of provider dicts (``fallback_providers``), a single
|
||||
dict (legacy ``fallback_model``), or None if not configured.
|
||||
AIAgent.__init__ normalizes both formats into a chain.
|
||||
"""
|
||||
try:
|
||||
import yaml as _y
|
||||
@@ -931,8 +932,8 @@ class GatewayRunner:
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path, encoding="utf-8") as _f:
|
||||
cfg = _y.safe_load(_f) or {}
|
||||
fb = cfg.get("fallback_model", {}) or {}
|
||||
if fb.get("provider") and fb.get("model"):
|
||||
fb = cfg.get("fallback_providers") or cfg.get("fallback_model") or None
|
||||
if fb:
|
||||
return fb
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -135,6 +135,7 @@ def ensure_hermes_home():
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"model": "anthropic/claude-opus-4.6",
|
||||
"fallback_providers": [],
|
||||
"toolsets": ["hermes-cli"],
|
||||
"agent": {
|
||||
"max_turns": 90,
|
||||
|
||||
59
run_agent.py
59
run_agent.py
@@ -896,16 +896,30 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize OpenAI client: {e}")
|
||||
|
||||
# Provider fallback — a single backup model/provider tried when the
|
||||
# primary is exhausted (rate-limit, overload, connection failure).
|
||||
# Config shape: {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}
|
||||
self._fallback_model = fallback_model if isinstance(fallback_model, dict) else None
|
||||
# Provider fallback chain — ordered list of backup providers tried
|
||||
# when the primary is exhausted (rate-limit, overload, connection
|
||||
# failure). Supports both legacy single-dict ``fallback_model`` and
|
||||
# new list ``fallback_providers`` format.
|
||||
if isinstance(fallback_model, list):
|
||||
self._fallback_chain = [
|
||||
f for f in fallback_model
|
||||
if isinstance(f, dict) and f.get("provider") and f.get("model")
|
||||
]
|
||||
elif isinstance(fallback_model, dict) and fallback_model.get("provider") and fallback_model.get("model"):
|
||||
self._fallback_chain = [fallback_model]
|
||||
else:
|
||||
self._fallback_chain = []
|
||||
self._fallback_index = 0
|
||||
self._fallback_activated = False
|
||||
if self._fallback_model:
|
||||
fb_p = self._fallback_model.get("provider", "")
|
||||
fb_m = self._fallback_model.get("model", "")
|
||||
if fb_p and fb_m and not self.quiet_mode:
|
||||
print(f"🔄 Fallback model: {fb_m} ({fb_p})")
|
||||
# Legacy attribute kept for backward compat (tests, external callers)
|
||||
self._fallback_model = self._fallback_chain[0] if self._fallback_chain else None
|
||||
if self._fallback_chain and not self.quiet_mode:
|
||||
if len(self._fallback_chain) == 1:
|
||||
fb = self._fallback_chain[0]
|
||||
print(f"🔄 Fallback model: {fb['model']} ({fb['provider']})")
|
||||
else:
|
||||
print(f"🔄 Fallback chain ({len(self._fallback_chain)} providers): " +
|
||||
" → ".join(f"{f['model']} ({f['provider']})" for f in self._fallback_chain))
|
||||
|
||||
# Get available tools with filtering
|
||||
self.tools = get_tool_definitions(
|
||||
@@ -4318,25 +4332,26 @@ class AIAgent:
|
||||
# ── Provider fallback ──────────────────────────────────────────────────
|
||||
|
||||
def _try_activate_fallback(self) -> bool:
|
||||
"""Switch to the configured fallback model/provider.
|
||||
"""Switch to the next fallback model/provider in the chain.
|
||||
|
||||
Called when the primary model is failing after retries. Swaps the
|
||||
Called when the current model is failing after retries. Swaps the
|
||||
OpenAI client, model slug, and provider in-place so the retry loop
|
||||
can continue with the new backend. One-shot: returns False if
|
||||
already activated or not configured.
|
||||
can continue with the new backend. Advances through the chain on
|
||||
each call; returns False when exhausted.
|
||||
|
||||
Uses the centralized provider router (resolve_provider_client) for
|
||||
auth resolution and client construction — no duplicated provider→key
|
||||
mappings.
|
||||
"""
|
||||
if self._fallback_activated or not self._fallback_model:
|
||||
if self._fallback_index >= len(self._fallback_chain):
|
||||
return False
|
||||
|
||||
fb = self._fallback_model
|
||||
fb = self._fallback_chain[self._fallback_index]
|
||||
self._fallback_index += 1
|
||||
fb_provider = (fb.get("provider") or "").strip().lower()
|
||||
fb_model = (fb.get("model") or "").strip()
|
||||
if not fb_provider or not fb_model:
|
||||
return False
|
||||
return self._try_activate_fallback() # skip invalid, try next
|
||||
|
||||
# Use centralized router for client construction.
|
||||
# raw_codex=True because the main agent needs direct responses.stream()
|
||||
@@ -4349,7 +4364,7 @@ class AIAgent:
|
||||
logging.warning(
|
||||
"Fallback to %s failed: provider not configured",
|
||||
fb_provider)
|
||||
return False
|
||||
return self._try_activate_fallback() # try next in chain
|
||||
|
||||
# Determine api_mode from provider / base URL
|
||||
fb_api_mode = "chat_completions"
|
||||
@@ -4424,8 +4439,8 @@ class AIAgent:
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("Failed to activate fallback model: %s", e)
|
||||
return False
|
||||
logging.error("Failed to activate fallback %s: %s", fb_model, e)
|
||||
return self._try_activate_fallback() # try next in chain
|
||||
|
||||
# ── End provider fallback ──────────────────────────────────────────────
|
||||
|
||||
@@ -6528,9 +6543,9 @@ class AIAgent:
|
||||
# Eager fallback: empty/malformed responses are a common
|
||||
# rate-limit symptom. Switch to fallback immediately
|
||||
# rather than retrying with extended backoff.
|
||||
if not self._fallback_activated:
|
||||
if self._fallback_index < len(self._fallback_chain):
|
||||
self._emit_status("⚠️ Empty/malformed response — switching to fallback...")
|
||||
if not self._fallback_activated and self._try_activate_fallback():
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
|
||||
@@ -6993,7 +7008,7 @@ class AIAgent:
|
||||
or "usage limit" in error_msg
|
||||
or "quota" in error_msg
|
||||
)
|
||||
if is_rate_limited and not self._fallback_activated:
|
||||
if is_rate_limited and self._fallback_index < len(self._fallback_chain):
|
||||
self._emit_status("⚠️ Rate limited — switching to fallback provider...")
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
|
||||
@@ -25,6 +25,8 @@ def _make_agent_with_compressor() -> AIAgent:
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
# Context compressor with primary model values
|
||||
compressor = ContextCompressor(
|
||||
|
||||
156
tests/test_provider_fallback.py
Normal file
156
tests/test_provider_fallback.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Tests for ordered provider fallback chain (salvage of PR #1761).
|
||||
|
||||
Extends the single-fallback tests in test_fallback_model.py to cover
|
||||
the new list-based ``fallback_providers`` config format and chain
|
||||
advancement through multiple providers.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_agent(fallback_model=None):
    """Construct a stripped-down AIAgent for fallback-chain tests.

    Tool discovery, toolset requirement checks, and the real OpenAI client
    constructor are all patched out so ``AIAgent.__init__`` runs with no
    network or filesystem side effects.  The live client is then replaced
    with a MagicMock so individual tests can inspect or reconfigure it.

    Args:
        fallback_model: value forwarded verbatim to the ``fallback_model``
            constructor parameter (None, legacy dict, or list of dicts).

    Returns:
        A minimally-initialised ``AIAgent`` instance.
    """
    with (
        patch("run_agent.get_tool_definitions", return_value=[]),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        ai = AIAgent(
            api_key="test-key",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
            fallback_model=fallback_model,
        )
    ai.client = MagicMock()
    return ai
|
||||
|
||||
|
||||
def _mock_client(base_url="https://openrouter.ai/api/v1", api_key="fb-key"):
|
||||
mock = MagicMock()
|
||||
mock.base_url = base_url
|
||||
mock.api_key = api_key
|
||||
return mock
|
||||
|
||||
|
||||
# ── Chain initialisation ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFallbackChainInit:
    """Constructor normalisation of fallback config into ``_fallback_chain``."""

    def test_no_fallback(self):
        """No config at all yields an empty chain and no legacy attribute."""
        ai = _make_agent(fallback_model=None)
        assert ai._fallback_chain == []
        assert ai._fallback_index == 0
        assert ai._fallback_model is None

    def test_single_dict_backwards_compat(self):
        """Legacy single-dict config becomes a one-element chain."""
        entry = {"provider": "openai", "model": "gpt-4o"}
        ai = _make_agent(fallback_model=entry)
        assert ai._fallback_chain == [entry]
        assert ai._fallback_model == entry

    def test_list_of_providers(self):
        """New list config is kept in order; legacy attr mirrors the head."""
        entries = [
            {"provider": "openai", "model": "gpt-4o"},
            {"provider": "zai", "model": "glm-4.7"},
        ]
        ai = _make_agent(fallback_model=entries)
        assert len(ai._fallback_chain) == 2
        assert ai._fallback_model == entries[0]

    def test_invalid_entries_filtered(self):
        """Entries missing provider/model, or non-dicts, are dropped."""
        entries = [
            {"provider": "openai", "model": "gpt-4o"},
            {"provider": "", "model": "glm-4.7"},
            {"provider": "zai"},
            "not-a-dict",
        ]
        ai = _make_agent(fallback_model=entries)
        assert len(ai._fallback_chain) == 1
        assert ai._fallback_chain[0]["provider"] == "openai"

    def test_empty_list(self):
        """An explicit empty list behaves exactly like no config."""
        ai = _make_agent(fallback_model=[])
        assert ai._fallback_chain == []
        assert ai._fallback_model is None

    def test_invalid_dict_no_provider(self):
        """A legacy dict without a provider key is rejected."""
        ai = _make_agent(fallback_model={"model": "gpt-4o"})
        assert ai._fallback_chain == []
|
||||
|
||||
|
||||
# ── Chain advancement ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFallbackChainAdvancement:
    """Behaviour of ``_try_activate_fallback`` as it walks the chain."""

    def test_exhausted_returns_false(self):
        """With nothing configured, the very first activation attempt fails."""
        ai = _make_agent(fallback_model=None)
        assert ai._try_activate_fallback() is False

    def test_advances_index(self):
        """A successful activation bumps the index and swaps the model slug."""
        chain = [
            {"provider": "openai", "model": "gpt-4o"},
            {"provider": "zai", "model": "glm-4.7"},
        ]
        ai = _make_agent(fallback_model=chain)
        router = patch(
            "agent.auxiliary_client.resolve_provider_client",
            return_value=(_mock_client(), "gpt-4o"),
        )
        with router:
            assert ai._try_activate_fallback() is True
        assert ai._fallback_index == 1
        assert ai.model == "gpt-4o"
        assert ai._fallback_activated is True

    def test_second_fallback_works(self):
        """Two activations in a row step through both chain entries in order."""
        chain = [
            {"provider": "openai", "model": "gpt-4o"},
            {"provider": "zai", "model": "glm-4.7"},
        ]
        ai = _make_agent(fallback_model=chain)
        router = patch(
            "agent.auxiliary_client.resolve_provider_client",
            return_value=(_mock_client(), "resolved"),
        )
        with router:
            assert ai._try_activate_fallback() is True
            assert ai.model == "gpt-4o"
            assert ai._try_activate_fallback() is True
            assert ai.model == "glm-4.7"
            assert ai._fallback_index == 2

    def test_all_exhausted_returns_false(self):
        """Once every entry has been consumed, activation reports failure."""
        chain = [{"provider": "openai", "model": "gpt-4o"}]
        ai = _make_agent(fallback_model=chain)
        router = patch(
            "agent.auxiliary_client.resolve_provider_client",
            return_value=(_mock_client(), "gpt-4o"),
        )
        with router:
            assert ai._try_activate_fallback() is True
            assert ai._try_activate_fallback() is False

    def test_skips_unconfigured_provider_to_next(self):
        """If resolve_provider_client returns None, skip to next in chain."""
        chain = [
            {"provider": "broken", "model": "nope"},
            {"provider": "openai", "model": "gpt-4o"},
        ]
        ai = _make_agent(fallback_model=chain)
        with patch("agent.auxiliary_client.resolve_provider_client") as rpc:
            rpc.side_effect = [
                (None, None),                # first entry: unresolvable provider
                (_mock_client(), "gpt-4o"),  # second entry resolves fine
            ]
            assert ai._try_activate_fallback() is True
        assert ai.model == "gpt-4o"
        assert ai._fallback_index == 2

    def test_skips_provider_that_raises_to_next(self):
        """If resolve_provider_client raises, skip to next in chain."""
        chain = [
            {"provider": "broken", "model": "nope"},
            {"provider": "openai", "model": "gpt-4o"},
        ]
        ai = _make_agent(fallback_model=chain)
        with patch("agent.auxiliary_client.resolve_provider_client") as rpc:
            rpc.side_effect = [
                RuntimeError("auth failed"),
                (_mock_client(), "gpt-4o"),
            ]
            assert ai._try_activate_fallback() is True
        assert ai.model == "gpt-4o"
|
||||
@@ -2507,6 +2507,8 @@ class TestFallbackAnthropicProvider:
|
||||
def test_fallback_to_anthropic_sets_api_mode(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-20250514"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://api.anthropic.com/v1"
|
||||
@@ -2528,6 +2530,8 @@ class TestFallbackAnthropicProvider:
|
||||
def test_fallback_to_anthropic_enables_prompt_caching(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-20250514"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://api.anthropic.com/v1"
|
||||
@@ -2545,6 +2549,8 @@ class TestFallbackAnthropicProvider:
|
||||
def test_fallback_to_openrouter_uses_openai_client(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://openrouter.ai/api/v1"
|
||||
@@ -3238,6 +3244,8 @@ class TestFallbackSetsOAuthFlag:
|
||||
def test_fallback_to_anthropic_oauth_sets_flag(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-6"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://api.anthropic.com/v1"
|
||||
@@ -3259,6 +3267,8 @@ class TestFallbackSetsOAuthFlag:
|
||||
def test_fallback_to_anthropic_api_key_clears_flag(self, agent):
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-6"}
|
||||
agent._fallback_chain = [agent._fallback_model]
|
||||
agent._fallback_index = 0
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://api.anthropic.com/v1"
|
||||
|
||||
Reference in New Issue
Block a user