# ---------------------------------------------------------------------------
# Document cache utilities
#
# Same pattern as image/audio cache -- documents from platforms are downloaded
# here so the agent can reference them by local file path.
# ---------------------------------------------------------------------------

DOCUMENT_CACHE_DIR = Path(os.path.expanduser("~/.hermes/document_cache"))

# Maps each accepted (lowercase, dot-prefixed) extension to its MIME type.
SUPPORTED_DOCUMENT_TYPES = {
    ".pdf": "application/pdf",
    ".md": "text/markdown",
    ".txt": "text/plain",
    ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}


def get_document_cache_dir() -> Path:
    """Return the document cache directory, creating it if it doesn't exist."""
    DOCUMENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    return DOCUMENT_CACHE_DIR


def cache_document_from_bytes(data: bytes, filename: str) -> str:
    """
    Save raw document bytes to the cache and return the absolute file path.

    The cached filename preserves the original human-readable name with a
    unique prefix: ``doc_{uuid12}_{original_filename}``.

    Args:
        data: Raw document bytes.
        filename: Original filename (e.g. "report.pdf"). A falsy value
            (empty string or ``None``) falls back to ``"document"``.

    Returns:
        Absolute path to the cached document file as a string.

    Raises:
        ValueError: If the sanitized path escapes the cache directory.
    """
    cache_dir = get_document_cache_dir()
    # Keep only the leaf name so directory components cannot steer the write.
    leaf = Path(filename).name if filename else ""
    # Drop NUL and every other ASCII control character (newlines, tabs, ESC,
    # DEL ...). The name comes from the sending platform, and the resulting
    # path is later shown to the agent, so it must stay on a single line.
    safe_name = "".join(
        ch for ch in leaf if ord(ch) >= 32 and ord(ch) != 127
    ).strip()
    if not safe_name or safe_name in (".", ".."):
        safe_name = "document"
    cached_name = f"doc_{uuid.uuid4().hex[:12]}_{safe_name}"
    filepath = cache_dir / cached_name
    # Final safety check: ensure path stays inside cache dir
    if not filepath.resolve().is_relative_to(cache_dir.resolve()):
        raise ValueError(f"Path traversal rejected: {filename!r}")
    filepath.write_bytes(data)
    return str(filepath)


