Merge pull request #153 from tekelala/main

fix(agent): handle 413 payload-too-large via compression instead of aborting
This commit is contained in:
Teknium
2026-02-27 22:57:55 -08:00
committed by GitHub
7 changed files with 895 additions and 9 deletions

View File

@@ -171,6 +171,84 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg") -> str:
return cache_audio_from_bytes(response.content, ext)
# ---------------------------------------------------------------------------
# Document cache utilities
#
# Same pattern as image/audio cache -- documents from platforms are downloaded
# here so the agent can reference them by local file path.
# ---------------------------------------------------------------------------
DOCUMENT_CACHE_DIR = Path(os.path.expanduser("~/.hermes/document_cache"))
SUPPORTED_DOCUMENT_TYPES = {
".pdf": "application/pdf",
".md": "text/markdown",
".txt": "text/plain",
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}
def get_document_cache_dir() -> Path:
"""Return the document cache directory, creating it if it doesn't exist."""
DOCUMENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
return DOCUMENT_CACHE_DIR
def cache_document_from_bytes(data: bytes, filename: str) -> str:
"""
Save raw document bytes to the cache and return the absolute file path.
The cached filename preserves the original human-readable name with a
unique prefix: ``doc_{uuid12}_{original_filename}``.
Args:
data: Raw document bytes.
filename: Original filename (e.g. "report.pdf").
Returns:
Absolute path to the cached document file as a string.
Raises:
ValueError: If the sanitized path escapes the cache directory.
"""
cache_dir = get_document_cache_dir()
# Sanitize: strip directory components, null bytes, and control characters
safe_name = Path(filename).name if filename else "document"
safe_name = safe_name.replace("\x00", "").strip()
if not safe_name or safe_name in (".", ".."):
safe_name = "document"
cached_name = f"doc_{uuid.uuid4().hex[:12]}_{safe_name}"
filepath = cache_dir / cached_name
# Final safety check: ensure path stays inside cache dir
if not filepath.resolve().is_relative_to(cache_dir.resolve()):
raise ValueError(f"Path traversal rejected: {filename!r}")
filepath.write_bytes(data)
return str(filepath)
def cleanup_document_cache(max_age_hours: int = 24) -> int:
"""
Delete cached documents older than *max_age_hours*.
Returns the number of files removed.
"""
import time
cache_dir = get_document_cache_dir()
cutoff = time.time() - (max_age_hours * 3600)
removed = 0
for f in cache_dir.iterdir():
if f.is_file() and f.stat().st_mtime < cutoff:
try:
f.unlink()
removed += 1
except OSError:
pass
return removed
class MessageType(Enum):
"""Types of incoming messages."""
TEXT = "text"

View File

