Merge pull request #153 from tekelala/main
fix(agent): handle 413 payload-too-large via compression instead of aborting
This commit is contained in:
@@ -171,6 +171,84 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg") -> str:
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document cache utilities
|
||||
#
|
||||
# Same pattern as image/audio cache -- documents from platforms are downloaded
|
||||
# here so the agent can reference them by local file path.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Root directory for cached documents. Expanded once at import time; the
# directory itself is created lazily by get_document_cache_dir().
DOCUMENT_CACHE_DIR = Path(os.path.expanduser("~/.hermes/document_cache"))

# Allowed document extensions (lowercase, with leading dot) mapped to their
# MIME types. Adapters reject anything not listed here before downloading,
# and reverse-lookup an extension from the MIME type when a filename is
# missing.
SUPPORTED_DOCUMENT_TYPES = {
    ".pdf": "application/pdf",
    ".md": "text/markdown",
    ".txt": "text/plain",
    ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}
|
||||
|
||||
|
||||
def get_document_cache_dir() -> Path:
    """Return the document cache directory, creating it if it doesn't exist."""
    cache_dir = DOCUMENT_CACHE_DIR
    # Idempotent: exist_ok makes repeated calls safe.
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
|
||||
|
||||
|
||||
def cache_document_from_bytes(data: bytes, filename: str) -> str:
    """
    Save raw document bytes to the cache and return the absolute file path.

    The cached filename preserves the original human-readable name with a
    unique prefix: ``doc_{uuid12}_{original_filename}``.

    Args:
        data: Raw document bytes.
        filename: Original filename (e.g. "report.pdf").

    Returns:
        Absolute path to the cached document file as a string.

    Raises:
        ValueError: If the sanitized path escapes the cache directory.
    """
    cache_dir = get_document_cache_dir()
    # Sanitize: strip directory components, null bytes, and control characters
    safe_name = Path(filename).name if filename else "document"
    # Remove ALL ASCII control characters (0x00-0x1F and 0x7F), not just NUL.
    # The previous implementation only removed "\x00", which contradicted the
    # comment above and allowed e.g. newlines or escape sequences to survive
    # into on-disk filenames.
    safe_name = "".join(ch for ch in safe_name if ch >= " " and ch != "\x7f").strip()
    if not safe_name or safe_name in (".", ".."):
        safe_name = "document"
    cached_name = f"doc_{uuid.uuid4().hex[:12]}_{safe_name}"
    filepath = cache_dir / cached_name
    # Final safety check: ensure path stays inside cache dir
    if not filepath.resolve().is_relative_to(cache_dir.resolve()):
        raise ValueError(f"Path traversal rejected: {filename!r}")
    filepath.write_bytes(data)
    return str(filepath)
