316 lines
13 KiB
Python
316 lines
13 KiB
Python
"""Test Ollama disconnection handling.
|
|
|
|
Verifies that:
1. BaseAgent.run() retries on transient errors (contention/disconnect) with backoff (#70)
2. BaseAgent.run() re-raises the error after retries are exhausted
3. session.chat() returns disconnect-specific message on connection errors
4. session.chat_with_tools() returns _ErrorRunOutput with disconnect message on connection errors
5. session.continue_chat() returns a disconnect message on connection errors
6. session.chat() keeps the generic error message for non-connection errors
|
|
"""
|
|
|
|
import importlib
|
|
import logging
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
|
|
class TestBaseAgentDisconnect:
    """Test BaseAgent.run() disconnection handling.

    Every test drives ``BaseAgent.run()`` against a mocked backend agent whose
    ``run`` always raises, then checks how often the backend was called, what
    was logged, and that the original error reaches the caller.
    """

    # Number of backend calls the tests expect before run() gives up.
    _EXPECTED_ATTEMPTS = 3

    @staticmethod
    def _run_until_retries_exhausted(error, match=None):
        """Run a BaseAgent whose mocked backend always raises *error*.

        ``Ollama``, ``Agent`` and ``asyncio.sleep`` are patched as seen from
        ``timmy.agents.base`` for the whole construct-and-run sequence, so no
        real backend is contacted and any sleep call is replaced by a mock.

        Args:
            error: Exception instance installed as the ``side_effect`` of the
                mocked backend's ``run``.
            match: Optional regex forwarded to ``pytest.raises`` to check the
                message of the propagated exception.

        Returns:
            The mocked backend agent, so callers can inspect ``run.call_count``.
        """
        import asyncio

        # Make sure the module is loaded before its attributes are patched.
        importlib.import_module("timmy.agents.base")

        with (
            patch("timmy.agents.base.Ollama") as mock_ollama,
            patch("timmy.agents.base.Agent") as mock_agent_class,
            patch("timmy.agents.base.asyncio.sleep"),
        ):
            mock_ollama.return_value = MagicMock()
            backend = MagicMock()
            backend.run.side_effect = error
            mock_agent_class.return_value = backend

            from timmy.agents.base import BaseAgent

            class ConcreteAgent(BaseAgent):
                """Minimal concrete subclass; only run() is exercised."""

                async def execute_task(self, task_id: str, description: str, context: dict):
                    return {"task_id": task_id, "status": "completed"}

            agent = ConcreteAgent(
                agent_id="test",
                name="Test",
                role="tester",
                system_prompt="You are a test agent.",
                tools=[],
            )

            # The error must propagate with its original type.
            with pytest.raises(type(error), match=match):
                asyncio.run(agent.run("test message"))

        return backend

    @staticmethod
    def _assert_logged(caplog, fragment):
        """Assert that at least one captured log message contains *fragment*."""
        messages = [record.message for record in caplog.records]
        assert any(fragment in message for message in messages), (
            f"Expected {fragment!r} in logs, got: {messages}"
        )

    def test_base_agent_retries_and_logs_on_connect_error(self, caplog):
        """BaseAgent.run() retries on ConnectError with backoff, then logs 'Ollama unreachable' (#70)."""
        caplog.set_level(logging.WARNING)

        backend = self._run_until_retries_exhausted(httpx.ConnectError("Connection refused"))

        # Should have retried 3 times total
        assert backend.run.call_count == self._EXPECTED_ATTEMPTS
        self._assert_logged(caplog, "Ollama contention")
        self._assert_logged(caplog, "Ollama unreachable")

    def test_base_agent_retries_and_logs_on_read_error(self, caplog):
        """BaseAgent.run() retries on ReadError and logs 'Ollama contention' (#70)."""
        caplog.set_level(logging.WARNING)

        backend = self._run_until_retries_exhausted(httpx.ReadError("Server closed connection"))

        assert backend.run.call_count == self._EXPECTED_ATTEMPTS
        self._assert_logged(caplog, "Ollama contention")

    def test_base_agent_retries_and_logs_on_connection_error(self, caplog):
        """BaseAgent.run() retries on ConnectionError, then logs 'Ollama unreachable' (#70)."""
        caplog.set_level(logging.WARNING)

        backend = self._run_until_retries_exhausted(ConnectionError("Network unreachable"))

        assert backend.run.call_count == self._EXPECTED_ATTEMPTS
        self._assert_logged(caplog, "Ollama unreachable")

    def test_base_agent_re_raises_connection_error_after_retries(self):
        """BaseAgent.run() re-raises the connection error after exhausting retries (#70)."""
        # The helper fails the test unless the propagated error matches.
        self._run_until_retries_exhausted(
            httpx.ConnectError("Connection refused"),
            match="Connection refused",
        )
