"""Regression tests for how ``run_agent.AIAgent`` manages its OpenAI clients.

``run_agent.OpenAI`` is replaced with a factory that hands out pre-built fake
clients in order. The tests then drive ``AIAgent._interruptible_api_call`` and
``AIAgent._interruptible_streaming_api_call`` and check that:

* a request failing with ``APIConnectionError`` does not break the next one
  (a second client is constructed for it);
* a shared client whose ``_client.is_closed`` is already ``True`` is closed and
  replaced before a request is made;
* one request closing its own client does not break a concurrent request.
"""

import sys
import threading
import types
from types import SimpleNamespace

import httpx
import pytest
from openai import APIConnectionError

# Stub optional dependencies (presumably imported, directly or transitively,
# by ``run_agent``) so the tests don't need them installed. ``setdefault``
# leaves any module that was already imported untouched.
sys.modules.setdefault("fire", types.SimpleNamespace(Fire=lambda *a, **k: None))
sys.modules.setdefault("firecrawl", types.SimpleNamespace(Firecrawl=object))
sys.modules.setdefault("fal_client", types.SimpleNamespace())

import run_agent  # noqa: E402 - must run after the sys.modules stubs above


class FakeRequestClient:
    """Minimal stand-in for an ``openai.OpenAI`` client.

    ``chat.completions.create(**kwargs)`` delegates to *responder*. ``close()``
    counts its calls and sets ``_client.is_closed`` to ``True``, so tests can
    observe whether, and how often, the agent closed this client.
    """

    def __init__(self, responder):
        self._responder = responder
        # Mimics the wrapped HTTP client; the agent presumably checks
        # ``_client.is_closed`` to detect stale clients -- TODO confirm.
        self._client = SimpleNamespace(is_closed=False)
        self.chat = SimpleNamespace(
            completions=SimpleNamespace(create=self._create),
        )
        self.responses = SimpleNamespace()
        self.close_calls = 0

    def _create(self, **kwargs):
        return self._responder(**kwargs)

    def close(self):
        self.close_calls += 1
        self._client.is_closed = True


class FakeSharedClient(FakeRequestClient):
    """Fake used for the agent's long-lived ``agent.client`` attribute."""


class OpenAIFactory:
    """Callable replacement for ``run_agent.OpenAI``.

    Each call records its keyword arguments in ``calls`` and returns the next
    pre-built client. Once the supply runs out it raises ``AssertionError``,
    so an unexpected extra client construction fails the test loudly.
    """

    def __init__(self, clients):
        self._clients = list(clients)
        self.calls = []

    def __call__(self, **kwargs):
        self.calls.append(dict(kwargs))
        if not self._clients:
            raise AssertionError("OpenAI factory exhausted")
        return self._clients.pop(0)


def _build_agent(shared_client=None):
    """Return an ``AIAgent`` created via ``__new__`` (``__init__`` is skipped).

    Only the attributes these tests need are set by hand. ``agent.client``
    defaults to a ``FakeSharedClient`` that returns ``{"shared": True}``.
    """
    agent = run_agent.AIAgent.__new__(run_agent.AIAgent)
    agent.api_mode = "chat_completions"
    agent.provider = "openai-codex"
    agent.base_url = "https://chatgpt.com/backend-api/codex"
    agent.model = "gpt-5-codex"
    agent.log_prefix = ""
    agent.quiet_mode = True
    agent._interrupt_requested = False
    agent._interrupt_message = None
    agent._client_lock = threading.RLock()
    agent._client_kwargs = {"api_key": "***", "base_url": agent.base_url}
    agent.client = shared_client or FakeSharedClient(lambda **kwargs: {"shared": True})
    agent.stream_delta_callback = None
    agent._stream_callback = None
    agent.reasoning_callback = None
    return agent


def _connection_error():
    """Return a new ``APIConnectionError`` attached to a dummy POST request."""
    return APIConnectionError(
        message="Connection error.",
        request=httpx.Request("POST", "https://example.com/v1/chat/completions"),
    )


def _raising(exc_factory):
    """Return a responder that raises a fresh ``exc_factory()`` on every call."""

    def responder(**kwargs):
        raise exc_factory()

    return responder


def test_retry_after_api_connection_error_recreates_request_client(monkeypatch):
    """A connection error on one request must not break the following request."""
    first_request = FakeRequestClient(_raising(_connection_error))
    second_request = FakeRequestClient(lambda **kwargs: {"ok": True})
    factory = OpenAIFactory([first_request, second_request])
    monkeypatch.setattr(run_agent, "OpenAI", factory)
    agent = _build_agent()

    with pytest.raises(APIConnectionError):
        agent._interruptible_api_call({"model": agent.model, "messages": []})

    result = agent._interruptible_api_call({"model": agent.model, "messages": []})

    assert result == {"ok": True}
    # One client per request, and each is closed once its request is done.
    assert len(factory.calls) == 2
    assert first_request.close_calls >= 1
    assert second_request.close_calls >= 1


