Merge pull request #1661 from NousResearch/fix/discord-thread-persistence

fix(discord): persist thread participation across gateway restarts
This commit is contained in:
Teknium
2026-03-17 02:27:09 -07:00
committed by GitHub
2 changed files with 139 additions and 4 deletions

View File

@@ -10,6 +10,7 @@ Uses discord.py library for:
"""
import asyncio
import json
import logging
import os
import struct
@@ -18,6 +19,7 @@ import tempfile
import threading
import time
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Any
logger = logging.getLogger(__name__)
@@ -434,8 +436,11 @@ class DiscordAdapter(BasePlatformAdapter):
self._voice_input_callback: Optional[Callable] = None # set by run.py
self._on_voice_disconnect: Optional[Callable] = None # set by run.py
# Track threads where the bot has participated so follow-up messages
# in those threads don't require @mention.
self._bot_participated_threads: set = set()
# in those threads don't require @mention. Persisted to disk so the
# set survives gateway restarts.
self._bot_participated_threads: set = self._load_participated_threads()
# Cap to prevent unbounded growth (Discord threads get archived).
self._MAX_TRACKED_THREADS = 500
async def connect(self) -> bool:
"""Connect to Discord and start receiving events."""
@@ -1573,6 +1578,10 @@ class DiscordAdapter(BasePlatformAdapter):
link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**"
await interaction.followup.send(f"Created thread {link}", ephemeral=True)
# Track thread participation so follow-ups don't require @mention
if thread_id:
self._track_thread(thread_id)
# If a message was provided, kick off a new Hermes session in the thread
starter = (message or "").strip()
if starter and thread_id:
@@ -1801,6 +1810,49 @@ class DiscordAdapter(BasePlatformAdapter):
return f"{parent_name} / {thread_name}"
return thread_name
# ------------------------------------------------------------------
# Thread participation persistence
# ------------------------------------------------------------------
@staticmethod
def _thread_state_path() -> Path:
"""Path to the persisted thread participation set."""
from hermes_cli.config import get_hermes_home
return get_hermes_home() / "discord_threads.json"
@classmethod
def _load_participated_threads(cls) -> set:
"""Load persisted thread IDs from disk."""
path = cls._thread_state_path()
try:
if path.exists():
data = json.loads(path.read_text(encoding="utf-8"))
if isinstance(data, list):
return set(data)
except Exception as e:
logger.debug("Could not load discord thread state: %s", e)
return set()
def _save_participated_threads(self) -> None:
"""Persist the current thread set to disk (best-effort)."""
path = self._thread_state_path()
try:
# Trim to most recent entries if over cap
thread_list = list(self._bot_participated_threads)
if len(thread_list) > self._MAX_TRACKED_THREADS:
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
self._bot_participated_threads = set(thread_list)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(thread_list), encoding="utf-8")
except Exception as e:
logger.debug("Could not save discord thread state: %s", e)
def _track_thread(self, thread_id: str) -> None:
"""Add a thread to the participation set and persist."""
if thread_id not in self._bot_participated_threads:
self._bot_participated_threads.add(thread_id)
self._save_participated_threads()
async def _handle_message(self, message: DiscordMessage) -> None:
"""Handle incoming Discord messages."""
# In server channels (not DMs), require the bot to be @mentioned
@@ -1853,7 +1905,7 @@ class DiscordAdapter(BasePlatformAdapter):
is_thread = True
thread_id = str(thread.id)
auto_threaded_channel = thread
self._bot_participated_threads.add(thread_id)
self._track_thread(thread_id)
# Determine message type
msg_type = MessageType.TEXT
@@ -1957,7 +2009,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Track thread participation so the bot won't require @mention for
# follow-up messages in threads it has already engaged in.
if thread_id:
self._bot_participated_threads.add(thread_id)
self._track_thread(thread_id)
await self.handle_message(event)