forked from Rockachopa/Timmy-time-dashboard
@@ -119,75 +119,70 @@ class BaseAgent(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def run(self, message: str) -> str:
|
||||
"""Run the agent with a message.
|
||||
# Transient errors that indicate Ollama contention or temporary
|
||||
# unavailability — these deserve a retry with backoff.
|
||||
_TRANSIENT = (
|
||||
httpx.ConnectError,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.ConnectTimeout,
|
||||
ConnectionError,
|
||||
TimeoutError,
|
||||
)
|
||||
|
||||
Retries on transient failures (connection errors, timeouts) with
|
||||
exponential backoff. GPU contention from concurrent Ollama
|
||||
requests causes ReadError / ReadTimeout — these are transient
|
||||
and should be retried, not raised immediately (#70).
|
||||
async def run(self, message: str, *, max_retries: int = 3) -> str:
|
||||
"""Run the agent with a message, retrying on transient failures.
|
||||
|
||||
Returns:
|
||||
Agent response
|
||||
GPU contention from concurrent Ollama requests causes ReadError /
|
||||
ReadTimeout — these are transient and retried with exponential
|
||||
backoff (#70).
|
||||
"""
|
||||
max_retries = 3
|
||||
last_exception = None
|
||||
# Transient errors that indicate Ollama contention or temporary
|
||||
# unavailability — these deserve a retry with backoff.
|
||||
_transient = (
|
||||
httpx.ConnectError,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.ConnectTimeout,
|
||||
ConnectionError,
|
||||
TimeoutError,
|
||||
)
|
||||
response = await self._run_with_retries(message, max_retries)
|
||||
await self._emit_response_event(message, response)
|
||||
return response
|
||||
|
||||
async def _run_with_retries(self, message: str, max_retries: int) -> str:
|
||||
"""Execute agent.run() with retry logic for transient errors."""
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
result = self.agent.run(message, stream=False)
|
||||
response = result.content if hasattr(result, "content") else str(result)
|
||||
break # Success, exit the retry loop
|
||||
except _transient as exc:
|
||||
last_exception = exc
|
||||
if attempt < max_retries:
|
||||
# Contention backoff — longer waits because the GPU
|
||||
# needs time to finish the other request.
|
||||
wait = min(2**attempt, 16)
|
||||
logger.warning(
|
||||
"Ollama contention on attempt %d/%d: %s. Waiting %ds before retry...",
|
||||
attempt,
|
||||
max_retries,
|
||||
type(exc).__name__,
|
||||
wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
else:
|
||||
logger.error(
|
||||
"Ollama unreachable after %d attempts: %s",
|
||||
max_retries,
|
||||
exc,
|
||||
)
|
||||
raise last_exception from exc
|
||||
return result.content if hasattr(result, "content") else str(result)
|
||||
except self._TRANSIENT as exc:
|
||||
self._handle_retry_or_raise(
|
||||
exc, attempt, max_retries, transient=True,
|
||||
)
|
||||
await asyncio.sleep(min(2**attempt, 16))
|
||||
except Exception as exc:
|
||||
last_exception = exc
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
"Agent run failed on attempt %d/%d: %s. Retrying...",
|
||||
attempt,
|
||||
max_retries,
|
||||
exc,
|
||||
)
|
||||
await asyncio.sleep(min(2 ** (attempt - 1), 8))
|
||||
else:
|
||||
logger.error(
|
||||
"Agent run failed after %d attempts: %s",
|
||||
max_retries,
|
||||
exc,
|
||||
)
|
||||
raise last_exception from exc
|
||||
self._handle_retry_or_raise(
|
||||
exc, attempt, max_retries, transient=False,
|
||||
)
|
||||
await asyncio.sleep(min(2 ** (attempt - 1), 8))
|
||||
# Unreachable — _handle_retry_or_raise raises on last attempt.
|
||||
raise RuntimeError("retry loop exited unexpectedly") # pragma: no cover
|
||||
|
||||
# Emit completion event
|
||||
@staticmethod
|
||||
def _handle_retry_or_raise(
|
||||
exc: Exception, attempt: int, max_retries: int, *, transient: bool,
|
||||
) -> None:
|
||||
"""Log a retry warning or raise after exhausting attempts."""
|
||||
if attempt < max_retries:
|
||||
if transient:
|
||||
logger.warning(
|
||||
"Ollama contention on attempt %d/%d: %s. Waiting before retry...",
|
||||
attempt, max_retries, type(exc).__name__,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Agent run failed on attempt %d/%d: %s. Retrying...",
|
||||
attempt, max_retries, exc,
|
||||
)
|
||||
else:
|
||||
label = "Ollama unreachable" if transient else "Agent run failed"
|
||||
logger.error("%s after %d attempts: %s", label, max_retries, exc)
|
||||
raise exc
|
||||
|
||||
async def _emit_response_event(self, message: str, response: str) -> None:
|
||||
"""Publish a completion event to the event bus if connected."""
|
||||
if self.event_bus:
|
||||
await self.event_bus.publish(
|
||||
Event(
|
||||
@@ -197,8 +192,6 @@ class BaseAgent(ABC):
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_capabilities(self) -> list[str]:
|
||||
"""Get list of capabilities this agent provides."""
|
||||
return self.tools
|
||||
|
||||
@@ -361,6 +361,41 @@ class TestRun:
|
||||
assert response == "ok"
|
||||
|
||||
|
||||
# ── _handle_retry_or_raise ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHandleRetryOrRaise:
|
||||
def test_raises_on_last_attempt(self):
|
||||
BaseAgent = _make_base_class()
|
||||
with pytest.raises(ValueError, match="boom"):
|
||||
BaseAgent._handle_retry_or_raise(
|
||||
ValueError("boom"), attempt=3, max_retries=3, transient=False,
|
||||
)
|
||||
|
||||
def test_raises_on_last_attempt_transient(self):
|
||||
BaseAgent = _make_base_class()
|
||||
exc = httpx.ConnectError("down")
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
BaseAgent._handle_retry_or_raise(
|
||||
exc, attempt=3, max_retries=3, transient=True,
|
||||
)
|
||||
|
||||
def test_no_raise_on_early_attempt(self):
|
||||
BaseAgent = _make_base_class()
|
||||
# Should return None (no raise) on non-final attempt
|
||||
result = BaseAgent._handle_retry_or_raise(
|
||||
ValueError("retry me"), attempt=1, max_retries=3, transient=False,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_no_raise_on_early_transient(self):
|
||||
BaseAgent = _make_base_class()
|
||||
result = BaseAgent._handle_retry_or_raise(
|
||||
httpx.ReadTimeout("busy"), attempt=2, max_retries=3, transient=True,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── get_capabilities / get_status ────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user