From 0791efe2c340370e2bd734e12cf94221f7d3ec5b Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Mon, 23 Mar 2026 15:40:42 -0700 Subject: [PATCH] fix(security): add SSRF protection to vision_tools and web_tools (hardened) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(security): add SSRF protection to vision_tools and web_tools Both vision_analyze and web_extract/web_crawl accept arbitrary URLs without checking if they target private/internal network addresses. A prompt-injected or malicious skill could use this to access cloud metadata endpoints (169.254.169.254), localhost services, or private network hosts. Adds a shared url_safety.is_safe_url() that resolves hostnames and blocks private, loopback, link-local, and reserved IP ranges. Also blocks known internal hostnames (metadata.google.internal). Integrated at the URL validation layer in vision_tools and before each website_policy check in web_tools (extract, crawl). * test(vision): update localhost test to reflect SSRF protection The existing test_valid_url_with_port asserted localhost URLs pass validation. With SSRF protection, localhost is now correctly blocked. Update the test to verify the block, and add a separate test for valid URLs with ports using a public hostname. * fix(security): harden SSRF protection — fail-closed, CGNAT, multicast, redirect guard Follow-up hardening on top of dieutx's SSRF protection (PR #2630): - Change fail-open to fail-closed: DNS errors and unexpected exceptions now block the request instead of allowing it (OWASP best practice) - Block CGNAT range (100.64.0.0/10): Python's ipaddress.is_private does NOT cover this range (returns False for both is_private and is_global). Used by Tailscale/WireGuard and carrier infrastructure. - Add is_multicast and is_unspecified checks: multicast (224.0.0.0/4) and unspecified (0.0.0.0) addresses were not caught by the original four-check chain - Add redirect guard for vision_tools: httpx event hook re-validates each redirect target against SSRF checks, preventing the classic redirect-based SSRF bypass (302 to internal IP) - Move SSRF filtering before backend dispatch in web_extract: now covers Parallel and Tavily backends, not just Firecrawl - Extract _is_blocked_ip() helper for cleaner IP range checking - Add 24 new tests (CGNAT, multicast, IPv4-mapped IPv6, fail-closed behavior, parametrized blocked/allowed IP lists) - Fix existing tests to mock DNS resolution for test hostnames --------- Co-authored-by: dieutx --- tests/tools/test_url_safety.py | 176 +++++++++++++++++++ tests/tools/test_vision_tools.py | 21 ++- tests/tools/test_website_policy.py | 9 + tools/url_safety.py | 96 +++++++++++ tools/vision_tools.py | 28 ++- tools/web_tools.py | 266 ++++++++++++++++------------- 6 files changed, 472 insertions(+), 124 deletions(-) create mode 100644 tests/tools/test_url_safety.py create mode 100644 tools/url_safety.py diff --git a/tests/tools/test_url_safety.py b/tests/tools/test_url_safety.py new file mode 100644 index 000000000..6a2de78f6 --- /dev/null +++ b/tests/tools/test_url_safety.py @@ -0,0 +1,176 @@ +"""Tests for SSRF protection in url_safety module.""" + +import socket +from unittest.mock import patch + +from tools.url_safety import is_safe_url, _is_blocked_ip + +import ipaddress +import pytest + + +class TestIsSafeUrl: + def test_public_url_allowed(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("93.184.216.34", 0)), + ]): + assert is_safe_url("https://example.com/image.png") is True + + def test_localhost_blocked(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("127.0.0.1", 0)), + ]): + assert is_safe_url("http://localhost:8080/secret") is False + + def test_loopback_ip_blocked(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("127.0.0.1", 0)), + ]): + assert is_safe_url("http://127.0.0.1/admin") is False + + def test_private_10_blocked(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("10.0.0.1", 0)), + ]): + assert is_safe_url("http://internal-service.local/api") is False + + def test_private_172_blocked(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("172.16.0.1", 0)), + ]): + assert is_safe_url("http://private.corp/data") is False + + def test_private_192_blocked(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("192.168.1.1", 0)), + ]): + assert is_safe_url("http://router.local") is False + + def test_link_local_169_254_blocked(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("169.254.169.254", 0)), + ]): + assert is_safe_url("http://169.254.169.254/latest/meta-data/") is False + + def test_metadata_google_internal_blocked(self): + assert is_safe_url("http://metadata.google.internal/computeMetadata/v1/") is False + + def test_ipv6_loopback_blocked(self): + with patch("socket.getaddrinfo", return_value=[ + (10, 1, 6, "", ("::1", 0, 0, 0)), + ]): + assert is_safe_url("http://[::1]:8080/") is False + + def test_dns_failure_blocked(self): + """DNS failures now fail closed — block the request.""" + with patch("socket.getaddrinfo", side_effect=socket.gaierror("Name resolution failed")): + assert is_safe_url("https://nonexistent.example.com") is False + + def test_empty_url_blocked(self): + assert is_safe_url("") is False + + def test_no_hostname_blocked(self): + assert is_safe_url("http://") is False + + def test_public_ip_allowed(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("93.184.216.34", 0)), + ]): + assert is_safe_url("https://example.com") is True + + # ── New tests for hardened SSRF protection ── + + def test_cgnat_100_64_blocked(self): + """100.64.0.0/10 (CGNAT/Shared Address Space) is NOT covered by + ipaddress.is_private — must be blocked explicitly.""" + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("100.64.0.1", 0)), + ]): + assert is_safe_url("http://some-cgnat-host.example/") is False + + def test_cgnat_100_127_blocked(self): + """Upper end of CGNAT range (100.127.255.255).""" + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("100.127.255.254", 0)), + ]): + assert is_safe_url("http://tailscale-peer.example/") is False + + def test_multicast_blocked(self): + """Multicast addresses (224.0.0.0/4) not caught by is_private.""" + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("224.0.0.251", 0)), + ]): + assert is_safe_url("http://mdns-host.local/") is False + + def test_multicast_ipv6_blocked(self): + with patch("socket.getaddrinfo", return_value=[ + (10, 1, 6, "", ("ff02::1", 0, 0, 0)), + ]): + assert is_safe_url("http://[ff02::1]/") is False + + def test_ipv4_mapped_ipv6_loopback_blocked(self): + """::ffff:127.0.0.1 — IPv4-mapped IPv6 loopback.""" + with patch("socket.getaddrinfo", return_value=[ + (10, 1, 6, "", ("::ffff:127.0.0.1", 0, 0, 0)), + ]): + assert is_safe_url("http://[::ffff:127.0.0.1]/") is False + + def test_ipv4_mapped_ipv6_metadata_blocked(self): + """::ffff:169.254.169.254 — IPv4-mapped IPv6 cloud metadata.""" + with patch("socket.getaddrinfo", return_value=[ + (10, 1, 6, "", ("::ffff:169.254.169.254", 0, 0, 0)), + ]): + assert is_safe_url("http://[::ffff:169.254.169.254]/") is False + + def test_unspecified_address_blocked(self): + """0.0.0.0 — unspecified address, can bind to all interfaces.""" + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("0.0.0.0", 0)), + ]): + assert is_safe_url("http://0.0.0.0/") is False + + def test_unexpected_error_fails_closed(self): + """Unexpected exceptions should block, not allow.""" + with patch("tools.url_safety.urlparse", side_effect=ValueError("bad url")): + assert is_safe_url("http://evil.com/") is False + + def test_metadata_goog_blocked(self): + assert is_safe_url("http://metadata.goog/computeMetadata/v1/") is False + + def test_ipv6_unique_local_blocked(self): + """fc00::/7 — IPv6 unique local addresses.""" + with patch("socket.getaddrinfo", return_value=[ + (10, 1, 6, "", ("fd12::1", 0, 0, 0)), + ]): + assert is_safe_url("http://[fd12::1]/internal") is False + + def test_non_cgnat_100_allowed(self): + """100.0.0.1 is NOT in CGNAT range (100.64.0.0/10), should be allowed.""" + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("100.0.0.1", 0)), + ]): + # 100.0.0.1 is a global IP, not in CGNAT range + assert is_safe_url("http://legit-host.example/") is True + + +class TestIsBlockedIp: + """Direct tests for the _is_blocked_ip helper.""" + + @pytest.mark.parametrize("ip_str", [ + "127.0.0.1", "10.0.0.1", "172.16.0.1", "192.168.1.1", + "169.254.169.254", "0.0.0.0", "224.0.0.1", "255.255.255.255", + "100.64.0.1", "100.100.100.100", "100.127.255.254", + "::1", "fe80::1", "fc00::1", "fd12::1", "ff02::1", + "::ffff:127.0.0.1", "::ffff:169.254.169.254", + ]) + def test_blocked_ips(self, ip_str): + ip = ipaddress.ip_address(ip_str) + assert _is_blocked_ip(ip) is True, f"{ip_str} should be blocked" + + @pytest.mark.parametrize("ip_str", [ + "8.8.8.8", "93.184.216.34", "1.1.1.1", "100.0.0.1", + "2606:4700::1", "2001:4860:4860::8888", + ]) + def test_allowed_ips(self, ip_str): + ip = ipaddress.ip_address(ip_str) + assert _is_blocked_ip(ip) is False, f"{ip_str} should be allowed" diff --git a/tests/tools/test_vision_tools.py b/tests/tools/test_vision_tools.py index 14febac0b..4f152cebd 100644 --- a/tests/tools/test_vision_tools.py +++ b/tests/tools/test_vision_tools.py @@ -33,17 +33,30 @@ class TestValidateImageUrl: assert _validate_image_url("https://example.com/image.jpg") is True def test_valid_http_url(self): - assert _validate_image_url("http://cdn.example.org/photo.png") is True + with patch("tools.url_safety.socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("93.184.216.34", 0)), + ]): + assert _validate_image_url("http://cdn.example.org/photo.png") is True def test_valid_url_without_extension(self): """CDN endpoints that redirect to images should still pass.""" - assert _validate_image_url("https://cdn.example.com/abcdef123") is True + with patch("tools.url_safety.socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("93.184.216.34", 0)), + ]): + assert _validate_image_url("https://cdn.example.com/abcdef123") is True def test_valid_url_with_query_params(self): - assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True + with patch("tools.url_safety.socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("93.184.216.34", 0)), + ]): + assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True + + def test_localhost_url_blocked_by_ssrf(self): + """localhost URLs are now blocked by SSRF protection.""" + assert _validate_image_url("http://localhost:8080/image.png") is False def test_valid_url_with_port(self): - assert _validate_image_url("http://localhost:8080/image.png") is True + assert _validate_image_url("http://example.com:8080/image.png") is True def test_valid_url_with_path_only(self): assert _validate_image_url("https://example.com/") is True diff --git a/tests/tools/test_website_policy.py b/tests/tools/test_website_policy.py index 9d620b59a..52618a1d6 100644 --- a/tests/tools/test_website_policy.py +++ b/tests/tools/test_website_policy.py @@ -343,6 +343,8 @@ def test_browser_navigate_allows_when_shared_file_missing(monkeypatch, tmp_path) async def test_web_extract_short_circuits_blocked_url(monkeypatch): from tools import web_tools + # Allow test URLs past SSRF check so website policy is what gets tested + monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True) monkeypatch.setattr( web_tools, "check_website_access", @@ -389,6 +391,9 @@ def test_check_website_access_fails_open_on_malformed_config(tmp_path, monkeypat async def test_web_extract_blocks_redirected_final_url(monkeypatch): from tools import web_tools + # Allow test URLs past SSRF check so website policy is what gets tested + monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True) + def fake_check(url): if url == "https://allowed.test": return None @@ -428,6 +433,8 @@ async def test_web_crawl_short_circuits_blocked_url(monkeypatch): # web_crawl_tool checks for Firecrawl env before website policy monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key") + # Allow test URLs past SSRF check so website policy is what gets tested + monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True) monkeypatch.setattr( web_tools, "check_website_access", @@ -457,6 +464,8 @@ async def test_web_crawl_blocks_redirected_final_url(monkeypatch): # web_crawl_tool checks for Firecrawl env before website policy monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key") + # Allow test URLs past SSRF check so website policy is what gets tested + monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True) def fake_check(url): if url == "https://allowed.test": diff --git a/tools/url_safety.py b/tools/url_safety.py new file mode 100644 index 000000000..ae610d0f7 --- /dev/null +++ b/tools/url_safety.py @@ -0,0 +1,96 @@ +"""URL safety checks — blocks requests to private/internal network addresses. + +Prevents SSRF (Server-Side Request Forgery) where a malicious prompt or +skill could trick the agent into fetching internal resources like cloud +metadata endpoints (169.254.169.254), localhost services, or private +network hosts. + +Limitations (documented, not fixable at pre-flight level): + - DNS rebinding (TOCTOU): an attacker-controlled DNS server with TTL=0 + can return a public IP for the check, then a private IP for the actual + connection. Fixing this requires connection-level validation (e.g. + Python's Champion library or an egress proxy like Stripe's Smokescreen). + - Redirect-based bypass in vision_tools is mitigated by an httpx event + hook that re-validates each redirect target. Web tools use third-party + SDKs (Firecrawl/Tavily) where redirect handling is on their servers. +""" + +import ipaddress +import logging +import socket +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +# Hostnames that should always be blocked regardless of IP resolution +_BLOCKED_HOSTNAMES = frozenset({ + "metadata.google.internal", + "metadata.goog", +}) + +# 100.64.0.0/10 (CGNAT / Shared Address Space, RFC 6598) is NOT covered by +# ipaddress.is_private — it returns False for both is_private and is_global. +# Must be blocked explicitly. Used by carrier-grade NAT, Tailscale/WireGuard +# VPNs, and some cloud internal networks. +_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10") + + +def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: + """Return True if the IP should be blocked for SSRF protection.""" + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: + return True + if ip.is_multicast or ip.is_unspecified: + return True + # CGNAT range not covered by is_private + if ip in _CGNAT_NETWORK: + return True + return False + + +def is_safe_url(url: str) -> bool: + """Return True if the URL target is not a private/internal address. + + Resolves the hostname to an IP and checks against private ranges. + Fails closed: DNS errors and unexpected exceptions block the request. + """ + try: + parsed = urlparse(url) + hostname = (parsed.hostname or "").strip().lower() + if not hostname: + return False + + # Block known internal hostnames + if hostname in _BLOCKED_HOSTNAMES: + logger.warning("Blocked request to internal hostname: %s", hostname) + return False + + # Try to resolve and check IP + try: + addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM) + except socket.gaierror: + # DNS resolution failed — fail closed. If DNS can't resolve it, + # the HTTP client will also fail, so blocking loses nothing. + logger.warning("Blocked request — DNS resolution failed for: %s", hostname) + return False + + for family, _, _, _, sockaddr in addr_info: + ip_str = sockaddr[0] + try: + ip = ipaddress.ip_address(ip_str) + except ValueError: + continue + + if _is_blocked_ip(ip): + logger.warning( + "Blocked request to private/internal address: %s -> %s", + hostname, ip_str, + ) + return False + + return True + + except Exception as exc: + # Fail closed on unexpected errors — don't let parsing edge cases + # become SSRF bypass vectors + logger.warning("Blocked request — URL safety check error for %s: %s", url, exc) + return False diff --git a/tools/vision_tools.py b/tools/vision_tools.py index 867d9ef39..1b64c4eb2 100644 --- a/tools/vision_tools.py +++ b/tools/vision_tools.py @@ -69,7 +69,12 @@ def _validate_image_url(url: str) -> bool: if not parsed.netloc: return False - return True # Allow all well-formed HTTP/HTTPS URLs for flexibility + # Block private/internal addresses to prevent SSRF + from tools.url_safety import is_safe_url + if not is_safe_url(url): + return False + + return True async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path: @@ -92,12 +97,31 @@ async def _download_image(image_url: str, destination: Path, max_retries: int = # Create parent directories if they don't exist destination.parent.mkdir(parents=True, exist_ok=True) + def _ssrf_redirect_guard(response): + """Re-validate each redirect target to prevent redirect-based SSRF. + + Without this, an attacker can host a public URL that 302-redirects + to http://169.254.169.254/ and bypass the pre-flight is_safe_url check. + """ + if response.is_redirect and response.next_request: + redirect_url = str(response.next_request.url) + from tools.url_safety import is_safe_url + if not is_safe_url(redirect_url): + raise ValueError( + f"Blocked redirect to private/internal address: {redirect_url}" + ) + last_error = None for attempt in range(max_retries): try: # Download the image with appropriate headers using async httpx # Enable follow_redirects to handle image CDNs that redirect (e.g., Imgur, Picsum) - async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + # SSRF: event_hooks validates each redirect target against private IP ranges + async with httpx.AsyncClient( + timeout=30.0, + follow_redirects=True, + event_hooks={"response": [_ssrf_redirect_guard]}, + ) as client: response = await client.get( image_url, headers={ diff --git a/tools/web_tools.py b/tools/web_tools.py index fad0e021e..fc089cb75 100644 --- a/tools/web_tools.py +++ b/tools/web_tools.py @@ -46,6 +46,7 @@ import httpx from firecrawl import Firecrawl from agent.auxiliary_client import async_call_llm from tools.debug_helpers import DebugSession +from tools.url_safety import is_safe_url from tools.website_policy import check_website_access logger = logging.getLogger(__name__) @@ -861,136 +862,155 @@ async def web_extract_tool( try: logger.info("Extracting content from %d URL(s)", len(urls)) - # Dispatch to the configured backend - backend = _get_backend() - - if backend == "parallel": - results = await _parallel_extract(urls) - elif backend == "tavily": - logger.info("Tavily extract: %d URL(s)", len(urls)) - raw = _tavily_request("extract", { - "urls": urls, - "include_images": False, - }) - results = _normalize_tavily_documents(raw, fallback_url=urls[0] if urls else "") - else: - # ── Firecrawl extraction ── - # Determine requested formats for Firecrawl v2 - formats: List[str] = [] - if format == "markdown": - formats = ["markdown"] - elif format == "html": - formats = ["html"] + # ── SSRF protection — filter out private/internal URLs before any backend ── + safe_urls = [] + ssrf_blocked: List[Dict[str, Any]] = [] + for url in urls: + if not is_safe_url(url): + ssrf_blocked.append({ + "url": url, "title": "", "content": "", + "error": "Blocked: URL targets a private or internal network address", + }) else: - # Default: request markdown for LLM-readiness and include html as backup - formats = ["markdown", "html"] + safe_urls.append(url) - # Always use individual scraping for simplicity and reliability - # Batch scraping adds complexity without much benefit for small numbers of URLs - results: List[Dict[str, Any]] = [] + # Dispatch only safe URLs to the configured backend + if not safe_urls: + results = [] + else: + backend = _get_backend() - from tools.interrupt import is_interrupted as _is_interrupted - for url in urls: - if _is_interrupted(): - results.append({"url": url, "error": "Interrupted", "title": ""}) - continue + if backend == "parallel": + results = await _parallel_extract(safe_urls) + elif backend == "tavily": + logger.info("Tavily extract: %d URL(s)", len(safe_urls)) + raw = _tavily_request("extract", { + "urls": safe_urls, + "include_images": False, + }) + results = _normalize_tavily_documents(raw, fallback_url=safe_urls[0] if safe_urls else "") + else: + # ── Firecrawl extraction ── + # Determine requested formats for Firecrawl v2 + formats: List[str] = [] + if format == "markdown": + formats = ["markdown"] + elif format == "html": + formats = ["html"] + else: + # Default: request markdown for LLM-readiness and include html as backup + formats = ["markdown", "html"] - # Website policy check — block before fetching - blocked = check_website_access(url) - if blocked: - logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"]) - results.append({ - "url": url, "title": "", "content": "", - "error": blocked["message"], - "blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}, - }) - continue + # Always use individual scraping for simplicity and reliability + # Batch scraping adds complexity without much benefit for small numbers of URLs + results: List[Dict[str, Any]] = [] - try: - logger.info("Scraping: %s", url) - scrape_result = _get_firecrawl_client().scrape( - url=url, - formats=formats - ) + from tools.interrupt import is_interrupted as _is_interrupted + for url in safe_urls: + if _is_interrupted(): + results.append({"url": url, "error": "Interrupted", "title": ""}) + continue - # Process the result - properly handle object serialization - metadata = {} - title = "" - content_markdown = None - content_html = None - - # Extract data from the scrape result - if hasattr(scrape_result, 'model_dump'): - # Pydantic model - use model_dump to get dict - result_dict = scrape_result.model_dump() - content_markdown = result_dict.get('markdown') - content_html = result_dict.get('html') - metadata = result_dict.get('metadata', {}) - elif hasattr(scrape_result, '__dict__'): - # Regular object with attributes - content_markdown = getattr(scrape_result, 'markdown', None) - content_html = getattr(scrape_result, 'html', None) - - # Handle metadata - convert to dict if it's an object - metadata_obj = getattr(scrape_result, 'metadata', {}) - if hasattr(metadata_obj, 'model_dump'): - metadata = metadata_obj.model_dump() - elif hasattr(metadata_obj, '__dict__'): - metadata = metadata_obj.__dict__ - elif isinstance(metadata_obj, dict): - metadata = metadata_obj - else: - metadata = {} - elif isinstance(scrape_result, dict): - # Already a dictionary - content_markdown = scrape_result.get('markdown') - content_html = scrape_result.get('html') - metadata = scrape_result.get('metadata', {}) - - # Ensure metadata is a dict (not an object) - if not isinstance(metadata, dict): - if hasattr(metadata, 'model_dump'): - metadata = metadata.model_dump() - elif hasattr(metadata, '__dict__'): - metadata = metadata.__dict__ - else: - metadata = {} - - # Get title from metadata - title = metadata.get("title", "") - - # Re-check final URL after redirect - final_url = metadata.get("sourceURL", url) - final_blocked = check_website_access(final_url) - if final_blocked: - logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"]) + # Website policy check — block before fetching + blocked = check_website_access(url) + if blocked: + logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"]) results.append({ - "url": final_url, "title": title, "content": "", "raw_content": "", - "error": final_blocked["message"], - "blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]}, + "url": url, "title": "", "content": "", + "error": blocked["message"], + "blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}, }) continue - # Choose content based on requested format - chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or "" + try: + logger.info("Scraping: %s", url) + scrape_result = _get_firecrawl_client().scrape( + url=url, + formats=formats + ) - results.append({ - "url": final_url, - "title": title, - "content": chosen_content, - "raw_content": chosen_content, - "metadata": metadata # Now guaranteed to be a dict - }) + # Process the result - properly handle object serialization + metadata = {} + title = "" + content_markdown = None + content_html = None - except Exception as scrape_err: - logger.debug("Scrape failed for %s: %s", url, scrape_err) - results.append({ - "url": url, - "title": "", - "content": "", - "raw_content": "", - "error": str(scrape_err) - }) + # Extract data from the scrape result + if hasattr(scrape_result, 'model_dump'): + # Pydantic model - use model_dump to get dict + result_dict = scrape_result.model_dump() + content_markdown = result_dict.get('markdown') + content_html = result_dict.get('html') + metadata = result_dict.get('metadata', {}) + elif hasattr(scrape_result, '__dict__'): + # Regular object with attributes + content_markdown = getattr(scrape_result, 'markdown', None) + content_html = getattr(scrape_result, 'html', None) + + # Handle metadata - convert to dict if it's an object + metadata_obj = getattr(scrape_result, 'metadata', {}) + if hasattr(metadata_obj, 'model_dump'): + metadata = metadata_obj.model_dump() + elif hasattr(metadata_obj, '__dict__'): + metadata = metadata_obj.__dict__ + elif isinstance(metadata_obj, dict): + metadata = metadata_obj + else: + metadata = {} + elif isinstance(scrape_result, dict): + # Already a dictionary + content_markdown = scrape_result.get('markdown') + content_html = scrape_result.get('html') + metadata = scrape_result.get('metadata', {}) + + # Ensure metadata is a dict (not an object) + if not isinstance(metadata, dict): + if hasattr(metadata, 'model_dump'): + metadata = metadata.model_dump() + elif hasattr(metadata, '__dict__'): + metadata = metadata.__dict__ + else: + metadata = {} + + # Get title from metadata + title = metadata.get("title", "") + + # Re-check final URL after redirect + final_url = metadata.get("sourceURL", url) + final_blocked = check_website_access(final_url) + if final_blocked: + logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"]) + results.append({ + "url": final_url, "title": title, "content": "", "raw_content": "", + "error": final_blocked["message"], + "blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]}, + }) + continue + + # Choose content based on requested format + chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or "" + + results.append({ + "url": final_url, + "title": title, + "content": chosen_content, + "raw_content": chosen_content, + "metadata": metadata # Now guaranteed to be a dict + }) + + except Exception as scrape_err: + logger.debug("Scrape failed for %s: %s", url, scrape_err) + results.append({ + "url": url, + "title": "", + "content": "", + "raw_content": "", + "error": str(scrape_err) + }) + + # Merge any SSRF-blocked results back in + if ssrf_blocked: + results = ssrf_blocked + results response = {"results": results} @@ -1173,6 +1193,11 @@ async def web_crawl_tool( if not url.startswith(('http://', 'https://')): url = f'https://{url}' + # SSRF protection — block private/internal addresses + if not is_safe_url(url): + return json.dumps({"results": [{"url": url, "title": "", "content": "", + "error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False) + # Website policy check blocked = check_website_access(url) if blocked: @@ -1258,6 +1283,11 @@ async def web_crawl_tool( instructions_text = f" with instructions: '{instructions}'" if instructions else "" logger.info("Crawling %s%s", url, instructions_text) + # SSRF protection — block private/internal addresses + if not is_safe_url(url): + return json.dumps({"results": [{"url": url, "title": "", "content": "", + "error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False) + # Website policy check — block before crawling blocked = check_website_access(url) if blocked: