feat(discord): only create threads and reactions for authorized users
This commit is contained in:
@@ -408,7 +408,7 @@ class VoiceReceiver:
|
||||
class DiscordAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Discord bot adapter.
|
||||
|
||||
|
||||
Handles:
|
||||
- Receiving messages from servers and DMs
|
||||
- Sending responses with Discord markdown
|
||||
@@ -418,10 +418,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
- Auto-threading for long conversations
|
||||
- Reaction-based feedback
|
||||
"""
|
||||
|
||||
|
||||
# Discord message limits
|
||||
MAX_MESSAGE_LENGTH = 2000
|
||||
|
||||
|
||||
# Auto-disconnect from voice channel after this many seconds of inactivity
|
||||
VOICE_TIMEOUT = 300
|
||||
|
||||
@@ -449,7 +449,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._bot_task: Optional[asyncio.Task] = None
|
||||
# Cap to prevent unbounded growth (Discord threads get archived).
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
@@ -480,11 +480,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
logger.warning("Opus codec found at %s but failed to load", opus_path)
|
||||
if not discord.opus.is_loaded():
|
||||
logger.warning("Opus codec not found — voice channel playback disabled")
|
||||
|
||||
|
||||
if not self.config.token:
|
||||
logger.error("[%s] No bot token configured", self.name)
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Acquire scoped lock to prevent duplicate bot token usage
|
||||
from gateway.status import acquire_scoped_lock
|
||||
@@ -504,13 +504,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
intents.guild_messages = True
|
||||
intents.members = True
|
||||
intents.voice_states = True
|
||||
|
||||
|
||||
# Create bot
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
)
|
||||
|
||||
|
||||
# Parse allowed user entries (may contain usernames or IDs)
|
||||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed_env:
|
||||
@@ -518,17 +518,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
_clean_discord_id(uid) for uid in allowed_env.split(",")
|
||||
if uid.strip()
|
||||
}
|
||||
|
||||
|
||||
adapter_self = self # capture for closure
|
||||
|
||||
|
||||
# Register event handlers
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
logger.info("[%s] Connected as %s", adapter_self.name, adapter_self._client.user)
|
||||
|
||||
|
||||
# Resolve any usernames in the allowed list to numeric IDs
|
||||
await adapter_self._resolve_allowed_usernames()
|
||||
|
||||
|
||||
# Sync slash commands with Discord
|
||||
try:
|
||||
synced = await adapter_self._client.tree.sync()
|
||||
@@ -536,18 +536,22 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.warning("[%s] Slash command sync failed: %s", adapter_self.name, e, exc_info=True)
|
||||
adapter_self._ready_event.set()
|
||||
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Always ignore our own messages
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
|
||||
|
||||
# Ignore Discord system messages (thread renames, pins, member joins, etc.)
|
||||
# Allow both default and reply types — replies have a distinct MessageType.
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return
|
||||
|
||||
|
||||
# Check if the message author is in the allowed user list
|
||||
if not self._is_allowed_user(str(message.author.id)):
|
||||
return
|
||||
|
||||
# Bot message filtering (DISCORD_ALLOW_BOTS):
|
||||
# "none" — ignore all other bots (default)
|
||||
# "mentions" — accept bot messages only when they @mention us
|
||||
@@ -560,7 +564,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if not self._client.user or self._client.user not in message.mentions:
|
||||
return
|
||||
# "all" falls through to handle_message
|
||||
|
||||
|
||||
# If the message @mentions other users but NOT the bot, the
|
||||
# sender is talking to someone else — stay silent. Only
|
||||
# applies in server channels; in DMs the user is always
|
||||
@@ -614,23 +618,23 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Register slash commands
|
||||
self._register_slash_commands()
|
||||
|
||||
|
||||
# Start the bot in background
|
||||
self._bot_task = asyncio.create_task(self._client.start(self.config.token))
|
||||
|
||||
|
||||
# Wait for ready
|
||||
await asyncio.wait_for(self._ready_event.wait(), timeout=30)
|
||||
|
||||
|
||||
self._running = True
|
||||
return True
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True)
|
||||
return False
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Discord."""
|
||||
# Clean up all active voice connections before closing the client
|
||||
@@ -703,7 +707,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if hasattr(message, "add_reaction"):
|
||||
await self._remove_reaction(message, "👀")
|
||||
await self._add_reaction(message, "✅" if success else "❌")
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -720,24 +724,24 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
|
||||
message_ids = []
|
||||
reference = None
|
||||
|
||||
|
||||
if reply_to:
|
||||
try:
|
||||
ref_msg = await channel.fetch_message(int(reply_to))
|
||||
reference = ref_msg
|
||||
except Exception as e:
|
||||
logger.debug("Could not fetch reply-to message: %s", e)
|
||||
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_reference = reference if i == 0 else None
|
||||
try:
|
||||
@@ -764,13 +768,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
raise
|
||||
message_ids.append(str(msg.id))
|
||||
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids}
|
||||
)
|
||||
|
||||
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to send Discord message: %s", self.name, e, exc_info=True)
|
||||
return SendResult(success=False, error=str(e))
|
||||
@@ -1242,25 +1246,25 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
"""Send an image natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
# Download the image and send as a Discord file attachment
|
||||
# (Discord renders attachments inline, unlike plain URLs)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||||
|
||||
|
||||
image_data = await resp.read()
|
||||
|
||||
|
||||
# Determine filename from URL or content type
|
||||
content_type = resp.headers.get("content-type", "image/png")
|
||||
ext = "png"
|
||||
@@ -1270,16 +1274,16 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
ext = "gif"
|
||||
elif "webp" in content_type:
|
||||
ext = "webp"
|
||||
|
||||
|
||||
import io
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"[%s] aiohttp not installed, falling back to URL. Run: pip install aiohttp",
|
||||
@@ -1330,7 +1334,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to send document, falling back to base adapter: %s", self.name, e, exc_info=True)
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to, metadata=metadata)
|
||||
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Start a persistent typing indicator for a channel.
|
||||
|
||||
@@ -1374,20 +1378,20 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Discord channel."""
|
||||
if not self._client:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
|
||||
if not channel:
|
||||
return {"name": str(chat_id), "type": "dm"}
|
||||
|
||||
|
||||
# Determine channel type
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
@@ -1403,7 +1407,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
chat_type = "channel"
|
||||
name = getattr(channel, "name", str(chat_id))
|
||||
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"type": chat_type,
|
||||
@@ -1413,7 +1417,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to get chat info for %s: %s", self.name, chat_id, e, exc_info=True)
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
|
||||
async def _resolve_allowed_usernames(self) -> None:
|
||||
"""
|
||||
Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs.
|
||||
@@ -1481,7 +1485,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format message for Discord.
|
||||
|
||||
|
||||
Discord uses its own markdown variant.
|
||||
"""
|
||||
# Discord markdown is fairly standard, no special escaping needed
|
||||
@@ -1647,7 +1651,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_name = interaction.channel.name
|
||||
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
|
||||
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
|
||||
|
||||
|
||||
# Get channel topic (if available)
|
||||
chat_topic = getattr(interaction.channel, "topic", None)
|
||||
|
||||
@@ -2051,7 +2055,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if doc_ext in SUPPORTED_DOCUMENT_TYPES:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
|
||||
# When auto-threading kicked in, route responses to the new thread
|
||||
effective_channel = auto_threaded_channel or message.channel
|
||||
|
||||
@@ -2070,7 +2074,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
|
||||
chat_topic = getattr(message.channel, "topic", None)
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(effective_channel.id),
|
||||
@@ -2081,7 +2085,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
|
||||
# Build media URLs -- download image attachments to local cache so the
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
@@ -2175,7 +2179,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
"[Discord] Failed to cache document %s: %s",
|
||||
att.filename, e, exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
event_text = message.content
|
||||
if pending_text_injection:
|
||||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||||
|
||||
Reference in New Issue
Block a user