feat(discord): only create threads and reactions for authorized users

This commit is contained in:
Laura Batalha
2026-03-31 23:39:40 +01:00
committed by Teknium
parent 0a6d366327
commit f4d44c777b

View File

@@ -408,7 +408,7 @@ class VoiceReceiver:
class DiscordAdapter(BasePlatformAdapter):
"""
Discord bot adapter.
Handles:
- Receiving messages from servers and DMs
- Sending responses with Discord markdown
@@ -418,10 +418,10 @@ class DiscordAdapter(BasePlatformAdapter):
- Auto-threading for long conversations
- Reaction-based feedback
"""
# Discord message limits
MAX_MESSAGE_LENGTH = 2000
# Auto-disconnect from voice channel after this many seconds of inactivity
VOICE_TIMEOUT = 300
@@ -449,7 +449,7 @@ class DiscordAdapter(BasePlatformAdapter):
self._bot_task: Optional[asyncio.Task] = None
# Cap to prevent unbounded growth (Discord threads get archived).
self._MAX_TRACKED_THREADS = 500
async def connect(self) -> bool:
"""Connect to Discord and start receiving events."""
if not DISCORD_AVAILABLE:
@@ -480,11 +480,11 @@ class DiscordAdapter(BasePlatformAdapter):
logger.warning("Opus codec found at %s but failed to load", opus_path)
if not discord.opus.is_loaded():
logger.warning("Opus codec not found — voice channel playback disabled")
if not self.config.token:
logger.error("[%s] No bot token configured", self.name)
return False
try:
# Acquire scoped lock to prevent duplicate bot token usage
from gateway.status import acquire_scoped_lock
@@ -504,13 +504,13 @@ class DiscordAdapter(BasePlatformAdapter):
intents.guild_messages = True
intents.members = True
intents.voice_states = True
# Create bot
self._client = commands.Bot(
command_prefix="!", # Not really used, we handle raw messages
intents=intents,
)
# Parse allowed user entries (may contain usernames or IDs)
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
if allowed_env:
@@ -518,17 +518,17 @@ class DiscordAdapter(BasePlatformAdapter):
_clean_discord_id(uid) for uid in allowed_env.split(",")
if uid.strip()
}
adapter_self = self # capture for closure
# Register event handlers
@self._client.event
async def on_ready():
logger.info("[%s] Connected as %s", adapter_self.name, adapter_self._client.user)
# Resolve any usernames in the allowed list to numeric IDs
await adapter_self._resolve_allowed_usernames()
# Sync slash commands with Discord
try:
synced = await adapter_self._client.tree.sync()
@@ -536,18 +536,22 @@ class DiscordAdapter(BasePlatformAdapter):
except Exception as e: # pragma: no cover - defensive logging
logger.warning("[%s] Slash command sync failed: %s", adapter_self.name, e, exc_info=True)
adapter_self._ready_event.set()
@self._client.event
async def on_message(message: DiscordMessage):
# Always ignore our own messages
if message.author == self._client.user:
return
# Ignore Discord system messages (thread renames, pins, member joins, etc.)
# Allow both default and reply types — replies have a distinct MessageType.
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
return
# Check if the message author is in the allowed user list
if not self._is_allowed_user(str(message.author.id)):
return
# Bot message filtering (DISCORD_ALLOW_BOTS):
# "none" — ignore all other bots (default)
# "mentions" — accept bot messages only when they @mention us
@@ -560,7 +564,7 @@ class DiscordAdapter(BasePlatformAdapter):
if not self._client.user or self._client.user not in message.mentions:
return
# "all" falls through to handle_message
# If the message @mentions other users but NOT the bot, the
# sender is talking to someone else — stay silent. Only
# applies in server channels; in DMs the user is always
@@ -614,23 +618,23 @@ class DiscordAdapter(BasePlatformAdapter):
# Register slash commands
self._register_slash_commands()
# Start the bot in background
self._bot_task = asyncio.create_task(self._client.start(self.config.token))
# Wait for ready
await asyncio.wait_for(self._ready_event.wait(), timeout=30)
self._running = True
return True
except asyncio.TimeoutError:
logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True)
return False
except Exception as e: # pragma: no cover - defensive logging
logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True)
return False
async def disconnect(self) -> None:
"""Disconnect from Discord."""
# Clean up all active voice connections before closing the client
@@ -703,7 +707,7 @@ class DiscordAdapter(BasePlatformAdapter):
if hasattr(message, "add_reaction"):
await self._remove_reaction(message, "👀")
await self._add_reaction(message, "" if success else "")
async def send(
self,
chat_id: str,
@@ -720,24 +724,24 @@ class DiscordAdapter(BasePlatformAdapter):
channel = self._client.get_channel(int(chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
if not channel:
return SendResult(success=False, error=f"Channel {chat_id} not found")
# Format and split message if needed
formatted = self.format_message(content)
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
message_ids = []
reference = None
if reply_to:
try:
ref_msg = await channel.fetch_message(int(reply_to))
reference = ref_msg
except Exception as e:
logger.debug("Could not fetch reply-to message: %s", e)
for i, chunk in enumerate(chunks):
chunk_reference = reference if i == 0 else None
try:
@@ -764,13 +768,13 @@ class DiscordAdapter(BasePlatformAdapter):
else:
raise
message_ids.append(str(msg.id))
return SendResult(
success=True,
message_id=message_ids[0] if message_ids else None,
raw_response={"message_ids": message_ids}
)
except Exception as e: # pragma: no cover - defensive logging
logger.error("[%s] Failed to send Discord message: %s", self.name, e, exc_info=True)
return SendResult(success=False, error=str(e))
@@ -1242,25 +1246,25 @@ class DiscordAdapter(BasePlatformAdapter):
"""Send an image natively as a Discord file attachment."""
if not self._client:
return SendResult(success=False, error="Not connected")
try:
import aiohttp
channel = self._client.get_channel(int(chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
if not channel:
return SendResult(success=False, error=f"Channel {chat_id} not found")
# Download the image and send as a Discord file attachment
# (Discord renders attachments inline, unlike plain URLs)
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status != 200:
raise Exception(f"Failed to download image: HTTP {resp.status}")
image_data = await resp.read()
# Determine filename from URL or content type
content_type = resp.headers.get("content-type", "image/png")
ext = "png"
@@ -1270,16 +1274,16 @@ class DiscordAdapter(BasePlatformAdapter):
ext = "gif"
elif "webp" in content_type:
ext = "webp"
import io
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
msg = await channel.send(
content=caption if caption else None,
file=file,
)
return SendResult(success=True, message_id=str(msg.id))
except ImportError:
logger.warning(
"[%s] aiohttp not installed, falling back to URL. Run: pip install aiohttp",
@@ -1330,7 +1334,7 @@ class DiscordAdapter(BasePlatformAdapter):
except Exception as e: # pragma: no cover - defensive logging
logger.error("[%s] Failed to send document, falling back to base adapter: %s", self.name, e, exc_info=True)
return await super().send_document(chat_id, file_path, caption, file_name, reply_to, metadata=metadata)
async def send_typing(self, chat_id: str, metadata=None) -> None:
"""Start a persistent typing indicator for a channel.
@@ -1374,20 +1378,20 @@ class DiscordAdapter(BasePlatformAdapter):
await task
except (asyncio.CancelledError, Exception):
pass
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
"""Get information about a Discord channel."""
if not self._client:
return {"name": "Unknown", "type": "dm"}
try:
channel = self._client.get_channel(int(chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
if not channel:
return {"name": str(chat_id), "type": "dm"}
# Determine channel type
if isinstance(channel, discord.DMChannel):
chat_type = "dm"
@@ -1403,7 +1407,7 @@ class DiscordAdapter(BasePlatformAdapter):
else:
chat_type = "channel"
name = getattr(channel, "name", str(chat_id))
return {
"name": name,
"type": chat_type,
@@ -1413,7 +1417,7 @@ class DiscordAdapter(BasePlatformAdapter):
except Exception as e: # pragma: no cover - defensive logging
logger.error("[%s] Failed to get chat info for %s: %s", self.name, chat_id, e, exc_info=True)
return {"name": str(chat_id), "type": "dm", "error": str(e)}
async def _resolve_allowed_usernames(self) -> None:
"""
Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs.
@@ -1481,7 +1485,7 @@ class DiscordAdapter(BasePlatformAdapter):
def format_message(self, content: str) -> str:
"""
Format message for Discord.
Discord uses its own markdown variant.
"""
# Discord markdown is fairly standard, no special escaping needed
@@ -1647,7 +1651,7 @@ class DiscordAdapter(BasePlatformAdapter):
chat_name = interaction.channel.name
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
# Get channel topic (if available)
chat_topic = getattr(interaction.channel, "topic", None)
@@ -2051,7 +2055,7 @@ class DiscordAdapter(BasePlatformAdapter):
if doc_ext in SUPPORTED_DOCUMENT_TYPES:
msg_type = MessageType.DOCUMENT
break
# When auto-threading kicked in, route responses to the new thread
effective_channel = auto_threaded_channel or message.channel
@@ -2070,7 +2074,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
chat_topic = getattr(message.channel, "topic", None)
# Build source
source = self.build_source(
chat_id=str(effective_channel.id),
@@ -2081,7 +2085,7 @@ class DiscordAdapter(BasePlatformAdapter):
thread_id=thread_id,
chat_topic=chat_topic,
)
# Build media URLs -- download image attachments to local cache so the
# vision tool can access them reliably (Discord CDN URLs can expire).
media_urls = []
@@ -2175,7 +2179,7 @@ class DiscordAdapter(BasePlatformAdapter):
"[Discord] Failed to cache document %s: %s",
att.filename, e, exc_info=True,
)
event_text = message.content
if pending_text_injection:
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection