From b2172c4b2e808860f3c46dacbb352d3f3347a33d Mon Sep 17 00:00:00 2001 From: tekelala Date: Fri, 27 Feb 2026 11:44:57 -0500 Subject: [PATCH 1/3] feat(telegram): add document file processing for PDF, text, and Office files Download, cache, and enrich document files sent via Telegram. Supports .pdf, .md, .txt, .docx, .xlsx, .pptx with size validation, unsupported type rejection, text content injection for .md/.txt, and hourly cache cleanup. Co-Authored-By: Claude Opus 4.6 --- gateway/platforms/base.py | 68 ++++++++++++++++++++++++++++++++++ gateway/platforms/telegram.py | 70 ++++++++++++++++++++++++++++++++++- gateway/run.py | 41 ++++++++++++++++++-- 3 files changed, 175 insertions(+), 4 deletions(-) diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index b28b78e7c..f854723a4 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -171,6 +171,74 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg") -> str: return cache_audio_from_bytes(response.content, ext) +# --------------------------------------------------------------------------- +# Document cache utilities +# +# Same pattern as image/audio cache -- documents from platforms are downloaded +# here so the agent can reference them by local file path. +# --------------------------------------------------------------------------- + +DOCUMENT_CACHE_DIR = Path(os.path.expanduser("~/.hermes/document_cache")) + +SUPPORTED_DOCUMENT_TYPES = { + ".pdf": "application/pdf", + ".md": "text/markdown", + ".txt": "text/plain", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", +} + + +def get_document_cache_dir() -> Path: + """Return the document cache directory, creating it if it doesn't exist.""" + DOCUMENT_CACHE_DIR.mkdir(parents=True, exist_ok=True) + return DOCUMENT_CACHE_DIR + + +def cache_document_from_bytes(data: bytes, filename: str) -> str: + """ + Save raw document bytes to the cache and return the absolute file path. + + The cached filename preserves the original human-readable name with a + unique prefix: ``doc_{uuid12}_{original_filename}``. + + Args: + data: Raw document bytes. + filename: Original filename (e.g. "report.pdf"). + + Returns: + Absolute path to the cached document file as a string. + """ + cache_dir = get_document_cache_dir() + safe_name = filename if filename else "document" + cached_name = f"doc_{uuid.uuid4().hex[:12]}_{safe_name}" + filepath = cache_dir / cached_name + filepath.write_bytes(data) + return str(filepath) + + +def cleanup_document_cache(max_age_hours: int = 24) -> int: + """ + Delete cached documents older than *max_age_hours*. + + Returns the number of files removed. + """ + import time + + cache_dir = get_document_cache_dir() + cutoff = time.time() - (max_age_hours * 3600) + removed = 0 + for f in cache_dir.iterdir(): + if f.is_file() and f.stat().st_mtime < cutoff: + try: + f.unlink() + removed += 1 + except OSError: + pass + return removed + + class MessageType(Enum): """Types of incoming messages.""" TEXT = "text" diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 73d749bd3..2bfd5085a 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -8,6 +8,7 @@ Uses python-telegram-bot library for: """ import asyncio +import os import re from typing import Dict, List, Optional, Any @@ -42,6 +43,8 @@ from gateway.platforms.base import ( SendResult, cache_image_from_bytes, cache_audio_from_bytes, + cache_document_from_bytes, + SUPPORTED_DOCUMENT_TYPES, ) @@ -419,6 +422,8 @@ class TelegramAdapter(BasePlatformAdapter): msg_type = MessageType.AUDIO elif msg.voice: msg_type = MessageType.VOICE + elif msg.document: + msg_type = MessageType.DOCUMENT else: msg_type = MessageType.DOCUMENT @@ -479,7 +484,70 @@ class TelegramAdapter(BasePlatformAdapter): print(f"[Telegram] Cached user audio: {cached_path}", flush=True) except Exception as e: print(f"[Telegram] Failed to cache audio: {e}", flush=True) - + + # Download document files to cache for agent processing + elif msg.document: + doc = msg.document + try: + # Determine file extension + ext = "" + original_filename = doc.file_name or "" + if original_filename: + _, ext = os.path.splitext(original_filename) + ext = ext.lower() + + # If no extension from filename, reverse-lookup from MIME type + if not ext and doc.mime_type: + mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()} + ext = mime_to_ext.get(doc.mime_type, "") + + # Check if supported + if ext not in SUPPORTED_DOCUMENT_TYPES: + supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys())) + event.text = ( + f"Unsupported document type '{ext or 'unknown'}'. " + f"Supported types: {supported_list}" + ) + print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True) + await self.handle_message(event) + return + + # Check file size (Telegram Bot API limit: 20 MB) + if doc.file_size and doc.file_size > 20 * 1024 * 1024: + event.text = ( + "The document is too large (over 20 MB). " + "Please send a smaller file." + ) + print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True) + await self.handle_message(event) + return + + # Download and cache + file_obj = await doc.get_file() + doc_bytes = await file_obj.download_as_bytearray() + raw_bytes = bytes(doc_bytes) + cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}") + mime_type = SUPPORTED_DOCUMENT_TYPES[ext] + event.media_urls = [cached_path] + event.media_types = [mime_type] + print(f"[Telegram] Cached user document: {cached_path}", flush=True) + + # For text files, inject content into event.text + if ext in (".md", ".txt"): + try: + text_content = raw_bytes.decode("utf-8") + display_name = original_filename or f"document{ext}" + injection = f"[Content of {display_name}]:\n{text_content}" + if event.text: + event.text = f"{injection}\n\n{event.text}" + else: + event.text = injection + except UnicodeDecodeError: + print(f"[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True) + + except Exception as e: + print(f"[Telegram] Failed to cache document: {e}", flush=True) + await self.handle_message(event) async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None: diff --git a/gateway/run.py b/gateway/run.py index df882d8e6..48c4b3ce2 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -742,7 +742,36 @@ class GatewayRunner: message_text = await self._enrich_message_with_transcription( message_text, audio_paths ) - + + # ----------------------------------------------------------------- + # Enrich document messages with context notes for the agent + # ----------------------------------------------------------------- + if event.media_urls and event.message_type == MessageType.DOCUMENT: + for i, path in enumerate(event.media_urls): + mtype = event.media_types[i] if i < len(event.media_types) else "" + if not (mtype.startswith("application/") or mtype.startswith("text/")): + continue + # Extract display filename by stripping the doc_{uuid12}_ prefix + import os as _os + basename = _os.path.basename(path) + # Format: doc_<12hex>_ + parts = basename.split("_", 2) + display_name = parts[2] if len(parts) >= 3 else basename + + if mtype.startswith("text/"): + context_note = ( + f"[The user sent a text document: '{display_name}'. " + f"Its content has been included below. " + f"The file is also saved at: {path}]" + ) + else: + context_note = ( + f"[The user sent a document: '{display_name}'. " + f"The file is saved at: {path}. " + f"Ask the user what they'd like you to do with it.]" + ) + message_text = f"{context_note}\n\n{message_text}" + try: # Emit agent:start hook hook_ctx = { @@ -1754,10 +1783,10 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int needing a separate `hermes cron daemon` or system cron entry. Also refreshes the channel directory every 5 minutes and prunes the - image/audio cache once per hour. + image/audio/document cache once per hour. """ from cron.scheduler import tick as cron_tick - from gateway.platforms.base import cleanup_image_cache + from gateway.platforms.base import cleanup_image_cache, cleanup_document_cache IMAGE_CACHE_EVERY = 60 # ticks — once per hour at default 60s interval CHANNEL_DIR_EVERY = 5 # ticks — every 5 minutes @@ -1786,6 +1815,12 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int logger.info("Image cache cleanup: removed %d stale file(s)", removed) except Exception as e: logger.debug("Image cache cleanup error: %s", e) + try: + removed = cleanup_document_cache(max_age_hours=24) + if removed: + logger.info("Document cache cleanup: removed %d stale file(s)", removed) + except Exception as e: + logger.debug("Document cache cleanup error: %s", e) stop_event.wait(timeout=interval) logger.info("Cron ticker stopped") From fbb1923fad18eb3bba332c3bfbdcfd69dddae19e Mon Sep 17 00:00:00 2001 From: tekelala Date: Fri, 27 Feb 2026 11:53:46 -0500 Subject: [PATCH 2/3] fix(security): patch path traversal, size bypass, and prompt injection in document processing - Sanitize filenames in cache_document_from_bytes to prevent path traversal (strip directory components, null bytes, resolve check) - Reject documents with None file_size instead of silently allowing download - Cap text file injection at 100 KB to prevent oversized prompt payloads - Sanitize display_name in run.py context notes to block prompt injection via filenames - Add 35 unit tests covering document cache utilities and Telegram document handling Co-Authored-By: Claude Opus 4.6 --- gateway/platforms/base.py | 12 +- gateway/platforms/telegram.py | 12 +- gateway/run.py | 3 + tests/gateway/test_document_cache.py | 157 +++++++++++ tests/gateway/test_telegram_documents.py | 338 +++++++++++++++++++++++ 5 files changed, 516 insertions(+), 6 deletions(-) create mode 100644 tests/gateway/test_document_cache.py create mode 100644 tests/gateway/test_telegram_documents.py diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index f854723a4..2e818b4ea 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -209,11 +209,21 @@ def cache_document_from_bytes(data: bytes, filename: str) -> str: Returns: Absolute path to the cached document file as a string. + + Raises: + ValueError: If the sanitized path escapes the cache directory. """ cache_dir = get_document_cache_dir() - safe_name = filename if filename else "document" + # Sanitize: strip directory components, null bytes, and control characters + safe_name = Path(filename).name if filename else "document" + safe_name = safe_name.replace("\x00", "").strip() + if not safe_name or safe_name in (".", ".."): + safe_name = "document" cached_name = f"doc_{uuid.uuid4().hex[:12]}_{safe_name}" filepath = cache_dir / cached_name + # Final safety check: ensure path stays inside cache dir + if not filepath.resolve().is_relative_to(cache_dir.resolve()): + raise ValueError(f"Path traversal rejected: {filename!r}") filepath.write_bytes(data) return str(filepath) diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 2bfd5085a..e7c6062a1 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -513,10 +513,11 @@ class TelegramAdapter(BasePlatformAdapter): return # Check file size (Telegram Bot API limit: 20 MB) - if doc.file_size and doc.file_size > 20 * 1024 * 1024: + MAX_DOC_BYTES = 20 * 1024 * 1024 + if not doc.file_size or doc.file_size > MAX_DOC_BYTES: event.text = ( - "The document is too large (over 20 MB). " - "Please send a smaller file." + "The document is too large or its size could not be verified. " + "Maximum: 20 MB." ) print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True) await self.handle_message(event) @@ -532,8 +533,9 @@ class TelegramAdapter(BasePlatformAdapter): event.media_types = [mime_type] print(f"[Telegram] Cached user document: {cached_path}", flush=True) - # For text files, inject content into event.text - if ext in (".md", ".txt"): + # For text files, inject content into event.text (capped at 100 KB) + MAX_TEXT_INJECT_BYTES = 100 * 1024 + if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES: try: text_content = raw_bytes.decode("utf-8") display_name = original_filename or f"document{ext}" diff --git a/gateway/run.py b/gateway/run.py index 48c4b3ce2..83f781fb0 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -757,6 +757,9 @@ class GatewayRunner: # Format: doc_<12hex>_ parts = basename.split("_", 2) display_name = parts[2] if len(parts) >= 3 else basename + # Sanitize to prevent prompt injection via filenames + import re as _re + display_name = _re.sub(r'[^\w.\- ]', '_', display_name) if mtype.startswith("text/"): context_note = ( diff --git a/tests/gateway/test_document_cache.py b/tests/gateway/test_document_cache.py new file mode 100644 index 000000000..18440ed9c --- /dev/null +++ b/tests/gateway/test_document_cache.py @@ -0,0 +1,157 @@ +""" +Tests for document cache utilities in gateway/platforms/base.py. + +Covers: get_document_cache_dir, cache_document_from_bytes, + cleanup_document_cache, SUPPORTED_DOCUMENT_TYPES. +""" + +import os +import time +from pathlib import Path + +import pytest + +from gateway.platforms.base import ( + SUPPORTED_DOCUMENT_TYPES, + cache_document_from_bytes, + cleanup_document_cache, + get_document_cache_dir, +) + +# --------------------------------------------------------------------------- +# Fixture: redirect DOCUMENT_CACHE_DIR to a temp directory for every test +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _redirect_cache(tmp_path, monkeypatch): + """Point the module-level DOCUMENT_CACHE_DIR to a fresh tmp_path.""" + monkeypatch.setattr( + "gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache" + ) + + +# --------------------------------------------------------------------------- +# TestGetDocumentCacheDir +# --------------------------------------------------------------------------- + +class TestGetDocumentCacheDir: + def test_creates_directory(self, tmp_path): + cache_dir = get_document_cache_dir() + assert cache_dir.exists() + assert cache_dir.is_dir() + + def test_returns_existing_directory(self): + first = get_document_cache_dir() + second = get_document_cache_dir() + assert first == second + assert first.exists() + + +# --------------------------------------------------------------------------- +# TestCacheDocumentFromBytes +# --------------------------------------------------------------------------- + +class TestCacheDocumentFromBytes: + def test_basic_caching(self): + data = b"hello world" + path = cache_document_from_bytes(data, "test.txt") + assert os.path.exists(path) + assert Path(path).read_bytes() == data + + def test_filename_preserved_in_path(self): + path = cache_document_from_bytes(b"data", "report.pdf") + assert "report.pdf" in os.path.basename(path) + + def test_empty_filename_uses_fallback(self): + path = cache_document_from_bytes(b"data", "") + assert "document" in os.path.basename(path) + + def test_unique_filenames(self): + p1 = cache_document_from_bytes(b"a", "same.txt") + p2 = cache_document_from_bytes(b"b", "same.txt") + assert p1 != p2 + + def test_path_traversal_blocked(self): + """Malicious directory components are stripped — only the leaf name survives.""" + path = cache_document_from_bytes(b"data", "../../etc/passwd") + basename = os.path.basename(path) + assert "passwd" in basename + # Must NOT contain directory separators + assert ".." not in basename + # File must reside inside the cache directory + cache_dir = get_document_cache_dir() + assert Path(path).resolve().is_relative_to(cache_dir.resolve()) + + def test_null_bytes_stripped(self): + path = cache_document_from_bytes(b"data", "file\x00.pdf") + basename = os.path.basename(path) + assert "\x00" not in basename + assert "file.pdf" in basename + + def test_dot_dot_filename_handled(self): + """A filename that is literally '..' falls back to 'document'.""" + path = cache_document_from_bytes(b"data", "..") + basename = os.path.basename(path) + assert "document" in basename + + def test_none_filename_uses_fallback(self): + path = cache_document_from_bytes(b"data", None) + assert "document" in os.path.basename(path) + + +# --------------------------------------------------------------------------- +# TestCleanupDocumentCache +# --------------------------------------------------------------------------- + +class TestCleanupDocumentCache: + def test_removes_old_files(self, tmp_path): + cache_dir = get_document_cache_dir() + old_file = cache_dir / "old.txt" + old_file.write_text("old") + # Set modification time to 48 hours ago + old_mtime = time.time() - 48 * 3600 + os.utime(old_file, (old_mtime, old_mtime)) + + removed = cleanup_document_cache(max_age_hours=24) + assert removed == 1 + assert not old_file.exists() + + def test_keeps_recent_files(self): + cache_dir = get_document_cache_dir() + recent = cache_dir / "recent.txt" + recent.write_text("fresh") + + removed = cleanup_document_cache(max_age_hours=24) + assert removed == 0 + assert recent.exists() + + def test_returns_removed_count(self): + cache_dir = get_document_cache_dir() + old_time = time.time() - 48 * 3600 + for i in range(3): + f = cache_dir / f"old_{i}.txt" + f.write_text("x") + os.utime(f, (old_time, old_time)) + + assert cleanup_document_cache(max_age_hours=24) == 3 + + def test_empty_cache_dir(self): + assert cleanup_document_cache(max_age_hours=24) == 0 + + +# --------------------------------------------------------------------------- +# TestSupportedDocumentTypes +# --------------------------------------------------------------------------- + +class TestSupportedDocumentTypes: + def test_all_extensions_have_mime_types(self): + for ext, mime in SUPPORTED_DOCUMENT_TYPES.items(): + assert ext.startswith("."), f"{ext} missing leading dot" + assert "/" in mime, f"{mime} is not a valid MIME type" + + @pytest.mark.parametrize( + "ext", + [".pdf", ".md", ".txt", ".docx", ".xlsx", ".pptx"], + ) + def test_expected_extensions_present(self, ext): + assert ext in SUPPORTED_DOCUMENT_TYPES diff --git a/tests/gateway/test_telegram_documents.py b/tests/gateway/test_telegram_documents.py new file mode 100644 index 000000000..4aceda842 --- /dev/null +++ b/tests/gateway/test_telegram_documents.py @@ -0,0 +1,338 @@ +""" +Tests for Telegram document handling in gateway/platforms/telegram.py. + +Covers: document type detection, download/cache flow, size limits, + text injection, error handling. + +Note: python-telegram-bot may not be installed in the test environment. +We mock the telegram module at import time to avoid collection errors. +""" + +import asyncio +import importlib +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + MessageEvent, + MessageType, + SUPPORTED_DOCUMENT_TYPES, +) + + +# --------------------------------------------------------------------------- +# Mock the telegram package if it's not installed +# --------------------------------------------------------------------------- + +def _ensure_telegram_mock(): + """Install mock telegram modules so TelegramAdapter can be imported.""" + if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"): + # Real library is installed — no mocking needed + return + + telegram_mod = MagicMock() + # ContextTypes needs DEFAULT_TYPE as an actual attribute for the annotation + telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None) + telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2" + telegram_mod.constants.ChatType.GROUP = "group" + telegram_mod.constants.ChatType.SUPERGROUP = "supergroup" + telegram_mod.constants.ChatType.CHANNEL = "channel" + telegram_mod.constants.ChatType.PRIVATE = "private" + + for name in ("telegram", "telegram.ext", "telegram.constants"): + sys.modules.setdefault(name, telegram_mod) + + +_ensure_telegram_mock() + +# Now we can safely import +from gateway.platforms.telegram import TelegramAdapter # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers to build mock Telegram objects +# --------------------------------------------------------------------------- + +def _make_file_obj(data: bytes = b"hello"): + """Create a mock Telegram File with download_as_bytearray.""" + f = AsyncMock() + f.download_as_bytearray = AsyncMock(return_value=bytearray(data)) + f.file_path = "documents/file.pdf" + return f + + +def _make_document( + file_name="report.pdf", + mime_type="application/pdf", + file_size=1024, + file_obj=None, +): + """Create a mock Telegram Document object.""" + doc = MagicMock() + doc.file_name = file_name + doc.mime_type = mime_type + doc.file_size = file_size + doc.get_file = AsyncMock(return_value=file_obj or _make_file_obj()) + return doc + + +def _make_message(document=None, caption=None): + """Build a mock Telegram Message with the given document.""" + msg = MagicMock() + msg.message_id = 42 + msg.text = caption or "" + msg.caption = caption + msg.date = None + # Media flags — all None except document + msg.photo = None + msg.video = None + msg.audio = None + msg.voice = None + msg.sticker = None + msg.document = document + # Chat / user + msg.chat = MagicMock() + msg.chat.id = 100 + msg.chat.type = "private" + msg.chat.title = None + msg.chat.full_name = "Test User" + msg.from_user = MagicMock() + msg.from_user.id = 1 + msg.from_user.full_name = "Test User" + msg.message_thread_id = None + return msg + + +def _make_update(msg): + """Wrap a message in a mock Update.""" + update = MagicMock() + update.message = msg + return update + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture() +def adapter(): + config = PlatformConfig(enabled=True, token="fake-token") + a = TelegramAdapter(config) + # Capture events instead of processing them + a.handle_message = AsyncMock() + return a + + +@pytest.fixture(autouse=True) +def _redirect_cache(tmp_path, monkeypatch): + """Point document cache to tmp_path so tests don't touch ~/.hermes.""" + monkeypatch.setattr( + "gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache" + ) + + +# --------------------------------------------------------------------------- +# TestDocumentTypeDetection +# --------------------------------------------------------------------------- + +class TestDocumentTypeDetection: + @pytest.mark.asyncio + async def test_document_detected_explicitly(self, adapter): + doc = _make_document() + msg = _make_message(document=doc) + update = _make_update(msg) + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert event.message_type == MessageType.DOCUMENT + + @pytest.mark.asyncio + async def test_fallback_is_document(self, adapter): + """When no specific media attr is set, message_type defaults to DOCUMENT.""" + msg = _make_message() + msg.document = None # no media at all + update = _make_update(msg) + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert event.message_type == MessageType.DOCUMENT + + +# --------------------------------------------------------------------------- +# TestDocumentDownloadBlock +# --------------------------------------------------------------------------- + +class TestDocumentDownloadBlock: + @pytest.mark.asyncio + async def test_supported_pdf_is_cached(self, adapter): + pdf_bytes = b"%PDF-1.4 fake" + file_obj = _make_file_obj(pdf_bytes) + doc = _make_document(file_name="report.pdf", file_size=1024, file_obj=file_obj) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert len(event.media_urls) == 1 + assert os.path.exists(event.media_urls[0]) + assert event.media_types == ["application/pdf"] + + @pytest.mark.asyncio + async def test_supported_txt_injects_content(self, adapter): + content = b"Hello from a text file" + file_obj = _make_file_obj(content) + doc = _make_document( + file_name="notes.txt", mime_type="text/plain", + file_size=len(content), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "Hello from a text file" in event.text + assert "[Content of notes.txt]" in event.text + + @pytest.mark.asyncio + async def test_supported_md_injects_content(self, adapter): + content = b"# Title\nSome markdown" + file_obj = _make_file_obj(content) + doc = _make_document( + file_name="readme.md", mime_type="text/markdown", + file_size=len(content), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "# Title" in event.text + + @pytest.mark.asyncio + async def test_caption_preserved_with_injection(self, adapter): + content = b"file text" + file_obj = _make_file_obj(content) + doc = _make_document( + file_name="doc.txt", mime_type="text/plain", + file_size=len(content), file_obj=file_obj, + ) + msg = _make_message(document=doc, caption="Please summarize") + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "file text" in event.text + assert "Please summarize" in event.text + + @pytest.mark.asyncio + async def test_unsupported_type_rejected(self, adapter): + doc = _make_document(file_name="archive.zip", mime_type="application/zip", file_size=100) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "Unsupported document type" in event.text + assert ".zip" in event.text + + @pytest.mark.asyncio + async def test_oversized_file_rejected(self, adapter): + doc = _make_document(file_name="huge.pdf", file_size=25 * 1024 * 1024) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "too large" in event.text + + @pytest.mark.asyncio + async def test_none_file_size_rejected(self, adapter): + """Security fix: file_size=None must be rejected (not silently allowed).""" + doc = _make_document(file_name="tricky.pdf", file_size=None) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "too large" in event.text or "could not be verified" in event.text + + @pytest.mark.asyncio + async def test_missing_filename_uses_mime_lookup(self, adapter): + """No file_name but valid mime_type should resolve to extension.""" + content = b"some pdf bytes" + file_obj = _make_file_obj(content) + doc = _make_document( + file_name=None, mime_type="application/pdf", + file_size=len(content), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert len(event.media_urls) == 1 + assert event.media_types == ["application/pdf"] + + @pytest.mark.asyncio + async def test_missing_filename_and_mime_rejected(self, adapter): + doc = _make_document(file_name=None, mime_type=None, file_size=100) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + assert "Unsupported" in event.text + + @pytest.mark.asyncio + async def test_unicode_decode_error_handled(self, adapter): + """Binary bytes that aren't valid UTF-8 in a .txt — content not injected but file still cached.""" + binary = bytes(range(128, 256)) # not valid UTF-8 + file_obj = _make_file_obj(binary) + doc = _make_document( + file_name="binary.txt", mime_type="text/plain", + file_size=len(binary), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + # File should still be cached + assert len(event.media_urls) == 1 + assert os.path.exists(event.media_urls[0]) + # Content NOT injected — text should be empty (no caption set) + assert "[Content of" not in (event.text or "") + + @pytest.mark.asyncio + async def test_text_injection_capped(self, adapter): + """A .txt file over 100 KB should NOT have its content injected.""" + large = b"x" * (200 * 1024) # 200 KB + file_obj = _make_file_obj(large) + doc = _make_document( + file_name="big.txt", mime_type="text/plain", + file_size=len(large), file_obj=file_obj, + ) + msg = _make_message(document=doc) + update = _make_update(msg) + + await adapter._handle_media_message(update, MagicMock()) + event = adapter.handle_message.call_args[0][0] + # File should be cached + assert len(event.media_urls) == 1 + # Content should NOT be injected + assert "[Content of" not in (event.text or "") + + @pytest.mark.asyncio + async def test_download_exception_handled(self, adapter): + """If get_file() raises, the handler logs the error without crashing.""" + doc = _make_document(file_name="crash.pdf", file_size=100) + doc.get_file = AsyncMock(side_effect=RuntimeError("Telegram API down")) + msg = _make_message(document=doc) + update = _make_update(msg) + + # Should not raise + await adapter._handle_media_message(update, MagicMock()) + # handle_message should still be called (the handler catches the exception) + adapter.handle_message.assert_called_once() From 79bd65034c9254bdb49d90d7177bc1fa5b706a45 Mon Sep 17 00:00:00 2001 From: tekelala Date: Fri, 27 Feb 2026 12:21:27 -0500 Subject: [PATCH 3/3] fix(agent): handle 413 payload-too-large via compression instead of aborting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 413 "Request Entity Too Large" error from the LLM API was caught by the generic 4xx handler which aborts immediately. This is wrong for 413 — it's a payload-size issue that can be resolved by compressing conversation history. - Intercept 413 before the generic 4xx block and route to _compress_context - Exclude 413 from generic is_client_error detection - Add 'request entity too large' to context-length phrases as safety net - Add tests for 413 compression behavior Co-Authored-By: Claude Opus 4.6 --- run_agent.py | 44 ++++++++- tests/test_413_compression.py | 171 ++++++++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+), 5 deletions(-) create mode 100644 tests/test_413_compression.py diff --git a/run_agent.py b/run_agent.py index 1cf3808e1..49131ff7a 100644 --- a/run_agent.py +++ b/run_agent.py @@ -2092,11 +2092,44 @@ class AIAgent: "interrupted": True, } + # Check for 413 payload-too-large BEFORE generic 4xx handler. + # A 413 is a payload-size error — the correct response is to + # compress history and retry, not abort immediately. + status_code = getattr(api_error, "status_code", None) + is_payload_too_large = ( + status_code == 413 + or 'request entity too large' in error_msg + or 'error code: 413' in error_msg + ) + + if is_payload_too_large: + print(f"{self.log_prefix}⚠️ Request payload too large (413) - attempting compression...") + + original_len = len(messages) + messages, active_system_prompt = self._compress_context( + messages, system_message, approx_tokens=approx_tokens + ) + + if len(messages) < original_len: + print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") + continue # Retry with compressed messages + else: + print(f"{self.log_prefix}❌ Payload too large and cannot compress further.") + logging.error(f"{self.log_prefix}413 payload too large. Cannot compress further.") + self._persist_session(messages, conversation_history) + return { + "messages": messages, + "completed": False, + "api_calls": api_call_count, + "error": "Request payload too large (413). Cannot compress further.", + "partial": True + } + # Check for non-retryable client errors (4xx HTTP status codes). # These indicate a problem with the request itself (bad model ID, # invalid API key, forbidden, etc.) and will never succeed on retry. - status_code = getattr(api_error, "status_code", None) - is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 + # Note: 413 is excluded — it's handled above via compression. + is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code != 413 is_client_error = is_client_status_error or any(phrase in error_msg for phrase in [ 'error code: 400', 'error code: 401', 'error code: 403', 'error code: 404', 'error code: 422', @@ -2104,7 +2137,7 @@ class AIAgent: 'invalid api key', 'invalid_api_key', 'authentication', 'unauthorized', 'forbidden', 'not found', ]) - + if is_client_error: self._dump_api_request_debug( api_kwargs, reason="non_retryable_client_error", error=api_error, @@ -2124,8 +2157,9 @@ class AIAgent: # Check for non-retryable errors (context length exceeded) is_context_length_error = any(phrase in error_msg for phrase in [ - 'context length', 'maximum context', 'token limit', - 'too many tokens', 'reduce the length', 'exceeds the limit' + 'context length', 'maximum context', 'token limit', + 'too many tokens', 'reduce the length', 'exceeds the limit', + 'request entity too large', # OpenRouter/Nous 413 safety net ]) if is_context_length_error: diff --git a/tests/test_413_compression.py b/tests/test_413_compression.py new file mode 100644 index 000000000..f6274ebf1 --- /dev/null +++ b/tests/test_413_compression.py @@ -0,0 +1,171 @@ +"""Tests for 413 payload-too-large → compression retry logic in AIAgent. + +Verifies that HTTP 413 errors trigger history compression and retry, +rather than being treated as non-retryable generic 4xx errors. +""" + +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from run_agent import AIAgent + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_tool_defs(*names: str) -> list: + return [ + { + "type": "function", + "function": { + "name": n, + "description": f"{n} tool", + "parameters": {"type": "object", "properties": {}}, + }, + } + for n in names + ] + + +def _mock_response(content="Hello", finish_reason="stop", tool_calls=None, usage=None): + msg = SimpleNamespace( + content=content, + tool_calls=tool_calls, + reasoning_content=None, + reasoning=None, + ) + choice = SimpleNamespace(message=msg, finish_reason=finish_reason) + resp = SimpleNamespace(choices=[choice], model="test/model") + resp.usage = SimpleNamespace(**usage) if usage else None + return resp + + +def _make_413_error(*, use_status_code=True, message="Request entity too large"): + """Create an exception that mimics a 413 HTTP error.""" + err = Exception(message) + if use_status_code: + err.status_code = 413 + return err + + +@pytest.fixture() +def agent(): + with ( + patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + a = AIAgent( + api_key="test-key-1234567890", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + a.client = MagicMock() + a._cached_system_prompt = "You are helpful." + a._use_prompt_caching = False + a.tool_delay = 0 + a.compression_enabled = False + a.save_trajectories = False + return a + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestHTTP413Compression: + """413 errors should trigger compression, not abort as generic 4xx.""" + + def test_413_triggers_compression(self, agent): + """A 413 error should call _compress_context and retry, not abort.""" + # First call raises 413; second call succeeds after compression. + err_413 = _make_413_error() + ok_resp = _mock_response(content="Success after compression", finish_reason="stop") + agent.client.chat.completions.create.side_effect = [err_413, ok_resp] + + with ( + patch.object(agent, "_compress_context") as mock_compress, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + # Compression removes messages, enabling retry + mock_compress.return_value = ( + [{"role": "user", "content": "hello"}], + "compressed prompt", + ) + result = agent.run_conversation("hello") + + mock_compress.assert_called_once() + assert result["completed"] is True + assert result["final_response"] == "Success after compression" + + def test_413_not_treated_as_generic_4xx(self, agent): + """413 must NOT hit the generic 4xx abort path; it should attempt compression.""" + err_413 = _make_413_error() + ok_resp = _mock_response(content="Recovered", finish_reason="stop") + agent.client.chat.completions.create.side_effect = [err_413, ok_resp] + + with ( + patch.object(agent, "_compress_context") as mock_compress, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + mock_compress.return_value = ( + [{"role": "user", "content": "hello"}], + "compressed", + ) + result = agent.run_conversation("hello") + + # If 413 were treated as generic 4xx, result would have "failed": True + assert result.get("failed") is not True + assert result["completed"] is True + + def test_413_error_message_detection(self, agent): + """413 detected via error message string (no status_code attr).""" + err = _make_413_error(use_status_code=False, message="error code: 413") + ok_resp = _mock_response(content="OK", finish_reason="stop") + agent.client.chat.completions.create.side_effect = [err, ok_resp] + + with ( + patch.object(agent, "_compress_context") as mock_compress, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + mock_compress.return_value = ( + [{"role": "user", "content": "hello"}], + "compressed", + ) + result = agent.run_conversation("hello") + + mock_compress.assert_called_once() + assert result["completed"] is True + + def test_413_cannot_compress_further(self, agent): + """When compression can't reduce messages, return partial result.""" + err_413 = _make_413_error() + agent.client.chat.completions.create.side_effect = [err_413] + + with ( + patch.object(agent, "_compress_context") as mock_compress, + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + # Compression returns same number of messages → can't compress further + mock_compress.return_value = ( + [{"role": "user", "content": "hello"}], + "same prompt", + ) + result = agent.run_conversation("hello") + + assert result["completed"] is False + assert result.get("partial") is True + assert "413" in result["error"]