Fix test_analysis_error_logs_exc_info: mock _aux_async_client so download path is reached

This commit is contained in:
SHL0MS
2026-03-10 16:03:19 -04:00
parent c358af7861
commit 0229e6b407

View File

@@ -25,6 +25,7 @@ from tools.vision_tools import (
# _validate_image_url — urlparse-based validation
# ---------------------------------------------------------------------------
class TestValidateImageUrl:
"""Tests for URL validation, including urlparse-based netloc check."""
@@ -95,6 +96,7 @@ class TestValidateImageUrl:
# _determine_mime_type
# ---------------------------------------------------------------------------
class TestDetermineMimeType:
def test_jpg(self):
assert _determine_mime_type(Path("photo.jpg")) == "image/jpeg"
@@ -119,6 +121,7 @@ class TestDetermineMimeType:
# _image_to_base64_data_url
# ---------------------------------------------------------------------------
class TestImageToBase64DataUrl:
def test_returns_data_url(self, tmp_path):
img = tmp_path / "test.png"
@@ -141,15 +144,21 @@ class TestImageToBase64DataUrl:
# _handle_vision_analyze — type signature & behavior
# ---------------------------------------------------------------------------
class TestHandleVisionAnalyze:
"""Verify _handle_vision_analyze returns an Awaitable and builds correct prompt."""
def test_returns_awaitable(self):
"""The handler must return an Awaitable (coroutine) since it's registered as async."""
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
with patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool:
mock_tool.return_value = json.dumps({"result": "ok"})
result = _handle_vision_analyze(
{"image_url": "https://example.com/img.png", "question": "What is this?"}
{
"image_url": "https://example.com/img.png",
"question": "What is this?",
}
)
# It should be an Awaitable (coroutine)
assert isinstance(result, Awaitable)
@@ -158,10 +167,15 @@ class TestHandleVisionAnalyze:
def test_prompt_contains_question(self):
"""The full prompt should incorporate the user's question."""
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
with patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool:
mock_tool.return_value = json.dumps({"result": "ok"})
coro = _handle_vision_analyze(
{"image_url": "https://example.com/img.png", "question": "Describe the cat"}
{
"image_url": "https://example.com/img.png",
"question": "Describe the cat",
}
)
# Clean up coroutine
coro.close()
@@ -172,8 +186,12 @@ class TestHandleVisionAnalyze:
def test_uses_auxiliary_vision_model_env(self):
"""AUXILIARY_VISION_MODEL env var should override DEFAULT_VISION_MODEL."""
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}):
with (
patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool,
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}),
):
mock_tool.return_value = json.dumps({"result": "ok"})
coro = _handle_vision_analyze(
{"image_url": "https://example.com/img.png", "question": "test"}
@@ -185,8 +203,12 @@ class TestHandleVisionAnalyze:
def test_falls_back_to_default_model(self):
"""Without AUXILIARY_VISION_MODEL, should use DEFAULT_VISION_MODEL or fallback."""
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
patch.dict(os.environ, {}, clear=False):
with (
patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool,
patch.dict(os.environ, {}, clear=False),
):
# Ensure AUXILIARY_VISION_MODEL is not set
os.environ.pop("AUXILIARY_VISION_MODEL", None)
mock_tool.return_value = json.dumps({"result": "ok"})
@@ -202,7 +224,9 @@ class TestHandleVisionAnalyze:
def test_empty_args_graceful(self):
"""Missing keys should default to empty strings, not raise."""
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
with patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool:
mock_tool.return_value = json.dumps({"result": "ok"})
result = _handle_vision_analyze({})
assert isinstance(result, Awaitable)
@@ -213,6 +237,7 @@ class TestHandleVisionAnalyze:
# Error logging with exc_info — verify tracebacks are logged
# ---------------------------------------------------------------------------
class TestErrorLoggingExcInfo:
"""Verify that exc_info=True is used in error/warning log calls."""
@@ -229,9 +254,13 @@ class TestErrorLoggingExcInfo:
mock_client_cls.return_value = mock_client
dest = tmp_path / "image.jpg"
with caplog.at_level(logging.ERROR, logger="tools.vision_tools"), \
pytest.raises(ConnectionError):
await _download_image("https://example.com/img.jpg", dest, max_retries=1)
with (
caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
pytest.raises(ConnectionError),
):
await _download_image(
"https://example.com/img.jpg", dest, max_retries=1
)
# Should have logged with exc_info (traceback present)
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
@@ -241,11 +270,17 @@ class TestErrorLoggingExcInfo:
@pytest.mark.asyncio
async def test_analysis_error_logs_exc_info(self, caplog):
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
with patch("tools.vision_tools._validate_image_url", return_value=True), \
patch("tools.vision_tools._download_image", new_callable=AsyncMock,
side_effect=Exception("download boom")), \
caplog.at_level(logging.ERROR, logger="tools.vision_tools"):
with (
patch("tools.vision_tools._validate_image_url", return_value=True),
patch(
"tools.vision_tools._download_image",
new_callable=AsyncMock,
side_effect=Exception("download boom"),
),
patch("tools.vision_tools._aux_async_client", MagicMock()),
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"),
caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
):
result = await vision_analyze_tool(
"https://example.com/img.jpg", "describe this", "test/model"
)
@@ -269,14 +304,20 @@ class TestErrorLoggingExcInfo:
dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
return dest
with patch("tools.vision_tools._validate_image_url", return_value=True), \
patch("tools.vision_tools._download_image", side_effect=fake_download), \
patch("tools.vision_tools._image_to_base64_data_url",
return_value="data:image/jpeg;base64,abc"), \
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None), \
patch("agent.auxiliary_client.auxiliary_max_tokens_param", return_value={"max_tokens": 2000}), \
caplog.at_level(logging.WARNING, logger="tools.vision_tools"):
with (
patch("tools.vision_tools._validate_image_url", return_value=True),
patch("tools.vision_tools._download_image", side_effect=fake_download),
patch(
"tools.vision_tools._image_to_base64_data_url",
return_value="data:image/jpeg;base64,abc",
),
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None),
patch(
"agent.auxiliary_client.auxiliary_max_tokens_param",
return_value={"max_tokens": 2000},
),
caplog.at_level(logging.WARNING, logger="tools.vision_tools"),
):
# Mock the vision client
mock_client = AsyncMock()
mock_response = MagicMock()
@@ -286,11 +327,13 @@ class TestErrorLoggingExcInfo:
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
# Patch module-level _aux_async_client so the tool doesn't bail early
with patch("tools.vision_tools._aux_async_client", mock_client), \
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"):
with (
patch("tools.vision_tools._aux_async_client", mock_client),
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"),
):
# Make unlink fail to trigger cleanup warning
original_unlink = Path.unlink
def failing_unlink(self, *args, **kwargs):
raise PermissionError("no permission")
@@ -299,8 +342,12 @@ class TestErrorLoggingExcInfo:
"https://example.com/tempimg.jpg", "describe", "test/model"
)
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING
and "temporary file" in r.getMessage().lower()]
warning_records = [
r
for r in caplog.records
if r.levelno == logging.WARNING
and "temporary file" in r.getMessage().lower()
]
assert len(warning_records) >= 1
assert warning_records[0].exc_info is not None
@@ -309,6 +356,7 @@ class TestErrorLoggingExcInfo:
# check_vision_requirements & get_debug_session_info
# ---------------------------------------------------------------------------
class TestVisionRequirements:
def test_check_requirements_returns_bool(self):
result = check_vision_requirements()
@@ -327,9 +375,11 @@ class TestVisionRequirements:
# Integration: registry entry
# ---------------------------------------------------------------------------
class TestVisionRegistration:
def test_vision_analyze_registered(self):
from tools.registry import registry
entry = registry._tools.get("vision_analyze")
assert entry is not None
assert entry.toolset == "vision"
@@ -337,6 +387,7 @@ class TestVisionRegistration:
def test_schema_has_required_fields(self):
from tools.registry import registry
entry = registry._tools.get("vision_analyze")
schema = entry.schema
assert schema["name"] == "vision_analyze"
@@ -347,5 +398,6 @@ class TestVisionRegistration:
def test_handler_is_callable(self):
from tools.registry import registry
entry = registry._tools.get("vision_analyze")
assert callable(entry.handler)