# Listing metadata (from the code viewer): 159 lines, 5.4 KiB, Python
"""Unit tests for the web_fetch tool in timmy.tools."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from timmy.tools import web_fetch
|
|
|
|
|
|
class TestWebFetch:
    """Tests for web_fetch function.

    No test performs real network I/O: ``requests`` and ``trafilatura`` are
    replaced by mocks registered in ``sys.modules``.  This only has an effect
    if ``web_fetch`` resolves those modules at call time -- the tests are
    written on that assumption (NOTE(review): confirm against timmy.tools).
    """

    def test_invalid_url_no_scheme(self):
        """URLs without http(s) scheme are rejected."""
        result = web_fetch("example.com")
        assert "Error: invalid URL" in result

    def test_invalid_url_empty(self):
        """Empty URL is rejected."""
        result = web_fetch("")
        assert "Error: invalid URL" in result

    def test_invalid_url_ftp(self):
        """Non-HTTP schemes are rejected."""
        result = web_fetch("ftp://example.com")
        assert "Error: invalid URL" in result

    @patch("timmy.tools.trafilatura", create=True)
    @patch("timmy.tools._requests", create=True)
    def test_successful_fetch(self, mock_requests, mock_trafilatura):
        """Happy path: fetch + extract returns text."""
        # Decorators apply bottom-up, so the ``_requests`` patch is the first
        # mock argument and the ``trafilatura`` patch is the second.
        mock_resp = MagicMock()
        mock_resp.text = "<html><body><p>Hello world</p></body></html>"

        # The same mocks are also registered in sys.modules so that an
        # ``import`` executed inside web_fetch picks them up.
        with patch.dict(
            "sys.modules", {"requests": mock_requests, "trafilatura": mock_trafilatura}
        ):
            mock_requests.get.return_value = mock_resp
            mock_requests.exceptions = _make_exceptions()
            mock_trafilatura.extract.return_value = "Hello world"

            # Must run while sys.modules is still patched.
            result = web_fetch("https://example.com")

        assert result == "Hello world"

    @patch.dict("sys.modules", {"requests": MagicMock(), "trafilatura": MagicMock()})
    def test_truncation(self):
        """Long text is truncated to max_tokens * 4 chars."""
        import sys

        # Retrieve the mocks that the decorator installed for this test.
        mock_trafilatura = sys.modules["trafilatura"]
        mock_requests = sys.modules["requests"]

        long_text = "a" * 20000
        mock_resp = MagicMock()
        mock_resp.text = "<html><body>" + long_text + "</body></html>"
        mock_requests.get.return_value = mock_resp
        mock_requests.exceptions = _make_exceptions()
        mock_trafilatura.extract.return_value = long_text

        result = web_fetch("https://example.com", max_tokens=100)

        # 100 tokens * 4 chars = 400 chars max; the bound of 500 leaves room
        # for the truncation marker appended to the text.
        assert len(result) < 500
        assert "[…truncated" in result

    @patch.dict("sys.modules", {"requests": MagicMock(), "trafilatura": MagicMock()})
    def test_extraction_failure(self):
        """Returns error when trafilatura can't extract text."""
        import sys

        # Retrieve the mocks that the decorator installed for this test.
        mock_trafilatura = sys.modules["trafilatura"]
        mock_requests = sys.modules["requests"]

        mock_resp = MagicMock()
        mock_resp.text = "<html></html>"
        mock_requests.get.return_value = mock_resp
        mock_requests.exceptions = _make_exceptions()
        # ``None`` from extract() stands for "no extractable content".
        mock_trafilatura.extract.return_value = None

        result = web_fetch("https://example.com")
        assert "Error: could not extract" in result

    @patch.dict("sys.modules", {"trafilatura": MagicMock()})
    def test_timeout(self):
        """Timeout errors are handled gracefully."""

        mock_requests = MagicMock()
        exc_mod = _make_exceptions()
        mock_requests.exceptions = exc_mod
        # Raise an instance of the same class that is exposed as
        # ``requests.exceptions.Timeout`` so an ``except`` clause can match it.
        mock_requests.get.side_effect = exc_mod.Timeout("timed out")

        with patch.dict("sys.modules", {"requests": mock_requests}):
            result = web_fetch("https://example.com")

        assert "timed out" in result

    @patch.dict("sys.modules", {"trafilatura": MagicMock()})
    def test_http_error(self):
        """HTTP errors (404, 500, etc.) are handled gracefully."""

        mock_requests = MagicMock()
        exc_mod = _make_exceptions()
        mock_requests.exceptions = exc_mod

        mock_response = MagicMock()
        mock_response.status_code = 404
        # get() itself succeeds; the failure surfaces from raise_for_status().
        mock_requests.get.return_value.raise_for_status.side_effect = exc_mod.HTTPError(
            response=mock_response
        )

        with patch.dict("sys.modules", {"requests": mock_requests}):
            result = web_fetch("https://example.com/nope")

        assert "404" in result

    def test_missing_requests(self):
        """Graceful error when requests not installed."""
        # A ``None`` entry in sys.modules makes ``import requests`` fail.
        with patch.dict("sys.modules", {"requests": None}):
            result = web_fetch("https://example.com")
            assert "requests" in result and "not installed" in result

    def test_missing_trafilatura(self):
        """Graceful error when trafilatura not installed."""
        mock_requests = MagicMock()
        # ``requests`` is importable (mocked); ``trafilatura`` is forced to fail.
        with patch.dict("sys.modules", {"requests": mock_requests, "trafilatura": None}):
            result = web_fetch("https://example.com")
            assert "trafilatura" in result and "not installed" in result

    def test_catalog_entry_exists(self):
        """web_fetch should appear in the tool catalog."""
        from timmy.tools import get_all_available_tools

        catalog = get_all_available_tools()
        assert "web_fetch" in catalog
        assert "orchestrator" in catalog["web_fetch"]["available_in"]


def _make_exceptions():
|
|
"""Create a mock exceptions module with real exception classes."""
|
|
|
|
class Timeout(Exception):
|
|
pass
|
|
|
|
class HTTPError(Exception):
|
|
def __init__(self, *args, response=None, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self.response = response
|
|
|
|
class RequestException(Exception):
|
|
pass
|
|
|
|
mod = MagicMock()
|
|
mod.Timeout = Timeout
|
|
mod.HTTPError = HTTPError
|
|
mod.RequestException = RequestException
|
|
return mod
|