@@ -8,6 +8,7 @@ Uses python-telegram-bot library for:
"""
import asyncio
import os
import re
from typing import Dict, List, Optional, Any
@@ -42,6 +43,8 @@ from gateway.platforms.base import (
SendResult,
cache_image_from_bytes,
cache_audio_from_bytes,
cache_document_from_bytes,
SUPPORTED_DOCUMENT_TYPES,
)
@@ -419,6 +422,8 @@ class TelegramAdapter(BasePlatformAdapter):
msg_type = MessageType.AUDIO
elif msg.voice:
msg_type = MessageType.VOICE
elif msg.document:
msg_type = MessageType.DOCUMENT
else:
msg_type = MessageType.DOCUMENT
@@ -479,7 +484,72 @@ class TelegramAdapter(BasePlatformAdapter):
print(f"[Telegram] Cached user audio: {cached_path}", flush=True)
except Exception as e:
print(f"[Telegram] Failed to cache audio: {e}", flush=True)
# Download document files to cache for agent processing
elif msg.document:
doc = msg.document
try:
# Determine file extension
ext = ""
original_filename = doc.file_name or ""
if original_filename:
_, ext = os.path.splitext(original_filename)
ext = ext.lower()
# If no extension from filename, reverse-lookup from MIME type
if not ext and doc.mime_type:
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
ext = mime_to_ext.get(doc.mime_type, "")
# Check if supported
if ext not in SUPPORTED_DOCUMENT_TYPES:
supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys()))
event.text = (
f"Unsupported document type '{ext or 'unknown'}'. "
f"Supported types: {supported_list}"
)
print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True)
await self.handle_message(event)
return
# Check file size (Telegram Bot API limit: 20 MB)
MAX_DOC_BYTES = 20 * 1024 * 1024
if not doc.file_size or doc.file_size > MAX_DOC_BYTES:
event.text = (
"The document is too large or its size could not be verified. "
"Maximum: 20 MB."
)
print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True)
await self.handle_message(event)
return
# Download and cache
file_obj = await doc.get_file()
doc_bytes = await file_obj.download_as_bytearray()
raw_bytes = bytes(doc_bytes)
cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}")
mime_type = SUPPORTED_DOCUMENT_TYPES[ext]
event.media_urls = [cached_path]
event.media_types = [mime_type]
print(f"[Telegram] Cached user document: {cached_path}", flush=True)
# For text files, inject content into event.text (capped at 100 KB)
MAX_TEXT_INJECT_BYTES = 100 * 1024
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
try:
text_content = raw_bytes.decode("utf-8")
display_name = original_filename or f"document{ext}"
injection = f"[Content of {display_name}]:\n{text_content}"
if event.text:
event.text = f"{injection}\n\n{event.text}"
else:
event.text = injection
except UnicodeDecodeError:
print(f"[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
except Exception as e:
print(f"[Telegram] Failed to cache document: {e}", flush=True)
await self.handle_message(event)
async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None:

View File

@@ -742,7 +742,39 @@ class GatewayRunner:
message_text = await self._enrich_message_with_transcription(
message_text, audio_paths
)
# -----------------------------------------------------------------
# Enrich document messages with context notes for the agent
# -----------------------------------------------------------------
if event.media_urls and event.message_type == MessageType.DOCUMENT:
for i, path in enumerate(event.media_urls):
mtype = event.media_types[i] if i < len(event.media_types) else ""
if not (mtype.startswith("application/") or mtype.startswith("text/")):
continue
# Extract display filename by stripping the doc_{uuid12}_ prefix
import os as _os
basename = _os.path.basename(path)
# Format: doc_<12hex>_<original_filename>
parts = basename.split("_", 2)
display_name = parts[2] if len(parts) >= 3 else basename
# Sanitize to prevent prompt injection via filenames
import re as _re
display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
if mtype.startswith("text/"):
context_note = (
f"[The user sent a text document: '{display_name}'. "
f"Its content has been included below. "
f"The file is also saved at: {path}]"
)
else:
context_note = (
f"[The user sent a document: '{display_name}'. "
f"The file is saved at: {path}. "
f"Ask the user what they'd like you to do with it.]"
)
message_text = f"{context_note}\n\n{message_text}"
try:
# Emit agent:start hook
hook_ctx = {
@@ -1813,10 +1845,10 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int
needing a separate `hermes cron daemon` or system cron entry.
Also refreshes the channel directory every 5 minutes and prunes the
image/audio cache once per hour.
image/audio/document cache once per hour.
"""
from cron.scheduler import tick as cron_tick
from gateway.platforms.base import cleanup_image_cache
from gateway.platforms.base import cleanup_image_cache, cleanup_document_cache
IMAGE_CACHE_EVERY = 60 # ticks — once per hour at default 60s interval
CHANNEL_DIR_EVERY = 5 # ticks — every 5 minutes
@@ -1845,6 +1877,12 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int
logger.info("Image cache cleanup: removed %d stale file(s)", removed)
except Exception as e:
logger.debug("Image cache cleanup error: %s", e)
try:
removed = cleanup_document_cache(max_age_hours=24)
if removed:
logger.info("Document cache cleanup: removed %d stale file(s)", removed)
except Exception as e:
logger.debug("Document cache cleanup error: %s", e)
stop_event.wait(timeout=interval)
logger.info("Cron ticker stopped")