def test_closed_shared_client_is_recreated_before_request(monkeypatch):
    """A shared client that is already closed is replaced and never used."""
    stale_shared = FakeSharedClient(
        _raising(lambda: AssertionError("stale shared client used"))
    )
    stale_shared._client.is_closed = True
    replacement_shared = FakeSharedClient(lambda **kwargs: {"replacement": True})
    request_client = FakeRequestClient(lambda **kwargs: {"ok": "fresh-request-client"})
    # Factory order matters: the replacement shared client is built first,
    # then the per-request client.
    factory = OpenAIFactory([replacement_shared, request_client])
    monkeypatch.setattr(run_agent, "OpenAI", factory)
    agent = _build_agent(shared_client=stale_shared)

    result = agent._interruptible_api_call({"model": agent.model, "messages": []})

    assert result == {"ok": "fresh-request-client"}
    assert agent.client is replacement_shared
    assert stale_shared.close_calls >= 1
    assert replacement_shared.close_calls == 0
    assert len(factory.calls) == 2


def test_concurrent_requests_do_not_break_each_other_when_one_client_closes(monkeypatch):
    """One request closing its client must not affect a concurrent request."""
    first_started = threading.Event()
    first_closed = threading.Event()

    def first_responder(**kwargs):
        first_started.set()
        first_client.close()
        first_closed.set()
        raise _connection_error()

    def second_responder(**kwargs):
        # Only answer after the other client has been closed, so the success
        # below proves the close did not leak into this request.
        assert first_started.wait(timeout=2)
        assert first_closed.wait(timeout=2)
        return {"ok": "second"}

    first_client = FakeRequestClient(first_responder)
    second_client = FakeRequestClient(second_responder)
    factory = OpenAIFactory([first_client, second_client])
    monkeypatch.setattr(run_agent, "OpenAI", factory)
    agent = _build_agent()
    results = {}

    def run_call(name):
        try:
            results[name] = agent._interruptible_api_call(
                {"model": agent.model, "messages": []}
            )
        except Exception as exc:  # noqa: BLE001 - asserting exact type below
            results[name] = exc

    thread_one = threading.Thread(target=run_call, args=("first",), daemon=True)
    thread_two = threading.Thread(target=run_call, args=("second",), daemon=True)
    thread_one.start()
    thread_two.start()
    thread_one.join(timeout=5)
    thread_two.join(timeout=5)

    # join(timeout=...) returns silently on timeout; fail with a clear message
    # instead of a confusing count mismatch if a call hung (e.g. a deadlock).
    assert not thread_one.is_alive(), "thread 'first' did not finish its API call"
    assert not thread_two.is_alive(), "thread 'second' did not finish its API call"

    # Which thread receives which fake client depends on scheduling, so only
    # the multiset of outcomes is checked.
    values = list(results.values())
    assert sum(isinstance(value, APIConnectionError) for value in values) == 1
    assert values.count({"ok": "second"}) == 1
    assert len(factory.calls) == 2


def test_streaming_call_recreates_closed_shared_client_before_request(monkeypatch):
    """The streaming path also replaces a closed shared client before use."""
    chunks = iter([
        SimpleNamespace(
            model="gpt-5-codex",
            choices=[
                SimpleNamespace(
                    delta=SimpleNamespace(content="Hello", tool_calls=None),
                    finish_reason=None,
                )
            ],
        ),
        SimpleNamespace(
            model="gpt-5-codex",
            choices=[
                SimpleNamespace(
                    delta=SimpleNamespace(content=" world", tool_calls=None),
                    finish_reason="stop",
                )
            ],
        ),
    ])
    stale_shared = FakeSharedClient(
        _raising(lambda: AssertionError("stale shared client used"))
    )
    stale_shared._client.is_closed = True
    replacement_shared = FakeSharedClient(lambda **kwargs: {"replacement": True})
    request_client = FakeRequestClient(lambda **kwargs: chunks)
    factory = OpenAIFactory([replacement_shared, request_client])
    monkeypatch.setattr(run_agent, "OpenAI", factory)
    agent = _build_agent(shared_client=stale_shared)
    agent.stream_delta_callback = lambda _delta: None
    # Force chat_completions mode so the streaming path uses
    # chat.completions.create(stream=True) instead of Codex responses.stream().
    # _build_agent already sets this; repeating it keeps the test explicit.
    agent.api_mode = "chat_completions"

    response = agent._interruptible_streaming_api_call(
        {"model": agent.model, "messages": []}
    )

    assert response.choices[0].message.content == "Hello world"
    assert agent.client is replacement_shared
    assert stale_shared.close_calls >= 1
    assert request_client.close_calls >= 1
    assert len(factory.calls) == 2