diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 168919b09..6146bb2bc 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -408,7 +408,7 @@ class VoiceReceiver: class DiscordAdapter(BasePlatformAdapter): """ Discord bot adapter. - + Handles: - Receiving messages from servers and DMs - Sending responses with Discord markdown @@ -418,10 +418,10 @@ class DiscordAdapter(BasePlatformAdapter): - Auto-threading for long conversations - Reaction-based feedback """ - + # Discord message limits MAX_MESSAGE_LENGTH = 2000 - + # Auto-disconnect from voice channel after this many seconds of inactivity VOICE_TIMEOUT = 300 @@ -449,7 +449,7 @@ class DiscordAdapter(BasePlatformAdapter): self._bot_task: Optional[asyncio.Task] = None # Cap to prevent unbounded growth (Discord threads get archived). self._MAX_TRACKED_THREADS = 500 - + async def connect(self) -> bool: """Connect to Discord and start receiving events.""" if not DISCORD_AVAILABLE: @@ -480,11 +480,11 @@ class DiscordAdapter(BasePlatformAdapter): logger.warning("Opus codec found at %s but failed to load", opus_path) if not discord.opus.is_loaded(): logger.warning("Opus codec not found — voice channel playback disabled") - + if not self.config.token: logger.error("[%s] No bot token configured", self.name) return False - + try: # Acquire scoped lock to prevent duplicate bot token usage from gateway.status import acquire_scoped_lock @@ -504,13 +504,13 @@ class DiscordAdapter(BasePlatformAdapter): intents.guild_messages = True intents.members = True intents.voice_states = True - + # Create bot self._client = commands.Bot( command_prefix="!", # Not really used, we handle raw messages intents=intents, ) - + # Parse allowed user entries (may contain usernames or IDs) allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "") if allowed_env: @@ -518,17 +518,17 @@ class DiscordAdapter(BasePlatformAdapter): _clean_discord_id(uid) for uid in allowed_env.split(",") if uid.strip() } - + adapter_self = self # capture for closure - + # Register event handlers @self._client.event async def on_ready(): logger.info("[%s] Connected as %s", adapter_self.name, adapter_self._client.user) - + # Resolve any usernames in the allowed list to numeric IDs await adapter_self._resolve_allowed_usernames() - + # Sync slash commands with Discord try: synced = await adapter_self._client.tree.sync() @@ -536,18 +536,22 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: # pragma: no cover - defensive logging logger.warning("[%s] Slash command sync failed: %s", adapter_self.name, e, exc_info=True) adapter_self._ready_event.set() - + @self._client.event async def on_message(message: DiscordMessage): # Always ignore our own messages if message.author == self._client.user: return - + # Ignore Discord system messages (thread renames, pins, member joins, etc.) # Allow both default and reply types — replies have a distinct MessageType. if message.type not in (discord.MessageType.default, discord.MessageType.reply): return - + + # Check if the message author is in the allowed user list + if not self._is_allowed_user(str(message.author.id)): + return + # Bot message filtering (DISCORD_ALLOW_BOTS): # "none" — ignore all other bots (default) # "mentions" — accept bot messages only when they @mention us @@ -560,7 +564,7 @@ class DiscordAdapter(BasePlatformAdapter): if not self._client.user or self._client.user not in message.mentions: return # "all" falls through to handle_message - + # If the message @mentions other users but NOT the bot, the # sender is talking to someone else — stay silent. Only # applies in server channels; in DMs the user is always @@ -614,23 +618,23 @@ class DiscordAdapter(BasePlatformAdapter): # Register slash commands self._register_slash_commands() - + # Start the bot in background self._bot_task = asyncio.create_task(self._client.start(self.config.token)) - + # Wait for ready await asyncio.wait_for(self._ready_event.wait(), timeout=30) - + self._running = True return True - + except asyncio.TimeoutError: logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True) return False except Exception as e: # pragma: no cover - defensive logging logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True) return False - + async def disconnect(self) -> None: """Disconnect from Discord.""" # Clean up all active voice connections before closing the client @@ -703,7 +707,7 @@ class DiscordAdapter(BasePlatformAdapter): if hasattr(message, "add_reaction"): await self._remove_reaction(message, "👀") await self._add_reaction(message, "✅" if success else "❌") - + async def send( self, chat_id: str, @@ -720,24 +724,24 @@ class DiscordAdapter(BasePlatformAdapter): channel = self._client.get_channel(int(chat_id)) if not channel: channel = await self._client.fetch_channel(int(chat_id)) - + if not channel: return SendResult(success=False, error=f"Channel {chat_id} not found") - + # Format and split message if needed formatted = self.format_message(content) chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH) - + message_ids = [] reference = None - + if reply_to: try: ref_msg = await channel.fetch_message(int(reply_to)) reference = ref_msg except Exception as e: logger.debug("Could not fetch reply-to message: %s", e) - + for i, chunk in enumerate(chunks): chunk_reference = reference if i == 0 else None try: @@ -764,13 +768,13 @@ class DiscordAdapter(BasePlatformAdapter): else: raise message_ids.append(str(msg.id)) - + return SendResult( success=True, message_id=message_ids[0] if message_ids else None, raw_response={"message_ids": message_ids} ) - + except Exception as e: # pragma: no cover - defensive logging logger.error("[%s] Failed to send Discord message: %s", self.name, e, exc_info=True) return SendResult(success=False, error=str(e)) @@ -1242,25 +1246,25 @@ class DiscordAdapter(BasePlatformAdapter): """Send an image natively as a Discord file attachment.""" if not self._client: return SendResult(success=False, error="Not connected") - + try: import aiohttp - + channel = self._client.get_channel(int(chat_id)) if not channel: channel = await self._client.fetch_channel(int(chat_id)) if not channel: return SendResult(success=False, error=f"Channel {chat_id} not found") - + # Download the image and send as a Discord file attachment # (Discord renders attachments inline, unlike plain URLs) async with aiohttp.ClientSession() as session: async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp: if resp.status != 200: raise Exception(f"Failed to download image: HTTP {resp.status}") - + image_data = await resp.read() - + # Determine filename from URL or content type content_type = resp.headers.get("content-type", "image/png") ext = "png" @@ -1270,16 +1274,16 @@ class DiscordAdapter(BasePlatformAdapter): ext = "gif" elif "webp" in content_type: ext = "webp" - + import io file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}") - + msg = await channel.send( content=caption if caption else None, file=file, ) return SendResult(success=True, message_id=str(msg.id)) - + except ImportError: logger.warning( "[%s] aiohttp not installed, falling back to URL. Run: pip install aiohttp", @@ -1330,7 +1334,7 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: # pragma: no cover - defensive logging logger.error("[%s] Failed to send document, falling back to base adapter: %s", self.name, e, exc_info=True) return await super().send_document(chat_id, file_path, caption, file_name, reply_to, metadata=metadata) - + async def send_typing(self, chat_id: str, metadata=None) -> None: """Start a persistent typing indicator for a channel. @@ -1374,20 +1378,20 @@ class DiscordAdapter(BasePlatformAdapter): await task except (asyncio.CancelledError, Exception): pass - + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: """Get information about a Discord channel.""" if not self._client: return {"name": "Unknown", "type": "dm"} - + try: channel = self._client.get_channel(int(chat_id)) if not channel: channel = await self._client.fetch_channel(int(chat_id)) - + if not channel: return {"name": str(chat_id), "type": "dm"} - + # Determine channel type if isinstance(channel, discord.DMChannel): chat_type = "dm" @@ -1403,7 +1407,7 @@ class DiscordAdapter(BasePlatformAdapter): else: chat_type = "channel" name = getattr(channel, "name", str(chat_id)) - + return { "name": name, "type": chat_type, @@ -1413,7 +1417,7 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: # pragma: no cover - defensive logging logger.error("[%s] Failed to get chat info for %s: %s", self.name, chat_id, e, exc_info=True) return {"name": str(chat_id), "type": "dm", "error": str(e)} - + async def _resolve_allowed_usernames(self) -> None: """ Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs. @@ -1481,7 +1485,7 @@ class DiscordAdapter(BasePlatformAdapter): def format_message(self, content: str) -> str: """ Format message for Discord. - + Discord uses its own markdown variant. """ # Discord markdown is fairly standard, no special escaping needed @@ -1647,7 +1651,7 @@ class DiscordAdapter(BasePlatformAdapter): chat_name = interaction.channel.name if hasattr(interaction.channel, "guild") and interaction.channel.guild: chat_name = f"{interaction.channel.guild.name} / #{chat_name}" - + # Get channel topic (if available) chat_topic = getattr(interaction.channel, "topic", None) @@ -2051,7 +2055,7 @@ class DiscordAdapter(BasePlatformAdapter): if doc_ext in SUPPORTED_DOCUMENT_TYPES: msg_type = MessageType.DOCUMENT break - + # When auto-threading kicked in, route responses to the new thread effective_channel = auto_threaded_channel or message.channel @@ -2070,7 +2074,7 @@ class DiscordAdapter(BasePlatformAdapter): # Get channel topic (if available - TextChannels have topics, DMs/threads don't) chat_topic = getattr(message.channel, "topic", None) - + # Build source source = self.build_source( chat_id=str(effective_channel.id), @@ -2081,7 +2085,7 @@ class DiscordAdapter(BasePlatformAdapter): thread_id=thread_id, chat_topic=chat_topic, ) - + # Build media URLs -- download image attachments to local cache so the # vision tool can access them reliably (Discord CDN URLs can expire). media_urls = [] @@ -2175,7 +2179,7 @@ class DiscordAdapter(BasePlatformAdapter): "[Discord] Failed to cache document %s: %s", att.filename, e, exc_info=True, ) - + event_text = message.content if pending_text_injection: event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection