diff --git a/tests/tools/test_vision_tools.py b/tests/tools/test_vision_tools.py new file mode 100644 index 000000000..3bdd30178 --- /dev/null +++ b/tests/tools/test_vision_tools.py @@ -0,0 +1,351 @@ +"""Tests for tools/vision_tools.py — URL validation, type hints, error logging.""" + +import asyncio +import json +import logging +import os +from pathlib import Path +from typing import Awaitable +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tools.vision_tools import ( + _validate_image_url, + _handle_vision_analyze, + _determine_mime_type, + _image_to_base64_data_url, + vision_analyze_tool, + check_vision_requirements, + get_debug_session_info, +) + + +# --------------------------------------------------------------------------- +# _validate_image_url — urlparse-based validation +# --------------------------------------------------------------------------- + +class TestValidateImageUrl: + """Tests for URL validation, including urlparse-based netloc check.""" + + def test_valid_https_url(self): + assert _validate_image_url("https://example.com/image.jpg") is True + + def test_valid_http_url(self): + assert _validate_image_url("http://cdn.example.org/photo.png") is True + + def test_valid_url_without_extension(self): + """CDN endpoints that redirect to images should still pass.""" + assert _validate_image_url("https://cdn.example.com/abcdef123") is True + + def test_valid_url_with_query_params(self): + assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True + + def test_valid_url_with_port(self): + assert _validate_image_url("http://localhost:8080/image.png") is True + + def test_valid_url_with_path_only(self): + assert _validate_image_url("https://example.com/") is True + + def test_rejects_empty_string(self): + assert _validate_image_url("") is False + + def test_rejects_none(self): + assert _validate_image_url(None) is False + + def test_rejects_non_string(self): + assert _validate_image_url(12345) is False + + def test_rejects_ftp_scheme(self): + assert _validate_image_url("ftp://files.example.com/image.jpg") is False + + def test_rejects_file_scheme(self): + assert _validate_image_url("file:///etc/passwd") is False + + def test_rejects_no_scheme(self): + assert _validate_image_url("example.com/image.jpg") is False + + def test_rejects_javascript_scheme(self): + assert _validate_image_url("javascript:alert(1)") is False + + def test_rejects_http_without_netloc(self): + """http:// alone has no network location — urlparse catches this.""" + assert _validate_image_url("http://") is False + + def test_rejects_https_without_netloc(self): + assert _validate_image_url("https://") is False + + def test_rejects_http_colon_only(self): + assert _validate_image_url("http:") is False + + def test_rejects_data_url(self): + assert _validate_image_url("data:image/png;base64,iVBOR") is False + + def test_rejects_whitespace_only(self): + assert _validate_image_url(" ") is False + + def test_rejects_boolean(self): + assert _validate_image_url(True) is False + + def test_rejects_list(self): + assert _validate_image_url(["https://example.com"]) is False + + +# --------------------------------------------------------------------------- +# _determine_mime_type +# --------------------------------------------------------------------------- + +class TestDetermineMimeType: + def test_jpg(self): + assert _determine_mime_type(Path("photo.jpg")) == "image/jpeg" + + def test_jpeg(self): + assert _determine_mime_type(Path("photo.jpeg")) == "image/jpeg" + + def test_png(self): + assert _determine_mime_type(Path("screenshot.png")) == "image/png" + + def test_gif(self): + assert _determine_mime_type(Path("anim.gif")) == "image/gif" + + def test_webp(self): + assert _determine_mime_type(Path("modern.webp")) == "image/webp" + + def test_unknown_extension_defaults_to_jpeg(self): + assert _determine_mime_type(Path("file.xyz")) == "image/jpeg" + + +# --------------------------------------------------------------------------- +# _image_to_base64_data_url +# --------------------------------------------------------------------------- + +class TestImageToBase64DataUrl: + def test_returns_data_url(self, tmp_path): + img = tmp_path / "test.png" + img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8) + result = _image_to_base64_data_url(img) + assert result.startswith("data:image/png;base64,") + + def test_custom_mime_type(self, tmp_path): + img = tmp_path / "test.bin" + img.write_bytes(b"\x00" * 16) + result = _image_to_base64_data_url(img, mime_type="image/webp") + assert result.startswith("data:image/webp;base64,") + + def test_file_not_found_raises(self, tmp_path): + with pytest.raises(FileNotFoundError): + _image_to_base64_data_url(tmp_path / "nonexistent.png") + + +# --------------------------------------------------------------------------- +# _handle_vision_analyze — type signature & behavior +# --------------------------------------------------------------------------- + +class TestHandleVisionAnalyze: + """Verify _handle_vision_analyze returns an Awaitable and builds correct prompt.""" + + def test_returns_awaitable(self): + """The handler must return an Awaitable (coroutine) since it's registered as async.""" + with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool: + mock_tool.return_value = json.dumps({"result": "ok"}) + result = _handle_vision_analyze( + {"image_url": "https://example.com/img.png", "question": "What is this?"} + ) + # It should be an Awaitable (coroutine) + assert isinstance(result, Awaitable) + # Clean up the coroutine to avoid RuntimeWarning + result.close() + + def test_prompt_contains_question(self): + """The full prompt should incorporate the user's question.""" + with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool: + mock_tool.return_value = json.dumps({"result": "ok"}) + coro = _handle_vision_analyze( + {"image_url": "https://example.com/img.png", "question": "Describe the cat"} + ) + # Clean up coroutine + coro.close() + call_args = mock_tool.call_args + full_prompt = call_args[0][1] # second positional arg + assert "Describe the cat" in full_prompt + assert "Fully describe and explain" in full_prompt + + def test_uses_auxiliary_vision_model_env(self): + """AUXILIARY_VISION_MODEL env var should override DEFAULT_VISION_MODEL.""" + with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \ + patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}): + mock_tool.return_value = json.dumps({"result": "ok"}) + coro = _handle_vision_analyze( + {"image_url": "https://example.com/img.png", "question": "test"} + ) + coro.close() + call_args = mock_tool.call_args + model = call_args[0][2] # third positional arg + assert model == "custom/model-v1" + + def test_falls_back_to_default_model(self): + """Without AUXILIARY_VISION_MODEL, should use DEFAULT_VISION_MODEL or fallback.""" + with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \ + patch.dict(os.environ, {}, clear=False): + # Ensure AUXILIARY_VISION_MODEL is not set + os.environ.pop("AUXILIARY_VISION_MODEL", None) + mock_tool.return_value = json.dumps({"result": "ok"}) + coro = _handle_vision_analyze( + {"image_url": "https://example.com/img.png", "question": "test"} + ) + coro.close() + call_args = mock_tool.call_args + model = call_args[0][2] + # Should be DEFAULT_VISION_MODEL or the hardcoded fallback + assert model is not None + assert len(model) > 0 + + def test_empty_args_graceful(self): + """Missing keys should default to empty strings, not raise.""" + with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool: + mock_tool.return_value = json.dumps({"result": "ok"}) + result = _handle_vision_analyze({}) + assert isinstance(result, Awaitable) + result.close() + + +# --------------------------------------------------------------------------- +# Error logging with exc_info — verify tracebacks are logged +# --------------------------------------------------------------------------- + +class TestErrorLoggingExcInfo: + """Verify that exc_info=True is used in error/warning log calls.""" + + @pytest.mark.asyncio + async def test_download_failure_logs_exc_info(self, tmp_path, caplog): + """After max retries, the download error should include exc_info.""" + from tools.vision_tools import _download_image + + with patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(side_effect=ConnectionError("network down")) + mock_client_cls.return_value = mock_client + + dest = tmp_path / "image.jpg" + with caplog.at_level(logging.ERROR, logger="tools.vision_tools"), \ + pytest.raises(ConnectionError): + await _download_image("https://example.com/img.jpg", dest, max_retries=1) + + # Should have logged with exc_info (traceback present) + error_records = [r for r in caplog.records if r.levelno >= logging.ERROR] + assert len(error_records) >= 1 + assert error_records[0].exc_info is not None + + @pytest.mark.asyncio + async def test_analysis_error_logs_exc_info(self, caplog): + """When vision_analyze_tool encounters an error, it should log with exc_info.""" + with patch("tools.vision_tools._validate_image_url", return_value=True), \ + patch("tools.vision_tools._download_image", new_callable=AsyncMock, + side_effect=Exception("download boom")), \ + caplog.at_level(logging.ERROR, logger="tools.vision_tools"): + + result = await vision_analyze_tool( + "https://example.com/img.jpg", "describe this", "test/model" + ) + result_data = json.loads(result) + # Error response uses "success": False, not an "error" key + assert result_data["success"] is False + + error_records = [r for r in caplog.records if r.levelno >= logging.ERROR] + assert any(r.exc_info is not None for r in error_records) + + @pytest.mark.asyncio + async def test_cleanup_error_logs_exc_info(self, tmp_path, caplog): + """Temp file cleanup failure should log warning with exc_info.""" + # Create a real temp file that will be "downloaded" + temp_dir = tmp_path / "temp_vision_images" + temp_dir.mkdir() + + async def fake_download(url, dest, max_retries=3): + """Simulate download by writing file to the expected destination.""" + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16) + return dest + + with patch("tools.vision_tools._validate_image_url", return_value=True), \ + patch("tools.vision_tools._download_image", side_effect=fake_download), \ + patch("tools.vision_tools._image_to_base64_data_url", + return_value="data:image/jpeg;base64,abc"), \ + patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None), \ + patch("agent.auxiliary_client.auxiliary_max_tokens_param", return_value={"max_tokens": 2000}), \ + caplog.at_level(logging.WARNING, logger="tools.vision_tools"): + + # Mock the vision client + mock_client = AsyncMock() + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "A test image description" + mock_response.choices = [mock_choice] + mock_client.chat.completions.create = AsyncMock(return_value=mock_response) + + # Patch module-level _aux_async_client so the tool doesn't bail early + with patch("tools.vision_tools._aux_async_client", mock_client), \ + patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"): + + # Make unlink fail to trigger cleanup warning + original_unlink = Path.unlink + def failing_unlink(self, *args, **kwargs): + raise PermissionError("no permission") + + with patch.object(Path, "unlink", failing_unlink): + result = await vision_analyze_tool( + "https://example.com/tempimg.jpg", "describe", "test/model" + ) + + warning_records = [r for r in caplog.records if r.levelno == logging.WARNING + and "temporary file" in r.getMessage().lower()] + assert len(warning_records) >= 1 + assert warning_records[0].exc_info is not None + + +# --------------------------------------------------------------------------- +# check_vision_requirements & get_debug_session_info +# --------------------------------------------------------------------------- + +class TestVisionRequirements: + def test_check_requirements_returns_bool(self): + result = check_vision_requirements() + assert isinstance(result, bool) + + def test_debug_session_info_returns_dict(self): + info = get_debug_session_info() + assert isinstance(info, dict) + # DebugSession.get_session_info() returns these keys + assert "enabled" in info + assert "session_id" in info + assert "total_calls" in info + + +# --------------------------------------------------------------------------- +# Integration: registry entry +# --------------------------------------------------------------------------- + +class TestVisionRegistration: + def test_vision_analyze_registered(self): + from tools.registry import registry + entry = registry._tools.get("vision_analyze") + assert entry is not None + assert entry.toolset == "vision" + assert entry.is_async is True + + def test_schema_has_required_fields(self): + from tools.registry import registry + entry = registry._tools.get("vision_analyze") + schema = entry.schema + assert schema["name"] == "vision_analyze" + params = schema.get("parameters", {}) + props = params.get("properties", {}) + assert "image_url" in props + assert "question" in props + + def test_handler_is_callable(self): + from tools.registry import registry + entry = registry._tools.get("vision_analyze") + assert callable(entry.handler) diff --git a/tools/vision_tools.py b/tools/vision_tools.py index 718e17363..d91051175 100644 --- a/tools/vision_tools.py +++ b/tools/vision_tools.py @@ -27,14 +27,15 @@ Usage: ) """ +import asyncio +import base64 import json import logging import os -import asyncio import uuid -import base64 from pathlib import Path -from typing import Dict, Any, Optional +from typing import Any, Awaitable, Dict, Optional +from urllib.parse import urlparse import httpx from openai import AsyncOpenAI from agent.auxiliary_client import get_vision_auxiliary_client @@ -73,15 +74,18 @@ def _validate_image_url(url: str) -> bool: """ if not url or not isinstance(url, str): return False - - # Check if it's a valid URL format - if not (url.startswith('http://') or url.startswith('https://')): + + # Basic HTTP/HTTPS URL check + if not (url.startswith("http://") or url.startswith("https://")): return False - - # Check for common image extensions (optional, as URLs may not have extensions) - image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg'] - - return True # Allow all HTTP/HTTPS URLs for flexibility + + # Parse to ensure we at least have a network location; still allow URLs + # without file extensions (e.g. CDN endpoints that redirect to images). + parsed = urlparse(url) + if not parsed.netloc: + return False + + return True # Allow all well-formed HTTP/HTTPS URLs for flexibility async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path: @@ -131,7 +135,12 @@ async def _download_image(image_url: str, destination: Path, max_retries: int = logger.warning("Retrying in %ss...", wait_time) await asyncio.sleep(wait_time) else: - logger.error("Image download failed after %s attempts: %s", max_retries, str(e)[:100]) + logger.error( + "Image download failed after %s attempts: %s", + max_retries, + str(e)[:100], + exc_info=True, + ) raise last_error @@ -188,7 +197,7 @@ def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None) async def vision_analyze_tool( image_url: str, user_prompt: str, - model: str = DEFAULT_VISION_MODEL + model: str = DEFAULT_VISION_MODEL, ) -> str: """ Analyze an image from a URL or local file path using vision AI. @@ -347,7 +356,7 @@ async def vision_analyze_tool( except Exception as e: error_msg = f"Error analyzing image: {str(e)}" - logger.error("%s", error_msg) + logger.error("%s", error_msg, exc_info=True) # Prepare error response result = { @@ -368,7 +377,9 @@ async def vision_analyze_tool( temp_image_path.unlink() logger.debug("Cleaned up temporary image file") except Exception as cleanup_error: - logger.warning("Could not delete temporary file: %s", cleanup_error) + logger.warning( + "Could not delete temporary file: %s", cleanup_error, exc_info=True + ) def check_vision_requirements() -> bool: @@ -464,10 +475,13 @@ VISION_ANALYZE_SCHEMA = { } -def _handle_vision_analyze(args, **kw): +def _handle_vision_analyze(args: Dict[str, Any], **kw: Any) -> Awaitable[str]: image_url = args.get("image_url", "") question = args.get("question", "") - full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}" + full_prompt = ( + "Fully describe and explain everything about this image, then answer the " + f"following question:\n\n{question}" + ) model = (os.getenv("AUXILIARY_VISION_MODEL", "").strip() or DEFAULT_VISION_MODEL or "google/gemini-3-flash-preview")