|
||||
|
||||
|
||||
def cleanup_document_cache(max_age_hours: int = 24) -> int:
    """
    Delete cached documents older than *max_age_hours*.

    Returns the number of files removed.
    """
    import time

    cutoff = time.time() - max_age_hours * 3600
    removed = 0
    for entry in get_document_cache_dir().iterdir():
        if not entry.is_file():
            continue
        if entry.stat().st_mtime >= cutoff:
            continue
        try:
            entry.unlink()
        except OSError:
            # Best-effort cleanup: skip files we cannot delete.
            continue
        removed += 1
    return removed
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Types of incoming messages."""
|
||||
TEXT = "text"
|
||||
|
||||
@@ -8,6 +8,7 @@ Uses python-telegram-bot library for:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
@@ -42,6 +43,8 @@ from gateway.platforms.base import (
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
@@ -419,6 +422,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
msg_type = MessageType.AUDIO
|
||||
elif msg.voice:
|
||||
msg_type = MessageType.VOICE
|
||||
elif msg.document:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
@@ -479,7 +484,72 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
print(f"[Telegram] Cached user audio: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache audio: {e}", flush=True)
|
||||
|
||||
|
||||
# Download document files to cache for agent processing
|
||||
elif msg.document:
|
||||
doc = msg.document
|
||||
try:
|
||||
# Determine file extension
|
||||
ext = ""
|
||||
original_filename = doc.file_name or ""
|
||||
if original_filename:
|
||||
_, ext = os.path.splitext(original_filename)
|
||||
ext = ext.lower()
|
||||
|
||||
# If no extension from filename, reverse-lookup from MIME type
|
||||
if not ext and doc.mime_type:
|
||||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||||
ext = mime_to_ext.get(doc.mime_type, "")
|
||||
|
||||
# Check if supported
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys()))
|
||||
event.text = (
|
||||
f"Unsupported document type '{ext or 'unknown'}'. "
|
||||
f"Supported types: {supported_list}"
|
||||
)
|
||||
print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
||||
# Check file size (Telegram Bot API limit: 20 MB)
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if not doc.file_size or doc.file_size > MAX_DOC_BYTES:
|
||||
event.text = (
|
||||
"The document is too large or its size could not be verified. "
|
||||
"Maximum: 20 MB."
|
||||
)
|
||||
print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
||||
# Download and cache
|
||||
file_obj = await doc.get_file()
|
||||
doc_bytes = await file_obj.download_as_bytearray()
|
||||
raw_bytes = bytes(doc_bytes)
|
||||
cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}")
|
||||
mime_type = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
event.media_urls = [cached_path]
|
||||
event.media_types = [mime_type]
|
||||
print(f"[Telegram] Cached user document: {cached_path}", flush=True)
|
||||
|
||||
# For text files, inject content into event.text (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if event.text:
|
||||
event.text = f"{injection}\n\n{event.text}"
|
||||
else:
|
||||
event.text = injection
|
||||
except UnicodeDecodeError:
|
||||
print(f"[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache document: {e}", flush=True)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None:
|
||||
|
||||
@@ -742,7 +742,39 @@ class GatewayRunner:
|
||||
message_text = await self._enrich_message_with_transcription(
|
||||
message_text, audio_paths
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Enrich document messages with context notes for the agent
|
||||
# -----------------------------------------------------------------
|
||||
if event.media_urls and event.message_type == MessageType.DOCUMENT:
|
||||
for i, path in enumerate(event.media_urls):
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
if not (mtype.startswith("application/") or mtype.startswith("text/")):
|
||||
continue
|
||||
# Extract display filename by stripping the doc_{uuid12}_ prefix
|
||||
import os as _os
|
||||
basename = _os.path.basename(path)
|
||||
# Format: doc_<12hex>_<original_filename>
|
||||
parts = basename.split("_", 2)
|
||||
display_name = parts[2] if len(parts) >= 3 else basename
|
||||
# Sanitize to prevent prompt injection via filenames
|
||||
import re as _re
|
||||
display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
|
||||
if mtype.startswith("text/"):
|
||||
context_note = (
|
||||
f"[The user sent a text document: '{display_name}'. "
|
||||
f"Its content has been included below. "
|
||||
f"The file is also saved at: {path}]"
|
||||
)
|
||||
else:
|
||||
context_note = (
|
||||
f"[The user sent a document: '{display_name}'. "
|
||||
f"The file is saved at: {path}. "
|
||||
f"Ask the user what they'd like you to do with it.]"
|
||||
)
|
||||
message_text = f"{context_note}\n\n{message_text}"
|
||||
|
||||
try:
|
||||
# Emit agent:start hook
|
||||
hook_ctx = {
|
||||
@@ -1813,10 +1845,10 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int
|
||||
needing a separate `hermes cron daemon` or system cron entry.
|
||||
|
||||
Also refreshes the channel directory every 5 minutes and prunes the
|
||||
image/audio cache once per hour.
|
||||
image/audio/document cache once per hour.
|
||||
"""
|
||||
from cron.scheduler import tick as cron_tick
|
||||
from gateway.platforms.base import cleanup_image_cache
|
||||
from gateway.platforms.base import cleanup_image_cache, cleanup_document_cache
|
||||
|
||||
IMAGE_CACHE_EVERY = 60 # ticks — once per hour at default 60s interval
|
||||
CHANNEL_DIR_EVERY = 5 # ticks — every 5 minutes
|
||||
@@ -1845,6 +1877,12 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int
|
||||
logger.info("Image cache cleanup: removed %d stale file(s)", removed)
|
||||
except Exception as e:
|
||||
logger.debug("Image cache cleanup error: %s", e)
|
||||
try:
|
||||
removed = cleanup_document_cache(max_age_hours=24)
|
||||
if removed:
|
||||
logger.info("Document cache cleanup: removed %d stale file(s)", removed)
|
||||
except Exception as e:
|
||||
logger.debug("Document cache cleanup error: %s", e)
|
||||
|
||||
stop_event.wait(timeout=interval)
|
||||
logger.info("Cron ticker stopped")
|
||||
|
||||
44
run_agent.py
44
run_agent.py
@@ -2092,11 +2092,44 @@ class AIAgent:
|
||||
"interrupted": True,
|
||||
}
|
||||
|
||||
# Check for 413 payload-too-large BEFORE generic 4xx handler.
|
||||
# A 413 is a payload-size error — the correct response is to
|
||||
# compress history and retry, not abort immediately.
|
||||
status_code = getattr(api_error, "status_code", None)
|
||||
is_payload_too_large = (
|
||||
status_code == 413
|
||||
or 'request entity too large' in error_msg
|
||||
or 'error code: 413' in error_msg
|
||||
)
|
||||
|
||||
if is_payload_too_large:
|
||||
print(f"{self.log_prefix}⚠️ Request payload too large (413) - attempting compression...")
|
||||
|
||||
original_len = len(messages)
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
messages, system_message, approx_tokens=approx_tokens
|
||||
)
|
||||
|
||||
if len(messages) < original_len:
|
||||
print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
|
||||
continue # Retry with compressed messages
|
||||
else:
|
||||
print(f"{self.log_prefix}❌ Payload too large and cannot compress further.")
|
||||
logging.error(f"{self.log_prefix}413 payload too large. Cannot compress further.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"messages": messages,
|
||||
"completed": False,
|
||||
"api_calls": api_call_count,
|
||||
"error": "Request payload too large (413). Cannot compress further.",
|
||||
"partial": True
|
||||
}
|
||||
|
||||
# Check for non-retryable client errors (4xx HTTP status codes).
|
||||
# These indicate a problem with the request itself (bad model ID,
|
||||
# invalid API key, forbidden, etc.) and will never succeed on retry.
|
||||
status_code = getattr(api_error, "status_code", None)
|
||||
is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500
|
||||
# Note: 413 is excluded — it's handled above via compression.
|
||||
is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code != 413
|
||||
is_client_error = is_client_status_error or any(phrase in error_msg for phrase in [
|
||||
'error code: 400', 'error code: 401', 'error code: 403',
|
||||
'error code: 404', 'error code: 422',
|
||||
@@ -2104,7 +2137,7 @@ class AIAgent:
|
||||
'invalid api key', 'invalid_api_key', 'authentication',
|
||||
'unauthorized', 'forbidden', 'not found',
|
||||
])
|
||||
|
||||
|
||||
if is_client_error:
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="non_retryable_client_error", error=api_error,
|
||||
@@ -2124,8 +2157,9 @@ class AIAgent:
|
||||
|
||||
# Check for non-retryable errors (context length exceeded)
|
||||
is_context_length_error = any(phrase in error_msg for phrase in [
|
||||
'context length', 'maximum context', 'token limit',
|
||||
'too many tokens', 'reduce the length', 'exceeds the limit'
|
||||
'context length', 'maximum context', 'token limit',
|
||||
'too many tokens', 'reduce the length', 'exceeds the limit',
|
||||
'request entity too large', # OpenRouter/Nous 413 safety net
|
||||
])
|
||||
|
||||
if is_context_length_error:
|
||||
|
||||
157
tests/gateway/test_document_cache.py
Normal file
157
tests/gateway/test_document_cache.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Tests for document cache utilities in gateway/platforms/base.py.
|
||||
|
||||
Covers: get_document_cache_dir, cache_document_from_bytes,
|
||||
cleanup_document_cache, SUPPORTED_DOCUMENT_TYPES.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_document_from_bytes,
|
||||
cleanup_document_cache,
|
||||
get_document_cache_dir,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture: redirect DOCUMENT_CACHE_DIR to a temp directory for every test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
def _redirect_cache(tmp_path, monkeypatch):
    """Point the module-level DOCUMENT_CACHE_DIR to a fresh tmp_path."""
    sandbox_dir = tmp_path / "doc_cache"
    monkeypatch.setattr("gateway.platforms.base.DOCUMENT_CACHE_DIR", sandbox_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestGetDocumentCacheDir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetDocumentCacheDir:
    """get_document_cache_dir() creates and returns a stable directory."""

    def test_creates_directory(self, tmp_path):
        # The first call must materialize the directory on disk.
        result = get_document_cache_dir()
        assert result.exists()
        assert result.is_dir()

    def test_returns_existing_directory(self):
        # Repeated calls are idempotent and return the same path.
        first_call = get_document_cache_dir()
        second_call = get_document_cache_dir()
        assert first_call == second_call
        assert first_call.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCacheDocumentFromBytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheDocumentFromBytes:
    """cache_document_from_bytes() round-trips data and sanitizes names."""

    def test_basic_caching(self):
        payload = b"hello world"
        cached = cache_document_from_bytes(payload, "test.txt")
        assert os.path.exists(cached)
        assert Path(cached).read_bytes() == payload

    def test_filename_preserved_in_path(self):
        cached = cache_document_from_bytes(b"data", "report.pdf")
        assert "report.pdf" in os.path.basename(cached)

    def test_empty_filename_uses_fallback(self):
        cached = cache_document_from_bytes(b"data", "")
        assert "document" in os.path.basename(cached)

    def test_unique_filenames(self):
        # Two writes of the same name must never collide.
        first = cache_document_from_bytes(b"a", "same.txt")
        second = cache_document_from_bytes(b"b", "same.txt")
        assert first != second

    def test_path_traversal_blocked(self):
        """Malicious directory components are stripped — only the leaf name survives."""
        cached = cache_document_from_bytes(b"data", "../../etc/passwd")
        leaf = os.path.basename(cached)
        assert "passwd" in leaf
        # Must NOT contain directory separators
        assert ".." not in leaf
        # File must reside inside the cache directory
        root = get_document_cache_dir()
        assert Path(cached).resolve().is_relative_to(root.resolve())

    def test_null_bytes_stripped(self):
        cached = cache_document_from_bytes(b"data", "file\x00.pdf")
        leaf = os.path.basename(cached)
        assert "\x00" not in leaf
        assert "file.pdf" in leaf

    def test_dot_dot_filename_handled(self):
        """A filename that is literally '..' falls back to 'document'."""
        cached = cache_document_from_bytes(b"data", "..")
        assert "document" in os.path.basename(cached)

    def test_none_filename_uses_fallback(self):
        cached = cache_document_from_bytes(b"data", None)
        assert "document" in os.path.basename(cached)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestCleanupDocumentCache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanupDocumentCache:
    """cleanup_document_cache() removes only files past the age cutoff."""

    @staticmethod
    def _age(path, hours=48):
        # Push a file's mtime into the past so cleanup treats it as stale.
        stamp = time.time() - hours * 3600
        os.utime(path, (stamp, stamp))

    def test_removes_old_files(self, tmp_path):
        stale = get_document_cache_dir() / "old.txt"
        stale.write_text("old")
        # Set modification time to 48 hours ago
        self._age(stale)

        assert cleanup_document_cache(max_age_hours=24) == 1
        assert not stale.exists()

    def test_keeps_recent_files(self):
        fresh = get_document_cache_dir() / "recent.txt"
        fresh.write_text("fresh")

        assert cleanup_document_cache(max_age_hours=24) == 0
        assert fresh.exists()

    def test_returns_removed_count(self):
        root = get_document_cache_dir()
        for idx in range(3):
            stale = root / f"old_{idx}.txt"
            stale.write_text("x")
            self._age(stale)

        assert cleanup_document_cache(max_age_hours=24) == 3

    def test_empty_cache_dir(self):
        assert cleanup_document_cache(max_age_hours=24) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSupportedDocumentTypes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSupportedDocumentTypes:
    """Sanity checks on the SUPPORTED_DOCUMENT_TYPES mapping."""

    def test_all_extensions_have_mime_types(self):
        for extension, mime_type in SUPPORTED_DOCUMENT_TYPES.items():
            assert extension.startswith("."), f"{extension} missing leading dot"
            assert "/" in mime_type, f"{mime_type} is not a valid MIME type"

    @pytest.mark.parametrize(
        "ext",
        [".pdf", ".md", ".txt", ".docx", ".xlsx", ".pptx"],
    )
    def test_expected_extensions_present(self, ext):
        # Every extension the adapters advertise must be in the mapping.
        assert ext in SUPPORTED_DOCUMENT_TYPES
|
||||
338
tests/gateway/test_telegram_documents.py
Normal file
338
tests/gateway/test_telegram_documents.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Tests for Telegram document handling in gateway/platforms/telegram.py.
|
||||
|
||||
Covers: document type detection, download/cache flow, size limits,
|
||||
text injection, error handling.
|
||||
|
||||
Note: python-telegram-bot may not be installed in the test environment.
|
||||
We mock the telegram module at import time to avoid collection errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock the telegram package if it's not installed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
"""Install mock telegram modules so TelegramAdapter can be imported."""
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
# Real library is installed — no mocking needed
|
||||
return
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
# ContextTypes needs DEFAULT_TYPE as an actual attribute for the annotation
|
||||
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
telegram_mod.constants.ChatType.GROUP = "group"
|
||||
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
# Now we can safely import
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers to build mock Telegram objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_file_obj(data: bytes = b"hello"):
|
||||
"""Create a mock Telegram File with download_as_bytearray."""
|
||||
f = AsyncMock()
|
||||
f.download_as_bytearray = AsyncMock(return_value=bytearray(data))
|
||||
f.file_path = "documents/file.pdf"
|
||||
return f
|
||||
|
||||
|
||||
def _make_document(
|
||||
file_name="report.pdf",
|
||||
mime_type="application/pdf",
|
||||
file_size=1024,
|
||||
file_obj=None,
|
||||
):
|
||||
"""Create a mock Telegram Document object."""
|
||||
doc = MagicMock()
|
||||
doc.file_name = file_name
|
||||
doc.mime_type = mime_type
|
||||
doc.file_size = file_size
|
||||
doc.get_file = AsyncMock(return_value=file_obj or _make_file_obj())
|
||||
return doc
|
||||
|
||||
|
||||
def _make_message(document=None, caption=None):
|
||||
"""Build a mock Telegram Message with the given document."""
|
||||
msg = MagicMock()
|
||||
msg.message_id = 42
|
||||
msg.text = caption or ""
|
||||
msg.caption = caption
|
||||
msg.date = None
|
||||
# Media flags — all None except document
|
||||
msg.photo = None
|
||||
msg.video = None
|
||||
msg.audio = None
|
||||
msg.voice = None
|
||||
msg.sticker = None
|
||||
msg.document = document
|
||||
# Chat / user
|
||||
msg.chat = MagicMock()
|
||||
msg.chat.id = 100
|
||||
msg.chat.type = "private"
|
||||
msg.chat.title = None
|
||||
msg.chat.full_name = "Test User"
|
||||
msg.from_user = MagicMock()
|
||||
msg.from_user.id = 1
|
||||
msg.from_user.full_name = "Test User"
|
||||
msg.message_thread_id = None
|
||||
return msg
|
||||
|
||||
|
||||
def _make_update(msg):
|
||||
"""Wrap a message in a mock Update."""
|
||||
update = MagicMock()
|
||||
update.message = msg
|
||||
return update
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
def adapter():
    """Build a TelegramAdapter whose handle_message only records events."""
    cfg = PlatformConfig(enabled=True, token="fake-token")
    instance = TelegramAdapter(cfg)
    # Capture events instead of processing them
    instance.handle_message = AsyncMock()
    return instance
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _redirect_cache(tmp_path, monkeypatch):
    """Point document cache to tmp_path so tests don't touch ~/.hermes."""
    sandbox_dir = tmp_path / "doc_cache"
    monkeypatch.setattr("gateway.platforms.base.DOCUMENT_CACHE_DIR", sandbox_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestDocumentTypeDetection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentTypeDetection:
    """message_type classification for incoming Telegram updates."""

    @pytest.mark.asyncio
    async def test_document_detected_explicitly(self, adapter):
        update = _make_update(_make_message(document=_make_document()))
        await adapter._handle_media_message(update, MagicMock())
        captured_event = adapter.handle_message.call_args[0][0]
        assert captured_event.message_type == MessageType.DOCUMENT

    @pytest.mark.asyncio
    async def test_fallback_is_document(self, adapter):
        """When no specific media attr is set, message_type defaults to DOCUMENT."""
        bare_msg = _make_message()
        bare_msg.document = None  # no media at all
        await adapter._handle_media_message(_make_update(bare_msg), MagicMock())
        captured_event = adapter.handle_message.call_args[0][0]
        assert captured_event.message_type == MessageType.DOCUMENT
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestDocumentDownloadBlock
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentDownloadBlock:
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_pdf_is_cached(self, adapter):
|
||||
pdf_bytes = b"%PDF-1.4 fake"
|
||||
file_obj = _make_file_obj(pdf_bytes)
|
||||
doc = _make_document(file_name="report.pdf", file_size=1024, file_obj=file_obj)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert event.media_types == ["application/pdf"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_txt_injects_content(self, adapter):
|
||||
content = b"Hello from a text file"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name="notes.txt", mime_type="text/plain",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "Hello from a text file" in event.text
|
||||
assert "[Content of notes.txt]" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_md_injects_content(self, adapter):
|
||||
content = b"# Title\nSome markdown"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name="readme.md", mime_type="text/markdown",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "# Title" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caption_preserved_with_injection(self, adapter):
|
||||
content = b"file text"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name="doc.txt", mime_type="text/plain",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc, caption="Please summarize")
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "file text" in event.text
|
||||
assert "Please summarize" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_type_rejected(self, adapter):
|
||||
doc = _make_document(file_name="archive.zip", mime_type="application/zip", file_size=100)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "Unsupported document type" in event.text
|
||||
assert ".zip" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_file_rejected(self, adapter):
|
||||
doc = _make_document(file_name="huge.pdf", file_size=25 * 1024 * 1024)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "too large" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_file_size_rejected(self, adapter):
|
||||
"""Security fix: file_size=None must be rejected (not silently allowed)."""
|
||||
doc = _make_document(file_name="tricky.pdf", file_size=None)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "too large" in event.text or "could not be verified" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_filename_uses_mime_lookup(self, adapter):
|
||||
"""No file_name but valid mime_type should resolve to extension."""
|
||||
content = b"some pdf bytes"
|
||||
file_obj = _make_file_obj(content)
|
||||
doc = _make_document(
|
||||
file_name=None, mime_type="application/pdf",
|
||||
file_size=len(content), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert event.media_types == ["application/pdf"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_filename_and_mime_rejected(self, adapter):
|
||||
doc = _make_document(file_name=None, mime_type=None, file_size=100)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "Unsupported" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unicode_decode_error_handled(self, adapter):
|
||||
"""Binary bytes that aren't valid UTF-8 in a .txt — content not injected but file still cached."""
|
||||
binary = bytes(range(128, 256)) # not valid UTF-8
|
||||
file_obj = _make_file_obj(binary)
|
||||
doc = _make_document(
|
||||
file_name="binary.txt", mime_type="text/plain",
|
||||
file_size=len(binary), file_obj=file_obj,
|
||||
)
|
||||
msg = _make_message(document=doc)
|
||||
update = _make_update(msg)
|
||||
|
||||
await adapter._handle_media_message(update, MagicMock())
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
# File should still be cached
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
# Content NOT injected — text should be empty (no caption set)
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
async def test_text_injection_capped(self, adapter):
    """A .txt file over 100 KB should NOT have its content injected."""
    payload = b"x" * (200 * 1024)  # 200 KB, well over the injection cap
    telegram_file = _make_file_obj(payload)
    document = _make_document(
        file_name="big.txt",
        mime_type="text/plain",
        file_size=len(payload),
        file_obj=telegram_file,
    )
    update = _make_update(_make_message(document=document))

    await adapter._handle_media_message(update, MagicMock())

    event = adapter.handle_message.call_args[0][0]
    # Still cached on disk...
    assert len(event.media_urls) == 1
    # ...but the oversized text is not inlined.
    assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_exception_handled(self, adapter):
    """If get_file() raises, the handler logs the error without crashing."""
    document = _make_document(file_name="crash.pdf", file_size=100)
    document.get_file = AsyncMock(side_effect=RuntimeError("Telegram API down"))
    update = _make_update(_make_message(document=document))

    # Must not let the RuntimeError propagate.
    await adapter._handle_media_message(update, MagicMock())

    # The handler swallows the failure and still reports the message.
    adapter.handle_message.assert_called_once()
|
||||
171
tests/test_413_compression.py
Normal file
171
tests/test_413_compression.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Tests for 413 payload-too-large → compression retry logic in AIAgent.
|
||||
|
||||
Verifies that HTTP 413 errors trigger history compression and retry,
|
||||
rather than being treated as non-retryable generic 4xx errors.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": n,
|
||||
"description": f"{n} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None, usage=None):
|
||||
msg = SimpleNamespace(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_content=None,
|
||||
reasoning=None,
|
||||
)
|
||||
choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
|
||||
resp = SimpleNamespace(choices=[choice], model="test/model")
|
||||
resp.usage = SimpleNamespace(**usage) if usage else None
|
||||
return resp
|
||||
|
||||
|
||||
def _make_413_error(*, use_status_code=True, message="Request entity too large"):
|
||||
"""Create an exception that mimics a 413 HTTP error."""
|
||||
err = Exception(message)
|
||||
if use_status_code:
|
||||
err.status_code = 413
|
||||
return err
|
||||
|
||||
|
||||
@pytest.fixture()
def agent():
    """AIAgent wired for unit tests: patched tool discovery, mocked OpenAI client,
    no context files, no memory, compression and trajectory saving disabled."""
    with (
        patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        instance = AIAgent(
            api_key="test-key-1234567890",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
        # Replace the real client and neutralize behaviors irrelevant to the tests.
        instance.client = MagicMock()
        instance._cached_system_prompt = "You are helpful."
        instance._use_prompt_caching = False
        instance.tool_delay = 0
        instance.compression_enabled = False
        instance.save_trajectories = False
        return instance
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHTTP413Compression:
    """413 errors should trigger compression, not abort as generic 4xx."""

    def test_413_triggers_compression(self, agent):
        """A 413 error should call _compress_context and retry, not abort."""
        payload_error = _make_413_error()
        success = _mock_response(content="Success after compression", finish_reason="stop")
        # First API call blows up with 413; the retry after compression succeeds.
        agent.client.chat.completions.create.side_effect = [payload_error, success]

        with (
            patch.object(agent, "_compress_context") as compress_mock,
            patch.object(agent, "_persist_session"),
            patch.object(agent, "_save_trajectory"),
            patch.object(agent, "_cleanup_task_resources"),
        ):
            # Compression shrinks the history, which enables the retry.
            compress_mock.return_value = (
                [{"role": "user", "content": "hello"}],
                "compressed prompt",
            )
            outcome = agent.run_conversation("hello")

        compress_mock.assert_called_once()
        assert outcome["completed"] is True
        assert outcome["final_response"] == "Success after compression"

    def test_413_not_treated_as_generic_4xx(self, agent):
        """413 must NOT hit the generic 4xx abort path; it should attempt compression."""
        agent.client.chat.completions.create.side_effect = [
            _make_413_error(),
            _mock_response(content="Recovered", finish_reason="stop"),
        ]

        with (
            patch.object(agent, "_compress_context") as compress_mock,
            patch.object(agent, "_persist_session"),
            patch.object(agent, "_save_trajectory"),
            patch.object(agent, "_cleanup_task_resources"),
        ):
            compress_mock.return_value = (
                [{"role": "user", "content": "hello"}],
                "compressed",
            )
            outcome = agent.run_conversation("hello")

        # The generic 4xx path would have marked the run as failed.
        assert outcome.get("failed") is not True
        assert outcome["completed"] is True

    def test_413_error_message_detection(self, agent):
        """413 detected via error message string (no status_code attr)."""
        agent.client.chat.completions.create.side_effect = [
            _make_413_error(use_status_code=False, message="error code: 413"),
            _mock_response(content="OK", finish_reason="stop"),
        ]

        with (
            patch.object(agent, "_compress_context") as compress_mock,
            patch.object(agent, "_persist_session"),
            patch.object(agent, "_save_trajectory"),
            patch.object(agent, "_cleanup_task_resources"),
        ):
            compress_mock.return_value = (
                [{"role": "user", "content": "hello"}],
                "compressed",
            )
            outcome = agent.run_conversation("hello")

        compress_mock.assert_called_once()
        assert outcome["completed"] is True

    def test_413_cannot_compress_further(self, agent):
        """When compression can't reduce messages, return partial result."""
        agent.client.chat.completions.create.side_effect = [_make_413_error()]

        with (
            patch.object(agent, "_compress_context") as compress_mock,
            patch.object(agent, "_persist_session"),
            patch.object(agent, "_save_trajectory"),
            patch.object(agent, "_cleanup_task_resources"),
        ):
            # Same message count back from compression → retry is pointless.
            compress_mock.return_value = (
                [{"role": "user", "content": "hello"}],
                "same prompt",
            )
            outcome = agent.run_conversation("hello")

        assert outcome["completed"] is False
        assert outcome.get("partial") is True
        assert "413" in outcome["error"]
|
||||
Reference in New Issue
Block a user