From afe2f0abe19940aff75f6ab31d4f74ebed34f586 Mon Sep 17 00:00:00 2001
From: Teknium <127238744+teknium1@users.noreply.github.com>
Date: Sun, 22 Mar 2026 07:38:14 -0700
Subject: [PATCH] feat(discord): add document caching and text-file injection (#2503)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Download and cache .pdf, .docx, .xlsx, .pptx attachments locally instead of
  passing expiring CDN URLs to the agent
- Inject .txt and .md content (≤100 KB) into event.text so the agent sees
  file content without needing to fetch the URL
- Add 20 MB size guard and SUPPORTED_DOCUMENT_TYPES allowlist
- Fix: unsupported types (.zip etc.) no longer get MessageType.DOCUMENT
- Add 9 unit tests in test_discord_document_handling.py

Mirrors the Slack implementation from PR #784. Discord CDN URLs are publicly
accessible so no auth header is needed (unlike Slack).

Co-authored-by: Dilee
---
Review notes (placed below the separator, so not part of the commit message):
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Indentation of context and removed lines is inferred
  from the code structure; verify it against the target file before applying.
* The UnicodeDecodeError handler now logs at debug level instead of `pass`.
* test_multiple_text_files_both_injected: the unused `call_count` local was
  replaced by a comment and `responses` is now passed to make_session().
* Follow-up: the message-type loop looks only at the filename extension, while
  the caching branch also falls back to the MIME type, so a supported document
  without an extension is cached but not typed as DOCUMENT.
* Follow-up: the 20 MB guard checks att.size only; the length of the
  downloaded body is not re-checked.

 gateway/platforms/discord.py                  |  78 +++-
 .../gateway/test_discord_document_handling.py | 347 ++++++++++++++++++
 2 files changed, 420 insertions(+), 5 deletions(-)
 create mode 100644 tests/gateway/test_discord_document_handling.py

diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py
index 9553906a8..c05feb2ed 100644
--- a/gateway/platforms/discord.py
+++ b/gateway/platforms/discord.py
@@ -43,6 +43,8 @@ from pathlib import Path as _Path
 sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
 
 from gateway.config import Platform, PlatformConfig
+import re
+
 from gateway.platforms.base import (
     BasePlatformAdapter,
     MessageEvent,
@@ -50,6 +52,8 @@ from gateway.platforms.base import (
     SendResult,
     cache_image_from_url,
     cache_audio_from_url,
+    cache_document_from_bytes,
+    SUPPORTED_DOCUMENT_TYPES,
 )
 
 
@@ -1950,7 +1954,12 @@ class DiscordAdapter(BasePlatformAdapter):
                     elif att.content_type.startswith("audio/"):
                         msg_type = MessageType.AUDIO
                     else:
-                        msg_type = MessageType.DOCUMENT
+                        doc_ext = ""
+                        if att.filename:
+                            _, doc_ext = os.path.splitext(att.filename)
+                            doc_ext = doc_ext.lower()
+                        if doc_ext in SUPPORTED_DOCUMENT_TYPES:
+                            msg_type = MessageType.DOCUMENT
                     break
 
         # When auto-threading kicked in, route responses to the new thread
@@ -1987,6 +1996,7 @@ class DiscordAdapter(BasePlatformAdapter):
         # vision tool can access them reliably (Discord CDN URLs can expire).
         media_urls = []
         media_types = []
+        pending_text_injection: Optional[str] = None
         for att in message.attachments:
             content_type = att.content_type or "unknown"
             if content_type.startswith("image/"):
@@ -2018,12 +2028,70 @@ class DiscordAdapter(BasePlatformAdapter):
                     media_urls.append(att.url)
                     media_types.append(content_type)
             else:
-                # Other attachments: keep the original URL
-                media_urls.append(att.url)
-                media_types.append(content_type)
+                # Document attachments: download, cache, and optionally inject text
+                ext = ""
+                if att.filename:
+                    _, ext = os.path.splitext(att.filename)
+                    ext = ext.lower()
+                if not ext and content_type:
+                    mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
+                    ext = mime_to_ext.get(content_type, "")
+                if ext not in SUPPORTED_DOCUMENT_TYPES:
+                    logger.warning(
+                        "[Discord] Unsupported document type '%s' (%s), skipping",
+                        ext or "unknown", content_type,
+                    )
+                else:
+                    MAX_DOC_BYTES = 20 * 1024 * 1024
+                    if att.size and att.size > MAX_DOC_BYTES:
+                        logger.warning(
+                            "[Discord] Document too large (%s bytes), skipping: %s",
+                            att.size, att.filename,
+                        )
+                    else:
+                        try:
+                            import aiohttp
+                            async with aiohttp.ClientSession() as session:
+                                async with session.get(
+                                    att.url,
+                                    timeout=aiohttp.ClientTimeout(total=30),
+                                ) as resp:
+                                    if resp.status != 200:
+                                        raise Exception(f"HTTP {resp.status}")
+                                    raw_bytes = await resp.read()
+                            cached_path = cache_document_from_bytes(
+                                raw_bytes, att.filename or f"document{ext}"
+                            )
+                            doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
+                            media_urls.append(cached_path)
+                            media_types.append(doc_mime)
+                            logger.info("[Discord] Cached user document: %s", cached_path)
+                            # Inject text content for .txt/.md files (capped at 100 KB)
+                            MAX_TEXT_INJECT_BYTES = 100 * 1024
+                            if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
+                                try:
+                                    text_content = raw_bytes.decode("utf-8")
+                                    display_name = att.filename or f"document{ext}"
+                                    display_name = re.sub(r'[^\w.\- ]', '_', display_name)
+                                    injection = f"[Content of {display_name}]:\n{text_content}"
+                                    if pending_text_injection:
+                                        pending_text_injection = f"{pending_text_injection}\n\n{injection}"
+                                    else:
+                                        pending_text_injection = injection
+                                except UnicodeDecodeError:
+                                    logger.debug("[Discord] %s is not valid UTF-8; skipping text injection", att.filename)
+                        except Exception as e:
+                            logger.warning(
+                                "[Discord] Failed to cache document %s: %s",
+                                att.filename, e, exc_info=True,
+                            )
 
+        event_text = message.content
+        if pending_text_injection:
+            event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
+
         event = MessageEvent(
-            text=message.content,
+            text=event_text,
             message_type=msg_type,
             source=source,
             raw_message=message,
diff --git a/tests/gateway/test_discord_document_handling.py b/tests/gateway/test_discord_document_handling.py
new file mode 100644
index 000000000..b3ee5d00f
--- /dev/null
+++ b/tests/gateway/test_discord_document_handling.py
@@ -0,0 +1,347 @@
+"""Tests for Discord incoming document/file attachment handling.
+
+Covers the document branch in DiscordAdapter._handle_message() —
+the `else` clause of the attachment content-type loop that was added
+to download, cache, and optionally inject text from non-image/audio files.
+"""
+
+import os
+import sys
+from datetime import datetime, timezone
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from gateway.config import PlatformConfig
+from gateway.platforms.base import MessageType
+
+
+# ---------------------------------------------------------------------------
+# Discord mock setup (copied from test_discord_free_response.py)
+# ---------------------------------------------------------------------------
+
+def _ensure_discord_mock():
+    """Install a mock discord module when discord.py isn't available."""
+    if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
+        return
+
+    discord_mod = MagicMock()
+    discord_mod.Intents.default.return_value = MagicMock()
+    discord_mod.Client = MagicMock
+    discord_mod.File = MagicMock
+    discord_mod.DMChannel = type("DMChannel", (), {})
+    discord_mod.Thread = type("Thread", (), {})
+    discord_mod.ForumChannel = type("ForumChannel", (), {})
+    discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
+    discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
+    discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
+    discord_mod.Interaction = object
+    discord_mod.Embed = MagicMock
+    discord_mod.app_commands = SimpleNamespace(
+        describe=lambda **kwargs: (lambda fn: fn),
+        choices=lambda **kwargs: (lambda fn: fn),
+        Choice=lambda **kwargs: SimpleNamespace(**kwargs),
+    )
+
+    ext_mod = MagicMock()
+    commands_mod = MagicMock()
+    commands_mod.Bot = MagicMock
+    ext_mod.commands = commands_mod
+
+    sys.modules.setdefault("discord", discord_mod)
+    sys.modules.setdefault("discord.ext", ext_mod)
+    sys.modules.setdefault("discord.ext.commands", commands_mod)
+
+
+_ensure_discord_mock()
+
+import gateway.platforms.discord as discord_platform  # noqa: E402
+from gateway.platforms.discord import DiscordAdapter  # noqa: E402
+
+
+# ---------------------------------------------------------------------------
+# Fake channel / thread types
+# ---------------------------------------------------------------------------
+
+class FakeDMChannel:
+    def __init__(self, channel_id: int = 1):
+        self.id = channel_id
+        self.name = "dm"
+
+
+class FakeThread:
+    def __init__(self, channel_id: int = 10):
+        self.id = channel_id
+        self.name = "thread"
+        self.parent = None
+        self.parent_id = None
+        self.guild = SimpleNamespace(name="TestServer")
+        self.topic = None
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+@pytest.fixture(autouse=True)
+def _redirect_cache(tmp_path, monkeypatch):
+    """Point document cache to tmp_path so tests never write to ~/.hermes."""
+    monkeypatch.setattr(
+        "gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
+    )
+
+
+@pytest.fixture
+def adapter(monkeypatch):
+    monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
+    monkeypatch.setattr(discord_platform.discord, "Thread", FakeThread, raising=False)
+
+    config = PlatformConfig(enabled=True, token="fake-token")
+    a = DiscordAdapter(config)
+    a._client = SimpleNamespace(user=SimpleNamespace(id=999))
+    a.handle_message = AsyncMock()
+    return a
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def make_attachment(
+    *,
+    filename: str,
+    content_type: str,
+    size: int = 1024,
+    url: str = "https://cdn.discordapp.com/attachments/fake/file",
+) -> SimpleNamespace:
+    return SimpleNamespace(
+        filename=filename,
+        content_type=content_type,
+        size=size,
+        url=url,
+    )
+
+
+def make_message(attachments: list, content: str = "") -> SimpleNamespace:
+    return SimpleNamespace(
+        id=123,
+        content=content,
+        attachments=attachments,
+        mentions=[],
+        reference=None,
+        created_at=datetime.now(timezone.utc),
+        channel=FakeDMChannel(),
+        author=SimpleNamespace(id=42, display_name="Tester", name="Tester"),
+    )
+
+
+def _mock_aiohttp_download(raw_bytes: bytes):
+    """Return a patch context manager that makes aiohttp return raw_bytes."""
+    resp = AsyncMock()
+    resp.status = 200
+    resp.read = AsyncMock(return_value=raw_bytes)
+    resp.__aenter__ = AsyncMock(return_value=resp)
+    resp.__aexit__ = AsyncMock(return_value=False)
+
+    session = AsyncMock()
+    session.get = MagicMock(return_value=resp)
+    session.__aenter__ = AsyncMock(return_value=session)
+    session.__aexit__ = AsyncMock(return_value=False)
+
+    return patch("aiohttp.ClientSession", return_value=session)
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+class TestIncomingDocumentHandling:
+
+    @pytest.mark.asyncio
+    async def test_pdf_document_cached(self, adapter):
+        """A PDF attachment should be downloaded, cached, typed as DOCUMENT."""
+        pdf_bytes = b"%PDF-1.4 fake content"
+
+        with _mock_aiohttp_download(pdf_bytes):
+            msg = make_message([make_attachment(filename="report.pdf", content_type="application/pdf")])
+            await adapter._handle_message(msg)
+
+        event = adapter.handle_message.call_args[0][0]
+        assert event.message_type == MessageType.DOCUMENT
+        assert len(event.media_urls) == 1
+        assert os.path.exists(event.media_urls[0])
+        assert event.media_types == ["application/pdf"]
+        assert "[Content of" not in (event.text or "")
+
+    @pytest.mark.asyncio
+    async def test_txt_content_injected(self, adapter):
+        """.txt file under 100KB should have its content injected into event.text."""
+        file_content = b"Hello from a text file"
+
+        with _mock_aiohttp_download(file_content):
+            msg = make_message(
+                attachments=[make_attachment(filename="notes.txt", content_type="text/plain")],
+                content="summarize this",
+            )
+            await adapter._handle_message(msg)
+
+        event = adapter.handle_message.call_args[0][0]
+        assert "[Content of notes.txt]:" in event.text
+        assert "Hello from a text file" in event.text
+        assert "summarize this" in event.text
+        # injection prepended before caption
+        assert event.text.index("[Content of") < event.text.index("summarize this")
+
+    @pytest.mark.asyncio
+    async def test_md_content_injected(self, adapter):
+        """.md file under 100KB should have its content injected."""
+        file_content = b"# Title\nSome markdown content"
+
+        with _mock_aiohttp_download(file_content):
+            msg = make_message(
+                attachments=[make_attachment(filename="readme.md", content_type="text/markdown")],
+                content="",
+            )
+            await adapter._handle_message(msg)
+
+        event = adapter.handle_message.call_args[0][0]
+        assert "[Content of readme.md]:" in event.text
+        assert "# Title" in event.text
+
+    @pytest.mark.asyncio
+    async def test_oversized_document_skipped(self, adapter):
+        """A document over 20MB should be skipped — media_urls stays empty."""
+        msg = make_message([
+            make_attachment(
+                filename="huge.pdf",
+                content_type="application/pdf",
+                size=25 * 1024 * 1024,
+            )
+        ])
+        await adapter._handle_message(msg)
+
+        event = adapter.handle_message.call_args[0][0]
+        assert event.media_urls == []
+        # handler must still be called
+        adapter.handle_message.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_unsupported_type_skipped(self, adapter):
+        """An unsupported file type (.zip) should be skipped and leave the message typed TEXT."""
+        msg = make_message([
+            make_attachment(filename="archive.zip", content_type="application/zip")
+        ])
+        await adapter._handle_message(msg)
+
+        event = adapter.handle_message.call_args[0][0]
+        assert event.media_urls == []
+        assert event.message_type == MessageType.TEXT
+
+    @pytest.mark.asyncio
+    async def test_download_error_handled(self, adapter):
+        """If the HTTP download raises, the handler should not crash."""
+        resp = AsyncMock()
+        resp.__aenter__ = AsyncMock(side_effect=RuntimeError("connection reset"))
+        resp.__aexit__ = AsyncMock(return_value=False)
+
+        session = AsyncMock()
+        session.get = MagicMock(return_value=resp)
+        session.__aenter__ = AsyncMock(return_value=session)
+        session.__aexit__ = AsyncMock(return_value=False)
+
+        with patch("aiohttp.ClientSession", return_value=session):
+            msg = make_message([
+                make_attachment(filename="report.pdf", content_type="application/pdf")
+            ])
+            await adapter._handle_message(msg)
+
+        # Must still deliver an event
+        adapter.handle_message.assert_called_once()
+        event = adapter.handle_message.call_args[0][0]
+        assert event.media_urls == []
+
+    @pytest.mark.asyncio
+    async def test_large_txt_cached_not_injected(self, adapter):
+        """.txt over 100KB should be cached but NOT injected into event.text."""
+        large_content = b"x" * (200 * 1024)
+
+        with _mock_aiohttp_download(large_content):
+            msg = make_message(
+                attachments=[make_attachment(filename="big.txt", content_type="text/plain", size=len(large_content))],
+                content="",
+            )
+            await adapter._handle_message(msg)
+
+        event = adapter.handle_message.call_args[0][0]
+        assert len(event.media_urls) == 1
+        assert os.path.exists(event.media_urls[0])
+        assert "[Content of" not in (event.text or "")
+
+    @pytest.mark.asyncio
+    async def test_multiple_text_files_both_injected(self, adapter):
+        """Two text file attachments should both be injected into event.text in order."""
+        content1 = b"First file content"
+        content2 = b"Second file content"
+
+        # One payload per session.get() call, served in attachment order.
+        responses = [content1, content2]
+
+        def make_session(_responses):
+            idx = 0
+
+            class FakeSession:
+                async def __aenter__(self):
+                    return self
+
+                async def __aexit__(self, *_):
+                    pass
+
+                def get(self, url, **kwargs):
+                    nonlocal idx
+                    data = _responses[idx % len(_responses)]
+                    idx += 1
+
+                    resp = AsyncMock()
+                    resp.status = 200
+                    resp.read = AsyncMock(return_value=data)
+                    resp.__aenter__ = AsyncMock(return_value=resp)
+                    resp.__aexit__ = AsyncMock(return_value=False)
+                    return resp
+
+            return FakeSession()
+
+        with patch("aiohttp.ClientSession", return_value=make_session(responses)):
+            msg = make_message(
+                attachments=[
+                    make_attachment(filename="file1.txt", content_type="text/plain"),
+                    make_attachment(filename="file2.txt", content_type="text/plain"),
+                ],
+                content="",
+            )
+            await adapter._handle_message(msg)
+
+        event = adapter.handle_message.call_args[0][0]
+        assert "[Content of file1.txt]:" in event.text
+        assert "First file content" in event.text
+        assert "[Content of file2.txt]:" in event.text
+        assert "Second file content" in event.text
+        assert event.text.index("file1") < event.text.index("file2")
+
+    @pytest.mark.asyncio
+    async def test_image_attachment_unaffected(self, adapter):
+        """Image attachments should still go through the image path, not the document path."""
+        with patch(
+            "gateway.platforms.discord.cache_image_from_url",
+            new_callable=AsyncMock,
+            return_value="/tmp/cached_image.png",
+        ):
+            msg = make_message([
+                make_attachment(filename="photo.png", content_type="image/png")
+            ])
+            await adapter._handle_message(msg)
+
+        event = adapter.handle_message.call_args[0][0]
+        assert event.message_type == MessageType.PHOTO
+        assert event.media_urls == ["/tmp/cached_image.png"]
+        assert event.media_types == ["image/png"]