Fix test_analysis_error_logs_exc_info: mock _aux_async_client so download path is reached
This commit is contained in:
@@ -25,6 +25,7 @@ from tools.vision_tools import (
|
||||
# _validate_image_url — urlparse-based validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateImageUrl:
|
||||
"""Tests for URL validation, including urlparse-based netloc check."""
|
||||
|
||||
@@ -95,6 +96,7 @@ class TestValidateImageUrl:
|
||||
# _determine_mime_type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDetermineMimeType:
|
||||
def test_jpg(self):
|
||||
assert _determine_mime_type(Path("photo.jpg")) == "image/jpeg"
|
||||
@@ -119,6 +121,7 @@ class TestDetermineMimeType:
|
||||
# _image_to_base64_data_url
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestImageToBase64DataUrl:
|
||||
def test_returns_data_url(self, tmp_path):
|
||||
img = tmp_path / "test.png"
|
||||
@@ -141,15 +144,21 @@ class TestImageToBase64DataUrl:
|
||||
# _handle_vision_analyze — type signature & behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandleVisionAnalyze:
|
||||
"""Verify _handle_vision_analyze returns an Awaitable and builds correct prompt."""
|
||||
|
||||
def test_returns_awaitable(self):
|
||||
"""The handler must return an Awaitable (coroutine) since it's registered as async."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
with patch(
|
||||
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||
) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
result = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "What is this?"}
|
||||
{
|
||||
"image_url": "https://example.com/img.png",
|
||||
"question": "What is this?",
|
||||
}
|
||||
)
|
||||
# It should be an Awaitable (coroutine)
|
||||
assert isinstance(result, Awaitable)
|
||||
@@ -158,10 +167,15 @@ class TestHandleVisionAnalyze:
|
||||
|
||||
def test_prompt_contains_question(self):
|
||||
"""The full prompt should incorporate the user's question."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
with patch(
|
||||
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||
) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "Describe the cat"}
|
||||
{
|
||||
"image_url": "https://example.com/img.png",
|
||||
"question": "Describe the cat",
|
||||
}
|
||||
)
|
||||
# Clean up coroutine
|
||||
coro.close()
|
||||
@@ -172,8 +186,12 @@ class TestHandleVisionAnalyze:
|
||||
|
||||
def test_uses_auxiliary_vision_model_env(self):
|
||||
"""AUXILIARY_VISION_MODEL env var should override DEFAULT_VISION_MODEL."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
||||
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}):
|
||||
with (
|
||||
patch(
|
||||
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||
) as mock_tool,
|
||||
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}),
|
||||
):
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "test"}
|
||||
@@ -185,8 +203,12 @@ class TestHandleVisionAnalyze:
|
||||
|
||||
def test_falls_back_to_default_model(self):
|
||||
"""Without AUXILIARY_VISION_MODEL, should use DEFAULT_VISION_MODEL or fallback."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
||||
patch.dict(os.environ, {}, clear=False):
|
||||
with (
|
||||
patch(
|
||||
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||
) as mock_tool,
|
||||
patch.dict(os.environ, {}, clear=False),
|
||||
):
|
||||
# Ensure AUXILIARY_VISION_MODEL is not set
|
||||
os.environ.pop("AUXILIARY_VISION_MODEL", None)
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
@@ -202,7 +224,9 @@ class TestHandleVisionAnalyze:
|
||||
|
||||
def test_empty_args_graceful(self):
|
||||
"""Missing keys should default to empty strings, not raise."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
with patch(
|
||||
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||
) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
result = _handle_vision_analyze({})
|
||||
assert isinstance(result, Awaitable)
|
||||
@@ -213,6 +237,7 @@ class TestHandleVisionAnalyze:
|
||||
# Error logging with exc_info — verify tracebacks are logged
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestErrorLoggingExcInfo:
|
||||
"""Verify that exc_info=True is used in error/warning log calls."""
|
||||
|
||||
@@ -229,9 +254,13 @@ class TestErrorLoggingExcInfo:
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
dest = tmp_path / "image.jpg"
|
||||
with caplog.at_level(logging.ERROR, logger="tools.vision_tools"), \
|
||||
pytest.raises(ConnectionError):
|
||||
await _download_image("https://example.com/img.jpg", dest, max_retries=1)
|
||||
with (
|
||||
caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
|
||||
pytest.raises(ConnectionError),
|
||||
):
|
||||
await _download_image(
|
||||
"https://example.com/img.jpg", dest, max_retries=1
|
||||
)
|
||||
|
||||
# Should have logged with exc_info (traceback present)
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
@@ -241,11 +270,17 @@ class TestErrorLoggingExcInfo:
|
||||
@pytest.mark.asyncio
|
||||
async def test_analysis_error_logs_exc_info(self, caplog):
|
||||
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
|
||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
||||
patch("tools.vision_tools._download_image", new_callable=AsyncMock,
|
||||
side_effect=Exception("download boom")), \
|
||||
caplog.at_level(logging.ERROR, logger="tools.vision_tools"):
|
||||
|
||||
with (
|
||||
patch("tools.vision_tools._validate_image_url", return_value=True),
|
||||
patch(
|
||||
"tools.vision_tools._download_image",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("download boom"),
|
||||
),
|
||||
patch("tools.vision_tools._aux_async_client", MagicMock()),
|
||||
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"),
|
||||
caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
|
||||
):
|
||||
result = await vision_analyze_tool(
|
||||
"https://example.com/img.jpg", "describe this", "test/model"
|
||||
)
|
||||
@@ -269,14 +304,20 @@ class TestErrorLoggingExcInfo:
|
||||
dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
|
||||
return dest
|
||||
|
||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
||||
patch("tools.vision_tools._download_image", side_effect=fake_download), \
|
||||
patch("tools.vision_tools._image_to_base64_data_url",
|
||||
return_value="data:image/jpeg;base64,abc"), \
|
||||
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None), \
|
||||
patch("agent.auxiliary_client.auxiliary_max_tokens_param", return_value={"max_tokens": 2000}), \
|
||||
caplog.at_level(logging.WARNING, logger="tools.vision_tools"):
|
||||
|
||||
with (
|
||||
patch("tools.vision_tools._validate_image_url", return_value=True),
|
||||
patch("tools.vision_tools._download_image", side_effect=fake_download),
|
||||
patch(
|
||||
"tools.vision_tools._image_to_base64_data_url",
|
||||
return_value="data:image/jpeg;base64,abc",
|
||||
),
|
||||
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None),
|
||||
patch(
|
||||
"agent.auxiliary_client.auxiliary_max_tokens_param",
|
||||
return_value={"max_tokens": 2000},
|
||||
),
|
||||
caplog.at_level(logging.WARNING, logger="tools.vision_tools"),
|
||||
):
|
||||
# Mock the vision client
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
@@ -286,11 +327,13 @@ class TestErrorLoggingExcInfo:
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch module-level _aux_async_client so the tool doesn't bail early
|
||||
with patch("tools.vision_tools._aux_async_client", mock_client), \
|
||||
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"):
|
||||
|
||||
with (
|
||||
patch("tools.vision_tools._aux_async_client", mock_client),
|
||||
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"),
|
||||
):
|
||||
# Make unlink fail to trigger cleanup warning
|
||||
original_unlink = Path.unlink
|
||||
|
||||
def failing_unlink(self, *args, **kwargs):
|
||||
raise PermissionError("no permission")
|
||||
|
||||
@@ -299,8 +342,12 @@ class TestErrorLoggingExcInfo:
|
||||
"https://example.com/tempimg.jpg", "describe", "test/model"
|
||||
)
|
||||
|
||||
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING
|
||||
and "temporary file" in r.getMessage().lower()]
|
||||
warning_records = [
|
||||
r
|
||||
for r in caplog.records
|
||||
if r.levelno == logging.WARNING
|
||||
and "temporary file" in r.getMessage().lower()
|
||||
]
|
||||
assert len(warning_records) >= 1
|
||||
assert warning_records[0].exc_info is not None
|
||||
|
||||
@@ -309,6 +356,7 @@ class TestErrorLoggingExcInfo:
|
||||
# check_vision_requirements & get_debug_session_info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVisionRequirements:
|
||||
def test_check_requirements_returns_bool(self):
|
||||
result = check_vision_requirements()
|
||||
@@ -327,9 +375,11 @@ class TestVisionRequirements:
|
||||
# Integration: registry entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVisionRegistration:
|
||||
def test_vision_analyze_registered(self):
|
||||
from tools.registry import registry
|
||||
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
assert entry is not None
|
||||
assert entry.toolset == "vision"
|
||||
@@ -337,6 +387,7 @@ class TestVisionRegistration:
|
||||
|
||||
def test_schema_has_required_fields(self):
|
||||
from tools.registry import registry
|
||||
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
schema = entry.schema
|
||||
assert schema["name"] == "vision_analyze"
|
||||
@@ -347,5 +398,6 @@ class TestVisionRegistration:
|
||||
|
||||
def test_handler_is_callable(self):
|
||||
from tools.registry import registry
|
||||
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
assert callable(entry.handler)
|
||||
|
||||
Reference in New Issue
Block a user