|
|
|
|
|
|
class TestSessionDisconnect:
    """Test session.py disconnection handling.

    Each test replaces ``timmy.session._get_agent`` with a mock returning an
    agent whose async entry point raises, then checks the text handed back to
    the caller and, for connection errors, the logged message.
    """

    # User-facing text expected when the failure is a connection error.
    _DISCONNECT_REPLY = "Ollama appears to be disconnected"
    # Fragment expected in the captured log for connection errors.
    _DISCONNECT_LOG = "Ollama disconnected"
    # User-facing text expected for any other kind of failure.
    _GENERIC_REPLY = "I'm having trouble reaching my inference backend"

    @staticmethod
    def _patched_agent(method_name, error):
        """Return a patcher for ``timmy.session._get_agent``.

        Args:
            method_name: Name of the agent method that should fail
                (``"arun"`` or ``"acontinue_run"`` in this class).
            error: Exception instance raised when that method is called.

        Returns:
            A ``patch`` context manager whose mock returns the failing agent.
        """
        failing_agent = MagicMock()
        getattr(failing_agent, method_name).side_effect = error
        return patch("timmy.session._get_agent", return_value=failing_agent)

    @classmethod
    def _assert_disconnect_logged(cls, caplog):
        """Assert that the disconnect log line was captured."""
        messages = [record.message for record in caplog.records]
        assert any(cls._DISCONNECT_LOG in message for message in messages), (
            f"Expected {cls._DISCONNECT_LOG!r} in logs, got: {messages}"
        )

    @pytest.mark.asyncio
    async def test_chat_returns_disconnect_message_on_connect_error(self, caplog):
        """session.chat() returns disconnect-specific message on httpx.ConnectError."""
        caplog.set_level(logging.ERROR)

        with self._patched_agent("arun", httpx.ConnectError("Connection refused")):
            # Import after patching
            from timmy import session

            result = await session.chat("test message")

        assert self._DISCONNECT_REPLY in result
        self._assert_disconnect_logged(caplog)

    @pytest.mark.asyncio
    async def test_chat_returns_disconnect_message_on_read_error(self, caplog):
        """session.chat() returns disconnect-specific message on httpx.ReadError."""
        caplog.set_level(logging.ERROR)

        with self._patched_agent("arun", httpx.ReadError("Server closed connection")):
            from timmy import session

            result = await session.chat("test message")

        assert self._DISCONNECT_REPLY in result
        self._assert_disconnect_logged(caplog)

    @pytest.mark.asyncio
    async def test_chat_returns_disconnect_message_on_connection_error(self, caplog):
        """session.chat() returns disconnect-specific message on ConnectionError."""
        caplog.set_level(logging.ERROR)

        with self._patched_agent("arun", ConnectionError("Network unreachable")):
            from timmy import session

            result = await session.chat("test message")

        assert self._DISCONNECT_REPLY in result
        self._assert_disconnect_logged(caplog)

    @pytest.mark.asyncio
    async def test_chat_with_tools_returns_error_run_output_on_connect_error(self, caplog):
        """session.chat_with_tools() returns _ErrorRunOutput with disconnect message on ConnectError."""
        caplog.set_level(logging.ERROR)

        with self._patched_agent("arun", httpx.ConnectError("Connection refused")):
            from timmy import session

            result = await session.chat_with_tools("test message")

        # The result must look like a run output: content plus an ERROR status.
        assert hasattr(result, "content")
        assert hasattr(result, "status")
        assert self._DISCONNECT_REPLY in result.content
        assert result.status == "ERROR"
        self._assert_disconnect_logged(caplog)

    @pytest.mark.asyncio
    async def test_chat_with_tools_returns_error_run_output_on_read_error(self, caplog):
        """session.chat_with_tools() returns _ErrorRunOutput with disconnect message on ReadError."""
        caplog.set_level(logging.ERROR)

        with self._patched_agent("arun", httpx.ReadError("Server closed connection")):
            from timmy import session

            result = await session.chat_with_tools("test message")

        assert self._DISCONNECT_REPLY in result.content
        self._assert_disconnect_logged(caplog)

    @pytest.mark.asyncio
    async def test_continue_chat_returns_error_run_output_on_connect_error(self, caplog):
        """session.continue_chat() returns _ErrorRunOutput with disconnect message on ConnectError."""
        caplog.set_level(logging.ERROR)

        with self._patched_agent("acontinue_run", httpx.ConnectError("Connection refused")):
            from timmy import session

            # Stand-in for the run output being resumed.
            paused_run = MagicMock()
            result = await session.continue_chat(paused_run)

        assert hasattr(result, "content")
        assert self._DISCONNECT_REPLY in result.content
        self._assert_disconnect_logged(caplog)

    @pytest.mark.asyncio
    async def test_other_errors_use_generic_message(self, caplog):
        """Non-connection errors still use the generic error message."""
        caplog.set_level(logging.ERROR)

        with self._patched_agent("arun", ValueError("Some other error")):
            from timmy import session

            result = await session.chat("test message")

        assert self._GENERIC_REPLY in result
        # Should NOT have Ollama disconnected message
        assert self._DISCONNECT_REPLY not in result
|