diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index ec14dd2d6..b610f5a2d 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -10,6 +10,7 @@ Uses discord.py library for: """ import asyncio +import json import logging import os import struct @@ -18,6 +19,7 @@ import tempfile import threading import time from collections import defaultdict +from pathlib import Path from typing import Callable, Dict, List, Optional, Any logger = logging.getLogger(__name__) @@ -434,8 +436,11 @@ class DiscordAdapter(BasePlatformAdapter): self._voice_input_callback: Optional[Callable] = None # set by run.py self._on_voice_disconnect: Optional[Callable] = None # set by run.py # Track threads where the bot has participated so follow-up messages - # in those threads don't require @mention. - self._bot_participated_threads: set = set() + # in those threads don't require @mention. Persisted to disk so the + # set survives gateway restarts. + self._bot_participated_threads: set = self._load_participated_threads() + # Cap to prevent unbounded growth (Discord threads get archived). + self._MAX_TRACKED_THREADS = 500 async def connect(self) -> bool: """Connect to Discord and start receiving events.""" @@ -1573,6 +1578,10 @@ class DiscordAdapter(BasePlatformAdapter): link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**" await interaction.followup.send(f"Created thread {link}", ephemeral=True) + # Track thread participation so follow-ups don't require @mention + if thread_id: + self._track_thread(thread_id) + # If a message was provided, kick off a new Hermes session in the thread starter = (message or "").strip() if starter and thread_id: @@ -1798,6 +1807,49 @@ class DiscordAdapter(BasePlatformAdapter): return f"{parent_name} / {thread_name}" return thread_name + # ------------------------------------------------------------------ + # Thread participation persistence + # ------------------------------------------------------------------ + + @staticmethod + def _thread_state_path() -> Path: + """Path to the persisted thread participation set.""" + from hermes_cli.config import get_hermes_home + return get_hermes_home() / "discord_threads.json" + + @classmethod + def _load_participated_threads(cls) -> set: + """Load persisted thread IDs from disk.""" + path = cls._thread_state_path() + try: + if path.exists(): + data = json.loads(path.read_text(encoding="utf-8")) + if isinstance(data, list): + return set(data) + except Exception as e: + logger.debug("Could not load discord thread state: %s", e) + return set() + + def _save_participated_threads(self) -> None: + """Persist the current thread set to disk (best-effort).""" + path = self._thread_state_path() + try: + # Trim to most recent entries if over cap + thread_list = list(self._bot_participated_threads) + if len(thread_list) > self._MAX_TRACKED_THREADS: + thread_list = thread_list[-self._MAX_TRACKED_THREADS:] + self._bot_participated_threads = set(thread_list) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(thread_list), encoding="utf-8") + except Exception as e: + logger.debug("Could not save discord thread state: %s", e) + + def _track_thread(self, thread_id: str) -> None: + """Add a thread to the participation set and persist.""" + if thread_id not in self._bot_participated_threads: + self._bot_participated_threads.add(thread_id) + self._save_participated_threads() + async def _handle_message(self, message: DiscordMessage) -> None: """Handle incoming Discord messages.""" # In server channels (not DMs), require the bot to be @mentioned @@ -1850,7 +1902,7 @@ class DiscordAdapter(BasePlatformAdapter): is_thread = True thread_id = str(thread.id) auto_threaded_channel = thread - self._bot_participated_threads.add(thread_id) + self._track_thread(thread_id) # Determine message type msg_type = MessageType.TEXT @@ -1954,7 +2006,7 @@ class DiscordAdapter(BasePlatformAdapter): # Track thread participation so the bot won't require @mention for # follow-up messages in threads it has already engaged in. if thread_id: - self._bot_participated_threads.add(thread_id) + self._track_thread(thread_id) await self.handle_message(event) diff --git a/tests/gateway/test_discord_thread_persistence.py b/tests/gateway/test_discord_thread_persistence.py new file mode 100644 index 000000000..0288b620d --- /dev/null +++ b/tests/gateway/test_discord_thread_persistence.py @@ -0,0 +1,83 @@ +"""Tests for Discord thread participation persistence. + +Verifies that _bot_participated_threads survives adapter restarts by +being persisted to ~/.hermes/discord_threads.json. +""" + +import json +import os +from unittest.mock import patch + +import pytest + + +class TestDiscordThreadPersistence: + """Thread IDs are saved to disk and reloaded on init.""" + + def _make_adapter(self, tmp_path): + """Build a minimal DiscordAdapter with HERMES_HOME pointed at tmp_path.""" + from gateway.config import PlatformConfig + from gateway.platforms.discord import DiscordAdapter + + config = PlatformConfig(enabled=True, token="test-token") + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + return DiscordAdapter(config=config) + + def test_starts_empty_when_no_state_file(self, tmp_path): + adapter = self._make_adapter(tmp_path) + assert adapter._bot_participated_threads == set() + + def test_track_thread_persists_to_disk(self, tmp_path): + adapter = self._make_adapter(tmp_path) + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + adapter._track_thread("111") + adapter._track_thread("222") + + state_file = tmp_path / "discord_threads.json" + assert state_file.exists() + saved = json.loads(state_file.read_text()) + assert set(saved) == {"111", "222"} + + def test_threads_survive_restart(self, tmp_path): + """Threads tracked by one adapter instance are visible to the next.""" + adapter1 = self._make_adapter(tmp_path) + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + adapter1._track_thread("aaa") + adapter1._track_thread("bbb") + + adapter2 = self._make_adapter(tmp_path) + assert "aaa" in adapter2._bot_participated_threads + assert "bbb" in adapter2._bot_participated_threads + + def test_duplicate_track_does_not_double_save(self, tmp_path): + adapter = self._make_adapter(tmp_path) + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + adapter._track_thread("111") + adapter._track_thread("111") # no-op + + saved = json.loads((tmp_path / "discord_threads.json").read_text()) + assert saved.count("111") == 1 + + def test_caps_at_max_tracked_threads(self, tmp_path): + adapter = self._make_adapter(tmp_path) + adapter._MAX_TRACKED_THREADS = 5 + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + for i in range(10): + adapter._track_thread(str(i)) + + assert len(adapter._bot_participated_threads) == 5 + + def test_corrupted_state_file_falls_back_to_empty(self, tmp_path): + state_file = tmp_path / "discord_threads.json" + state_file.write_text("not valid json{{{") + adapter = self._make_adapter(tmp_path) + assert adapter._bot_participated_threads == set() + + def test_missing_hermes_home_does_not_crash(self, tmp_path): + """Load/save tolerate missing directories.""" + fake_home = tmp_path / "nonexistent" / "deep" + with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}): + from gateway.platforms.discord import DiscordAdapter + # _load should return empty set, not crash + threads = DiscordAdapter._load_participated_threads() + assert threads == set()