diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index efa5ed318..9a821727e 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -898,6 +898,26 @@ class BasePlatformAdapter(ABC): except Exception: pass + # ── Processing lifecycle hooks ────────────────────────────────────────── + # Subclasses override these to react to message processing events + # (e.g. Discord adds 👀/✅/❌ reactions). + + async def on_processing_start(self, event: MessageEvent) -> None: + """Hook called when background processing begins.""" + + async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: + """Hook called when background processing completes.""" + + async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None: + """Run a lifecycle hook without letting failures break message flow.""" + hook = getattr(self, hook_name, None) + if not callable(hook): + return + try: + await hook(*args, **kwargs) + except Exception as e: + logger.warning("[%s] %s hook failed: %s", self.name, hook_name, e) + @staticmethod def _is_retryable_error(error: Optional[str]) -> bool: """Return True if the error string looks like a transient network failure.""" @@ -1060,6 +1080,18 @@ class BasePlatformAdapter(ABC): async def _process_message_background(self, event: MessageEvent, session_key: str) -> None: """Background task that actually processes the message.""" + # Track delivery outcomes for the processing-complete hook + delivery_attempted = False + delivery_succeeded = False + + def _record_delivery(result): + nonlocal delivery_attempted, delivery_succeeded + if result is None: + return + delivery_attempted = True + if getattr(result, "success", False): + delivery_succeeded = True + # Create interrupt event for this session interrupt_event = asyncio.Event() self._active_sessions[session_key] = interrupt_event @@ -1069,6 +1101,8 @@ class BasePlatformAdapter(ABC): typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id, metadata=_thread_metadata)) try: + await self._run_processing_hook("on_processing_start", event) + # Call the handler (this can take a while with tool calls) response = await self._message_handler(event) @@ -1138,6 +1172,7 @@ class BasePlatformAdapter(ABC): reply_to=event.message_id, metadata=_thread_metadata, ) + _record_delivery(result) # Human-like pacing delay between text and media human_delay = self._get_human_delay() @@ -1237,6 +1272,10 @@ class BasePlatformAdapter(ABC): except Exception as file_err: logger.error("[%s] Error sending local file %s: %s", self.name, file_path, file_err) + # Determine overall success for the processing hook + processing_ok = delivery_succeeded if delivery_attempted else not bool(response) + await self._run_processing_hook("on_processing_complete", event, processing_ok) + # Check if there's a pending message that was queued during our processing if session_key in self._pending_messages: pending_event = self._pending_messages.pop(session_key) @@ -1253,7 +1292,11 @@ class BasePlatformAdapter(ABC): await self._process_message_background(pending_event, session_key) return # Already cleaned up + except asyncio.CancelledError: + await self._run_processing_hook("on_processing_complete", event, False) + raise except Exception as e: + await self._run_processing_hook("on_processing_complete", event, False) logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True) # Send the error to the user so they aren't left with radio silence try: diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 1da9925cd..9e0c9c123 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -660,6 +660,41 @@ class DiscordAdapter(BasePlatformAdapter): pass logger.info("[%s] Disconnected", self.name) + + async def _add_reaction(self, message: Any, emoji: str) -> bool: + """Add an emoji reaction to a Discord message.""" + if not message or not hasattr(message, "add_reaction"): + return False + try: + await message.add_reaction(emoji) + return True + except Exception as e: + logger.debug("[%s] add_reaction failed (%s): %s", self.name, emoji, e) + return False + + async def _remove_reaction(self, message: Any, emoji: str) -> bool: + """Remove the bot's own emoji reaction from a Discord message.""" + if not message or not hasattr(message, "remove_reaction") or not self._client or not self._client.user: + return False + try: + await message.remove_reaction(emoji, self._client.user) + return True + except Exception as e: + logger.debug("[%s] remove_reaction failed (%s): %s", self.name, emoji, e) + return False + + async def on_processing_start(self, event: MessageEvent) -> None: + """Add an in-progress reaction for normal Discord message events.""" + message = event.raw_message + if hasattr(message, "add_reaction"): + await self._add_reaction(message, "👀") + + async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: + """Swap the in-progress reaction for a final success/failure reaction.""" + message = event.raw_message + if hasattr(message, "add_reaction"): + await self._remove_reaction(message, "👀") + await self._add_reaction(message, "✅" if success else "❌") async def send( self, diff --git a/tests/gateway/test_base_topic_sessions.py b/tests/gateway/test_base_topic_sessions.py index e3ca7ae72..37e00b279 100644 --- a/tests/gateway/test_base_topic_sessions.py +++ b/tests/gateway/test_base_topic_sessions.py @@ -15,6 +15,7 @@ class DummyTelegramAdapter(BasePlatformAdapter): super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM) self.sent = [] self.typing = [] + self.processing_hooks = [] async def connect(self) -> bool: return True @@ -40,6 +41,12 @@ class DummyTelegramAdapter(BasePlatformAdapter): async def get_chat_info(self, chat_id: str): return {"id": chat_id} + async def on_processing_start(self, event: MessageEvent) -> None: + self.processing_hooks.append(("start", event.message_id)) + + async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: + self.processing_hooks.append(("complete", event.message_id, success)) + def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent: return MessageEvent( @@ -133,3 +140,83 @@ class TestBasePlatformTopicSessions: "metadata": {"thread_id": "17585"}, } ] + assert adapter.processing_hooks == [ + ("start", "1"), + ("complete", "1", True), + ] + + @pytest.mark.asyncio + async def test_process_message_background_marks_total_send_failure_unsuccessful(self): + adapter = DummyTelegramAdapter() + + async def handler(_event): + await asyncio.sleep(0) + return "ack" + + async def failing_send(*_args, **_kwargs): + return SendResult(success=False, error="send failed") + + async def hold_typing(_chat_id, interval=2.0, metadata=None): + await asyncio.Event().wait() + + adapter.set_message_handler(handler) + adapter.send = failing_send + adapter._keep_typing = hold_typing + + event = _make_event("-1001", "17585") + await adapter._process_message_background(event, build_session_key(event.source)) + + assert adapter.processing_hooks == [ + ("start", "1"), + ("complete", "1", False), + ] + + @pytest.mark.asyncio + async def test_process_message_background_marks_exception_unsuccessful(self): + adapter = DummyTelegramAdapter() + + async def handler(_event): + await asyncio.sleep(0) + raise RuntimeError("boom") + + async def hold_typing(_chat_id, interval=2.0, metadata=None): + await asyncio.Event().wait() + + adapter.set_message_handler(handler) + adapter._keep_typing = hold_typing + + event = _make_event("-1001", "17585") + await adapter._process_message_background(event, build_session_key(event.source)) + + assert adapter.processing_hooks == [ + ("start", "1"), + ("complete", "1", False), + ] + + @pytest.mark.asyncio + async def test_process_message_background_marks_cancellation_unsuccessful(self): + adapter = DummyTelegramAdapter() + release = asyncio.Event() + + async def handler(_event): + await release.wait() + return "ack" + + async def hold_typing(_chat_id, interval=2.0, metadata=None): + await asyncio.Event().wait() + + adapter.set_message_handler(handler) + adapter._keep_typing = hold_typing + + event = _make_event("-1001", "17585") + task = asyncio.create_task(adapter._process_message_background(event, build_session_key(event.source))) + await asyncio.sleep(0) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert adapter.processing_hooks == [ + ("start", "1"), + ("complete", "1", False), + ] diff --git a/tests/gateway/test_discord_reactions.py b/tests/gateway/test_discord_reactions.py new file mode 100644 index 000000000..c19913a4c --- /dev/null +++ b/tests/gateway/test_discord_reactions.py @@ -0,0 +1,170 @@ +"""Tests for Discord message reactions tied to processing lifecycle hooks.""" + +import asyncio +import sys +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import MessageEvent, MessageType, SendResult +from gateway.session import SessionSource, build_session_key + + +def _ensure_discord_mock(): + if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"): + return + + discord_mod = MagicMock() + discord_mod.Intents.default.return_value = MagicMock() + discord_mod.DMChannel = type("DMChannel", (), {}) + discord_mod.Thread = type("Thread", (), {}) + discord_mod.ForumChannel = type("ForumChannel", (), {}) + discord_mod.Interaction = object + discord_mod.app_commands = SimpleNamespace( + describe=lambda **kwargs: (lambda fn: fn), + choices=lambda **kwargs: (lambda fn: fn), + Choice=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + ext_mod = MagicMock() + commands_mod = MagicMock() + commands_mod.Bot = MagicMock + ext_mod.commands = commands_mod + + sys.modules.setdefault("discord", discord_mod) + sys.modules.setdefault("discord.ext", ext_mod) + sys.modules.setdefault("discord.ext.commands", commands_mod) + + +_ensure_discord_mock() + +from gateway.platforms.discord import DiscordAdapter # noqa: E402 + + +class FakeTree: + def __init__(self): + self.commands = {} + + def command(self, *, name, description): + def decorator(fn): + self.commands[name] = fn + return fn + + return decorator + + +@pytest.fixture +def adapter(): + config = PlatformConfig(enabled=True, token="***") + adapter = DiscordAdapter(config) + adapter._client = SimpleNamespace( + tree=FakeTree(), + get_channel=lambda _id: None, + fetch_channel=AsyncMock(), + user=SimpleNamespace(id=99999, name="HermesBot"), + ) + return adapter + + +def _make_event(message_id: str, raw_message) -> MessageEvent: + return MessageEvent( + text="hello", + message_type=MessageType.TEXT, + source=SessionSource( + platform=Platform.DISCORD, + chat_id="123", + chat_type="dm", + user_id="42", + user_name="Jezza", + ), + raw_message=raw_message, + message_id=message_id, + ) + + +@pytest.mark.asyncio +async def test_process_message_background_adds_and_swaps_reactions(adapter): + raw_message = SimpleNamespace( + add_reaction=AsyncMock(), + remove_reaction=AsyncMock(), + ) + + async def handler(_event): + await asyncio.sleep(0) + return "ack" + + async def hold_typing(_chat_id, interval=2.0, metadata=None): + await asyncio.Event().wait() + + adapter.set_message_handler(handler) + adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="999")) + adapter._keep_typing = hold_typing + + event = _make_event("1", raw_message) + await adapter._process_message_background(event, build_session_key(event.source)) + + assert raw_message.add_reaction.await_args_list[0].args == ("👀",) + assert raw_message.remove_reaction.await_args_list[0].args == ("👀", adapter._client.user) + assert raw_message.add_reaction.await_args_list[1].args == ("✅",) + + +@pytest.mark.asyncio +async def test_interaction_backed_events_do_not_attempt_reactions(adapter): + interaction = SimpleNamespace(guild_id=123456789) + + async def handler(_event): + await asyncio.sleep(0) + return None + + async def hold_typing(_chat_id, interval=2.0, metadata=None): + await asyncio.Event().wait() + + adapter.set_message_handler(handler) + adapter._add_reaction = AsyncMock() + adapter._remove_reaction = AsyncMock() + adapter._keep_typing = hold_typing + + event = MessageEvent( + text="/status", + message_type=MessageType.COMMAND, + source=SessionSource( + platform=Platform.DISCORD, + chat_id="123", + chat_type="dm", + user_id="42", + user_name="Jezza", + ), + raw_message=interaction, + message_id="2", + ) + + await adapter._process_message_background(event, build_session_key(event.source)) + + adapter._add_reaction.assert_not_awaited() + adapter._remove_reaction.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_reaction_helper_failures_do_not_break_message_flow(adapter): + raw_message = SimpleNamespace( + add_reaction=AsyncMock(side_effect=[RuntimeError("no perms"), RuntimeError("no perms")]), + remove_reaction=AsyncMock(side_effect=RuntimeError("no perms")), + ) + + async def handler(_event): + await asyncio.sleep(0) + return "ack" + + async def hold_typing(_chat_id, interval=2.0, metadata=None): + await asyncio.Event().wait() + + adapter.set_message_handler(handler) + adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="999")) + adapter._keep_typing = hold_typing + + event = _make_event("3", raw_message) + await adapter._process_message_background(event, build_session_key(event.source)) + + adapter.send.assert_awaited_once()