feat(discord): add document caching and text-file injection (#2503)
- Download and cache .pdf, .docx, .xlsx, .pptx attachments locally instead of passing expiring CDN URLs to the agent - Inject .txt and .md content (≤100 KB) into event.text so the agent sees file content without needing to fetch the URL - Add 20 MB size guard and SUPPORTED_DOCUMENT_TYPES allowlist - Fix: unsupported types (.zip etc.) no longer get MessageType.DOCUMENT - Add 9 unit tests in test_discord_document_handling.py Mirrors the Slack implementation from PR #784. Discord CDN URLs are publicly accessible so no auth header is needed (unlike Slack). Co-authored-by: Dilee <uzmpsk.dilekakbas@gmail.com>
This commit is contained in:
@@ -43,6 +43,8 @@ from pathlib import Path as _Path
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
import re
|
||||
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -50,6 +52,8 @@ from gateway.platforms.base import (
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
@@ -1950,7 +1954,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
elif att.content_type.startswith("audio/"):
|
||||
msg_type = MessageType.AUDIO
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
doc_ext = ""
|
||||
if att.filename:
|
||||
_, doc_ext = os.path.splitext(att.filename)
|
||||
doc_ext = doc_ext.lower()
|
||||
if doc_ext in SUPPORTED_DOCUMENT_TYPES:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
# When auto-threading kicked in, route responses to the new thread
|
||||
@@ -1987,6 +1996,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
media_types = []
|
||||
pending_text_injection: Optional[str] = None
|
||||
for att in message.attachments:
|
||||
content_type = att.content_type or "unknown"
|
||||
if content_type.startswith("image/"):
|
||||
@@ -2018,12 +2028,70 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
else:
|
||||
# Other attachments: keep the original URL
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
# Document attachments: download, cache, and optionally inject text
|
||||
ext = ""
|
||||
if att.filename:
|
||||
_, ext = os.path.splitext(att.filename)
|
||||
ext = ext.lower()
|
||||
if not ext and content_type:
|
||||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||||
ext = mime_to_ext.get(content_type, "")
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
logger.warning(
|
||||
"[Discord] Unsupported document type '%s' (%s), skipping",
|
||||
ext or "unknown", content_type,
|
||||
)
|
||||
else:
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if att.size and att.size > MAX_DOC_BYTES:
|
||||
logger.warning(
|
||||
"[Discord] Document too large (%s bytes), skipping: %s",
|
||||
att.size, att.filename,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
raw_bytes = await resp.read()
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, att.filename or f"document{ext}"
|
||||
)
|
||||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
logger.info("[Discord] Cached user document: %s", cached_path)
|
||||
# Inject text content for .txt/.md files (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = att.filename or f"document{ext}"
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if pending_text_injection:
|
||||
pending_text_injection = f"{pending_text_injection}\n\n{injection}"
|
||||
else:
|
||||
pending_text_injection = injection
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[Discord] Failed to cache document %s: %s",
|
||||
att.filename, e, exc_info=True,
|
||||
)
|
||||
|
||||
event_text = message.content
|
||||
if pending_text_injection:
|
||||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
text=event_text,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=message,
|
||||
|
||||
347
tests/gateway/test_discord_document_handling.py
Normal file
347
tests/gateway/test_discord_document_handling.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""Tests for Discord incoming document/file attachment handling.
|
||||
|
||||
Covers the document branch in DiscordAdapter._handle_message() —
|
||||
the `else` clause of the attachment content-type loop that was added
|
||||
to download, cache, and optionally inject text from non-image/audio files.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.base import MessageType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord mock setup (copied from test_discord_free_response.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake channel / thread types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class FakeDMChannel:
|
||||
def __init__(self, channel_id: int = 1):
|
||||
self.id = channel_id
|
||||
self.name = "dm"
|
||||
|
||||
|
||||
class FakeThread:
|
||||
def __init__(self, channel_id: int = 10):
|
||||
self.id = channel_id
|
||||
self.name = "thread"
|
||||
self.parent = None
|
||||
self.parent_id = None
|
||||
self.guild = SimpleNamespace(name="TestServer")
|
||||
self.topic = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point document cache to tmp_path so tests never write to ~/.hermes."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(monkeypatch):
|
||||
monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
|
||||
monkeypatch.setattr(discord_platform.discord, "Thread", FakeThread, raising=False)
|
||||
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = DiscordAdapter(config)
|
||||
a._client = SimpleNamespace(user=SimpleNamespace(id=999))
|
||||
a.handle_message = AsyncMock()
|
||||
return a
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_attachment(
|
||||
*,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
size: int = 1024,
|
||||
url: str = "https://cdn.discordapp.com/attachments/fake/file",
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
size=size,
|
||||
url=url,
|
||||
)
|
||||
|
||||
|
||||
def make_message(attachments: list, content: str = "") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
content=content,
|
||||
attachments=attachments,
|
||||
mentions=[],
|
||||
reference=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
channel=FakeDMChannel(),
|
||||
author=SimpleNamespace(id=42, display_name="Tester", name="Tester"),
|
||||
)
|
||||
|
||||
|
||||
def _mock_aiohttp_download(raw_bytes: bytes):
|
||||
"""Return a patch context manager that makes aiohttp return raw_bytes."""
|
||||
resp = AsyncMock()
|
||||
resp.status = 200
|
||||
resp.read = AsyncMock(return_value=raw_bytes)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = AsyncMock()
|
||||
session.get = MagicMock(return_value=resp)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return patch("aiohttp.ClientSession", return_value=session)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIncomingDocumentHandling:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_document_cached(self, adapter):
|
||||
"""A PDF attachment should be downloaded, cached, typed as DOCUMENT."""
|
||||
pdf_bytes = b"%PDF-1.4 fake content"
|
||||
|
||||
with _mock_aiohttp_download(pdf_bytes):
|
||||
msg = make_message([make_attachment(filename="report.pdf", content_type="application/pdf")])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.DOCUMENT
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert event.media_types == ["application/pdf"]
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_txt_content_injected(self, adapter):
|
||||
""".txt file under 100KB should have its content injected into event.text."""
|
||||
file_content = b"Hello from a text file"
|
||||
|
||||
with _mock_aiohttp_download(file_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="notes.txt", content_type="text/plain")],
|
||||
content="summarize this",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of notes.txt]:" in event.text
|
||||
assert "Hello from a text file" in event.text
|
||||
assert "summarize this" in event.text
|
||||
# injection prepended before caption
|
||||
assert event.text.index("[Content of") < event.text.index("summarize this")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_md_content_injected(self, adapter):
|
||||
""".md file under 100KB should have its content injected."""
|
||||
file_content = b"# Title\nSome markdown content"
|
||||
|
||||
with _mock_aiohttp_download(file_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="readme.md", content_type="text/markdown")],
|
||||
content="",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of readme.md]:" in event.text
|
||||
assert "# Title" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_document_skipped(self, adapter):
|
||||
"""A document over 20MB should be skipped — media_urls stays empty."""
|
||||
msg = make_message([
|
||||
make_attachment(
|
||||
filename="huge.pdf",
|
||||
content_type="application/pdf",
|
||||
size=25 * 1024 * 1024,
|
||||
)
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
# handler must still be called
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_type_skipped(self, adapter):
|
||||
"""An unsupported file type (.zip) should be skipped silently."""
|
||||
msg = make_message([
|
||||
make_attachment(filename="archive.zip", content_type="application/zip")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
assert event.message_type == MessageType.TEXT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_error_handled(self, adapter):
|
||||
"""If the HTTP download raises, the handler should not crash."""
|
||||
resp = AsyncMock()
|
||||
resp.__aenter__ = AsyncMock(side_effect=RuntimeError("connection reset"))
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = AsyncMock()
|
||||
session.get = MagicMock(return_value=resp)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session):
|
||||
msg = make_message([
|
||||
make_attachment(filename="report.pdf", content_type="application/pdf")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
# Must still deliver an event
|
||||
adapter.handle_message.assert_called_once()
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_txt_cached_not_injected(self, adapter):
|
||||
""".txt over 100KB should be cached but NOT injected into event.text."""
|
||||
large_content = b"x" * (200 * 1024)
|
||||
|
||||
with _mock_aiohttp_download(large_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="big.txt", content_type="text/plain", size=len(large_content))],
|
||||
content="",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_text_files_both_injected(self, adapter):
|
||||
"""Two text file attachments should both be injected into event.text in order."""
|
||||
content1 = b"First file content"
|
||||
content2 = b"Second file content"
|
||||
|
||||
call_count = 0
|
||||
responses = [content1, content2]
|
||||
|
||||
def make_session(_responses):
|
||||
idx = 0
|
||||
|
||||
class FakeSession:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_):
|
||||
pass
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
nonlocal idx
|
||||
data = _responses[idx % len(_responses)]
|
||||
idx += 1
|
||||
|
||||
resp = AsyncMock()
|
||||
resp.status = 200
|
||||
resp.read = AsyncMock(return_value=data)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return resp
|
||||
|
||||
return FakeSession()
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=make_session([content1, content2])):
|
||||
msg = make_message(
|
||||
attachments=[
|
||||
make_attachment(filename="file1.txt", content_type="text/plain"),
|
||||
make_attachment(filename="file2.txt", content_type="text/plain"),
|
||||
],
|
||||
content="",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of file1.txt]:" in event.text
|
||||
assert "First file content" in event.text
|
||||
assert "[Content of file2.txt]:" in event.text
|
||||
assert "Second file content" in event.text
|
||||
assert event.text.index("file1") < event.text.index("file2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_attachment_unaffected(self, adapter):
|
||||
"""Image attachments should still go through the image path, not the document path."""
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/cached_image.png",
|
||||
):
|
||||
msg = make_message([
|
||||
make_attachment(filename="photo.png", content_type="image/png")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.PHOTO
|
||||
assert event.media_urls == ["/tmp/cached_image.png"]
|
||||
assert event.media_types == ["image/png"]
|
||||
Reference in New Issue
Block a user