Merge pull request #792 from NousResearch/hermes/hermes-d2f5523a
Merge PR #428: Improve type hints and error diagnostics in vision_tools + add 42 tests
This commit is contained in:
351
tests/tools/test_vision_tools.py
Normal file
351
tests/tools/test_vision_tools.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""Tests for tools/vision_tools.py — URL validation, type hints, error logging."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Awaitable
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.vision_tools import (
|
||||
_validate_image_url,
|
||||
_handle_vision_analyze,
|
||||
_determine_mime_type,
|
||||
_image_to_base64_data_url,
|
||||
vision_analyze_tool,
|
||||
check_vision_requirements,
|
||||
get_debug_session_info,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_image_url — urlparse-based validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateImageUrl:
|
||||
"""Tests for URL validation, including urlparse-based netloc check."""
|
||||
|
||||
def test_valid_https_url(self):
|
||||
assert _validate_image_url("https://example.com/image.jpg") is True
|
||||
|
||||
def test_valid_http_url(self):
|
||||
assert _validate_image_url("http://cdn.example.org/photo.png") is True
|
||||
|
||||
def test_valid_url_without_extension(self):
|
||||
"""CDN endpoints that redirect to images should still pass."""
|
||||
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
|
||||
|
||||
def test_valid_url_with_query_params(self):
|
||||
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
|
||||
|
||||
def test_valid_url_with_port(self):
|
||||
assert _validate_image_url("http://localhost:8080/image.png") is True
|
||||
|
||||
def test_valid_url_with_path_only(self):
|
||||
assert _validate_image_url("https://example.com/") is True
|
||||
|
||||
def test_rejects_empty_string(self):
|
||||
assert _validate_image_url("") is False
|
||||
|
||||
def test_rejects_none(self):
|
||||
assert _validate_image_url(None) is False
|
||||
|
||||
def test_rejects_non_string(self):
|
||||
assert _validate_image_url(12345) is False
|
||||
|
||||
def test_rejects_ftp_scheme(self):
|
||||
assert _validate_image_url("ftp://files.example.com/image.jpg") is False
|
||||
|
||||
def test_rejects_file_scheme(self):
|
||||
assert _validate_image_url("file:///etc/passwd") is False
|
||||
|
||||
def test_rejects_no_scheme(self):
|
||||
assert _validate_image_url("example.com/image.jpg") is False
|
||||
|
||||
def test_rejects_javascript_scheme(self):
|
||||
assert _validate_image_url("javascript:alert(1)") is False
|
||||
|
||||
def test_rejects_http_without_netloc(self):
|
||||
"""http:// alone has no network location — urlparse catches this."""
|
||||
assert _validate_image_url("http://") is False
|
||||
|
||||
def test_rejects_https_without_netloc(self):
|
||||
assert _validate_image_url("https://") is False
|
||||
|
||||
def test_rejects_http_colon_only(self):
|
||||
assert _validate_image_url("http:") is False
|
||||
|
||||
def test_rejects_data_url(self):
|
||||
assert _validate_image_url("data:image/png;base64,iVBOR") is False
|
||||
|
||||
def test_rejects_whitespace_only(self):
|
||||
assert _validate_image_url(" ") is False
|
||||
|
||||
def test_rejects_boolean(self):
|
||||
assert _validate_image_url(True) is False
|
||||
|
||||
def test_rejects_list(self):
|
||||
assert _validate_image_url(["https://example.com"]) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _determine_mime_type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDetermineMimeType:
|
||||
def test_jpg(self):
|
||||
assert _determine_mime_type(Path("photo.jpg")) == "image/jpeg"
|
||||
|
||||
def test_jpeg(self):
|
||||
assert _determine_mime_type(Path("photo.jpeg")) == "image/jpeg"
|
||||
|
||||
def test_png(self):
|
||||
assert _determine_mime_type(Path("screenshot.png")) == "image/png"
|
||||
|
||||
def test_gif(self):
|
||||
assert _determine_mime_type(Path("anim.gif")) == "image/gif"
|
||||
|
||||
def test_webp(self):
|
||||
assert _determine_mime_type(Path("modern.webp")) == "image/webp"
|
||||
|
||||
def test_unknown_extension_defaults_to_jpeg(self):
|
||||
assert _determine_mime_type(Path("file.xyz")) == "image/jpeg"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _image_to_base64_data_url
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestImageToBase64DataUrl:
|
||||
def test_returns_data_url(self, tmp_path):
|
||||
img = tmp_path / "test.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
|
||||
result = _image_to_base64_data_url(img)
|
||||
assert result.startswith("data:image/png;base64,")
|
||||
|
||||
def test_custom_mime_type(self, tmp_path):
|
||||
img = tmp_path / "test.bin"
|
||||
img.write_bytes(b"\x00" * 16)
|
||||
result = _image_to_base64_data_url(img, mime_type="image/webp")
|
||||
assert result.startswith("data:image/webp;base64,")
|
||||
|
||||
def test_file_not_found_raises(self, tmp_path):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
_image_to_base64_data_url(tmp_path / "nonexistent.png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_vision_analyze — type signature & behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHandleVisionAnalyze:
|
||||
"""Verify _handle_vision_analyze returns an Awaitable and builds correct prompt."""
|
||||
|
||||
def test_returns_awaitable(self):
|
||||
"""The handler must return an Awaitable (coroutine) since it's registered as async."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
result = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "What is this?"}
|
||||
)
|
||||
# It should be an Awaitable (coroutine)
|
||||
assert isinstance(result, Awaitable)
|
||||
# Clean up the coroutine to avoid RuntimeWarning
|
||||
result.close()
|
||||
|
||||
def test_prompt_contains_question(self):
|
||||
"""The full prompt should incorporate the user's question."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "Describe the cat"}
|
||||
)
|
||||
# Clean up coroutine
|
||||
coro.close()
|
||||
call_args = mock_tool.call_args
|
||||
full_prompt = call_args[0][1] # second positional arg
|
||||
assert "Describe the cat" in full_prompt
|
||||
assert "Fully describe and explain" in full_prompt
|
||||
|
||||
def test_uses_auxiliary_vision_model_env(self):
|
||||
"""AUXILIARY_VISION_MODEL env var should override DEFAULT_VISION_MODEL."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
||||
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}):
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "test"}
|
||||
)
|
||||
coro.close()
|
||||
call_args = mock_tool.call_args
|
||||
model = call_args[0][2] # third positional arg
|
||||
assert model == "custom/model-v1"
|
||||
|
||||
def test_falls_back_to_default_model(self):
|
||||
"""Without AUXILIARY_VISION_MODEL, should use DEFAULT_VISION_MODEL or fallback."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
||||
patch.dict(os.environ, {}, clear=False):
|
||||
# Ensure AUXILIARY_VISION_MODEL is not set
|
||||
os.environ.pop("AUXILIARY_VISION_MODEL", None)
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "test"}
|
||||
)
|
||||
coro.close()
|
||||
call_args = mock_tool.call_args
|
||||
model = call_args[0][2]
|
||||
# Should be DEFAULT_VISION_MODEL or the hardcoded fallback
|
||||
assert model is not None
|
||||
assert len(model) > 0
|
||||
|
||||
def test_empty_args_graceful(self):
|
||||
"""Missing keys should default to empty strings, not raise."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
result = _handle_vision_analyze({})
|
||||
assert isinstance(result, Awaitable)
|
||||
result.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error logging with exc_info — verify tracebacks are logged
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestErrorLoggingExcInfo:
|
||||
"""Verify that exc_info=True is used in error/warning log calls."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_failure_logs_exc_info(self, tmp_path, caplog):
|
||||
"""After max retries, the download error should include exc_info."""
|
||||
from tools.vision_tools import _download_image
|
||||
|
||||
with patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(side_effect=ConnectionError("network down"))
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
dest = tmp_path / "image.jpg"
|
||||
with caplog.at_level(logging.ERROR, logger="tools.vision_tools"), \
|
||||
pytest.raises(ConnectionError):
|
||||
await _download_image("https://example.com/img.jpg", dest, max_retries=1)
|
||||
|
||||
# Should have logged with exc_info (traceback present)
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
assert len(error_records) >= 1
|
||||
assert error_records[0].exc_info is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analysis_error_logs_exc_info(self, caplog):
|
||||
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
|
||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
||||
patch("tools.vision_tools._download_image", new_callable=AsyncMock,
|
||||
side_effect=Exception("download boom")), \
|
||||
caplog.at_level(logging.ERROR, logger="tools.vision_tools"):
|
||||
|
||||
result = await vision_analyze_tool(
|
||||
"https://example.com/img.jpg", "describe this", "test/model"
|
||||
)
|
||||
result_data = json.loads(result)
|
||||
# Error response uses "success": False, not an "error" key
|
||||
assert result_data["success"] is False
|
||||
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
assert any(r.exc_info is not None for r in error_records)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_error_logs_exc_info(self, tmp_path, caplog):
|
||||
"""Temp file cleanup failure should log warning with exc_info."""
|
||||
# Create a real temp file that will be "downloaded"
|
||||
temp_dir = tmp_path / "temp_vision_images"
|
||||
temp_dir.mkdir()
|
||||
|
||||
async def fake_download(url, dest, max_retries=3):
|
||||
"""Simulate download by writing file to the expected destination."""
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
|
||||
return dest
|
||||
|
||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
||||
patch("tools.vision_tools._download_image", side_effect=fake_download), \
|
||||
patch("tools.vision_tools._image_to_base64_data_url",
|
||||
return_value="data:image/jpeg;base64,abc"), \
|
||||
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None), \
|
||||
patch("agent.auxiliary_client.auxiliary_max_tokens_param", return_value={"max_tokens": 2000}), \
|
||||
caplog.at_level(logging.WARNING, logger="tools.vision_tools"):
|
||||
|
||||
# Mock the vision client
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = "A test image description"
|
||||
mock_response.choices = [mock_choice]
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch module-level _aux_async_client so the tool doesn't bail early
|
||||
with patch("tools.vision_tools._aux_async_client", mock_client), \
|
||||
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"):
|
||||
|
||||
# Make unlink fail to trigger cleanup warning
|
||||
original_unlink = Path.unlink
|
||||
def failing_unlink(self, *args, **kwargs):
|
||||
raise PermissionError("no permission")
|
||||
|
||||
with patch.object(Path, "unlink", failing_unlink):
|
||||
result = await vision_analyze_tool(
|
||||
"https://example.com/tempimg.jpg", "describe", "test/model"
|
||||
)
|
||||
|
||||
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING
|
||||
and "temporary file" in r.getMessage().lower()]
|
||||
assert len(warning_records) >= 1
|
||||
assert warning_records[0].exc_info is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_vision_requirements & get_debug_session_info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVisionRequirements:
|
||||
def test_check_requirements_returns_bool(self):
|
||||
result = check_vision_requirements()
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_debug_session_info_returns_dict(self):
|
||||
info = get_debug_session_info()
|
||||
assert isinstance(info, dict)
|
||||
# DebugSession.get_session_info() returns these keys
|
||||
assert "enabled" in info
|
||||
assert "session_id" in info
|
||||
assert "total_calls" in info
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: registry entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVisionRegistration:
|
||||
def test_vision_analyze_registered(self):
|
||||
from tools.registry import registry
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
assert entry is not None
|
||||
assert entry.toolset == "vision"
|
||||
assert entry.is_async is True
|
||||
|
||||
def test_schema_has_required_fields(self):
|
||||
from tools.registry import registry
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
schema = entry.schema
|
||||
assert schema["name"] == "vision_analyze"
|
||||
params = schema.get("parameters", {})
|
||||
props = params.get("properties", {})
|
||||
assert "image_url" in props
|
||||
assert "question" in props
|
||||
|
||||
def test_handler_is_callable(self):
|
||||
from tools.registry import registry
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
assert callable(entry.handler)
|
||||
@@ -27,14 +27,15 @@ Usage:
|
||||
)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any, Awaitable, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
from agent.auxiliary_client import get_vision_auxiliary_client
|
||||
@@ -73,15 +74,18 @@ def _validate_image_url(url: str) -> bool:
|
||||
"""
|
||||
if not url or not isinstance(url, str):
|
||||
return False
|
||||
|
||||
# Check if it's a valid URL format
|
||||
if not (url.startswith('http://') or url.startswith('https://')):
|
||||
|
||||
# Basic HTTP/HTTPS URL check
|
||||
if not (url.startswith("http://") or url.startswith("https://")):
|
||||
return False
|
||||
|
||||
# Check for common image extensions (optional, as URLs may not have extensions)
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg']
|
||||
|
||||
return True # Allow all HTTP/HTTPS URLs for flexibility
|
||||
|
||||
# Parse to ensure we at least have a network location; still allow URLs
|
||||
# without file extensions (e.g. CDN endpoints that redirect to images).
|
||||
parsed = urlparse(url)
|
||||
if not parsed.netloc:
|
||||
return False
|
||||
|
||||
return True # Allow all well-formed HTTP/HTTPS URLs for flexibility
|
||||
|
||||
|
||||
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
|
||||
@@ -131,7 +135,12 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
||||
logger.warning("Retrying in %ss...", wait_time)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error("Image download failed after %s attempts: %s", max_retries, str(e)[:100])
|
||||
logger.error(
|
||||
"Image download failed after %s attempts: %s",
|
||||
max_retries,
|
||||
str(e)[:100],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
raise last_error
|
||||
|
||||
@@ -188,7 +197,7 @@ def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None)
|
||||
async def vision_analyze_tool(
|
||||
image_url: str,
|
||||
user_prompt: str,
|
||||
model: str = DEFAULT_VISION_MODEL
|
||||
model: str = DEFAULT_VISION_MODEL,
|
||||
) -> str:
|
||||
"""
|
||||
Analyze an image from a URL or local file path using vision AI.
|
||||
@@ -347,7 +356,7 @@ async def vision_analyze_tool(
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing image: {str(e)}"
|
||||
logger.error("%s", error_msg)
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
|
||||
# Prepare error response
|
||||
result = {
|
||||
@@ -368,7 +377,9 @@ async def vision_analyze_tool(
|
||||
temp_image_path.unlink()
|
||||
logger.debug("Cleaned up temporary image file")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning("Could not delete temporary file: %s", cleanup_error)
|
||||
logger.warning(
|
||||
"Could not delete temporary file: %s", cleanup_error, exc_info=True
|
||||
)
|
||||
|
||||
|
||||
def check_vision_requirements() -> bool:
|
||||
@@ -464,10 +475,13 @@ VISION_ANALYZE_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
def _handle_vision_analyze(args, **kw):
|
||||
def _handle_vision_analyze(args: Dict[str, Any], **kw: Any) -> Awaitable[str]:
|
||||
image_url = args.get("image_url", "")
|
||||
question = args.get("question", "")
|
||||
full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"
|
||||
full_prompt = (
|
||||
"Fully describe and explain everything about this image, then answer the "
|
||||
f"following question:\n\n{question}"
|
||||
)
|
||||
model = (os.getenv("AUXILIARY_VISION_MODEL", "").strip()
|
||||
or DEFAULT_VISION_MODEL
|
||||
or "google/gemini-3-flash-preview")
|
||||
|
||||
Reference in New Issue
Block a user