diff --git a/src/integrations/chat_bridge/vendors/discord.py b/src/integrations/chat_bridge/vendors/discord.py index 9abf4d2..857cea7 100644 --- a/src/integrations/chat_bridge/vendors/discord.py +++ b/src/integrations/chat_bridge/vendors/discord.py @@ -515,25 +515,36 @@ class DiscordVendor(ChatPlatform): async def _handle_message(self, message) -> None: """Process an incoming message and respond via a thread.""" - # Strip the bot mention from the message content - content = message.content - if self._client.user: - content = content.replace(f"<@{self._client.user.id}>", "").strip() - + content = self._extract_content(message) if not content: return - # Create or reuse a thread for this conversation thread = await self._get_or_create_thread(message) target = thread or message.channel + session_id = f"discord_{thread.id}" if thread else f"discord_{message.channel.id}" - # Derive session_id for per-conversation history via Agno's SQLite - if thread: - session_id = f"discord_{thread.id}" - else: - session_id = f"discord_{message.channel.id}" + run_output, response = await self._invoke_agent(content, session_id, target) - # Run Timmy agent with typing indicator and timeout + if run_output is not None: + await self._handle_paused_run(run_output, target, session_id) + raw_content = run_output.content if hasattr(run_output, "content") else "" + response = _clean_response(raw_content or "") + + await self._send_response(response, target) + + def _extract_content(self, message) -> str: + """Strip the bot mention and return clean message text.""" + content = message.content + if self._client.user: + content = content.replace(f"<@{self._client.user.id}>", "").strip() + return content + + async def _invoke_agent(self, content: str, session_id: str, target): + """Run chat_with_tools with a typing indicator and timeout. + + Returns a (run_output, error_response) tuple. On success the + error_response is ``None``; on failure run_output is ``None``. + """ run_output = None response = None try: @@ -548,51 +559,57 @@ class DiscordVendor(ChatPlatform): except Exception as exc: logger.error("Discord: chat_with_tools() failed: %s", exc) response = "I'm having trouble reaching my inference backend right now. Please try again shortly." + return run_output, response - # Check if Agno paused the run for tool confirmation - if run_output is not None: - status = getattr(run_output, "status", None) - is_paused = status == "PAUSED" or str(status) == "RunStatus.paused" + async def _handle_paused_run(self, run_output, target, session_id: str) -> None: + """If Agno paused the run for tool confirmation, enqueue approvals.""" + status = getattr(run_output, "status", None) + is_paused = status == "PAUSED" or str(status) == "RunStatus.paused" - if is_paused and getattr(run_output, "active_requirements", None): - from config import settings + if not (is_paused and getattr(run_output, "active_requirements", None)): + return - if settings.discord_confirm_actions: - for req in run_output.active_requirements: - if getattr(req, "needs_confirmation", False): - te = req.tool_execution - tool_name = getattr(te, "tool_name", "unknown") - tool_args = getattr(te, "tool_args", {}) or {} + from config import settings - from timmy.approvals import create_item + if not settings.discord_confirm_actions: + return - item = create_item( - title=f"Discord: {tool_name}", - description=_format_action_description(tool_name, tool_args), - proposed_action=json.dumps({"tool": tool_name, "args": tool_args}), - impact=_get_impact_level(tool_name), - ) - self._pending_actions[item.id] = { - "run_output": run_output, - "requirement": req, - "tool_name": tool_name, - "tool_args": tool_args, - "target": target, - "session_id": session_id, - } - await self._send_confirmation(target, tool_name, tool_args, item.id) + for req in run_output.active_requirements: + if not getattr(req, "needs_confirmation", False): + continue + te = req.tool_execution + tool_name = getattr(te, "tool_name", "unknown") + tool_args = getattr(te, "tool_args", {}) or {} - raw_content = run_output.content if hasattr(run_output, "content") else "" - response = _clean_response(raw_content or "") + from timmy.approvals import create_item - # Discord has a 2000 character limit — send with error handling - if response and response.strip(): - for chunk in _chunk_message(response, 2000): - try: - await target.send(chunk) - except Exception as exc: - logger.error("Discord: failed to send message chunk: %s", exc) - break + item = create_item( + title=f"Discord: {tool_name}", + description=_format_action_description(tool_name, tool_args), + proposed_action=json.dumps({"tool": tool_name, "args": tool_args}), + impact=_get_impact_level(tool_name), + ) + self._pending_actions[item.id] = { + "run_output": run_output, + "requirement": req, + "tool_name": tool_name, + "tool_args": tool_args, + "target": target, + "session_id": session_id, + } + await self._send_confirmation(target, tool_name, tool_args, item.id) + + @staticmethod + async def _send_response(response: str | None, target) -> None: + """Send a response to Discord, chunked to the 2000-char limit.""" + if not response or not response.strip(): + return + for chunk in _chunk_message(response, 2000): + try: + await target.send(chunk) + except Exception as exc: + logger.error("Discord: failed to send message chunk: %s", exc) + break async def _get_or_create_thread(self, message): """Get the active thread for a channel, or create one. diff --git a/tests/integrations/test_discord_vendor.py b/tests/integrations/test_discord_vendor.py index 6954717..acd8824 100644 --- a/tests/integrations/test_discord_vendor.py +++ b/tests/integrations/test_discord_vendor.py @@ -174,6 +174,103 @@ class TestDiscordVendor: assert result is False +class TestExtractContent: + def test_strips_bot_mention(self): + from integrations.chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + vendor._client = MagicMock() + vendor._client.user.id = 12345 + msg = MagicMock() + msg.content = "<@12345> hello there" + assert vendor._extract_content(msg) == "hello there" + + def test_no_client_user(self): + from integrations.chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + vendor._client = MagicMock() + vendor._client.user = None + msg = MagicMock() + msg.content = "hello" + assert vendor._extract_content(msg) == "hello" + + def test_empty_after_strip(self): + from integrations.chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + vendor._client = MagicMock() + vendor._client.user.id = 99 + msg = MagicMock() + msg.content = "<@99>" + assert vendor._extract_content(msg) == "" + + +class TestInvokeAgent: + @staticmethod + def _make_typing_target(): + """Build a mock target whose .typing() is an async context manager.""" + from contextlib import asynccontextmanager + + target = AsyncMock() + + @asynccontextmanager + async def _typing(): + yield + + target.typing = _typing + return target + + @pytest.mark.asyncio + async def test_timeout_returns_error(self): + from integrations.chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + target = self._make_typing_target() + + with patch( + "integrations.chat_bridge.vendors.discord.chat_with_tools", side_effect=TimeoutError + ): + run_output, response = await vendor._invoke_agent("hi", "sess", target) + assert run_output is None + assert "too long" in response + + @pytest.mark.asyncio + async def test_exception_returns_error(self): + from integrations.chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + target = self._make_typing_target() + + with patch( + "integrations.chat_bridge.vendors.discord.chat_with_tools", + side_effect=RuntimeError("boom"), + ): + run_output, response = await vendor._invoke_agent("hi", "sess", target) + assert run_output is None + assert "trouble" in response + + +class TestSendResponse: + @pytest.mark.asyncio + async def test_skips_empty(self): + from integrations.chat_bridge.vendors.discord import DiscordVendor + + target = AsyncMock() + await DiscordVendor._send_response(None, target) + target.send.assert_not_called() + await DiscordVendor._send_response("", target) + target.send.assert_not_called() + + @pytest.mark.asyncio + async def test_sends_short_message(self): + from integrations.chat_bridge.vendors.discord import DiscordVendor + + target = AsyncMock() + await DiscordVendor._send_response("hello", target) + target.send.assert_called_once_with("hello") + + class TestChunkMessage: def test_short_message(self): from integrations.chat_bridge.vendors.discord import _chunk_message