From 5e67fc8c40d8f867504bf168f21564eb6903a00e Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sun, 29 Mar 2026 20:55:04 -0700 Subject: [PATCH] fix(vision): reject non-image files and enforce website policy (salvage #1940) (#3845) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three safety gaps in vision_analyze_tool: 1. Local files accepted without checking if they're actually images — a renamed text file would get base64-encoded and sent to the model. Now validates magic bytes (PNG, JPEG, GIF, BMP, WebP, SVG). 2. No website policy enforcement on image URLs — blocked domains could be fetched via the vision tool. Now checks before download. 3. No redirect check — if an allowed URL redirected to a blocked domain, the download would proceed. Now re-checks the final URL. Fixed one test that needed _validate_image_url mocked to bypass DNS resolution on the fake blocked.test domain (is_safe_url does DNS checks that were added after the original PR). Co-authored-by: GutSlabs --- tests/tools/test_vision_tools.py | 72 ++++++++++++++++++++++++++++++++ tools/vision_tools.py | 42 ++++++++++++++++++- 2 files changed, 113 insertions(+), 1 deletion(-) diff --git a/tests/tools/test_vision_tools.py b/tests/tools/test_vision_tools.py index 4f152cebd..97ee57a11 100644 --- a/tests/tools/test_vision_tools.py +++ b/tests/tools/test_vision_tools.py @@ -354,6 +354,78 @@ class TestErrorLoggingExcInfo: assert warning_records[0].exc_info is not None +class TestVisionSafetyGuards: + @pytest.mark.asyncio + async def test_local_non_image_file_rejected_before_llm_call(self, tmp_path): + secret = tmp_path / "secret.txt" + secret.write_text("TOP-SECRET=1\n", encoding="utf-8") + + with patch("tools.vision_tools.async_call_llm", new_callable=AsyncMock) as mock_llm: + result = json.loads(await vision_analyze_tool(str(secret), "extract text")) + + assert result["success"] is False + assert "Only real image files are supported" in result["error"] + mock_llm.assert_not_awaited() + + @pytest.mark.asyncio + async def test_blocked_remote_url_short_circuits_before_download(self): + blocked = { + "host": "blocked.test", + "rule": "blocked.test", + "source": "config", + "message": "Blocked by website policy", + } + + with ( + patch("tools.vision_tools.check_website_access", return_value=blocked), + patch("tools.vision_tools._validate_image_url", return_value=True), + patch("tools.vision_tools._download_image", new_callable=AsyncMock) as mock_download, + ): + result = json.loads(await vision_analyze_tool("https://blocked.test/cat.png", "describe")) + + assert result["success"] is False + assert "Blocked by website policy" in result["error"] + mock_download.assert_not_awaited() + + @pytest.mark.asyncio + async def test_download_blocks_redirected_final_url(self, tmp_path): + from tools.vision_tools import _download_image + + def fake_check(url): + if url == "https://allowed.test/cat.png": + return None + if url == "https://blocked.test/final.png": + return { + "host": "blocked.test", + "rule": "blocked.test", + "source": "config", + "message": "Blocked by website policy", + } + raise AssertionError(f"unexpected URL checked: {url}") + + class FakeResponse: + url = "https://blocked.test/final.png" + content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16 + + def raise_for_status(self): + return None + + with ( + patch("tools.vision_tools.check_website_access", side_effect=fake_check), + patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls, + pytest.raises(PermissionError, match="Blocked by website policy"), + ): + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(return_value=FakeResponse()) + mock_client_cls.return_value = mock_client + + await _download_image("https://allowed.test/cat.png", tmp_path / "cat.png", max_retries=1) + + assert not (tmp_path / "cat.png").exists() + + # --------------------------------------------------------------------------- # check_vision_requirements & get_debug_session_info # --------------------------------------------------------------------------- diff --git a/tools/vision_tools.py b/tools/vision_tools.py index d8b96bc4e..47b406846 100644 --- a/tools/vision_tools.py +++ b/tools/vision_tools.py @@ -39,6 +39,7 @@ from urllib.parse import urlparse import httpx from agent.auxiliary_client import async_call_llm, extract_content_or_reasoning from tools.debug_helpers import DebugSession +from tools.website_policy import check_website_access logger = logging.getLogger(__name__) @@ -76,6 +77,28 @@ def _validate_image_url(url: str) -> bool: return True +def _detect_image_mime_type(image_path: Path) -> Optional[str]: + """Return a MIME type when the file looks like a supported image.""" + with image_path.open("rb") as f: + header = f.read(64) + + if header.startswith(b"\x89PNG\r\n\x1a\n"): + return "image/png" + if header.startswith(b"\xff\xd8\xff"): + return "image/jpeg" + if header.startswith((b"GIF87a", b"GIF89a")): + return "image/gif" + if header.startswith(b"BM"): + return "image/bmp" + if len(header) >= 12 and header[:4] == b"RIFF" and header[8:12] == b"WEBP": + return "image/webp" + if image_path.suffix.lower() == ".svg": + head = image_path.read_text(encoding="utf-8", errors="ignore")[:4096].lower() + if " Path: """ Download an image from a URL to a local destination (async) with retry logic. @@ -115,6 +138,10 @@ async def _download_image(image_url: str, destination: Path, max_retries: int = last_error = None for attempt in range(max_retries): try: + blocked = check_website_access(image_url) + if blocked: + raise PermissionError(blocked["message"]) + # Download the image with appropriate headers using async httpx # Enable follow_redirects to handle image CDNs that redirect (e.g., Imgur, Picsum) # SSRF: event_hooks validates each redirect target against private IP ranges @@ -131,6 +158,11 @@ async def _download_image(image_url: str, destination: Path, max_retries: int = }, ) response.raise_for_status() + + final_url = str(response.url) + blocked = check_website_access(final_url) + if blocked: + raise PermissionError(blocked["message"]) # Save the image content destination.write_bytes(response.content) @@ -257,6 +289,7 @@ async def vision_analyze_tool( # Track whether we should clean up the file after processing. # Local files (e.g. from the image cache) should NOT be deleted. should_cleanup = True + detected_mime_type = None try: from tools.interrupt import is_interrupted @@ -275,6 +308,9 @@ async def vision_analyze_tool( should_cleanup = False # Don't delete cached/local files elif _validate_image_url(image_url): # Remote URL -- download to a temporary location + blocked = check_website_access(image_url) + if blocked: + raise PermissionError(blocked["message"]) logger.info("Downloading image from URL...") temp_dir = Path("./temp_vision_images") temp_image_path = temp_dir / f"temp_image_{uuid.uuid4()}.jpg" @@ -289,10 +325,14 @@ async def vision_analyze_tool( image_size_bytes = temp_image_path.stat().st_size image_size_kb = image_size_bytes / 1024 logger.info("Image ready (%.1f KB)", image_size_kb) + + detected_mime_type = _detect_image_mime_type(temp_image_path) + if not detected_mime_type: + raise ValueError("Only real image files are supported for vision analysis.") # Convert image to base64 data URL logger.info("Converting image to base64...") - image_data_url = _image_to_base64_data_url(temp_image_path) + image_data_url = _image_to_base64_data_url(temp_image_path, mime_type=detected_mime_type) # Calculate size in KB for better readability data_size_kb = len(image_data_url) / 1024 logger.info("Image converted to base64 (%.1f KB)", data_size_kb)