[loop-cycle-1233] refactor: break up BaseAgent.run() (#561) (#584)

This commit is contained in:
2026-03-20 11:24:36 -04:00
parent 48103bb076
commit faf6c1a5f1
2 changed files with 90 additions and 62 deletions

View File

@@ -119,75 +119,70 @@ class BaseAgent(ABC):
"""
pass
async def run(self, message: str) -> str:
"""Run the agent with a message.
# Transient errors that indicate Ollama contention or temporary
# unavailability — these deserve a retry with backoff.
_TRANSIENT = (
httpx.ConnectError,
httpx.ReadError,
httpx.ReadTimeout,
httpx.ConnectTimeout,
ConnectionError,
TimeoutError,
)
Retries on transient failures (connection errors, timeouts) with
exponential backoff. GPU contention from concurrent Ollama
requests causes ReadError / ReadTimeout — these are transient
and should be retried, not raised immediately (#70).
async def run(self, message: str, *, max_retries: int = 3) -> str:
"""Run the agent with a message, retrying on transient failures.
Returns:
Agent response
GPU contention from concurrent Ollama requests causes ReadError /
ReadTimeout — these are transient and retried with exponential
backoff (#70).
"""
max_retries = 3
last_exception = None
# Transient errors that indicate Ollama contention or temporary
# unavailability — these deserve a retry with backoff.
_transient = (
httpx.ConnectError,
httpx.ReadError,
httpx.ReadTimeout,
httpx.ConnectTimeout,
ConnectionError,
TimeoutError,
)
response = await self._run_with_retries(message, max_retries)
await self._emit_response_event(message, response)
return response
async def _run_with_retries(self, message: str, max_retries: int) -> str:
"""Execute agent.run() with retry logic for transient errors."""
for attempt in range(1, max_retries + 1):
try:
result = self.agent.run(message, stream=False)
response = result.content if hasattr(result, "content") else str(result)
break # Success, exit the retry loop
except _transient as exc:
last_exception = exc
if attempt < max_retries:
# Contention backoff — longer waits because the GPU
# needs time to finish the other request.
wait = min(2**attempt, 16)
logger.warning(
"Ollama contention on attempt %d/%d: %s. Waiting %ds before retry...",
attempt,
max_retries,
type(exc).__name__,
wait,
)
await asyncio.sleep(wait)
else:
logger.error(
"Ollama unreachable after %d attempts: %s",
max_retries,
exc,
)
raise last_exception from exc
return result.content if hasattr(result, "content") else str(result)
except self._TRANSIENT as exc:
self._handle_retry_or_raise(
exc, attempt, max_retries, transient=True,
)
await asyncio.sleep(min(2**attempt, 16))
except Exception as exc:
last_exception = exc
if attempt < max_retries:
logger.warning(
"Agent run failed on attempt %d/%d: %s. Retrying...",
attempt,
max_retries,
exc,
)
await asyncio.sleep(min(2 ** (attempt - 1), 8))
else:
logger.error(
"Agent run failed after %d attempts: %s",
max_retries,
exc,
)
raise last_exception from exc
self._handle_retry_or_raise(
exc, attempt, max_retries, transient=False,
)
await asyncio.sleep(min(2 ** (attempt - 1), 8))
# Unreachable — _handle_retry_or_raise raises on last attempt.
raise RuntimeError("retry loop exited unexpectedly") # pragma: no cover
# Emit completion event
@staticmethod
def _handle_retry_or_raise(
exc: Exception, attempt: int, max_retries: int, *, transient: bool,
) -> None:
"""Log a retry warning or raise after exhausting attempts."""
if attempt < max_retries:
if transient:
logger.warning(
"Ollama contention on attempt %d/%d: %s. Waiting before retry...",
attempt, max_retries, type(exc).__name__,
)
else:
logger.warning(
"Agent run failed on attempt %d/%d: %s. Retrying...",
attempt, max_retries, exc,
)
else:
label = "Ollama unreachable" if transient else "Agent run failed"
logger.error("%s after %d attempts: %s", label, max_retries, exc)
raise exc
async def _emit_response_event(self, message: str, response: str) -> None:
"""Publish a completion event to the event bus if connected."""
if self.event_bus:
await self.event_bus.publish(
Event(
@@ -197,8 +192,6 @@ class BaseAgent(ABC):
)
)
return response
def get_capabilities(self) -> list[str]:
"""Get list of capabilities this agent provides."""
return self.tools

View File

@@ -361,6 +361,41 @@ class TestRun:
assert response == "ok"
# ── _handle_retry_or_raise ────────────────────────────────────────────────
class TestHandleRetryOrRaise:
def test_raises_on_last_attempt(self):
BaseAgent = _make_base_class()
with pytest.raises(ValueError, match="boom"):
BaseAgent._handle_retry_or_raise(
ValueError("boom"), attempt=3, max_retries=3, transient=False,
)
def test_raises_on_last_attempt_transient(self):
BaseAgent = _make_base_class()
exc = httpx.ConnectError("down")
with pytest.raises(httpx.ConnectError):
BaseAgent._handle_retry_or_raise(
exc, attempt=3, max_retries=3, transient=True,
)
def test_no_raise_on_early_attempt(self):
BaseAgent = _make_base_class()
# Should return None (no raise) on non-final attempt
result = BaseAgent._handle_retry_or_raise(
ValueError("retry me"), attempt=1, max_retries=3, transient=False,
)
assert result is None
def test_no_raise_on_early_transient(self):
BaseAgent = _make_base_class()
result = BaseAgent._handle_retry_or_raise(
httpx.ReadTimeout("busy"), attempt=2, max_retries=3, transient=True,
)
assert result is None
# ── get_capabilities / get_status ────────────────────────────────────────────