feat: per-turn primary runtime restoration and transport recovery (#4624)

Makes provider fallback turn-scoped in long-lived CLI sessions. Previously, a single transient failure pinned the session to the fallback provider for every subsequent turn.

- _primary_runtime dict snapshot at __init__ (model, provider, base_url, api_mode, client_kwargs, compressor state)
- _restore_primary_runtime() at top of run_conversation() — restores all state, resets fallback chain index
- _try_recover_primary_transport() — one extra recovery cycle (client rebuild + cooldown) for transient transport errors on direct endpoints before fallback
- Transport recovery is skipped for aggregator providers (OpenRouter, Nous)
- 25 tests

Inspired by #4612 (@betamod). Closes #4612.
This commit is contained in:
Teknium
2026-04-02 10:52:01 -07:00
committed by GitHub
parent 918d593544
commit 3186668799
2 changed files with 621 additions and 3 deletions

View File

@@ -1236,6 +1236,34 @@ class AIAgent:
else:
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)")
# Snapshot primary runtime for per-turn restoration. When fallback
# activates during a turn, the next turn restores these values so the
# preferred model gets a fresh attempt each time. Uses a single dict
# so new state fields are easy to add without N individual attributes.
_cc = self.context_compressor
self._primary_runtime = {
"model": self.model,
"provider": self.provider,
"base_url": self.base_url,
"api_mode": self.api_mode,
"api_key": getattr(self, "api_key", ""),
"client_kwargs": dict(self._client_kwargs),
"use_prompt_caching": self._use_prompt_caching,
# Compressor state that _try_activate_fallback() overwrites
"compressor_model": _cc.model,
"compressor_base_url": _cc.base_url,
"compressor_api_key": getattr(_cc, "api_key", ""),
"compressor_provider": _cc.provider,
"compressor_context_length": _cc.context_length,
"compressor_threshold_tokens": _cc.threshold_tokens,
}
if self.api_mode == "anthropic_messages":
self._primary_runtime.update({
"anthropic_api_key": self._anthropic_api_key,
"anthropic_base_url": self._anthropic_base_url,
"is_anthropic_oauth": self._is_anthropic_oauth,
})
def reset_session_state(self):
"""Reset all session-scoped token counters to 0 for a fresh session.
@@ -4770,6 +4798,156 @@ class AIAgent:
logging.error("Failed to activate fallback %s: %s", fb_model, e)
return self._try_activate_fallback() # try next in chain
# ── Per-turn primary restoration ─────────────────────────────────────
def _restore_primary_runtime(self) -> bool:
"""Restore the primary runtime at the start of a new turn.
In long-lived CLI sessions a single AIAgent instance spans multiple
turns. Without restoration, one transient failure pins the session
to the fallback provider for every subsequent turn. Calling this at
the top of ``run_conversation()`` makes fallback turn-scoped.
The gateway creates a fresh agent per message so this is a no-op
there (``_fallback_activated`` is always False at turn start).
"""
if not self._fallback_activated:
return False
rt = self._primary_runtime
try:
# ── Core runtime state ──
self.model = rt["model"]
self.provider = rt["provider"]
self.base_url = rt["base_url"] # setter updates _base_url_lower
self.api_mode = rt["api_mode"]
self.api_key = rt["api_key"]
self._client_kwargs = dict(rt["client_kwargs"])
self._use_prompt_caching = rt["use_prompt_caching"]
# ── Rebuild client for the primary provider ──
if self.api_mode == "anthropic_messages":
from agent.anthropic_adapter import build_anthropic_client
self._anthropic_api_key = rt["anthropic_api_key"]
self._anthropic_base_url = rt["anthropic_base_url"]
self._anthropic_client = build_anthropic_client(
rt["anthropic_api_key"], rt["anthropic_base_url"],
)
self._is_anthropic_oauth = rt["is_anthropic_oauth"]
self.client = None
else:
self.client = self._create_openai_client(
dict(rt["client_kwargs"]),
reason="restore_primary",
shared=True,
)
# ── Restore context compressor state ──
cc = self.context_compressor
cc.model = rt["compressor_model"]
cc.base_url = rt["compressor_base_url"]
cc.api_key = rt["compressor_api_key"]
cc.provider = rt["compressor_provider"]
cc.context_length = rt["compressor_context_length"]
cc.threshold_tokens = rt["compressor_threshold_tokens"]
# ── Reset fallback chain for the new turn ──
self._fallback_activated = False
self._fallback_index = 0
logging.info(
"Primary runtime restored for new turn: %s (%s)",
self.model, self.provider,
)
return True
except Exception as e:
logging.warning("Failed to restore primary runtime: %s", e)
return False
# Which error types indicate a transient transport failure worth
# one more attempt with a rebuilt client / connection pool.
_TRANSIENT_TRANSPORT_ERRORS = frozenset({
"ReadTimeout", "ConnectTimeout", "PoolTimeout",
"ConnectError", "RemoteProtocolError",
})
def _try_recover_primary_transport(
self, api_error: Exception, *, retry_count: int, max_retries: int,
) -> bool:
"""Attempt one extra primary-provider recovery cycle for transient transport failures.
After ``max_retries`` exhaust, rebuild the primary client (clearing
stale connection pools) and give it one more attempt before falling
back. This is most useful for direct endpoints (custom, Z.AI,
Anthropic, OpenAI, local models) where a TCP-level hiccup does not
mean the provider is down.
Skipped for proxy/aggregator providers (OpenRouter, Nous) which
already manage connection pools and retries server-side — if our
retries through them are exhausted, one more rebuilt client won't help.
"""
if self._fallback_activated:
return False
# Only for transient transport errors
error_type = type(api_error).__name__
if error_type not in self._TRANSIENT_TRANSPORT_ERRORS:
return False
# Skip for aggregator providers — they manage their own retry infra
if self._is_openrouter_url():
return False
provider_lower = (self.provider or "").strip().lower()
if provider_lower in ("nous", "nous-research"):
return False
try:
# Close existing client to release stale connections
if getattr(self, "client", None) is not None:
try:
self._close_openai_client(
self.client, reason="primary_recovery", shared=True,
)
except Exception:
pass
# Rebuild from primary snapshot
rt = self._primary_runtime
self._client_kwargs = dict(rt["client_kwargs"])
self.model = rt["model"]
self.provider = rt["provider"]
self.base_url = rt["base_url"]
self.api_mode = rt["api_mode"]
self.api_key = rt["api_key"]
if self.api_mode == "anthropic_messages":
from agent.anthropic_adapter import build_anthropic_client
self._anthropic_api_key = rt["anthropic_api_key"]
self._anthropic_base_url = rt["anthropic_base_url"]
self._anthropic_client = build_anthropic_client(
rt["anthropic_api_key"], rt["anthropic_base_url"],
)
self._is_anthropic_oauth = rt["is_anthropic_oauth"]
self.client = None
else:
self.client = self._create_openai_client(
dict(rt["client_kwargs"]),
reason="primary_recovery",
shared=True,
)
wait_time = min(3 + retry_count, 8)
self._vprint(
f"{self.log_prefix}🔁 Transient {error_type} on {self.provider}"
f"rebuilt client, waiting {wait_time}s before one last primary attempt.",
force=True,
)
time.sleep(wait_time)
return True
except Exception as e:
logging.warning("Primary transport recovery failed: %s", e)
return False
# ── End provider fallback ──────────────────────────────────────────────
@staticmethod
@@ -6408,6 +6586,11 @@ class AIAgent:
# Installed once, transparent when streams are healthy, prevents crash on write.
_install_safe_stdio()
# If the previous turn activated fallback, restore the primary
# runtime so this turn gets a fresh attempt with the preferred model.
# No-op when _fallback_activated is False (gateway, first turn, etc.).
self._restore_primary_runtime()
# Sanitize surrogate characters from user input. Clipboard paste from
# rich-text editors (Google Docs, Word, etc.) can inject lone surrogates
# that are invalid UTF-8 and crash JSON serialization in the OpenAI SDK.
@@ -6826,10 +7009,11 @@ class AIAgent:
api_start_time = time.time()
retry_count = 0
max_retries = 3
primary_recovery_attempted = False
max_compression_attempts = 3
codex_auth_retry_attempted = False
anthropic_auth_retry_attempted = False
nous_auth_retry_attempted = False
codex_auth_retry_attempted=False
anthropic_auth_retry_attempted=False
nous_auth_retry_attempted=False
has_retried_429 = False
restart_with_compressed_messages = False
restart_with_length_continuation = False
@@ -7664,6 +7848,16 @@ class AIAgent:
}
if retry_count >= max_retries:
# Before falling back, try rebuilding the primary
# client once for transient transport errors (stale
# connection pool, TCP reset). Only attempted once
# per API call block.
if not primary_recovery_attempted and self._try_recover_primary_transport(
api_error, retry_count=retry_count, max_retries=max_retries,
):
primary_recovery_attempted = True
retry_count = 0
continue
# Try fallback before giving up entirely
self._emit_status(f"⚠️ Max retries ({max_retries}) exhausted — trying fallback...")
if self._try_activate_fallback():

View File

@@ -0,0 +1,424 @@
"""Tests for per-turn primary runtime restoration and transport recovery.
Verifies that:
1. Fallback is turn-scoped: a new turn restores the primary model/provider
2. The fallback chain index resets so all fallbacks are available again
3. Context compressor state is restored alongside the runtime
4. Transient transport errors get one recovery cycle before fallback
5. Recovery is skipped for aggregator providers (OpenRouter, Nous)
6. Non-transport errors don't trigger recovery
"""
import time
from types import SimpleNamespace
from unittest.mock import MagicMock, patch, PropertyMock
import pytest
from run_agent import AIAgent
def _make_tool_defs(*names: str) -> list:
return [
{
"type": "function",
"function": {
"name": n,
"description": f"{n} tool",
"parameters": {"type": "object", "properties": {}},
},
}
for n in names
]
def _make_agent(fallback_model=None, provider="custom", base_url="https://my-llm.example.com/v1"):
    """Build a minimal AIAgent for tests, optionally with a fallback config.

    Tool discovery, toolset requirement checks and the ``OpenAI`` class in
    ``run_agent`` are patched while the agent is constructed; afterwards
    the agent's ``client`` is replaced with a ``MagicMock``.
    """
    tool_defs_patch = patch(
        "run_agent.get_tool_definitions",
        return_value=_make_tool_defs("web_search"),
    )
    requirements_patch = patch("run_agent.check_toolset_requirements", return_value={})
    openai_patch = patch("run_agent.OpenAI")

    with tool_defs_patch, requirements_patch, openai_patch:
        agent = AIAgent(
            api_key="test-key-12345678",
            base_url=base_url,
            provider=provider,
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
            fallback_model=fallback_model,
        )
        agent.client = MagicMock()
        return agent
def _mock_resolve(base_url="https://openrouter.ai/api/v1", api_key="fallback-key-1234"):
"""Helper to create a mock client for resolve_provider_client."""
mock_client = MagicMock()
mock_client.api_key = api_key
mock_client.base_url = base_url
return mock_client
# =============================================================================
# _primary_runtime snapshot
# =============================================================================
class TestPrimaryRuntimeSnapshot:
    """Checks on the ``_primary_runtime`` dict captured by ``AIAgent.__init__``."""

    def test_snapshot_created_at_init(self):
        agent = _make_agent()
        assert hasattr(agent, "_primary_runtime")

        snapshot = agent._primary_runtime
        expected = {
            "model": agent.model,
            "provider": "custom",
            "base_url": "https://my-llm.example.com/v1",
            "api_mode": agent.api_mode,
        }
        for key, value in expected.items():
            assert snapshot[key] == value
        assert "client_kwargs" in snapshot
        assert "compressor_context_length" in snapshot

    def test_snapshot_includes_compressor_state(self):
        agent = _make_agent()
        snapshot = agent._primary_runtime
        compressor = agent.context_compressor

        for attr in ("model", "provider", "context_length", "threshold_tokens"):
            assert snapshot[f"compressor_{attr}"] == getattr(compressor, attr)

    def test_snapshot_includes_anthropic_state_when_applicable(self):
        """Anthropic-mode agents should snapshot Anthropic-specific state."""
        tool_defs_patch = patch(
            "run_agent.get_tool_definitions",
            return_value=_make_tool_defs("web_search"),
        )
        requirements_patch = patch("run_agent.check_toolset_requirements", return_value={})
        openai_patch = patch("run_agent.OpenAI")
        anthropic_patch = patch(
            "agent.anthropic_adapter.build_anthropic_client",
            return_value=MagicMock(),
        )

        with tool_defs_patch, requirements_patch, openai_patch, anthropic_patch:
            agent = AIAgent(
                api_key="sk-ant-test-12345678",
                base_url="https://api.anthropic.com",
                provider="anthropic",
                api_mode="anthropic_messages",
                quiet_mode=True,
                skip_context_files=True,
                skip_memory=True,
            )

        snapshot = agent._primary_runtime
        for key in ("anthropic_api_key", "anthropic_base_url", "is_anthropic_oauth"):
            assert key in snapshot

    def test_snapshot_omits_anthropic_for_openai_mode(self):
        snapshot = _make_agent(provider="custom")._primary_runtime
        assert "anthropic_api_key" not in snapshot
# =============================================================================
# _restore_primary_runtime()
# =============================================================================
class TestRestorePrimaryRuntime:
    """Behaviour of ``AIAgent._restore_primary_runtime()``."""

    @staticmethod
    def _activate_fallback(agent):
        """Run ``_try_activate_fallback()`` with the provider resolver stubbed."""
        resolved = (_mock_resolve(), None)
        with patch("agent.auxiliary_client.resolve_provider_client", return_value=resolved):
            return agent._try_activate_fallback()

    @staticmethod
    def _restore(agent, **openai_patch):
        """Run ``_restore_primary_runtime()`` with ``run_agent.OpenAI`` patched."""
        with patch("run_agent.OpenAI", **openai_patch):
            return agent._restore_primary_runtime()

    def test_noop_when_not_fallback(self):
        agent = _make_agent()
        assert agent._fallback_activated is False
        assert agent._restore_primary_runtime() is False

    def test_restores_model_and_provider(self):
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
        )
        primary = (agent.model, agent.provider)

        # Simulate fallback activation.
        self._activate_fallback(agent)
        assert agent._fallback_activated is True
        assert agent.model == "anthropic/claude-sonnet-4"
        assert agent.provider == "openrouter"

        # Restoring should bring the primary back.
        restored = self._restore(agent, return_value=MagicMock())
        assert restored is True
        assert agent._fallback_activated is False
        assert agent.model == primary[0]
        assert agent.provider == primary[1]

    def test_resets_fallback_index(self):
        """After restore, the full fallback chain should be available again."""
        agent = _make_agent(
            fallback_model=[
                {"provider": "openrouter", "model": "model-a"},
                {"provider": "anthropic", "model": "model-b"},
            ],
        )

        # Consume the first entry of the chain.
        self._activate_fallback(agent)
        assert agent._fallback_index == 1

        self._restore(agent, return_value=MagicMock())
        assert agent._fallback_index == 0  # reset for next turn

    def test_restores_compressor_state(self):
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
        )
        compressor = agent.context_compressor
        primary_ctx_len = compressor.context_length
        primary_threshold = compressor.threshold_tokens

        self._activate_fallback(agent)
        # Manually simulate the compressor being changed (as
        # _try_activate_fallback does).
        agent.context_compressor.context_length = 32000
        agent.context_compressor.threshold_tokens = 25600

        self._restore(agent, return_value=MagicMock())
        assert agent.context_compressor.context_length == primary_ctx_len
        assert agent.context_compressor.threshold_tokens == primary_threshold

    def test_restores_prompt_caching_flag(self):
        agent = _make_agent()
        primary_caching = agent._use_prompt_caching

        # Simulate fallback flipping the caching flag.
        agent._fallback_activated = True
        agent._use_prompt_caching = not primary_caching

        self._restore(agent, return_value=MagicMock())
        assert agent._use_prompt_caching == primary_caching

    def test_restore_survives_exception(self):
        """If client rebuild fails, the method returns False gracefully."""
        agent = _make_agent()
        agent._fallback_activated = True

        restored = self._restore(agent, side_effect=Exception("connection refused"))
        assert restored is False
# =============================================================================
# _try_recover_primary_transport()
# =============================================================================
def _make_transport_error(error_type="ReadTimeout"):
"""Create an exception whose type().__name__ matches the given name."""
cls = type(error_type, (Exception,), {})
return cls("connection timed out")
class TestTryRecoverPrimaryTransport:
    """Behaviour of ``AIAgent._try_recover_primary_transport()``."""

    @staticmethod
    def _recover(agent, error, retry_count=3):
        """Call the recovery hook with ``max_retries=3`` and nothing patched."""
        return agent._try_recover_primary_transport(
            error, retry_count=retry_count, max_retries=3,
        )

    @classmethod
    def _recover_patched(cls, agent, error, retry_count=3, **openai_patch):
        """Call the hook with ``run_agent.OpenAI`` and ``time.sleep`` patched.

        ``openai_patch`` is forwarded to the ``OpenAI`` patch and defaults
        to ``return_value=MagicMock()``.  Returns the hook's result and the
        ``time.sleep`` mock.
        """
        if not openai_patch:
            openai_patch = {"return_value": MagicMock()}
        with patch("run_agent.OpenAI", **openai_patch), patch("time.sleep") as sleep_mock:
            outcome = cls._recover(agent, error, retry_count)
        return outcome, sleep_mock

    def test_recovers_on_read_timeout(self):
        agent = _make_agent(provider="custom")
        outcome, _ = self._recover_patched(agent, _make_transport_error("ReadTimeout"))
        assert outcome is True

    def test_recovers_on_connect_timeout(self):
        agent = _make_agent(provider="custom")
        outcome, _ = self._recover_patched(agent, _make_transport_error("ConnectTimeout"))
        assert outcome is True

    def test_recovers_on_pool_timeout(self):
        agent = _make_agent(provider="zai")
        outcome, _ = self._recover_patched(agent, _make_transport_error("PoolTimeout"))
        assert outcome is True

    def test_skipped_when_already_on_fallback(self):
        agent = _make_agent(provider="custom")
        agent._fallback_activated = True
        assert self._recover(agent, _make_transport_error("ReadTimeout")) is False

    def test_skipped_for_non_transport_error(self):
        """Non-transport errors (ValueError, APIError, etc.) skip recovery."""
        agent = _make_agent(provider="custom")
        assert self._recover(agent, ValueError("invalid model")) is False

    def test_skipped_for_openrouter(self):
        agent = _make_agent(provider="openrouter", base_url="https://openrouter.ai/api/v1")
        assert self._recover(agent, _make_transport_error("ReadTimeout")) is False

    def test_skipped_for_nous_provider(self):
        agent = _make_agent(
            provider="nous",
            base_url="https://inference.nous.nousresearch.com/v1",
        )
        assert self._recover(agent, _make_transport_error("ReadTimeout")) is False

    def test_allowed_for_anthropic_direct(self):
        """Direct Anthropic endpoint should get recovery."""
        agent = _make_agent(provider="anthropic", base_url="https://api.anthropic.com")
        # For non-anthropic_messages api_mode, it will use OpenAI client.
        outcome, _ = self._recover_patched(agent, _make_transport_error("ConnectError"))
        assert outcome is True

    def test_allowed_for_ollama(self):
        agent = _make_agent(provider="ollama", base_url="http://localhost:11434/v1")
        outcome, _ = self._recover_patched(agent, _make_transport_error("ConnectTimeout"))
        assert outcome is True

    def test_wait_time_scales_with_retry_count(self):
        agent = _make_agent(provider="custom")
        _, sleep_mock = self._recover_patched(
            agent, _make_transport_error("ReadTimeout"), retry_count=3,
        )
        # wait_time = min(3 + retry_count, 8) = min(6, 8) = 6
        sleep_mock.assert_called_once_with(6)

    def test_wait_time_capped_at_8(self):
        agent = _make_agent(provider="custom")
        _, sleep_mock = self._recover_patched(
            agent, _make_transport_error("ReadTimeout"), retry_count=10,
        )
        # wait_time = min(3 + 10, 8) = 8
        sleep_mock.assert_called_once_with(8)

    def test_closes_existing_client_before_rebuild(self):
        agent = _make_agent(provider="custom")
        stale_client = agent.client
        error = _make_transport_error("ReadTimeout")

        with patch.object(agent, "_close_openai_client") as close_mock:
            self._recover_patched(agent, error)
            close_mock.assert_called_once_with(
                stale_client, reason="primary_recovery", shared=True,
            )

    def test_survives_rebuild_failure(self):
        """If client rebuild fails, returns False gracefully."""
        agent = _make_agent(provider="custom")
        outcome, _ = self._recover_patched(
            agent,
            _make_transport_error("ReadTimeout"),
            side_effect=Exception("socket error"),
        )
        assert outcome is False
# =============================================================================
# Integration: restore_primary_runtime called from run_conversation
# =============================================================================
class TestRestoreInRunConversation:
    """Verify the hook in run_conversation() calls _restore_primary_runtime."""

    def test_restore_called_at_turn_start(self):
        """run_conversation() must invoke the restore hook itself.

        The hook is stubbed to raise a sentinel so the turn stops as soon
        as the hook is reached; no API call or further setup is needed.
        Calling the patched hook directly from the test would only prove
        that the mock records calls, not that run_conversation() uses it.
        """

        class _TurnStarted(Exception):
            """Sentinel raised by the stubbed hook to abort the turn."""

        agent = _make_agent()
        agent._fallback_activated = True

        # NOTE(review): assumes the first positional argument of
        # run_conversation() is the user message and that the hook call is
        # not wrapped in a try/except — confirm against run_agent.py.
        with patch.object(
            agent, "_restore_primary_runtime", side_effect=_TurnStarted,
        ) as mock_restore:
            with pytest.raises(_TurnStarted):
                agent.run_conversation("hello")

        mock_restore.assert_called_once_with()

    def test_full_cycle_fallback_then_restore(self):
        """Simulate: turn 1 activates fallback, turn 2 restores primary."""
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
            provider="custom",
        )

        # Turn 1: activate fallback.
        mock_client = _mock_resolve()
        with patch(
            "agent.auxiliary_client.resolve_provider_client",
            return_value=(mock_client, None),
        ):
            assert agent._try_activate_fallback() is True

        assert agent._fallback_activated is True
        assert agent.model == "anthropic/claude-sonnet-4"
        assert agent.provider == "openrouter"
        assert agent._fallback_index == 1

        # Turn 2: restore primary.
        with patch("run_agent.OpenAI", return_value=MagicMock()):
            assert agent._restore_primary_runtime() is True

        assert agent._fallback_activated is False
        assert agent._fallback_index == 0
        assert agent.provider == "custom"
        assert agent.base_url == "https://my-llm.example.com/v1"