Files
hermes-agent/tests/tools/test_vision_tools.py
Teknium 5e67fc8c40 fix(vision): reject non-image files and enforce website policy (salvage #1940) (#3845)
Three safety gaps in vision_analyze_tool:

1. Local files accepted without checking if they're actually images —
   a renamed text file would get base64-encoded and sent to the model.
   Now validates magic bytes (PNG, JPEG, GIF, BMP, WebP, SVG).

2. No website policy enforcement on image URLs — blocked domains could
   be fetched via the vision tool. Now checks before download.

3. No redirect check — if an allowed URL redirected to a blocked domain,
   the download would proceed. Now re-checks the final URL.

Fixed one test that needed _validate_image_url mocked to bypass DNS
resolution on the fake blocked.test domain (is_safe_url does DNS
checks that were added after the original PR).

Co-authored-by: GutSlabs <GutSlabs@users.noreply.github.com>
2026-03-29 20:55:04 -07:00

547 lines
21 KiB
Python

"""Tests for tools/vision_tools.py — URL validation, type hints, error logging."""
import asyncio
import json
import logging
import os
from pathlib import Path
from typing import Awaitable
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from tools.vision_tools import (
_validate_image_url,
_handle_vision_analyze,
_determine_mime_type,
_image_to_base64_data_url,
vision_analyze_tool,
check_vision_requirements,
get_debug_session_info,
)
# ---------------------------------------------------------------------------
# _validate_image_url — urlparse-based validation
# ---------------------------------------------------------------------------
class TestValidateImageUrl:
"""Tests for URL validation, including urlparse-based netloc check."""
def test_valid_https_url(self):
assert _validate_image_url("https://example.com/image.jpg") is True
def test_valid_http_url(self):
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("93.184.216.34", 0)),
]):
assert _validate_image_url("http://cdn.example.org/photo.png") is True
def test_valid_url_without_extension(self):
"""CDN endpoints that redirect to images should still pass."""
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("93.184.216.34", 0)),
]):
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
def test_valid_url_with_query_params(self):
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("93.184.216.34", 0)),
]):
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
def test_localhost_url_blocked_by_ssrf(self):
"""localhost URLs are now blocked by SSRF protection."""
assert _validate_image_url("http://localhost:8080/image.png") is False
def test_valid_url_with_port(self):
assert _validate_image_url("http://example.com:8080/image.png") is True
def test_valid_url_with_path_only(self):
assert _validate_image_url("https://example.com/") is True
def test_rejects_empty_string(self):
assert _validate_image_url("") is False
def test_rejects_none(self):
assert _validate_image_url(None) is False
def test_rejects_non_string(self):
assert _validate_image_url(12345) is False
def test_rejects_ftp_scheme(self):
assert _validate_image_url("ftp://files.example.com/image.jpg") is False
def test_rejects_file_scheme(self):
assert _validate_image_url("file:///etc/passwd") is False
def test_rejects_no_scheme(self):
assert _validate_image_url("example.com/image.jpg") is False
def test_rejects_javascript_scheme(self):
assert _validate_image_url("javascript:alert(1)") is False
def test_rejects_http_without_netloc(self):
"""http:// alone has no network location — urlparse catches this."""
assert _validate_image_url("http://") is False
def test_rejects_https_without_netloc(self):
assert _validate_image_url("https://") is False
def test_rejects_http_colon_only(self):
assert _validate_image_url("http:") is False
def test_rejects_data_url(self):
assert _validate_image_url("data:image/png;base64,iVBOR") is False
def test_rejects_whitespace_only(self):
assert _validate_image_url(" ") is False
def test_rejects_boolean(self):
assert _validate_image_url(True) is False
def test_rejects_list(self):
assert _validate_image_url(["https://example.com"]) is False
# ---------------------------------------------------------------------------
# _determine_mime_type
# ---------------------------------------------------------------------------
class TestDetermineMimeType:
def test_jpg(self):
assert _determine_mime_type(Path("photo.jpg")) == "image/jpeg"
def test_jpeg(self):
assert _determine_mime_type(Path("photo.jpeg")) == "image/jpeg"
def test_png(self):
assert _determine_mime_type(Path("screenshot.png")) == "image/png"
def test_gif(self):
assert _determine_mime_type(Path("anim.gif")) == "image/gif"
def test_webp(self):
assert _determine_mime_type(Path("modern.webp")) == "image/webp"
def test_unknown_extension_defaults_to_jpeg(self):
assert _determine_mime_type(Path("file.xyz")) == "image/jpeg"
# ---------------------------------------------------------------------------
# _image_to_base64_data_url
# ---------------------------------------------------------------------------
class TestImageToBase64DataUrl:
def test_returns_data_url(self, tmp_path):
img = tmp_path / "test.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
result = _image_to_base64_data_url(img)
assert result.startswith("data:image/png;base64,")
def test_custom_mime_type(self, tmp_path):
img = tmp_path / "test.bin"
img.write_bytes(b"\x00" * 16)
result = _image_to_base64_data_url(img, mime_type="image/webp")
assert result.startswith("data:image/webp;base64,")
def test_file_not_found_raises(self, tmp_path):
with pytest.raises(FileNotFoundError):
_image_to_base64_data_url(tmp_path / "nonexistent.png")
# ---------------------------------------------------------------------------
# _handle_vision_analyze — type signature & behavior
# ---------------------------------------------------------------------------
class TestHandleVisionAnalyze:
"""Verify _handle_vision_analyze returns an Awaitable and builds correct prompt."""
def test_returns_awaitable(self):
"""The handler must return an Awaitable (coroutine) since it's registered as async."""
with patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool:
mock_tool.return_value = json.dumps({"result": "ok"})
result = _handle_vision_analyze(
{
"image_url": "https://example.com/img.png",
"question": "What is this?",
}
)
# It should be an Awaitable (coroutine)
assert isinstance(result, Awaitable)
# Clean up the coroutine to avoid RuntimeWarning
result.close()
def test_prompt_contains_question(self):
"""The full prompt should incorporate the user's question."""
with patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool:
mock_tool.return_value = json.dumps({"result": "ok"})
coro = _handle_vision_analyze(
{
"image_url": "https://example.com/img.png",
"question": "Describe the cat",
}
)
# Clean up coroutine
coro.close()
call_args = mock_tool.call_args
full_prompt = call_args[0][1] # second positional arg
assert "Describe the cat" in full_prompt
assert "Fully describe and explain" in full_prompt
def test_uses_auxiliary_vision_model_env(self):
"""AUXILIARY_VISION_MODEL env var should override DEFAULT_VISION_MODEL."""
with (
patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool,
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}),
):
mock_tool.return_value = json.dumps({"result": "ok"})
coro = _handle_vision_analyze(
{"image_url": "https://example.com/img.png", "question": "test"}
)
coro.close()
call_args = mock_tool.call_args
model = call_args[0][2] # third positional arg
assert model == "custom/model-v1"
def test_falls_back_to_default_model(self):
"""Without AUXILIARY_VISION_MODEL, model should be None (let call_llm resolve default)."""
with (
patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool,
patch.dict(os.environ, {}, clear=False),
):
# Ensure AUXILIARY_VISION_MODEL is not set
os.environ.pop("AUXILIARY_VISION_MODEL", None)
mock_tool.return_value = json.dumps({"result": "ok"})
coro = _handle_vision_analyze(
{"image_url": "https://example.com/img.png", "question": "test"}
)
coro.close()
call_args = mock_tool.call_args
model = call_args[0][2]
# With no AUXILIARY_VISION_MODEL set, model should be None
# (the centralized call_llm router picks the default)
assert model is None
def test_empty_args_graceful(self):
"""Missing keys should default to empty strings, not raise."""
with patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
) as mock_tool:
mock_tool.return_value = json.dumps({"result": "ok"})
result = _handle_vision_analyze({})
assert isinstance(result, Awaitable)
result.close()
# ---------------------------------------------------------------------------
# Error logging with exc_info — verify tracebacks are logged
# ---------------------------------------------------------------------------
class TestErrorLoggingExcInfo:
"""Verify that exc_info=True is used in error/warning log calls."""
@pytest.mark.asyncio
async def test_download_failure_logs_exc_info(self, tmp_path, caplog):
"""After max retries, the download error should include exc_info."""
from tools.vision_tools import _download_image
with patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(side_effect=ConnectionError("network down"))
mock_client_cls.return_value = mock_client
dest = tmp_path / "image.jpg"
with (
caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
pytest.raises(ConnectionError),
):
await _download_image(
"https://example.com/img.jpg", dest, max_retries=1
)
# Should have logged with exc_info (traceback present)
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
assert len(error_records) >= 1
assert error_records[0].exc_info is not None
@pytest.mark.asyncio
async def test_analysis_error_logs_exc_info(self, caplog):
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
with (
patch("tools.vision_tools._validate_image_url", return_value=True),
patch(
"tools.vision_tools._download_image",
new_callable=AsyncMock,
side_effect=Exception("download boom"),
),
caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
):
result = await vision_analyze_tool(
"https://example.com/img.jpg", "describe this", "test/model"
)
result_data = json.loads(result)
# Error response uses "success": False, not an "error" key
assert result_data["success"] is False
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
assert any(r.exc_info and r.exc_info[0] is not None for r in error_records)
@pytest.mark.asyncio
async def test_cleanup_error_logs_exc_info(self, tmp_path, caplog):
"""Temp file cleanup failure should log warning with exc_info."""
# Create a real temp file that will be "downloaded"
temp_dir = tmp_path / "temp_vision_images"
temp_dir.mkdir()
async def fake_download(url, dest, max_retries=3):
"""Simulate download by writing file to the expected destination."""
dest.parent.mkdir(parents=True, exist_ok=True)
dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
return dest
with (
patch("tools.vision_tools._validate_image_url", return_value=True),
patch("tools.vision_tools._download_image", side_effect=fake_download),
patch(
"tools.vision_tools._image_to_base64_data_url",
return_value="data:image/jpeg;base64,abc",
),
caplog.at_level(logging.WARNING, logger="tools.vision_tools"),
):
# Mock the async_call_llm function to return a mock response
mock_response = MagicMock()
mock_choice = MagicMock()
mock_choice.message.content = "A test image description"
mock_response.choices = [mock_choice]
with (
patch("tools.vision_tools.async_call_llm", new_callable=AsyncMock, return_value=mock_response),
):
# Make unlink fail to trigger cleanup warning
original_unlink = Path.unlink
def failing_unlink(self, *args, **kwargs):
raise PermissionError("no permission")
with patch.object(Path, "unlink", failing_unlink):
result = await vision_analyze_tool(
"https://example.com/tempimg.jpg", "describe", "test/model"
)
warning_records = [
r
for r in caplog.records
if r.levelno == logging.WARNING
and "temporary file" in r.getMessage().lower()
]
assert len(warning_records) >= 1
assert warning_records[0].exc_info is not None
class TestVisionSafetyGuards:
@pytest.mark.asyncio
async def test_local_non_image_file_rejected_before_llm_call(self, tmp_path):
secret = tmp_path / "secret.txt"
secret.write_text("TOP-SECRET=1\n", encoding="utf-8")
with patch("tools.vision_tools.async_call_llm", new_callable=AsyncMock) as mock_llm:
result = json.loads(await vision_analyze_tool(str(secret), "extract text"))
assert result["success"] is False
assert "Only real image files are supported" in result["error"]
mock_llm.assert_not_awaited()
@pytest.mark.asyncio
async def test_blocked_remote_url_short_circuits_before_download(self):
blocked = {
"host": "blocked.test",
"rule": "blocked.test",
"source": "config",
"message": "Blocked by website policy",
}
with (
patch("tools.vision_tools.check_website_access", return_value=blocked),
patch("tools.vision_tools._validate_image_url", return_value=True),
patch("tools.vision_tools._download_image", new_callable=AsyncMock) as mock_download,
):
result = json.loads(await vision_analyze_tool("https://blocked.test/cat.png", "describe"))
assert result["success"] is False
assert "Blocked by website policy" in result["error"]
mock_download.assert_not_awaited()
@pytest.mark.asyncio
async def test_download_blocks_redirected_final_url(self, tmp_path):
from tools.vision_tools import _download_image
def fake_check(url):
if url == "https://allowed.test/cat.png":
return None
if url == "https://blocked.test/final.png":
return {
"host": "blocked.test",
"rule": "blocked.test",
"source": "config",
"message": "Blocked by website policy",
}
raise AssertionError(f"unexpected URL checked: {url}")
class FakeResponse:
url = "https://blocked.test/final.png"
content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16
def raise_for_status(self):
return None
with (
patch("tools.vision_tools.check_website_access", side_effect=fake_check),
patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls,
pytest.raises(PermissionError, match="Blocked by website policy"),
):
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=FakeResponse())
mock_client_cls.return_value = mock_client
await _download_image("https://allowed.test/cat.png", tmp_path / "cat.png", max_retries=1)
assert not (tmp_path / "cat.png").exists()
# ---------------------------------------------------------------------------
# check_vision_requirements & get_debug_session_info
# ---------------------------------------------------------------------------
class TestVisionRequirements:
def test_check_requirements_returns_bool(self):
result = check_vision_requirements()
assert isinstance(result, bool)
def test_check_requirements_accepts_codex_auth(self, monkeypatch, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
(tmp_path / "auth.json").write_text(
'{"active_provider":"openai-codex","providers":{"openai-codex":{"tokens":{"access_token":"codex-access-token","refresh_token":"codex-refresh-token"}}}}'
)
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("AUXILIARY_VISION_PROVIDER", raising=False)
monkeypatch.delenv("CONTEXT_VISION_PROVIDER", raising=False)
assert check_vision_requirements() is True
def test_debug_session_info_returns_dict(self):
info = get_debug_session_info()
assert isinstance(info, dict)
# DebugSession.get_session_info() returns these keys
assert "enabled" in info
assert "session_id" in info
assert "total_calls" in info
# ---------------------------------------------------------------------------
# Integration: registry entry
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Tilde expansion in local file paths
# ---------------------------------------------------------------------------
class TestTildeExpansion:
"""Verify that ~/path style paths are expanded correctly."""
@pytest.mark.asyncio
async def test_tilde_path_expanded_to_local_file(self, tmp_path, monkeypatch):
"""vision_analyze_tool should expand ~ in file paths."""
# Create a fake image file under a fake home directory
fake_home = tmp_path / "fakehome"
fake_home.mkdir()
img = fake_home / "test_image.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
monkeypatch.setenv("HOME", str(fake_home))
mock_response = MagicMock()
mock_choice = MagicMock()
mock_choice.message.content = "A test image"
mock_response.choices = [mock_choice]
with (
patch(
"tools.vision_tools._image_to_base64_data_url",
return_value="data:image/png;base64,abc",
),
patch(
"tools.vision_tools.async_call_llm",
new_callable=AsyncMock,
return_value=mock_response,
),
):
result = await vision_analyze_tool(
"~/test_image.png", "describe this", "test/model"
)
data = json.loads(result)
assert data["success"] is True
assert data["analysis"] == "A test image"
@pytest.mark.asyncio
async def test_tilde_path_nonexistent_file_gives_error(self, tmp_path, monkeypatch):
"""A tilde path that doesn't resolve to a real file should fail gracefully."""
fake_home = tmp_path / "fakehome"
fake_home.mkdir()
monkeypatch.setenv("HOME", str(fake_home))
result = await vision_analyze_tool(
"~/nonexistent.png", "describe this", "test/model"
)
data = json.loads(result)
assert data["success"] is False
class TestVisionRegistration:
def test_vision_analyze_registered(self):
from tools.registry import registry
entry = registry._tools.get("vision_analyze")
assert entry is not None
assert entry.toolset == "vision"
assert entry.is_async is True
def test_schema_has_required_fields(self):
from tools.registry import registry
entry = registry._tools.get("vision_analyze")
schema = entry.schema
assert schema["name"] == "vision_analyze"
params = schema.get("parameters", {})
props = params.get("properties", {})
assert "image_url" in props
assert "question" in props
def test_handler_is_callable(self):
from tools.registry import registry
entry = registry._tools.get("vision_analyze")
assert callable(entry.handler)