View File

@@ -2092,11 +2092,44 @@ class AIAgent:
"interrupted": True,
}
# Check for 413 payload-too-large BEFORE generic 4xx handler.
# A 413 is a payload-size error — the correct response is to
# compress history and retry, not abort immediately.
status_code = getattr(api_error, "status_code", None)
is_payload_too_large = (
status_code == 413
or 'request entity too large' in error_msg
or 'error code: 413' in error_msg
)
if is_payload_too_large:
print(f"{self.log_prefix}⚠️ Request payload too large (413) - attempting compression...")
original_len = len(messages)
messages, active_system_prompt = self._compress_context(
messages, system_message, approx_tokens=approx_tokens
)
if len(messages) < original_len:
print(f"{self.log_prefix} 🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
continue # Retry with compressed messages
else:
print(f"{self.log_prefix}❌ Payload too large and cannot compress further.")
logging.error(f"{self.log_prefix}413 payload too large. Cannot compress further.")
self._persist_session(messages, conversation_history)
return {
"messages": messages,
"completed": False,
"api_calls": api_call_count,
"error": "Request payload too large (413). Cannot compress further.",
"partial": True
}
# Check for non-retryable client errors (4xx HTTP status codes).
# These indicate a problem with the request itself (bad model ID,
# invalid API key, forbidden, etc.) and will never succeed on retry.
status_code = getattr(api_error, "status_code", None)
is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500
# Note: 413 is excluded — it's handled above via compression.
is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code != 413
is_client_error = is_client_status_error or any(phrase in error_msg for phrase in [
'error code: 400', 'error code: 401', 'error code: 403',
'error code: 404', 'error code: 422',
@@ -2104,7 +2137,7 @@ class AIAgent:
'invalid api key', 'invalid_api_key', 'authentication',
'unauthorized', 'forbidden', 'not found',
])
if is_client_error:
self._dump_api_request_debug(
api_kwargs, reason="non_retryable_client_error", error=api_error,
@@ -2124,8 +2157,9 @@ class AIAgent:
# Check for non-retryable errors (context length exceeded)
is_context_length_error = any(phrase in error_msg for phrase in [
'context length', 'maximum context', 'token limit',
'too many tokens', 'reduce the length', 'exceeds the limit'
'context length', 'maximum context', 'token limit',
'too many tokens', 'reduce the length', 'exceeds the limit',
'request entity too large', # OpenRouter/Nous 413 safety net
])
if is_context_length_error:

View File

@@ -0,0 +1,157 @@
"""
Tests for document cache utilities in gateway/platforms/base.py.
Covers: get_document_cache_dir, cache_document_from_bytes,
cleanup_document_cache, SUPPORTED_DOCUMENT_TYPES.
"""
import os
import time
from pathlib import Path
import pytest
from gateway.platforms.base import (
SUPPORTED_DOCUMENT_TYPES,
cache_document_from_bytes,
cleanup_document_cache,
get_document_cache_dir,
)
# ---------------------------------------------------------------------------
# Fixture: redirect DOCUMENT_CACHE_DIR to a temp directory for every test
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _redirect_cache(tmp_path, monkeypatch):
"""Point the module-level DOCUMENT_CACHE_DIR to a fresh tmp_path."""
monkeypatch.setattr(
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
)
# ---------------------------------------------------------------------------
# TestGetDocumentCacheDir
# ---------------------------------------------------------------------------
class TestGetDocumentCacheDir:
def test_creates_directory(self, tmp_path):
cache_dir = get_document_cache_dir()
assert cache_dir.exists()
assert cache_dir.is_dir()
def test_returns_existing_directory(self):
first = get_document_cache_dir()
second = get_document_cache_dir()
assert first == second
assert first.exists()
# ---------------------------------------------------------------------------
# TestCacheDocumentFromBytes
# ---------------------------------------------------------------------------
class TestCacheDocumentFromBytes:
def test_basic_caching(self):
data = b"hello world"
path = cache_document_from_bytes(data, "test.txt")
assert os.path.exists(path)
assert Path(path).read_bytes() == data
def test_filename_preserved_in_path(self):
path = cache_document_from_bytes(b"data", "report.pdf")
assert "report.pdf" in os.path.basename(path)
def test_empty_filename_uses_fallback(self):
path = cache_document_from_bytes(b"data", "")
assert "document" in os.path.basename(path)
def test_unique_filenames(self):
p1 = cache_document_from_bytes(b"a", "same.txt")
p2 = cache_document_from_bytes(b"b", "same.txt")
assert p1 != p2
def test_path_traversal_blocked(self):
"""Malicious directory components are stripped — only the leaf name survives."""
path = cache_document_from_bytes(b"data", "../../etc/passwd")
basename = os.path.basename(path)
assert "passwd" in basename
# Must NOT contain directory separators
assert ".." not in basename
# File must reside inside the cache directory
cache_dir = get_document_cache_dir()
assert Path(path).resolve().is_relative_to(cache_dir.resolve())
def test_null_bytes_stripped(self):
path = cache_document_from_bytes(b"data", "file\x00.pdf")
basename = os.path.basename(path)
assert "\x00" not in basename
assert "file.pdf" in basename
def test_dot_dot_filename_handled(self):
"""A filename that is literally '..' falls back to 'document'."""
path = cache_document_from_bytes(b"data", "..")
basename = os.path.basename(path)
assert "document" in basename
def test_none_filename_uses_fallback(self):
path = cache_document_from_bytes(b"data", None)
assert "document" in os.path.basename(path)
# ---------------------------------------------------------------------------
# TestCleanupDocumentCache
# ---------------------------------------------------------------------------
class TestCleanupDocumentCache:
def test_removes_old_files(self, tmp_path):
cache_dir = get_document_cache_dir()
old_file = cache_dir / "old.txt"
old_file.write_text("old")
# Set modification time to 48 hours ago
old_mtime = time.time() - 48 * 3600
os.utime(old_file, (old_mtime, old_mtime))
removed = cleanup_document_cache(max_age_hours=24)
assert removed == 1
assert not old_file.exists()
def test_keeps_recent_files(self):
cache_dir = get_document_cache_dir()
recent = cache_dir / "recent.txt"
recent.write_text("fresh")
removed = cleanup_document_cache(max_age_hours=24)
assert removed == 0
assert recent.exists()
def test_returns_removed_count(self):
cache_dir = get_document_cache_dir()
old_time = time.time() - 48 * 3600
for i in range(3):
f = cache_dir / f"old_{i}.txt"
f.write_text("x")
os.utime(f, (old_time, old_time))
assert cleanup_document_cache(max_age_hours=24) == 3
def test_empty_cache_dir(self):
assert cleanup_document_cache(max_age_hours=24) == 0
# ---------------------------------------------------------------------------
# TestSupportedDocumentTypes
# ---------------------------------------------------------------------------
class TestSupportedDocumentTypes:
def test_all_extensions_have_mime_types(self):
for ext, mime in SUPPORTED_DOCUMENT_TYPES.items():
assert ext.startswith("."), f"{ext} missing leading dot"
assert "/" in mime, f"{mime} is not a valid MIME type"
@pytest.mark.parametrize(
"ext",
[".pdf", ".md", ".txt", ".docx", ".xlsx", ".pptx"],
)
def test_expected_extensions_present(self, ext):
assert ext in SUPPORTED_DOCUMENT_TYPES

View File

@@ -0,0 +1,338 @@
"""
Tests for Telegram document handling in gateway/platforms/telegram.py.
Covers: document type detection, download/cache flow, size limits,
text injection, error handling.
Note: python-telegram-bot may not be installed in the test environment.
We mock the telegram module at import time to avoid collection errors.
"""
import asyncio
import importlib
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
MessageEvent,
MessageType,
SUPPORTED_DOCUMENT_TYPES,
)
# ---------------------------------------------------------------------------
# Mock the telegram package if it's not installed
# ---------------------------------------------------------------------------
def _ensure_telegram_mock():
"""Install mock telegram modules so TelegramAdapter can be imported."""
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
# Real library is installed — no mocking needed
return
telegram_mod = MagicMock()
# ContextTypes needs DEFAULT_TYPE as an actual attribute for the annotation
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
telegram_mod.constants.ChatType.GROUP = "group"
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
telegram_mod.constants.ChatType.CHANNEL = "channel"
telegram_mod.constants.ChatType.PRIVATE = "private"
for name in ("telegram", "telegram.ext", "telegram.constants"):
sys.modules.setdefault(name, telegram_mod)
_ensure_telegram_mock()
# Now we can safely import
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
# ---------------------------------------------------------------------------
# Helpers to build mock Telegram objects
# ---------------------------------------------------------------------------
def _make_file_obj(data: bytes = b"hello"):
"""Create a mock Telegram File with download_as_bytearray."""
f = AsyncMock()
f.download_as_bytearray = AsyncMock(return_value=bytearray(data))
f.file_path = "documents/file.pdf"
return f
def _make_document(
file_name="report.pdf",
mime_type="application/pdf",
file_size=1024,
file_obj=None,
):
"""Create a mock Telegram Document object."""
doc = MagicMock()
doc.file_name = file_name
doc.mime_type = mime_type
doc.file_size = file_size
doc.get_file = AsyncMock(return_value=file_obj or _make_file_obj())
return doc
def _make_message(document=None, caption=None):
"""Build a mock Telegram Message with the given document."""
msg = MagicMock()
msg.message_id = 42
msg.text = caption or ""
msg.caption = caption
msg.date = None
# Media flags — all None except document
msg.photo = None
msg.video = None
msg.audio = None
msg.voice = None
msg.sticker = None
msg.document = document
# Chat / user
msg.chat = MagicMock()
msg.chat.id = 100
msg.chat.type = "private"
msg.chat.title = None
msg.chat.full_name = "Test User"
msg.from_user = MagicMock()
msg.from_user.id = 1
msg.from_user.full_name = "Test User"
msg.message_thread_id = None
return msg
def _make_update(msg):
"""Wrap a message in a mock Update."""
update = MagicMock()
update.message = msg
return update
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def adapter():
config = PlatformConfig(enabled=True, token="fake-token")
a = TelegramAdapter(config)
# Capture events instead of processing them
a.handle_message = AsyncMock()
return a
@pytest.fixture(autouse=True)
def _redirect_cache(tmp_path, monkeypatch):
"""Point document cache to tmp_path so tests don't touch ~/.hermes."""
monkeypatch.setattr(
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
)
# ---------------------------------------------------------------------------
# TestDocumentTypeDetection
# ---------------------------------------------------------------------------
class TestDocumentTypeDetection:
@pytest.mark.asyncio
async def test_document_detected_explicitly(self, adapter):
doc = _make_document()
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert event.message_type == MessageType.DOCUMENT
@pytest.mark.asyncio
async def test_fallback_is_document(self, adapter):
"""When no specific media attr is set, message_type defaults to DOCUMENT."""
msg = _make_message()
msg.document = None # no media at all
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert event.message_type == MessageType.DOCUMENT
# ---------------------------------------------------------------------------
# TestDocumentDownloadBlock
# ---------------------------------------------------------------------------
class TestDocumentDownloadBlock:
@pytest.mark.asyncio
async def test_supported_pdf_is_cached(self, adapter):
pdf_bytes = b"%PDF-1.4 fake"
file_obj = _make_file_obj(pdf_bytes)
doc = _make_document(file_name="report.pdf", file_size=1024, file_obj=file_obj)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert len(event.media_urls) == 1
assert os.path.exists(event.media_urls[0])
assert event.media_types == ["application/pdf"]
@pytest.mark.asyncio
async def test_supported_txt_injects_content(self, adapter):
content = b"Hello from a text file"
file_obj = _make_file_obj(content)
doc = _make_document(
file_name="notes.txt", mime_type="text/plain",
file_size=len(content), file_obj=file_obj,
)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert "Hello from a text file" in event.text
assert "[Content of notes.txt]" in event.text
@pytest.mark.asyncio
async def test_supported_md_injects_content(self, adapter):
content = b"# Title\nSome markdown"
file_obj = _make_file_obj(content)
doc = _make_document(
file_name="readme.md", mime_type="text/markdown",
file_size=len(content), file_obj=file_obj,
)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert "# Title" in event.text
@pytest.mark.asyncio
async def test_caption_preserved_with_injection(self, adapter):
content = b"file text"
file_obj = _make_file_obj(content)
doc = _make_document(
file_name="doc.txt", mime_type="text/plain",
file_size=len(content), file_obj=file_obj,
)
msg = _make_message(document=doc, caption="Please summarize")
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert "file text" in event.text
assert "Please summarize" in event.text
@pytest.mark.asyncio
async def test_unsupported_type_rejected(self, adapter):
doc = _make_document(file_name="archive.zip", mime_type="application/zip", file_size=100)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert "Unsupported document type" in event.text
assert ".zip" in event.text
@pytest.mark.asyncio
async def test_oversized_file_rejected(self, adapter):
doc = _make_document(file_name="huge.pdf", file_size=25 * 1024 * 1024)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert "too large" in event.text
@pytest.mark.asyncio
async def test_none_file_size_rejected(self, adapter):
"""Security fix: file_size=None must be rejected (not silently allowed)."""
doc = _make_document(file_name="tricky.pdf", file_size=None)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert "too large" in event.text or "could not be verified" in event.text
@pytest.mark.asyncio
async def test_missing_filename_uses_mime_lookup(self, adapter):
"""No file_name but valid mime_type should resolve to extension."""
content = b"some pdf bytes"
file_obj = _make_file_obj(content)
doc = _make_document(
file_name=None, mime_type="application/pdf",
file_size=len(content), file_obj=file_obj,
)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert len(event.media_urls) == 1
assert event.media_types == ["application/pdf"]
@pytest.mark.asyncio
async def test_missing_filename_and_mime_rejected(self, adapter):
doc = _make_document(file_name=None, mime_type=None, file_size=100)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
assert "Unsupported" in event.text
@pytest.mark.asyncio
async def test_unicode_decode_error_handled(self, adapter):
"""Binary bytes that aren't valid UTF-8 in a .txt — content not injected but file still cached."""
binary = bytes(range(128, 256)) # not valid UTF-8
file_obj = _make_file_obj(binary)
doc = _make_document(
file_name="binary.txt", mime_type="text/plain",
file_size=len(binary), file_obj=file_obj,
)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
# File should still be cached
assert len(event.media_urls) == 1
assert os.path.exists(event.media_urls[0])
# Content NOT injected — text should be empty (no caption set)
assert "[Content of" not in (event.text or "")
@pytest.mark.asyncio
async def test_text_injection_capped(self, adapter):
"""A .txt file over 100 KB should NOT have its content injected."""
large = b"x" * (200 * 1024) # 200 KB
file_obj = _make_file_obj(large)
doc = _make_document(
file_name="big.txt", mime_type="text/plain",
file_size=len(large), file_obj=file_obj,
)
msg = _make_message(document=doc)
update = _make_update(msg)
await adapter._handle_media_message(update, MagicMock())
event = adapter.handle_message.call_args[0][0]
# File should be cached
assert len(event.media_urls) == 1
# Content should NOT be injected
assert "[Content of" not in (event.text or "")
@pytest.mark.asyncio
async def test_download_exception_handled(self, adapter):
"""If get_file() raises, the handler logs the error without crashing."""
doc = _make_document(file_name="crash.pdf", file_size=100)
doc.get_file = AsyncMock(side_effect=RuntimeError("Telegram API down"))
msg = _make_message(document=doc)
update = _make_update(msg)
# Should not raise
await adapter._handle_media_message(update, MagicMock())
# handle_message should still be called (the handler catches the exception)
adapter.handle_message.assert_called_once()

View File

@@ -0,0 +1,171 @@
"""Tests for 413 payload-too-large → compression retry logic in AIAgent.
Verifies that HTTP 413 errors trigger history compression and retry,
rather than being treated as non-retryable generic 4xx errors.
"""
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from run_agent import AIAgent
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_tool_defs(*names: str) -> list:
return [
{
"type": "function",
"function": {
"name": n,
"description": f"{n} tool",
"parameters": {"type": "object", "properties": {}},
},
}
for n in names
]
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None, usage=None):
msg = SimpleNamespace(
content=content,
tool_calls=tool_calls,
reasoning_content=None,
reasoning=None,
)
choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
resp = SimpleNamespace(choices=[choice], model="test/model")
resp.usage = SimpleNamespace(**usage) if usage else None
return resp
def _make_413_error(*, use_status_code=True, message="Request entity too large"):
"""Create an exception that mimics a 413 HTTP error."""
err = Exception(message)
if use_status_code:
err.status_code = 413
return err
@pytest.fixture()
def agent():
with (
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
a = AIAgent(
api_key="test-key-1234567890",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
a.client = MagicMock()
a._cached_system_prompt = "You are helpful."
a._use_prompt_caching = False
a.tool_delay = 0
a.compression_enabled = False
a.save_trajectories = False
return a
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestHTTP413Compression:
"""413 errors should trigger compression, not abort as generic 4xx."""
def test_413_triggers_compression(self, agent):
"""A 413 error should call _compress_context and retry, not abort."""
# First call raises 413; second call succeeds after compression.
err_413 = _make_413_error()
ok_resp = _mock_response(content="Success after compression", finish_reason="stop")
agent.client.chat.completions.create.side_effect = [err_413, ok_resp]
with (
patch.object(agent, "_compress_context") as mock_compress,
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
# Compression removes messages, enabling retry
mock_compress.return_value = (
[{"role": "user", "content": "hello"}],
"compressed prompt",
)
result = agent.run_conversation("hello")
mock_compress.assert_called_once()
assert result["completed"] is True
assert result["final_response"] == "Success after compression"
def test_413_not_treated_as_generic_4xx(self, agent):
"""413 must NOT hit the generic 4xx abort path; it should attempt compression."""
err_413 = _make_413_error()
ok_resp = _mock_response(content="Recovered", finish_reason="stop")
agent.client.chat.completions.create.side_effect = [err_413, ok_resp]
with (
patch.object(agent, "_compress_context") as mock_compress,
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
mock_compress.return_value = (
[{"role": "user", "content": "hello"}],
"compressed",
)
result = agent.run_conversation("hello")
# If 413 were treated as generic 4xx, result would have "failed": True
assert result.get("failed") is not True
assert result["completed"] is True
def test_413_error_message_detection(self, agent):
"""413 detected via error message string (no status_code attr)."""
err = _make_413_error(use_status_code=False, message="error code: 413")
ok_resp = _mock_response(content="OK", finish_reason="stop")
agent.client.chat.completions.create.side_effect = [err, ok_resp]
with (
patch.object(agent, "_compress_context") as mock_compress,
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
mock_compress.return_value = (
[{"role": "user", "content": "hello"}],
"compressed",
)
result = agent.run_conversation("hello")
mock_compress.assert_called_once()
assert result["completed"] is True
def test_413_cannot_compress_further(self, agent):
"""When compression can't reduce messages, return partial result."""
err_413 = _make_413_error()
agent.client.chat.completions.create.side_effect = [err_413]
with (
patch.object(agent, "_compress_context") as mock_compress,
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
# Compression returns same number of messages → can't compress further
mock_compress.return_value = (
[{"role": "user", "content": "hello"}],
"same prompt",
)
result = agent.run_conversation("hello")
assert result["completed"] is False
assert result.get("partial") is True
assert "413" in result["error"]