feat: per-turn primary runtime restoration and transport recovery (#4624)
Makes provider fallback turn-scoped in long-lived CLI sessions. Previously, a single transient failure pinned the session to the fallback provider for every subsequent turn. - _primary_runtime dict snapshot at __init__ (model, provider, base_url, api_mode, client_kwargs, compressor state) - _restore_primary_runtime() at top of run_conversation() — restores all state, resets fallback chain index - _try_recover_primary_transport() — one extra recovery cycle (client rebuild + cooldown) for transient transport errors on direct endpoints before fallback - Skipped for aggregator providers (OpenRouter, Nous) - 25 tests Inspired by #4612 (@betamod). Closes #4612.
This commit is contained in:
200
run_agent.py
200
run_agent.py
@@ -1236,6 +1236,34 @@ class AIAgent:
|
||||
else:
|
||||
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)")
|
||||
|
||||
# Snapshot primary runtime for per-turn restoration. When fallback
|
||||
# activates during a turn, the next turn restores these values so the
|
||||
# preferred model gets a fresh attempt each time. Uses a single dict
|
||||
# so new state fields are easy to add without N individual attributes.
|
||||
_cc = self.context_compressor
|
||||
self._primary_runtime = {
|
||||
"model": self.model,
|
||||
"provider": self.provider,
|
||||
"base_url": self.base_url,
|
||||
"api_mode": self.api_mode,
|
||||
"api_key": getattr(self, "api_key", ""),
|
||||
"client_kwargs": dict(self._client_kwargs),
|
||||
"use_prompt_caching": self._use_prompt_caching,
|
||||
# Compressor state that _try_activate_fallback() overwrites
|
||||
"compressor_model": _cc.model,
|
||||
"compressor_base_url": _cc.base_url,
|
||||
"compressor_api_key": getattr(_cc, "api_key", ""),
|
||||
"compressor_provider": _cc.provider,
|
||||
"compressor_context_length": _cc.context_length,
|
||||
"compressor_threshold_tokens": _cc.threshold_tokens,
|
||||
}
|
||||
if self.api_mode == "anthropic_messages":
|
||||
self._primary_runtime.update({
|
||||
"anthropic_api_key": self._anthropic_api_key,
|
||||
"anthropic_base_url": self._anthropic_base_url,
|
||||
"is_anthropic_oauth": self._is_anthropic_oauth,
|
||||
})
|
||||
|
||||
def reset_session_state(self):
|
||||
"""Reset all session-scoped token counters to 0 for a fresh session.
|
||||
|
||||
@@ -4770,6 +4798,156 @@ class AIAgent:
|
||||
logging.error("Failed to activate fallback %s: %s", fb_model, e)
|
||||
return self._try_activate_fallback() # try next in chain
|
||||
|
||||
# ── Per-turn primary restoration ─────────────────────────────────────
|
||||
|
||||
def _restore_primary_runtime(self) -> bool:
|
||||
"""Restore the primary runtime at the start of a new turn.
|
||||
|
||||
In long-lived CLI sessions a single AIAgent instance spans multiple
|
||||
turns. Without restoration, one transient failure pins the session
|
||||
to the fallback provider for every subsequent turn. Calling this at
|
||||
the top of ``run_conversation()`` makes fallback turn-scoped.
|
||||
|
||||
The gateway creates a fresh agent per message so this is a no-op
|
||||
there (``_fallback_activated`` is always False at turn start).
|
||||
"""
|
||||
if not self._fallback_activated:
|
||||
return False
|
||||
|
||||
rt = self._primary_runtime
|
||||
try:
|
||||
# ── Core runtime state ──
|
||||
self.model = rt["model"]
|
||||
self.provider = rt["provider"]
|
||||
self.base_url = rt["base_url"] # setter updates _base_url_lower
|
||||
self.api_mode = rt["api_mode"]
|
||||
self.api_key = rt["api_key"]
|
||||
self._client_kwargs = dict(rt["client_kwargs"])
|
||||
self._use_prompt_caching = rt["use_prompt_caching"]
|
||||
|
||||
# ── Rebuild client for the primary provider ──
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_client
|
||||
self._anthropic_api_key = rt["anthropic_api_key"]
|
||||
self._anthropic_base_url = rt["anthropic_base_url"]
|
||||
self._anthropic_client = build_anthropic_client(
|
||||
rt["anthropic_api_key"], rt["anthropic_base_url"],
|
||||
)
|
||||
self._is_anthropic_oauth = rt["is_anthropic_oauth"]
|
||||
self.client = None
|
||||
else:
|
||||
self.client = self._create_openai_client(
|
||||
dict(rt["client_kwargs"]),
|
||||
reason="restore_primary",
|
||||
shared=True,
|
||||
)
|
||||
|
||||
# ── Restore context compressor state ──
|
||||
cc = self.context_compressor
|
||||
cc.model = rt["compressor_model"]
|
||||
cc.base_url = rt["compressor_base_url"]
|
||||
cc.api_key = rt["compressor_api_key"]
|
||||
cc.provider = rt["compressor_provider"]
|
||||
cc.context_length = rt["compressor_context_length"]
|
||||
cc.threshold_tokens = rt["compressor_threshold_tokens"]
|
||||
|
||||
# ── Reset fallback chain for the new turn ──
|
||||
self._fallback_activated = False
|
||||
self._fallback_index = 0
|
||||
|
||||
logging.info(
|
||||
"Primary runtime restored for new turn: %s (%s)",
|
||||
self.model, self.provider,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warning("Failed to restore primary runtime: %s", e)
|
||||
return False
|
||||
|
||||
# Which error types indicate a transient transport failure worth
|
||||
# one more attempt with a rebuilt client / connection pool.
|
||||
_TRANSIENT_TRANSPORT_ERRORS = frozenset({
|
||||
"ReadTimeout", "ConnectTimeout", "PoolTimeout",
|
||||
"ConnectError", "RemoteProtocolError",
|
||||
})
|
||||
|
||||
def _try_recover_primary_transport(
|
||||
self, api_error: Exception, *, retry_count: int, max_retries: int,
|
||||
) -> bool:
|
||||
"""Attempt one extra primary-provider recovery cycle for transient transport failures.
|
||||
|
||||
After ``max_retries`` exhaust, rebuild the primary client (clearing
|
||||
stale connection pools) and give it one more attempt before falling
|
||||
back. This is most useful for direct endpoints (custom, Z.AI,
|
||||
Anthropic, OpenAI, local models) where a TCP-level hiccup does not
|
||||
mean the provider is down.
|
||||
|
||||
Skipped for proxy/aggregator providers (OpenRouter, Nous) which
|
||||
already manage connection pools and retries server-side — if our
|
||||
retries through them are exhausted, one more rebuilt client won't help.
|
||||
"""
|
||||
if self._fallback_activated:
|
||||
return False
|
||||
|
||||
# Only for transient transport errors
|
||||
error_type = type(api_error).__name__
|
||||
if error_type not in self._TRANSIENT_TRANSPORT_ERRORS:
|
||||
return False
|
||||
|
||||
# Skip for aggregator providers — they manage their own retry infra
|
||||
if self._is_openrouter_url():
|
||||
return False
|
||||
provider_lower = (self.provider or "").strip().lower()
|
||||
if provider_lower in ("nous", "nous-research"):
|
||||
return False
|
||||
|
||||
try:
|
||||
# Close existing client to release stale connections
|
||||
if getattr(self, "client", None) is not None:
|
||||
try:
|
||||
self._close_openai_client(
|
||||
self.client, reason="primary_recovery", shared=True,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Rebuild from primary snapshot
|
||||
rt = self._primary_runtime
|
||||
self._client_kwargs = dict(rt["client_kwargs"])
|
||||
self.model = rt["model"]
|
||||
self.provider = rt["provider"]
|
||||
self.base_url = rt["base_url"]
|
||||
self.api_mode = rt["api_mode"]
|
||||
self.api_key = rt["api_key"]
|
||||
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_client
|
||||
self._anthropic_api_key = rt["anthropic_api_key"]
|
||||
self._anthropic_base_url = rt["anthropic_base_url"]
|
||||
self._anthropic_client = build_anthropic_client(
|
||||
rt["anthropic_api_key"], rt["anthropic_base_url"],
|
||||
)
|
||||
self._is_anthropic_oauth = rt["is_anthropic_oauth"]
|
||||
self.client = None
|
||||
else:
|
||||
self.client = self._create_openai_client(
|
||||
dict(rt["client_kwargs"]),
|
||||
reason="primary_recovery",
|
||||
shared=True,
|
||||
)
|
||||
|
||||
wait_time = min(3 + retry_count, 8)
|
||||
self._vprint(
|
||||
f"{self.log_prefix}🔁 Transient {error_type} on {self.provider} — "
|
||||
f"rebuilt client, waiting {wait_time}s before one last primary attempt.",
|
||||
force=True,
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warning("Primary transport recovery failed: %s", e)
|
||||
return False
|
||||
|
||||
# ── End provider fallback ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
@@ -6408,6 +6586,11 @@ class AIAgent:
|
||||
# Installed once, transparent when streams are healthy, prevents crash on write.
|
||||
_install_safe_stdio()
|
||||
|
||||
# If the previous turn activated fallback, restore the primary
|
||||
# runtime so this turn gets a fresh attempt with the preferred model.
|
||||
# No-op when _fallback_activated is False (gateway, first turn, etc.).
|
||||
self._restore_primary_runtime()
|
||||
|
||||
# Sanitize surrogate characters from user input. Clipboard paste from
|
||||
# rich-text editors (Google Docs, Word, etc.) can inject lone surrogates
|
||||
# that are invalid UTF-8 and crash JSON serialization in the OpenAI SDK.
|
||||
@@ -6826,10 +7009,11 @@ class AIAgent:
|
||||
api_start_time = time.time()
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
primary_recovery_attempted = False
|
||||
max_compression_attempts = 3
|
||||
codex_auth_retry_attempted = False
|
||||
anthropic_auth_retry_attempted = False
|
||||
nous_auth_retry_attempted = False
|
||||
codex_auth_retry_attempted=False
|
||||
anthropic_auth_retry_attempted=False
|
||||
nous_auth_retry_attempted=False
|
||||
has_retried_429 = False
|
||||
restart_with_compressed_messages = False
|
||||
restart_with_length_continuation = False
|
||||
@@ -7664,6 +7848,16 @@ class AIAgent:
|
||||
}
|
||||
|
||||
if retry_count >= max_retries:
|
||||
# Before falling back, try rebuilding the primary
|
||||
# client once for transient transport errors (stale
|
||||
# connection pool, TCP reset). Only attempted once
|
||||
# per API call block.
|
||||
if not primary_recovery_attempted and self._try_recover_primary_transport(
|
||||
api_error, retry_count=retry_count, max_retries=max_retries,
|
||||
):
|
||||
primary_recovery_attempted = True
|
||||
retry_count = 0
|
||||
continue
|
||||
# Try fallback before giving up entirely
|
||||
self._emit_status(f"⚠️ Max retries ({max_retries}) exhausted — trying fallback...")
|
||||
if self._try_activate_fallback():
|
||||
|
||||
424
tests/test_primary_runtime_restore.py
Normal file
424
tests/test_primary_runtime_restore.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""Tests for per-turn primary runtime restoration and transport recovery.
|
||||
|
||||
Verifies that:
|
||||
1. Fallback is turn-scoped: a new turn restores the primary model/provider
|
||||
2. The fallback chain index resets so all fallbacks are available again
|
||||
3. Context compressor state is restored alongside the runtime
|
||||
4. Transient transport errors get one recovery cycle before fallback
|
||||
5. Recovery is skipped for aggregator providers (OpenRouter, Nous)
|
||||
6. Non-transport errors don't trigger recovery
|
||||
"""
|
||||
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": n,
|
||||
"description": f"{n} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
def _make_agent(fallback_model=None, provider="custom", base_url="https://my-llm.example.com/v1"):
|
||||
"""Create a minimal AIAgent with optional fallback config."""
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key-12345678",
|
||||
base_url=base_url,
|
||||
provider=provider,
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
fallback_model=fallback_model,
|
||||
)
|
||||
agent.client = MagicMock()
|
||||
return agent
|
||||
|
||||
|
||||
def _mock_resolve(base_url="https://openrouter.ai/api/v1", api_key="fallback-key-1234"):
|
||||
"""Helper to create a mock client for resolve_provider_client."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.api_key = api_key
|
||||
mock_client.base_url = base_url
|
||||
return mock_client
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _primary_runtime snapshot
|
||||
# =============================================================================
|
||||
|
||||
class TestPrimaryRuntimeSnapshot:
|
||||
def test_snapshot_created_at_init(self):
|
||||
agent = _make_agent()
|
||||
assert hasattr(agent, "_primary_runtime")
|
||||
rt = agent._primary_runtime
|
||||
assert rt["model"] == agent.model
|
||||
assert rt["provider"] == "custom"
|
||||
assert rt["base_url"] == "https://my-llm.example.com/v1"
|
||||
assert rt["api_mode"] == agent.api_mode
|
||||
assert "client_kwargs" in rt
|
||||
assert "compressor_context_length" in rt
|
||||
|
||||
def test_snapshot_includes_compressor_state(self):
|
||||
agent = _make_agent()
|
||||
rt = agent._primary_runtime
|
||||
cc = agent.context_compressor
|
||||
assert rt["compressor_model"] == cc.model
|
||||
assert rt["compressor_provider"] == cc.provider
|
||||
assert rt["compressor_context_length"] == cc.context_length
|
||||
assert rt["compressor_threshold_tokens"] == cc.threshold_tokens
|
||||
|
||||
def test_snapshot_includes_anthropic_state_when_applicable(self):
|
||||
"""Anthropic-mode agents should snapshot Anthropic-specific state."""
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="sk-ant-test-12345678",
|
||||
base_url="https://api.anthropic.com",
|
||||
provider="anthropic",
|
||||
api_mode="anthropic_messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
rt = agent._primary_runtime
|
||||
assert "anthropic_api_key" in rt
|
||||
assert "anthropic_base_url" in rt
|
||||
assert "is_anthropic_oauth" in rt
|
||||
|
||||
def test_snapshot_omits_anthropic_for_openai_mode(self):
|
||||
agent = _make_agent(provider="custom")
|
||||
rt = agent._primary_runtime
|
||||
assert "anthropic_api_key" not in rt
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _restore_primary_runtime()
|
||||
# =============================================================================
|
||||
|
||||
class TestRestorePrimaryRuntime:
|
||||
def test_noop_when_not_fallback(self):
|
||||
agent = _make_agent()
|
||||
assert agent._fallback_activated is False
|
||||
assert agent._restore_primary_runtime() is False
|
||||
|
||||
def test_restores_model_and_provider(self):
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
|
||||
)
|
||||
original_model = agent.model
|
||||
original_provider = agent.provider
|
||||
|
||||
# Simulate fallback activation
|
||||
mock_client = _mock_resolve()
|
||||
with patch("agent.auxiliary_client.resolve_provider_client", return_value=(mock_client, None)):
|
||||
agent._try_activate_fallback()
|
||||
|
||||
assert agent._fallback_activated is True
|
||||
assert agent.model == "anthropic/claude-sonnet-4"
|
||||
assert agent.provider == "openrouter"
|
||||
|
||||
# Restore should bring back the primary
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()):
|
||||
result = agent._restore_primary_runtime()
|
||||
|
||||
assert result is True
|
||||
assert agent._fallback_activated is False
|
||||
assert agent.model == original_model
|
||||
assert agent.provider == original_provider
|
||||
|
||||
def test_resets_fallback_index(self):
|
||||
"""After restore, the full fallback chain should be available again."""
|
||||
agent = _make_agent(
|
||||
fallback_model=[
|
||||
{"provider": "openrouter", "model": "model-a"},
|
||||
{"provider": "anthropic", "model": "model-b"},
|
||||
],
|
||||
)
|
||||
# Advance through the chain
|
||||
mock_client = _mock_resolve()
|
||||
with patch("agent.auxiliary_client.resolve_provider_client", return_value=(mock_client, None)):
|
||||
agent._try_activate_fallback()
|
||||
|
||||
assert agent._fallback_index == 1 # consumed one entry
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()):
|
||||
agent._restore_primary_runtime()
|
||||
|
||||
assert agent._fallback_index == 0 # reset for next turn
|
||||
|
||||
def test_restores_compressor_state(self):
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
|
||||
)
|
||||
original_ctx_len = agent.context_compressor.context_length
|
||||
original_threshold = agent.context_compressor.threshold_tokens
|
||||
|
||||
# Simulate fallback modifying compressor
|
||||
mock_client = _mock_resolve()
|
||||
with patch("agent.auxiliary_client.resolve_provider_client", return_value=(mock_client, None)):
|
||||
agent._try_activate_fallback()
|
||||
|
||||
# Manually simulate compressor being changed (as _try_activate_fallback does)
|
||||
agent.context_compressor.context_length = 32000
|
||||
agent.context_compressor.threshold_tokens = 25600
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()):
|
||||
agent._restore_primary_runtime()
|
||||
|
||||
assert agent.context_compressor.context_length == original_ctx_len
|
||||
assert agent.context_compressor.threshold_tokens == original_threshold
|
||||
|
||||
def test_restores_prompt_caching_flag(self):
|
||||
agent = _make_agent()
|
||||
original_caching = agent._use_prompt_caching
|
||||
|
||||
# Simulate fallback changing the caching flag
|
||||
agent._fallback_activated = True
|
||||
agent._use_prompt_caching = not original_caching
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()):
|
||||
agent._restore_primary_runtime()
|
||||
|
||||
assert agent._use_prompt_caching == original_caching
|
||||
|
||||
def test_restore_survives_exception(self):
|
||||
"""If client rebuild fails, the method returns False gracefully."""
|
||||
agent = _make_agent()
|
||||
agent._fallback_activated = True
|
||||
|
||||
with patch("run_agent.OpenAI", side_effect=Exception("connection refused")):
|
||||
result = agent._restore_primary_runtime()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _try_recover_primary_transport()
|
||||
# =============================================================================
|
||||
|
||||
def _make_transport_error(error_type="ReadTimeout"):
|
||||
"""Create an exception whose type().__name__ matches the given name."""
|
||||
cls = type(error_type, (Exception,), {})
|
||||
return cls("connection timed out")
|
||||
|
||||
|
||||
class TestTryRecoverPrimaryTransport:
|
||||
|
||||
def test_recovers_on_read_timeout(self):
|
||||
agent = _make_agent(provider="custom")
|
||||
error = _make_transport_error("ReadTimeout")
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()), \
|
||||
patch("time.sleep"):
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_recovers_on_connect_timeout(self):
|
||||
agent = _make_agent(provider="custom")
|
||||
error = _make_transport_error("ConnectTimeout")
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()), \
|
||||
patch("time.sleep"):
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_recovers_on_pool_timeout(self):
|
||||
agent = _make_agent(provider="zai")
|
||||
error = _make_transport_error("PoolTimeout")
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()), \
|
||||
patch("time.sleep"):
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_skipped_when_already_on_fallback(self):
|
||||
agent = _make_agent(provider="custom")
|
||||
agent._fallback_activated = True
|
||||
error = _make_transport_error("ReadTimeout")
|
||||
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_skipped_for_non_transport_error(self):
|
||||
"""Non-transport errors (ValueError, APIError, etc.) skip recovery."""
|
||||
agent = _make_agent(provider="custom")
|
||||
error = ValueError("invalid model")
|
||||
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_skipped_for_openrouter(self):
|
||||
agent = _make_agent(provider="openrouter", base_url="https://openrouter.ai/api/v1")
|
||||
error = _make_transport_error("ReadTimeout")
|
||||
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_skipped_for_nous_provider(self):
|
||||
agent = _make_agent(provider="nous", base_url="https://inference.nous.nousresearch.com/v1")
|
||||
error = _make_transport_error("ReadTimeout")
|
||||
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_allowed_for_anthropic_direct(self):
|
||||
"""Direct Anthropic endpoint should get recovery."""
|
||||
agent = _make_agent(provider="anthropic", base_url="https://api.anthropic.com")
|
||||
# For non-anthropic_messages api_mode, it will use OpenAI client
|
||||
error = _make_transport_error("ConnectError")
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()), \
|
||||
patch("time.sleep"):
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_allowed_for_ollama(self):
|
||||
agent = _make_agent(provider="ollama", base_url="http://localhost:11434/v1")
|
||||
error = _make_transport_error("ConnectTimeout")
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()), \
|
||||
patch("time.sleep"):
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_wait_time_scales_with_retry_count(self):
|
||||
agent = _make_agent(provider="custom")
|
||||
error = _make_transport_error("ReadTimeout")
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()), \
|
||||
patch("time.sleep") as mock_sleep:
|
||||
agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
# wait_time = min(3 + retry_count, 8) = min(6, 8) = 6
|
||||
mock_sleep.assert_called_once_with(6)
|
||||
|
||||
def test_wait_time_capped_at_8(self):
|
||||
agent = _make_agent(provider="custom")
|
||||
error = _make_transport_error("ReadTimeout")
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()), \
|
||||
patch("time.sleep") as mock_sleep:
|
||||
agent._try_recover_primary_transport(
|
||||
error, retry_count=10, max_retries=3,
|
||||
)
|
||||
# wait_time = min(3 + 10, 8) = 8
|
||||
mock_sleep.assert_called_once_with(8)
|
||||
|
||||
def test_closes_existing_client_before_rebuild(self):
|
||||
agent = _make_agent(provider="custom")
|
||||
old_client = agent.client
|
||||
error = _make_transport_error("ReadTimeout")
|
||||
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()), \
|
||||
patch("time.sleep"), \
|
||||
patch.object(agent, "_close_openai_client") as mock_close:
|
||||
agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
mock_close.assert_called_once_with(
|
||||
old_client, reason="primary_recovery", shared=True,
|
||||
)
|
||||
|
||||
def test_survives_rebuild_failure(self):
|
||||
"""If client rebuild fails, returns False gracefully."""
|
||||
agent = _make_agent(provider="custom")
|
||||
error = _make_transport_error("ReadTimeout")
|
||||
|
||||
with patch("run_agent.OpenAI", side_effect=Exception("socket error")), \
|
||||
patch("time.sleep"):
|
||||
result = agent._try_recover_primary_transport(
|
||||
error, retry_count=3, max_retries=3,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration: restore_primary_runtime called from run_conversation
|
||||
# =============================================================================
|
||||
|
||||
class TestRestoreInRunConversation:
|
||||
"""Verify the hook in run_conversation() calls _restore_primary_runtime."""
|
||||
|
||||
def test_restore_called_at_turn_start(self):
|
||||
agent = _make_agent()
|
||||
agent._fallback_activated = True
|
||||
|
||||
with patch.object(agent, "_restore_primary_runtime", return_value=True) as mock_restore, \
|
||||
patch.object(agent, "run_conversation", wraps=None) as _:
|
||||
# We can't easily run the full conversation, but we can verify
|
||||
# the method exists and is callable
|
||||
agent._restore_primary_runtime()
|
||||
mock_restore.assert_called_once()
|
||||
|
||||
def test_full_cycle_fallback_then_restore(self):
|
||||
"""Simulate: turn 1 activates fallback, turn 2 restores primary."""
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
|
||||
provider="custom",
|
||||
)
|
||||
|
||||
# Turn 1: activate fallback
|
||||
mock_client = _mock_resolve()
|
||||
with patch("agent.auxiliary_client.resolve_provider_client", return_value=(mock_client, None)):
|
||||
assert agent._try_activate_fallback() is True
|
||||
|
||||
assert agent._fallback_activated is True
|
||||
assert agent.model == "anthropic/claude-sonnet-4"
|
||||
assert agent.provider == "openrouter"
|
||||
assert agent._fallback_index == 1
|
||||
|
||||
# Turn 2: restore primary
|
||||
with patch("run_agent.OpenAI", return_value=MagicMock()):
|
||||
assert agent._restore_primary_runtime() is True
|
||||
|
||||
assert agent._fallback_activated is False
|
||||
assert agent._fallback_index == 0
|
||||
assert agent.provider == "custom"
|
||||
assert agent.base_url == "https://my-llm.example.com/v1"
|
||||
Reference in New Issue
Block a user