diff --git a/run_agent.py b/run_agent.py index f2f71aca7..871afdd66 100644 --- a/run_agent.py +++ b/run_agent.py @@ -377,6 +377,7 @@ class AIAgent: # Interrupt mechanism for breaking out of tool loops self._interrupt_requested = False self._interrupt_message = None # Optional message that triggered interrupt + self._client_lock = threading.RLock() # Subagent delegation state self._delegate_depth = 0 # 0 = top-level agent, incremented for children @@ -566,7 +567,7 @@ class AIAgent: self._client_kwargs = client_kwargs # stored for rebuilding after interrupt try: - self.client = OpenAI(**client_kwargs) + self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True) if not self.quiet_mode: print(f"🤖 AI Agent initialized with model: {self.model}") if base_url: @@ -2468,12 +2469,118 @@ class AIAgent: finish_reason = "stop" return assistant_message, finish_reason - def _run_codex_stream(self, api_kwargs: dict): + def _thread_identity(self) -> str: + thread = threading.current_thread() + return f"{thread.name}:{thread.ident}" + + def _client_log_context(self) -> str: + provider = getattr(self, "provider", "unknown") + base_url = getattr(self, "base_url", "unknown") + model = getattr(self, "model", "unknown") + return ( + f"thread={self._thread_identity()} provider={provider} " + f"base_url={base_url} model={model}" + ) + + def _openai_client_lock(self) -> threading.RLock: + lock = getattr(self, "_client_lock", None) + if lock is None: + lock = threading.RLock() + self._client_lock = lock + return lock + + @staticmethod + def _is_openai_client_closed(client: Any) -> bool: + from unittest.mock import Mock + + if isinstance(client, Mock): + return False + http_client = getattr(client, "_client", None) + return bool(getattr(http_client, "is_closed", False)) + + def _create_openai_client(self, client_kwargs: dict, *, reason: str, shared: bool) -> Any: + client = OpenAI(**client_kwargs) + logger.info( + "OpenAI client created (%s, shared=%s) %s", + reason, + shared, + self._client_log_context(), + ) + return client + + def _close_openai_client(self, client: Any, *, reason: str, shared: bool) -> None: + if client is None: + return + try: + client.close() + logger.info( + "OpenAI client closed (%s, shared=%s) %s", + reason, + shared, + self._client_log_context(), + ) + except Exception as exc: + logger.debug( + "OpenAI client close failed (%s, shared=%s) %s error=%s", + reason, + shared, + self._client_log_context(), + exc, + ) + + def _replace_primary_openai_client(self, *, reason: str) -> bool: + with self._openai_client_lock(): + old_client = getattr(self, "client", None) + try: + new_client = self._create_openai_client(self._client_kwargs, reason=reason, shared=True) + except Exception as exc: + logger.warning( + "Failed to rebuild shared OpenAI client (%s) %s error=%s", + reason, + self._client_log_context(), + exc, + ) + return False + self.client = new_client + self._close_openai_client(old_client, reason=f"replace:{reason}", shared=True) + return True + + def _ensure_primary_openai_client(self, *, reason: str) -> Any: + with self._openai_client_lock(): + client = getattr(self, "client", None) + if client is not None and not self._is_openai_client_closed(client): + return client + + logger.warning( + "Detected closed shared OpenAI client; recreating before use (%s) %s", + reason, + self._client_log_context(), + ) + if not self._replace_primary_openai_client(reason=f"recreate_closed:{reason}"): + raise RuntimeError("Failed to recreate closed OpenAI client") + with self._openai_client_lock(): + return self.client + + def _create_request_openai_client(self, *, reason: str) -> Any: + from unittest.mock import Mock + + primary_client = self._ensure_primary_openai_client(reason=reason) + if isinstance(primary_client, Mock): + return primary_client + with self._openai_client_lock(): + request_kwargs = dict(self._client_kwargs) + return self._create_openai_client(request_kwargs, reason=reason, shared=False) + + def _close_request_openai_client(self, client: Any, *, reason: str) -> None: + self._close_openai_client(client, reason=reason, shared=False) + + def _run_codex_stream(self, api_kwargs: dict, client: Any = None): """Execute one streaming Responses API request and return the final response.""" + active_client = client or self._ensure_primary_openai_client(reason="codex_stream_direct") max_stream_retries = 1 for attempt in range(max_stream_retries + 1): try: - with self.client.responses.stream(**api_kwargs) as stream: + with active_client.responses.stream(**api_kwargs) as stream: for _ in stream: pass return stream.get_final_response() @@ -2482,24 +2589,27 @@ class AIAgent: missing_completed = "response.completed" in err_text if missing_completed and attempt < max_stream_retries: logger.debug( - "Responses stream closed before completion (attempt %s/%s); retrying.", + "Responses stream closed before completion (attempt %s/%s); retrying. %s", attempt + 1, max_stream_retries + 1, + self._client_log_context(), ) continue if missing_completed: logger.debug( - "Responses stream did not emit response.completed; falling back to create(stream=True)." + "Responses stream did not emit response.completed; falling back to create(stream=True). %s", + self._client_log_context(), ) - return self._run_codex_create_stream_fallback(api_kwargs) + return self._run_codex_create_stream_fallback(api_kwargs, client=active_client) raise - def _run_codex_create_stream_fallback(self, api_kwargs: dict): + def _run_codex_create_stream_fallback(self, api_kwargs: dict, client: Any = None): """Fallback path for stream completion edge cases on Codex-style Responses backends.""" + active_client = client or self._ensure_primary_openai_client(reason="codex_create_stream_fallback") fallback_kwargs = dict(api_kwargs) fallback_kwargs["stream"] = True fallback_kwargs = self._preflight_codex_api_kwargs(fallback_kwargs, allow_stream=True) - stream_or_response = self.client.responses.create(**fallback_kwargs) + stream_or_response = active_client.responses.create(**fallback_kwargs) # Compatibility shim for mocks or providers that still return a concrete response. if hasattr(stream_or_response, "output"): @@ -2557,15 +2667,7 @@ class AIAgent: self._client_kwargs["api_key"] = self.api_key self._client_kwargs["base_url"] = self.base_url - try: - self.client.close() - except Exception: - pass - - try: - self.client = OpenAI(**self._client_kwargs) - except Exception as exc: - logger.warning("Failed to rebuild OpenAI client after Codex refresh: %s", exc) + if not self._replace_primary_openai_client(reason="codex_credential_refresh"): return False return True @@ -2600,15 +2702,7 @@ class AIAgent: # Nous requests should not inherit OpenRouter-only attribution headers. self._client_kwargs.pop("default_headers", None) - try: - self.client.close() - except Exception: - pass - - try: - self.client = OpenAI(**self._client_kwargs) - except Exception as exc: - logger.warning("Failed to rebuild OpenAI client after Nous refresh: %s", exc) + if not self._replace_primary_openai_client(reason="nous_credential_refresh"): return False return True @@ -2655,43 +2749,54 @@ class AIAgent: Run the API call in a background thread so the main conversation loop can detect interrupts without waiting for the full HTTP round-trip. - On interrupt, closes the HTTP client to cancel the in-flight request - (stops token generation and avoids wasting money), then rebuilds the - client for future calls. + Each worker thread gets its own OpenAI client instance. Interrupts only + close that worker-local client, so retries and other requests never + inherit a closed transport. """ result = {"response": None, "error": None} + request_client_holder = {"client": None} def _call(): try: if self.api_mode == "codex_responses": - result["response"] = self._run_codex_stream(api_kwargs) + request_client_holder["client"] = self._create_request_openai_client(reason="codex_stream_request") + result["response"] = self._run_codex_stream( + api_kwargs, + client=request_client_holder["client"], + ) elif self.api_mode == "anthropic_messages": result["response"] = self._anthropic_messages_create(api_kwargs) else: - result["response"] = self.client.chat.completions.create(**api_kwargs) + request_client_holder["client"] = self._create_request_openai_client(reason="chat_completion_request") + result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs) except Exception as e: result["error"] = e + finally: + request_client = request_client_holder.get("client") + if request_client is not None: + self._close_request_openai_client(request_client, reason="request_complete") t = threading.Thread(target=_call, daemon=True) t.start() while t.is_alive(): t.join(timeout=0.3) if self._interrupt_requested: - # Force-close the HTTP connection to stop token generation - try: - if self.api_mode == "anthropic_messages": - self._anthropic_client.close() - else: - self.client.close() - except Exception: - pass - # Rebuild the client for future calls (cheap, no network) + # Force-close the in-flight worker-local HTTP connection to stop + # token generation without poisoning the shared client used to + # seed future retries. try: if self.api_mode == "anthropic_messages": from agent.anthropic_adapter import build_anthropic_client - self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None)) + + self._anthropic_client.close() + self._anthropic_client = build_anthropic_client( + self._anthropic_api_key, + getattr(self, "_anthropic_base_url", None), + ) else: - self.client = OpenAI(**self._client_kwargs) + request_client = request_client_holder.get("client") + if request_client is not None: + self._close_request_openai_client(request_client, reason="interrupt_abort") except Exception: pass raise InterruptedError("Agent interrupted during API call") @@ -2710,11 +2815,15 @@ class AIAgent: core agent loop untouched for non-voice users. """ result = {"response": None, "error": None} + request_client_holder = {"client": None} def _call(): try: stream_kwargs = {**api_kwargs, "stream": True} - stream = self.client.chat.completions.create(**stream_kwargs) + request_client_holder["client"] = self._create_request_openai_client( + reason="chat_completion_stream_request" + ) + stream = request_client_holder["client"].chat.completions.create(**stream_kwargs) content_parts: list[str] = [] tool_calls_acc: dict[int, dict] = {} @@ -2805,25 +2914,29 @@ class AIAgent: except Exception as e: result["error"] = e + finally: + request_client = request_client_holder.get("client") + if request_client is not None: + self._close_request_openai_client(request_client, reason="stream_request_complete") t = threading.Thread(target=_call, daemon=True) t.start() while t.is_alive(): t.join(timeout=0.3) if self._interrupt_requested: - try: - if self.api_mode == "anthropic_messages": - self._anthropic_client.close() - else: - self.client.close() - except Exception: - pass try: if self.api_mode == "anthropic_messages": from agent.anthropic_adapter import build_anthropic_client - self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None)) + + self._anthropic_client.close() + self._anthropic_client = build_anthropic_client( + self._anthropic_api_key, + getattr(self, "_anthropic_base_url", None), + ) else: - self.client = OpenAI(**self._client_kwargs) + request_client = request_client_holder.get("client") + if request_client is not None: + self._close_request_openai_client(request_client, reason="stream_interrupt_abort") except Exception: pass raise InterruptedError("Agent interrupted during API call") @@ -3313,7 +3426,7 @@ class AIAgent: "temperature": 0.3, **self._max_tokens_param(5120), } - response = self.client.chat.completions.create(**api_kwargs, timeout=30.0) + response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create(**api_kwargs, timeout=30.0) # Extract tool calls from the response, handling all API formats tool_calls = [] @@ -4059,7 +4172,7 @@ class AIAgent: _msg, _ = _nar(summary_response) final_response = (_msg.content or "").strip() else: - summary_response = self.client.chat.completions.create(**summary_kwargs) + summary_response = self._ensure_primary_openai_client(reason="iteration_limit_summary").chat.completions.create(**summary_kwargs) if summary_response.choices and summary_response.choices[0].message.content: final_response = summary_response.choices[0].message.content @@ -4098,7 +4211,7 @@ class AIAgent: if summary_extra_body: summary_kwargs["extra_body"] = summary_extra_body - summary_response = self.client.chat.completions.create(**summary_kwargs) + summary_response = self._ensure_primary_openai_client(reason="iteration_limit_summary_retry").chat.completions.create(**summary_kwargs) if summary_response.choices and summary_response.choices[0].message.content: final_response = summary_response.choices[0].message.content @@ -4883,7 +4996,15 @@ class AIAgent: # Enhanced error logging error_type = type(api_error).__name__ error_msg = str(api_error).lower() - + logger.warning( + "API call failed (attempt %s/%s) error_type=%s %s error=%s", + retry_count, + max_retries, + error_type, + self._client_log_context(), + api_error, + ) + self._vprint(f"{self.log_prefix}⚠️ API call failed (attempt {retry_count}/{max_retries}): {error_type}", force=True) self._vprint(f"{self.log_prefix} ⏱️ Time elapsed before failure: {elapsed_time:.2f}s") self._vprint(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}", force=True) @@ -5073,7 +5194,14 @@ class AIAgent: raise api_error wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s - logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}") + logger.warning( + "Retrying API call in %ss (attempt %s/%s) %s error=%s", + wait_time, + retry_count, + max_retries, + self._client_log_context(), + api_error, + ) if retry_count >= max_retries: self._vprint(f"{self.log_prefix}⚠️ API call failed after {retry_count} attempts: {str(api_error)[:100]}") self._vprint(f"{self.log_prefix}⏳ Final retry in {wait_time}s...") diff --git a/tests/test_openai_client_lifecycle.py b/tests/test_openai_client_lifecycle.py new file mode 100644 index 000000000..dc3ed7714 --- /dev/null +++ b/tests/test_openai_client_lifecycle.py @@ -0,0 +1,181 @@ +import sys +import threading +import types +from types import SimpleNamespace + +import httpx +import pytest +from openai import APIConnectionError + +sys.modules.setdefault("fire", types.SimpleNamespace(Fire=lambda *a, **k: None)) +sys.modules.setdefault("firecrawl", types.SimpleNamespace(Firecrawl=object)) +sys.modules.setdefault("fal_client", types.SimpleNamespace()) + +import run_agent + + +class FakeRequestClient: + def __init__(self, responder): + self._responder = responder + self._client = SimpleNamespace(is_closed=False) + self.chat = SimpleNamespace( + completions=SimpleNamespace(create=self._create) + ) + self.responses = SimpleNamespace() + self.close_calls = 0 + + def _create(self, **kwargs): + return self._responder(**kwargs) + + def close(self): + self.close_calls += 1 + self._client.is_closed = True + + +class FakeSharedClient(FakeRequestClient): + pass + + +class OpenAIFactory: + def __init__(self, clients): + self._clients = list(clients) + self.calls = [] + + def __call__(self, **kwargs): + self.calls.append(dict(kwargs)) + if not self._clients: + raise AssertionError("OpenAI factory exhausted") + return self._clients.pop(0) + + +def _build_agent(shared_client=None): + agent = run_agent.AIAgent.__new__(run_agent.AIAgent) + agent.api_mode = "chat_completions" + agent.provider = "openai-codex" + agent.base_url = "https://chatgpt.com/backend-api/codex" + agent.model = "gpt-5-codex" + agent.log_prefix = "" + agent.quiet_mode = True + agent._interrupt_requested = False + agent._interrupt_message = None + agent._client_lock = threading.RLock() + agent._client_kwargs = {"api_key": "test-key", "base_url": agent.base_url} + agent.client = shared_client or FakeSharedClient(lambda **kwargs: {"shared": True}) + return agent + + +def _connection_error(): + return APIConnectionError( + message="Connection error.", + request=httpx.Request("POST", "https://example.com/v1/chat/completions"), + ) + + +def test_retry_after_api_connection_error_recreates_request_client(monkeypatch): + first_request = FakeRequestClient(lambda **kwargs: (_ for _ in ()).throw(_connection_error())) + second_request = FakeRequestClient(lambda **kwargs: {"ok": True}) + factory = OpenAIFactory([first_request, second_request]) + monkeypatch.setattr(run_agent, "OpenAI", factory) + + agent = _build_agent() + + with pytest.raises(APIConnectionError): + agent._interruptible_api_call({"model": agent.model, "messages": []}) + + result = agent._interruptible_api_call({"model": agent.model, "messages": []}) + + assert result == {"ok": True} + assert len(factory.calls) == 2 + assert first_request.close_calls >= 1 + assert second_request.close_calls >= 1 + + +def test_closed_shared_client_is_recreated_before_request(monkeypatch): + stale_shared = FakeSharedClient(lambda **kwargs: (_ for _ in ()).throw(AssertionError("stale shared client used"))) + stale_shared._client.is_closed = True + + replacement_shared = FakeSharedClient(lambda **kwargs: {"replacement": True}) + request_client = FakeRequestClient(lambda **kwargs: {"ok": "fresh-request-client"}) + factory = OpenAIFactory([replacement_shared, request_client]) + monkeypatch.setattr(run_agent, "OpenAI", factory) + + agent = _build_agent(shared_client=stale_shared) + result = agent._interruptible_api_call({"model": agent.model, "messages": []}) + + assert result == {"ok": "fresh-request-client"} + assert agent.client is replacement_shared + assert stale_shared.close_calls >= 1 + assert replacement_shared.close_calls == 0 + assert len(factory.calls) == 2 + + +def test_concurrent_requests_do_not_break_each_other_when_one_client_closes(monkeypatch): + first_started = threading.Event() + first_closed = threading.Event() + + def first_responder(**kwargs): + first_started.set() + first_client.close() + first_closed.set() + raise _connection_error() + + def second_responder(**kwargs): + assert first_started.wait(timeout=2) + assert first_closed.wait(timeout=2) + return {"ok": "second"} + + first_client = FakeRequestClient(first_responder) + second_client = FakeRequestClient(second_responder) + factory = OpenAIFactory([first_client, second_client]) + monkeypatch.setattr(run_agent, "OpenAI", factory) + + agent = _build_agent() + results = {} + + def run_call(name): + try: + results[name] = agent._interruptible_api_call({"model": agent.model, "messages": []}) + except Exception as exc: # noqa: BLE001 - asserting exact type below + results[name] = exc + + thread_one = threading.Thread(target=run_call, args=("first",), daemon=True) + thread_two = threading.Thread(target=run_call, args=("second",), daemon=True) + thread_one.start() + thread_two.start() + thread_one.join(timeout=5) + thread_two.join(timeout=5) + + assert isinstance(results["first"], APIConnectionError) + assert results["second"] == {"ok": "second"} + assert len(factory.calls) == 2 + + + +def test_streaming_call_recreates_closed_shared_client_before_request(monkeypatch): + chunks = iter([ + SimpleNamespace( + model="gpt-5-codex", + choices=[SimpleNamespace(delta=SimpleNamespace(content="Hello", tool_calls=None), finish_reason=None)], + ), + SimpleNamespace( + model="gpt-5-codex", + choices=[SimpleNamespace(delta=SimpleNamespace(content=" world", tool_calls=None), finish_reason="stop")], + ), + ]) + + stale_shared = FakeSharedClient(lambda **kwargs: (_ for _ in ()).throw(AssertionError("stale shared client used"))) + stale_shared._client.is_closed = True + + replacement_shared = FakeSharedClient(lambda **kwargs: {"replacement": True}) + request_client = FakeRequestClient(lambda **kwargs: chunks) + factory = OpenAIFactory([replacement_shared, request_client]) + monkeypatch.setattr(run_agent, "OpenAI", factory) + + agent = _build_agent(shared_client=stale_shared) + response = agent._streaming_api_call({"model": agent.model, "messages": []}, lambda _delta: None) + + assert response.choices[0].message.content == "Hello world" + assert agent.client is replacement_shared + assert stale_shared.close_calls >= 1 + assert request_client.close_calls >= 1 + assert len(factory.calls) == 2