def cleanup_document_cache(max_age_hours: int = 24) -> int:
    """
    Delete cached documents older than *max_age_hours*.

    Cleanup is best-effort: an entry that cannot be inspected or removed
    (for example because it vanished mid-scan) is skipped, and the sweep
    continues with the remaining entries.

    Returns the number of files removed.
    """
    import time

    cache_dir = get_document_cache_dir()
    cutoff = time.time() - (max_age_hours * 3600)
    removed = 0
    for entry in cache_dir.iterdir():
        try:
            # stat() lives inside the try: a file deleted between iterdir()
            # and stat() must not abort the whole sweep.
            if entry.is_file() and entry.stat().st_mtime < cutoff:
                entry.unlink()
                removed += 1
        except OSError:
            continue
    return removed
{cached_path}", flush=True) except Exception as e: print(f"[Telegram] Failed to cache audio: {e}", flush=True) - + + # Download document files to cache for agent processing + elif msg.document: + doc = msg.document + try: + # Determine file extension + ext = "" + original_filename = doc.file_name or "" + if original_filename: + _, ext = os.path.splitext(original_filename) + ext = ext.lower() + + # If no extension from filename, reverse-lookup from MIME type + if not ext and doc.mime_type: + mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()} + ext = mime_to_ext.get(doc.mime_type, "") + + # Check if supported + if ext not in SUPPORTED_DOCUMENT_TYPES: + supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys())) + event.text = ( + f"Unsupported document type '{ext or 'unknown'}'. " + f"Supported types: {supported_list}" + ) + print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True) + await self.handle_message(event) + return + + # Check file size (Telegram Bot API limit: 20 MB) + MAX_DOC_BYTES = 20 * 1024 * 1024 + if not doc.file_size or doc.file_size > MAX_DOC_BYTES: + event.text = ( + "The document is too large or its size could not be verified. " + "Maximum: 20 MB." 
+ ) + print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True) + await self.handle_message(event) + return + + # Download and cache + file_obj = await doc.get_file() + doc_bytes = await file_obj.download_as_bytearray() + raw_bytes = bytes(doc_bytes) + cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}") + mime_type = SUPPORTED_DOCUMENT_TYPES[ext] + event.media_urls = [cached_path] + event.media_types = [mime_type] + print(f"[Telegram] Cached user document: {cached_path}", flush=True) + + # For text files, inject content into event.text (capped at 100 KB) + MAX_TEXT_INJECT_BYTES = 100 * 1024 + if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES: + try: + text_content = raw_bytes.decode("utf-8") + display_name = original_filename or f"document{ext}" + injection = f"[Content of {display_name}]:\n{text_content}" + if event.text: + event.text = f"{injection}\n\n{event.text}" + else: + event.text = injection + except UnicodeDecodeError: + print(f"[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True) + + except Exception as e: + print(f"[Telegram] Failed to cache document: {e}", flush=True) + await self.handle_message(event) async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None: diff --git a/gateway/run.py b/gateway/run.py index fd005270a..36c7ceb38 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -742,7 +742,39 @@ class GatewayRunner: message_text = await self._enrich_message_with_transcription( message_text, audio_paths ) - + + # ----------------------------------------------------------------- + # Enrich document messages with context notes for the agent + # ----------------------------------------------------------------- + if event.media_urls and event.message_type == MessageType.DOCUMENT: + for i, path in enumerate(event.media_urls): + mtype = event.media_types[i] if i < len(event.media_types) else "" + if not 
(mtype.startswith("application/") or mtype.startswith("text/")): + continue + # Extract display filename by stripping the doc_{uuid12}_ prefix + import os as _os + basename = _os.path.basename(path) + # Format: doc_<12hex>_ + parts = basename.split("_", 2) + display_name = parts[2] if len(parts) >= 3 else basename + # Sanitize to prevent prompt injection via filenames + import re as _re + display_name = _re.sub(r'[^\w.\- ]', '_', display_name) + + if mtype.startswith("text/"): + context_note = ( + f"[The user sent a text document: '{display_name}'. " + f"Its content has been included below. " + f"The file is also saved at: {path}]" + ) + else: + context_note = ( + f"[The user sent a document: '{display_name}'. " + f"The file is saved at: {path}. " + f"Ask the user what they'd like you to do with it.]" + ) + message_text = f"{context_note}\n\n{message_text}" + try: # Emit agent:start hook hook_ctx = { @@ -1813,10 +1845,10 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int needing a separate `hermes cron daemon` or system cron entry. Also refreshes the channel directory every 5 minutes and prunes the - image/audio cache once per hour. + image/audio/document cache once per hour. 
""" from cron.scheduler import tick as cron_tick - from gateway.platforms.base import cleanup_image_cache + from gateway.platforms.base import cleanup_image_cache, cleanup_document_cache IMAGE_CACHE_EVERY = 60 # ticks — once per hour at default 60s interval CHANNEL_DIR_EVERY = 5 # ticks — every 5 minutes @@ -1845,6 +1877,12 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int logger.info("Image cache cleanup: removed %d stale file(s)", removed) except Exception as e: logger.debug("Image cache cleanup error: %s", e) + try: + removed = cleanup_document_cache(max_age_hours=24) + if removed: + logger.info("Document cache cleanup: removed %d stale file(s)", removed) + except Exception as e: + logger.debug("Document cache cleanup error: %s", e) stop_event.wait(timeout=interval) logger.info("Cron ticker stopped") diff --git a/run_agent.py b/run_agent.py index 8958353f5..5d687d0e4 100644 --- a/run_agent.py +++ b/run_agent.py @@ -2092,11 +2092,44 @@ class AIAgent: "interrupted": True, } + # Check for 413 payload-too-large BEFORE generic 4xx handler. + # A 413 is a payload-size error — the correct response is to + # compress history and retry, not abort immediately. + status_code = getattr(api_error, "status_code", None) + is_payload_too_large = ( + status_code == 413 + or 'request entity too large' in error_msg + or 'error code: 413' in error_msg + ) + + if is_payload_too_large: + print(f"{self.log_prefix}⚠️ Request payload too large (413) - attempting compression...") + + original_len = len(messages) + messages, active_system_prompt = self._compress_context( + messages, system_message, approx_tokens=approx_tokens + ) + + if len(messages) < original_len: + print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") + continue # Retry with compressed messages + else: + print(f"{self.log_prefix}❌ Payload too large and cannot compress further.") + logging.error(f"{self.log_prefix}413 payload too large. 
Cannot compress further.") + self._persist_session(messages, conversation_history) + return { + "messages": messages, + "completed": False, + "api_calls": api_call_count, + "error": "Request payload too large (413). Cannot compress further.", + "partial": True + } + # Check for non-retryable client errors (4xx HTTP status codes). # These indicate a problem with the request itself (bad model ID, # invalid API key, forbidden, etc.) and will never succeed on retry. - status_code = getattr(api_error, "status_code", None) - is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 + # Note: 413 is excluded — it's handled above via compression. + is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code != 413 is_client_error = is_client_status_error or any(phrase in error_msg for phrase in [ 'error code: 400', 'error code: 401', 'error code: 403', 'error code: 404', 'error code: 422', @@ -2104,7 +2137,7 @@ class AIAgent: 'invalid api key', 'invalid_api_key', 'authentication', 'unauthorized', 'forbidden', 'not found', ]) - + if is_client_error: self._dump_api_request_debug( api_kwargs, reason="non_retryable_client_error", error=api_error, @@ -2124,8 +2157,9 @@ class AIAgent: # Check for non-retryable errors (context length exceeded) is_context_length_error = any(phrase in error_msg for phrase in [ - 'context length', 'maximum context', 'token limit', - 'too many tokens', 'reduce the length', 'exceeds the limit' + 'context length', 'maximum context', 'token limit', + 'too many tokens', 'reduce the length', 'exceeds the limit', + 'request entity too large', # OpenRouter/Nous 413 safety net ]) if is_context_length_error: diff --git a/tests/gateway/test_document_cache.py b/tests/gateway/test_document_cache.py new file mode 100644 index 000000000..18440ed9c --- /dev/null +++ b/tests/gateway/test_document_cache.py @@ -0,0 +1,157 @@ +""" +Tests for document cache utilities in gateway/platforms/base.py. 
+ +Covers: get_document_cache_dir, cache_document_from_bytes, + cleanup_document_cache, SUPPORTED_DOCUMENT_TYPES. +""" + +import os +import time +from pathlib import Path + +import pytest + +from gateway.platforms.base import ( + SUPPORTED_DOCUMENT_TYPES, + cache_document_from_bytes, + cleanup_document_cache, + get_document_cache_dir, +) + +# --------------------------------------------------------------------------- +# Fixture: redirect DOCUMENT_CACHE_DIR to a temp directory for every test +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _redirect_cache(tmp_path, monkeypatch): + """Point the module-level DOCUMENT_CACHE_DIR to a fresh tmp_path.""" + monkeypatch.setattr( + "gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache" + ) + + +# --------------------------------------------------------------------------- +# TestGetDocumentCacheDir +# --------------------------------------------------------------------------- + +class TestGetDocumentCacheDir: + def test_creates_directory(self, tmp_path): + cache_dir = get_document_cache_dir() + assert cache_dir.exists() + assert cache_dir.is_dir() + + def test_returns_existing_directory(self): + first = get_document_cache_dir() + second = get_document_cache_dir() + assert first == second + assert first.exists() + + +# --------------------------------------------------------------------------- +# TestCacheDocumentFromBytes +# --------------------------------------------------------------------------- + +class TestCacheDocumentFromBytes: + def test_basic_caching(self): + data = b"hello world" + path = cache_document_from_bytes(data, "test.txt") + assert os.path.exists(path) + assert Path(path).read_bytes() == data + + def test_filename_preserved_in_path(self): + path = cache_document_from_bytes(b"data", "report.pdf") + assert "report.pdf" in os.path.basename(path) + + def test_empty_filename_uses_fallback(self): + path = 
cache_document_from_bytes(b"data", "") + assert "document" in os.path.basename(path) + + def test_unique_filenames(self): + p1 = cache_document_from_bytes(b"a", "same.txt") + p2 = cache_document_from_bytes(b"b", "same.txt") + assert p1 != p2 + + def test_path_traversal_blocked(self): + """Malicious directory components are stripped — only the leaf name survives.""" + path = cache_document_from_bytes(b"data", "../../etc/passwd") + basename = os.path.basename(path) + assert "passwd" in basename + # Must NOT contain directory separators + assert ".." not in basename + # File must reside inside the cache directory + cache_dir = get_document_cache_dir() + assert Path(path).resolve().is_relative_to(cache_dir.resolve()) + + def test_null_bytes_stripped(self): + path = cache_document_from_bytes(b"data", "file\x00.pdf") + basename = os.path.basename(path) + assert "\x00" not in basename + assert "file.pdf" in basename + + def test_dot_dot_filename_handled(self): + """A filename that is literally '..' 
falls back to 'document'.""" + path = cache_document_from_bytes(b"data", "..") + basename = os.path.basename(path) + assert "document" in basename + + def test_none_filename_uses_fallback(self): + path = cache_document_from_bytes(b"data", None) + assert "document" in os.path.basename(path) + + +# --------------------------------------------------------------------------- +# TestCleanupDocumentCache +# --------------------------------------------------------------------------- + +class TestCleanupDocumentCache: + def test_removes_old_files(self, tmp_path): + cache_dir = get_document_cache_dir() + old_file = cache_dir / "old.txt" + old_file.write_text("old") + # Set modification time to 48 hours ago + old_mtime = time.time() - 48 * 3600 + os.utime(old_file, (old_mtime, old_mtime)) + + removed = cleanup_document_cache(max_age_hours=24) + assert removed == 1 + assert not old_file.exists() + + def test_keeps_recent_files(self): + cache_dir = get_document_cache_dir() + recent = cache_dir / "recent.txt" + recent.write_text("fresh") + + removed = cleanup_document_cache(max_age_hours=24) + assert removed == 0 + assert recent.exists() + + def test_returns_removed_count(self): + cache_dir = get_document_cache_dir() + old_time = time.time() - 48 * 3600 + for i in range(3): + f = cache_dir / f"old_{i}.txt" + f.write_text("x") + os.utime(f, (old_time, old_time)) + + assert cleanup_document_cache(max_age_hours=24) == 3 + + def test_empty_cache_dir(self): + assert cleanup_document_cache(max_age_hours=24) == 0 + + +# --------------------------------------------------------------------------- +# TestSupportedDocumentTypes +# --------------------------------------------------------------------------- + +class TestSupportedDocumentTypes: + def test_all_extensions_have_mime_types(self): + for ext, mime in SUPPORTED_DOCUMENT_TYPES.items(): + assert ext.startswith("."), f"{ext} missing leading dot" + assert "/" in mime, f"{mime} is not a valid MIME type" + + @pytest.mark.parametrize( + 
"ext", + [".pdf", ".md", ".txt", ".docx", ".xlsx", ".pptx"], + ) + def test_expected_extensions_present(self, ext): + assert ext in SUPPORTED_DOCUMENT_TYPES diff --git a/tests/gateway/test_telegram_documents.py b/tests/gateway/test_telegram_documents.py new file mode 100644 index 000000000..4aceda842 --- /dev/null +++ b/tests/gateway/test_telegram_documents.py @@ -0,0 +1,338 @@ +""" +Tests for Telegram document handling in gateway/platforms/telegram.py. + +Covers: document type detection, download/cache flow, size limits, + text injection, error handling. + +Note: python-telegram-bot may not be installed in the test environment. +We mock the telegram module at import time to avoid collection errors. +""" + +import asyncio +import importlib +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + MessageEvent, + MessageType, + SUPPORTED_DOCUMENT_TYPES, +) + + +# --------------------------------------------------------------------------- +# Mock the telegram package if it's not installed +# --------------------------------------------------------------------------- + +def _ensure_telegram_mock(): + """Install mock telegram modules so TelegramAdapter can be imported.""" + if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"): + # Real library is installed — no mocking needed + return + + telegram_mod = MagicMock() + # ContextTypes needs DEFAULT_TYPE as an actual attribute for the annotation + telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None) + telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2" + telegram_mod.constants.ChatType.GROUP = "group" + telegram_mod.constants.ChatType.SUPERGROUP = "supergroup" + telegram_mod.constants.ChatType.CHANNEL = "channel" + telegram_mod.constants.ChatType.PRIVATE = "private" + + for name in ("telegram", "telegram.ext", "telegram.constants"): + 
sys.modules.setdefault(name, telegram_mod) + + +_ensure_telegram_mock() + +# Now we can safely import +from gateway.platforms.telegram import TelegramAdapter # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers to build mock Telegram objects +# --------------------------------------------------------------------------- + +def _make_file_obj(data: bytes = b"hello"): + """Create a mock Telegram File with download_as_bytearray.""" + f = AsyncMock() + f.download_as_bytearray = AsyncMock(return_value=bytearray(data)) + f.file_path = "documents/file.pdf" + return f + + +def _make_document( + file_name="report.pdf", + mime_type="application/pdf", + file_size=1024, + file_obj=None, +): + """Create a mock Telegram Document object.""" + doc = MagicMock() + doc.file_name = file_name + doc.mime_type = mime_type + doc.file_size = file_size + doc.get_file = AsyncMock(return_value=file_obj or _make_file_obj()) + return doc + + +def _make_message(document=None, caption=None): + """Build a mock Telegram Message with the given document.""" + msg = MagicMock() + msg.message_id = 42 + msg.text = caption or "" + msg.caption = caption + msg.date = None + # Media flags — all None except document + msg.photo = None + msg.video = None + msg.audio = None + msg.voice = None + msg.sticker = None + msg.document = document + # Chat / user + msg.chat = MagicMock() + msg.chat.id = 100 + msg.chat.type = "private" + msg.chat.title = None + msg.chat.full_name = "Test User" + msg.from_user = MagicMock() + msg.from_user.id = 1 + msg.from_user.full_name = "Test User" + msg.message_thread_id = None + return msg + + +def _make_update(msg): + """Wrap a message in a mock Update.""" + update = MagicMock() + update.message = msg + return update + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture() +def adapter(): + 
config = PlatformConfig(enabled=True, token="fake-token") + a = TelegramAdapter(config) + # Capture events instead of processing them + a.handle_message = AsyncMock() + return a + + +@pytest.fixture(autouse=True) +def _redirect_cache(tmp_path, monkeypatch): + """Point document cache to tmp_path so tests don't touch ~/.hermes.""" + monkeypatch.setattr( + "gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache" + ) + + +# --------------------------------------------------------------------------- +# TestDocumentTypeDetection +# --------------------------------------------------------------------------- + +class TestDocumentTypeDetection: + @pytest.mark.asyncio + async def test_document_detected_explicitly(self, adapter): + doc = _make_document() + msg = _make_message(document=doc) + update = _make_update(msg) + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert event.message_type == MessageType.DOCUMENT + + @pytest.mark.asyncio + async def test_fallback_is_document(self, adapter): + """When no specific media attr is set, message_type defaults to DOCUMENT.""" + msg = _make_message() + msg.document = None # no media at all + update = _make_update(msg) + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert event.message_type == MessageType.DOCUMENT + + +# --------------------------------------------------------------------------- +# TestDocumentDownloadBlock +# --------------------------------------------------------------------------- + +class TestDocumentDownloadBlock: + @pytest.mark.asyncio + async def test_supported_pdf_is_cached(self, adapter): + pdf_bytes = b"%PDF-1.4 fake" + file_obj = _make_file_obj(pdf_bytes) + doc = _make_document(file_name="report.pdf", file_size=1024, file_obj=file_obj) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = 
adapter.handle_message.call_args[0][0] + assert len(event.media_urls) == 1 + assert os.path.exists(event.media_urls[0]) + assert event.media_types == ["application/pdf"] + + @pytest.mark.asyncio + async def test_supported_txt_injects_content(self, adapter): + content = b"Hello from a text file" + file_obj = _make_file_obj(content) + doc = _make_document( + file_name="notes.txt", mime_type="text/plain", + file_size=len(content), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "Hello from a text file" in event.text + assert "[Content of notes.txt]" in event.text + + @pytest.mark.asyncio + async def test_supported_md_injects_content(self, adapter): + content = b"# Title\nSome markdown" + file_obj = _make_file_obj(content) + doc = _make_document( + file_name="readme.md", mime_type="text/markdown", + file_size=len(content), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "# Title" in event.text + + @pytest.mark.asyncio + async def test_caption_preserved_with_injection(self, adapter): + content = b"file text" + file_obj = _make_file_obj(content) + doc = _make_document( + file_name="doc.txt", mime_type="text/plain", + file_size=len(content), file_obj=file_obj, + ) + msg = _make_message(document=doc, caption="Please summarize") + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "file text" in event.text + assert "Please summarize" in event.text + + @pytest.mark.asyncio + async def test_unsupported_type_rejected(self, adapter): + doc = _make_document(file_name="archive.zip", mime_type="application/zip", file_size=100) + msg = _make_message(document=doc) + update = 
_make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "Unsupported document type" in event.text + assert ".zip" in event.text + + @pytest.mark.asyncio + async def test_oversized_file_rejected(self, adapter): + doc = _make_document(file_name="huge.pdf", file_size=25 * 1024 * 1024) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "too large" in event.text + + @pytest.mark.asyncio + async def test_none_file_size_rejected(self, adapter): + """Security fix: file_size=None must be rejected (not silently allowed).""" + doc = _make_document(file_name="tricky.pdf", file_size=None) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "too large" in event.text or "could not be verified" in event.text + + @pytest.mark.asyncio + async def test_missing_filename_uses_mime_lookup(self, adapter): + """No file_name but valid mime_type should resolve to extension.""" + content = b"some pdf bytes" + file_obj = _make_file_obj(content) + doc = _make_document( + file_name=None, mime_type="application/pdf", + file_size=len(content), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert len(event.media_urls) == 1 + assert event.media_types == ["application/pdf"] + + @pytest.mark.asyncio + async def test_missing_filename_and_mime_rejected(self, adapter): + doc = _make_document(file_name=None, mime_type=None, file_size=100) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + 
assert "Unsupported" in event.text + + @pytest.mark.asyncio + async def test_unicode_decode_error_handled(self, adapter): + """Binary bytes that aren't valid UTF-8 in a .txt — content not injected but file still cached.""" + binary = bytes(range(128, 256)) # not valid UTF-8 + file_obj = _make_file_obj(binary) + doc = _make_document( + file_name="binary.txt", mime_type="text/plain", + file_size=len(binary), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + # File should still be cached + assert len(event.media_urls) == 1 + assert os.path.exists(event.media_urls[0]) + # Content NOT injected — text should be empty (no caption set) + assert "[Content of" not in (event.text or "") + + @pytest.mark.asyncio + async def test_text_injection_capped(self, adapter): + """A .txt file over 100 KB should NOT have its content injected.""" + large = b"x" * (200 * 1024) # 200 KB + file_obj = _make_file_obj(large) + doc = _make_document( + file_name="big.txt", mime_type="text/plain", + file_size=len(large), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + # File should be cached + assert len(event.media_urls) == 1 + # Content should NOT be injected + assert "[Content of" not in (event.text or "") + + @pytest.mark.asyncio + async def test_download_exception_handled(self, adapter): + """If get_file() raises, the handler logs the error without crashing.""" + doc = _make_document(file_name="crash.pdf", file_size=100) + doc.get_file = AsyncMock(side_effect=RuntimeError("Telegram API down")) + msg = _make_message(document=doc) + update = _make_update(msg) + + # Should not raise + await adapter._handle_media_message(update, MagicMock()) + # handle_message should still be called (the handler 
catches the exception) + adapter.handle_message.assert_called_once() diff --git a/tests/test_413_compression.py b/tests/test_413_compression.py new file mode 100644 index 000000000..f6274ebf1 --- /dev/null +++ b/tests/test_413_compression.py @@ -0,0 +1,171 @@ +"""Tests for 413 payload-too-large → compression retry logic in AIAgent. + +Verifies that HTTP 413 errors trigger history compression and retry, +rather than being treated as non-retryable generic 4xx errors. +""" + +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from run_agent import AIAgent + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_tool_defs(*names: str) -> list: + return [ + { + "type": "function", + "function": { + "name": n, + "description": f"{n} tool", + "parameters": {"type": "object", "properties": {}}, + }, + } + for n in names + ] + + +def _mock_response(content="Hello", finish_reason="stop", tool_calls=None, usage=None): + msg = SimpleNamespace( + content=content, + tool_calls=tool_calls, + reasoning_content=None, + reasoning=None, + ) + choice = SimpleNamespace(message=msg, finish_reason=finish_reason) + resp = SimpleNamespace(choices=[choice], model="test/model") + resp.usage = SimpleNamespace(**usage) if usage else None + return resp + + +def _make_413_error(*, use_status_code=True, message="Request entity too large"): + """Create an exception that mimics a 413 HTTP error.""" + err = Exception(message) + if use_status_code: + err.status_code = 413 + return err + + +@pytest.fixture() +def agent(): + with ( + patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + a = AIAgent( + api_key="test-key-1234567890", + quiet_mode=True, + skip_context_files=True, + 
skip_memory=True, + ) + a.client = MagicMock() + a._cached_system_prompt = "You are helpful." + a._use_prompt_caching = False + a.tool_delay = 0 + a.compression_enabled = False + a.save_trajectories = False + return a + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestHTTP413Compression: + """413 errors should trigger compression, not abort as generic 4xx.""" + + def test_413_triggers_compression(self, agent): + """A 413 error should call _compress_context and retry, not abort.""" + # First call raises 413; second call succeeds after compression. + err_413 = _make_413_error() + ok_resp = _mock_response(content="Success after compression", finish_reason="stop") + agent.client.chat.completions.create.side_effect = [err_413, ok_resp] + + with ( + patch.object(agent, "_compress_context") as mock_compress, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + # Compression removes messages, enabling retry + mock_compress.return_value = ( + [{"role": "user", "content": "hello"}], + "compressed prompt", + ) + result = agent.run_conversation("hello") + + mock_compress.assert_called_once() + assert result["completed"] is True + assert result["final_response"] == "Success after compression" + + def test_413_not_treated_as_generic_4xx(self, agent): + """413 must NOT hit the generic 4xx abort path; it should attempt compression.""" + err_413 = _make_413_error() + ok_resp = _mock_response(content="Recovered", finish_reason="stop") + agent.client.chat.completions.create.side_effect = [err_413, ok_resp] + + with ( + patch.object(agent, "_compress_context") as mock_compress, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + mock_compress.return_value = ( + [{"role": 
"user", "content": "hello"}], + "compressed", + ) + result = agent.run_conversation("hello") + + # If 413 were treated as generic 4xx, result would have "failed": True + assert result.get("failed") is not True + assert result["completed"] is True + + def test_413_error_message_detection(self, agent): + """413 detected via error message string (no status_code attr).""" + err = _make_413_error(use_status_code=False, message="error code: 413") + ok_resp = _mock_response(content="OK", finish_reason="stop") + agent.client.chat.completions.create.side_effect = [err, ok_resp] + + with ( + patch.object(agent, "_compress_context") as mock_compress, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + mock_compress.return_value = ( + [{"role": "user", "content": "hello"}], + "compressed", + ) + result = agent.run_conversation("hello") + + mock_compress.assert_called_once() + assert result["completed"] is True + + def test_413_cannot_compress_further(self, agent): + """When compression can't reduce messages, return partial result.""" + err_413 = _make_413_error() + agent.client.chat.completions.create.side_effect = [err_413] + + with ( + patch.object(agent, "_compress_context") as mock_compress, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + # Compression returns same number of messages → can't compress further + mock_compress.return_value = ( + [{"role": "user", "content": "hello"}], + "same prompt", + ) + result = agent.run_conversation("hello") + + assert result["completed"] is False + assert result.get("partial") is True + assert "413" in result["error"]