From 86ddaaee9c24b80920221b64e0a6064b10f883af Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Wed, 11 Mar 2026 23:18:49 +0300 Subject: [PATCH] fix: extract voice reply logic and add comprehensive tests - Fix tempfile.mktemp() TOCTOU race in Discord voice input (use NamedTemporaryFile) - Extract voice reply decision from _handle_message into _should_send_voice_reply() - Rewrite TestAutoVoiceReply to call real method instead of testing a copy - Add 59 new tests: VoiceReceiver, VC commands, adapter methods, streaming TTS --- gateway/platforms/discord.py | 4 +- gateway/run.py | 96 ++-- tests/gateway/test_voice_command.py | 842 ++++++++++++++++++++++++++-- 3 files changed, 845 insertions(+), 97 deletions(-) diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index d10375e3c..cf57b37d9 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -851,7 +851,9 @@ class DiscordAdapter(BasePlatformAdapter): """Convert PCM -> WAV -> STT -> callback.""" from tools.voice_mode import is_whisper_hallucination - wav_path = tempfile.mktemp(suffix=".wav", prefix="vc_listen_") + tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", prefix="vc_listen_", delete=False) + wav_path = tmp_f.name + tmp_f.close() try: await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path) diff --git a/gateway/run.py b/gateway/run.py index b672ac48c..5deb093af 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1616,42 +1616,8 @@ class GatewayRunner: ) # Auto voice reply: send TTS audio before the text response - chat_id = source.chat_id - voice_mode = self._voice_mode.get(chat_id, "off") - is_voice_input = (event.message_type == MessageType.VOICE) - should_voice_reply = ( - (voice_mode == "all") - or (voice_mode == "voice_only" and is_voice_input) - ) - logger.info("Voice reply check: chat_id=%s, voice_mode=%s, is_voice=%s, should_reply=%s, has_response=%s", - chat_id, voice_mode, is_voice_input, should_voice_reply, bool(response)) - if should_voice_reply and response and not response.startswith("Error:"): - # Skip if agent already called TTS tool (avoid double voice) - has_agent_tts = any( - msg.get("role") == "assistant" - and any( - tc.get("function", {}).get("name") == "text_to_speech" - for tc in (msg.get("tool_calls") or []) - ) - for msg in agent_messages - ) - # Skip if voice input — base adapter auto-TTS in - # _process_message_background already sent audio for voice - # messages, so sending another would be double. - # Exception: Discord voice channel — the Discord play_tts - # override also skips (no-op), so the runner MUST handle it - # via play_in_voice_channel. - skip_double = is_voice_input - if skip_double: - adapter = self.adapters.get(source.platform) - guild_id = self._get_guild_id(event) - if (guild_id and adapter - and hasattr(adapter, "is_in_voice_channel") - and adapter.is_in_voice_channel(guild_id)): - skip_double = False - logger.info("Voice reply: has_agent_tts=%s, skip_double=%s, calling _send_voice_reply", has_agent_tts, skip_double) - if not has_agent_tts and not skip_double: - await self._send_voice_reply(event, response) + if self._should_send_voice_reply(event, response, agent_messages): + await self._send_voice_reply(event, response) return response @@ -2302,6 +2268,64 @@ class GatewayRunner: await adapter.handle_message(event) + def _should_send_voice_reply( + self, + event: MessageEvent, + response: str, + agent_messages: list, + ) -> bool: + """Decide whether the runner should send a TTS voice reply. + + Returns False when: + - voice_mode is off for this chat + - response is empty or an error + - agent already called text_to_speech tool (dedup) + - voice input and base adapter auto-TTS already handled it (skip_double) + Exception: Discord voice channel — base play_tts is a no-op there, + so the runner must handle VC playback. + """ + if not response or response.startswith("Error:"): + return False + + chat_id = event.source.chat_id + voice_mode = self._voice_mode.get(chat_id, "off") + is_voice_input = (event.message_type == MessageType.VOICE) + + should = ( + (voice_mode == "all") + or (voice_mode == "voice_only" and is_voice_input) + ) + if not should: + return False + + # Dedup: agent already called TTS tool + has_agent_tts = any( + msg.get("role") == "assistant" + and any( + tc.get("function", {}).get("name") == "text_to_speech" + for tc in (msg.get("tool_calls") or []) + ) + for msg in agent_messages + ) + if has_agent_tts: + return False + + # Dedup: base adapter auto-TTS already handles voice input. + # Exception: Discord voice channel — play_tts override is a no-op, + # so the runner must handle VC playback. + skip_double = is_voice_input + if skip_double: + adapter = self.adapters.get(event.source.platform) + guild_id = self._get_guild_id(event) + if (guild_id and adapter + and hasattr(adapter, "is_in_voice_channel") + and adapter.is_in_voice_channel(guild_id)): + skip_double = False + if skip_double: + return False + + return True + async def _send_voice_reply(self, event: MessageEvent, text: str) -> None: """Generate TTS audio and send as a voice message before the text reply.""" try: diff --git a/tests/gateway/test_voice_command.py b/tests/gateway/test_voice_command.py index 1e4250960..f8ec52423 100644 --- a/tests/gateway/test_voice_command.py +++ b/tests/gateway/test_voice_command.py @@ -2,7 +2,11 @@ import json import os +import queue +import threading +import time import pytest +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch from gateway.platforms.base import MessageEvent, MessageType, SessionSource @@ -126,7 +130,7 @@ class TestHandleVoiceCommand: # ===================================================================== class TestAutoVoiceReply: - """Test the should_voice_reply decision logic (extracted from _handle_message). + """Test the real _should_send_voice_reply method on GatewayRunner. The gateway has two TTS paths: 1. base adapter auto-TTS: fires for voice input in _process_message_background @@ -138,43 +142,30 @@ class TestAutoVoiceReply: override skip, so the runner must handle it via play_in_voice_channel. """ - def _should_reply(self, voice_mode, message_type, agent_messages=None, - response="Hello!", in_voice_channel=False): - """Replicate the auto voice reply decision from _handle_message.""" - if not response or response.startswith("Error:"): - return False + @pytest.fixture + def runner(self, tmp_path): + return _make_runner(tmp_path) - is_voice_input = (message_type == MessageType.VOICE) - should = ( - (voice_mode == "all") - or (voice_mode == "voice_only" and is_voice_input) + def _call(self, runner, voice_mode, message_type, agent_messages=None, + response="Hello!", in_voice_channel=False): + """Call real _should_send_voice_reply on a GatewayRunner instance.""" + chat_id = "123" + if voice_mode != "off": + runner._voice_mode[chat_id] = voice_mode + else: + runner._voice_mode.pop(chat_id, None) + + event = _make_event(message_type=message_type) + + if in_voice_channel: + mock_adapter = MagicMock() + mock_adapter.is_in_voice_channel = MagicMock(return_value=True) + event.raw_message = SimpleNamespace(guild_id=111, guild=None) + runner.adapters[event.source.platform] = mock_adapter + + return runner._should_send_voice_reply( + event, response, agent_messages or [] ) - if not should: - return False - - # Dedup: agent already called TTS tool - if agent_messages: - has_agent_tts = any( - msg.get("role") == "assistant" - and any( - tc.get("function", {}).get("name") == "text_to_speech" - for tc in (msg.get("tool_calls") or []) - ) - for msg in agent_messages - ) - if has_agent_tts: - return False - - # Dedup: base adapter auto-TTS already handles voice input. - # Exception: in voice channel, Discord play_tts also skips, - # so the runner must handle VC playback. - skip_double = is_voice_input - if skip_double and in_voice_channel: - skip_double = False - if skip_double: - return False - - return True # -- Full platform x input x mode matrix -------------------------------- # @@ -204,52 +195,52 @@ class TestAutoVoiceReply: # -- Telegram/Slack/Web: voice input, base handles --------------------- - def test_voice_input_voice_only_skipped(self): + def test_voice_input_voice_only_skipped(self, runner): """voice_only + voice input: base auto-TTS handles it, runner skips.""" - assert self._should_reply("voice_only", MessageType.VOICE) is False + assert self._call(runner, "voice_only", MessageType.VOICE) is False - def test_voice_input_all_mode_skipped(self): + def test_voice_input_all_mode_skipped(self, runner): """all + voice input: base auto-TTS handles it, runner skips.""" - assert self._should_reply("all", MessageType.VOICE) is False + assert self._call(runner, "all", MessageType.VOICE) is False # -- Text input: only runner handles ----------------------------------- - def test_text_input_all_mode_runner_fires(self): + def test_text_input_all_mode_runner_fires(self, runner): """all + text input: only runner fires (base auto-TTS only for voice).""" - assert self._should_reply("all", MessageType.TEXT) is True + assert self._call(runner, "all", MessageType.TEXT) is True - def test_text_input_voice_only_no_reply(self): + def test_text_input_voice_only_no_reply(self, runner): """voice_only + text input: neither fires.""" - assert self._should_reply("voice_only", MessageType.TEXT) is False + assert self._call(runner, "voice_only", MessageType.TEXT) is False # -- Mode off: nothing fires ------------------------------------------- - def test_off_mode_voice(self): - assert self._should_reply("off", MessageType.VOICE) is False + def test_off_mode_voice(self, runner): + assert self._call(runner, "off", MessageType.VOICE) is False - def test_off_mode_text(self): - assert self._should_reply("off", MessageType.TEXT) is False + def test_off_mode_text(self, runner): + assert self._call(runner, "off", MessageType.TEXT) is False # -- Discord VC exception: runner must handle -------------------------- - def test_discord_vc_voice_input_runner_fires(self): + def test_discord_vc_voice_input_runner_fires(self, runner): """Discord VC + voice input: base play_tts skips (VC override), so runner must handle via play_in_voice_channel.""" - assert self._should_reply("all", MessageType.VOICE, in_voice_channel=True) is True + assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is True - def test_discord_vc_voice_only_runner_fires(self): + def test_discord_vc_voice_only_runner_fires(self, runner): """Discord VC + voice_only + voice: runner must handle.""" - assert self._should_reply("voice_only", MessageType.VOICE, in_voice_channel=True) is True + assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is True # -- Edge cases -------------------------------------------------------- - def test_error_response_skipped(self): - assert self._should_reply("all", MessageType.TEXT, response="Error: boom") is False + def test_error_response_skipped(self, runner): + assert self._call(runner, "all", MessageType.TEXT, response="Error: boom") is False - def test_empty_response_skipped(self): - assert self._should_reply("all", MessageType.TEXT, response="") is False + def test_empty_response_skipped(self, runner): + assert self._call(runner, "all", MessageType.TEXT, response="") is False - def test_dedup_skips_when_agent_called_tts(self): + def test_dedup_skips_when_agent_called_tts(self, runner): messages = [{ "role": "assistant", "tool_calls": [{ @@ -258,9 +249,9 @@ class TestAutoVoiceReply: "function": {"name": "text_to_speech", "arguments": "{}"}, }], }] - assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is False + assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is False - def test_no_dedup_for_other_tools(self): + def test_no_dedup_for_other_tools(self, runner): messages = [{ "role": "assistant", "tool_calls": [{ @@ -269,7 +260,7 @@ class TestAutoVoiceReply: "function": {"name": "web_search", "arguments": "{}"}, }], }] - assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is True + assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is True # ===================================================================== @@ -443,3 +434,734 @@ class TestVoiceInHelp: import inspect source = inspect.getsource(GatewayRunner._handle_message) assert '"voice"' in source + + +# ===================================================================== +# VoiceReceiver unit tests +# ===================================================================== + +class TestVoiceReceiver: + """Test VoiceReceiver silence detection, SSRC mapping, and lifecycle.""" + + def _make_receiver(self): + from gateway.platforms.discord import VoiceReceiver + mock_vc = MagicMock() + mock_vc._connection.secret_key = [0] * 32 + mock_vc._connection.dave_session = None + mock_vc._connection.ssrc = 9999 + mock_vc._connection.add_socket_listener = MagicMock() + mock_vc._connection.remove_socket_listener = MagicMock() + mock_vc._connection.hook = None + receiver = VoiceReceiver(mock_vc) + return receiver + + def test_initial_state(self): + receiver = self._make_receiver() + assert receiver._running is False + assert receiver._paused is False + assert len(receiver._buffers) == 0 + assert len(receiver._ssrc_to_user) == 0 + + def test_start_sets_running(self): + receiver = self._make_receiver() + receiver.start() + assert receiver._running is True + + def test_stop_clears_state(self): + receiver = self._make_receiver() + receiver.start() + receiver.map_ssrc(100, 42) + receiver._buffers[100] = bytearray(b"\x00" * 1000) + receiver._last_packet_time[100] = time.monotonic() + receiver.stop() + assert receiver._running is False + assert len(receiver._buffers) == 0 + assert len(receiver._ssrc_to_user) == 0 + assert len(receiver._last_packet_time) == 0 + + def test_map_ssrc(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + assert receiver._ssrc_to_user[100] == 42 + + def test_map_ssrc_overwrites(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + receiver.map_ssrc(100, 99) + assert receiver._ssrc_to_user[100] == 99 + + def test_pause_resume(self): + receiver = self._make_receiver() + assert receiver._paused is False + receiver.pause() + assert receiver._paused is True + receiver.resume() + assert receiver._paused is False + + def test_check_silence_empty(self): + receiver = self._make_receiver() + assert receiver.check_silence() == [] + + def test_check_silence_returns_completed_utterance(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + # 48kHz, stereo, 16-bit = 192000 bytes/sec + # MIN_SPEECH_DURATION = 0.5s → need 96000 bytes + pcm_data = bytearray(b"\x00" * 96000) + receiver._buffers[100] = pcm_data + # Set last_packet_time far enough in the past to exceed SILENCE_THRESHOLD + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 1 + user_id, data = completed[0] + assert user_id == 42 + assert len(data) == 96000 + # Buffer should be cleared after extraction + assert len(receiver._buffers[100]) == 0 + + def test_check_silence_ignores_short_buffer(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + # Too short to meet MIN_SPEECH_DURATION + receiver._buffers[100] = bytearray(b"\x00" * 100) + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_check_silence_ignores_recent_audio(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + receiver._buffers[100] = bytearray(b"\x00" * 96000) + receiver._last_packet_time[100] = time.monotonic() # just now + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_check_silence_unknown_user_discarded(self): + receiver = self._make_receiver() + # No SSRC mapping — user_id will be 0 + receiver._buffers[100] = bytearray(b"\x00" * 96000) + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_stale_buffer_discarded(self): + receiver = self._make_receiver() + # Buffer with no user mapping and very old timestamp + receiver._buffers[200] = bytearray(b"\x00" * 100) + receiver._last_packet_time[200] = time.monotonic() - 10.0 + receiver.check_silence() + # Stale buffer (> 2x threshold) should be discarded + assert 200 not in receiver._buffers + + def test_on_packet_skips_when_not_running(self): + receiver = self._make_receiver() + # Not started — _running is False + receiver._on_packet(b"\x00" * 100) + assert len(receiver._buffers) == 0 + + def test_on_packet_skips_when_paused(self): + receiver = self._make_receiver() + receiver.start() + receiver.pause() + receiver._on_packet(b"\x00" * 100) + # Paused — should not process + assert len(receiver._buffers) == 0 + + def test_on_packet_skips_short_data(self): + receiver = self._make_receiver() + receiver.start() + receiver._on_packet(b"\x00" * 10) + assert len(receiver._buffers) == 0 + + def test_on_packet_skips_non_rtp(self): + receiver = self._make_receiver() + receiver.start() + # Valid length but wrong RTP version + data = bytearray(b"\x00" * 20) + data[0] = 0x00 # version 0, not 2 + receiver._on_packet(bytes(data)) + assert len(receiver._buffers) == 0 + + +# ===================================================================== +# Gateway voice channel commands (join / leave / input) +# ===================================================================== + +class TestVoiceChannelCommands: + """Test _handle_voice_channel_join, _handle_voice_channel_leave, + _handle_voice_channel_input on the GatewayRunner.""" + + @pytest.fixture + def runner(self, tmp_path): + return _make_runner(tmp_path) + + def _make_discord_event(self, text="/voice channel", chat_id="123", + guild_id=111, user_id="user1"): + """Create event with raw_message carrying guild info.""" + source = SessionSource( + chat_id=chat_id, + user_id=user_id, + platform=MagicMock(), + ) + source.platform.value = "discord" + source.thread_id = None + event = MessageEvent(text=text, message_type=MessageType.TEXT, source=source) + event.message_id = "msg42" + event.raw_message = SimpleNamespace(guild_id=guild_id, guild=None) + return event + + # -- _handle_voice_channel_join -- + + @pytest.mark.asyncio + async def test_join_unsupported_platform(self, runner): + """Platform without join_voice_channel returns unsupported message.""" + mock_adapter = AsyncMock(spec=[]) # no join_voice_channel + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "not supported" in result.lower() + + @pytest.mark.asyncio + async def test_join_no_guild_id(self, runner): + """DM context (no guild_id) returns error.""" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock() + event = self._make_discord_event() + event.raw_message = None # no guild info + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "discord server" in result.lower() + + @pytest.mark.asyncio + async def test_join_user_not_in_vc(self, runner): + """User not in any voice channel.""" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock() + mock_adapter.get_user_voice_channel = AsyncMock(return_value=None) + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "need to be in a voice channel" in result.lower() + + @pytest.mark.asyncio + async def test_join_success(self, runner): + """Successful join sets voice_mode and returns confirmation.""" + mock_channel = MagicMock() + mock_channel.name = "General" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock(return_value=True) + mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel) + mock_adapter._voice_text_channels = {} + mock_adapter._voice_input_callback = None + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "joined" in result.lower() + assert "General" in result + assert runner._voice_mode["123"] == "all" + + @pytest.mark.asyncio + async def test_join_failure(self, runner): + """Failed join returns permissions error.""" + mock_channel = MagicMock() + mock_channel.name = "General" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock(return_value=False) + mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel) + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "failed" in result.lower() + + @pytest.mark.asyncio + async def test_join_exception(self, runner): + """Exception during join is caught and reported.""" + mock_channel = MagicMock() + mock_channel.name = "General" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock(side_effect=RuntimeError("No permission")) + mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel) + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "failed" in result.lower() + + # -- _handle_voice_channel_leave -- + + @pytest.mark.asyncio + async def test_leave_not_in_vc(self, runner): + """Leave when not in VC returns appropriate message.""" + mock_adapter = AsyncMock() + mock_adapter.is_in_voice_channel = MagicMock(return_value=False) + event = self._make_discord_event("/voice leave") + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_leave(event) + assert "not in" in result.lower() + + @pytest.mark.asyncio + async def test_leave_no_guild(self, runner): + """Leave from DM returns not in voice channel.""" + mock_adapter = AsyncMock() + event = self._make_discord_event("/voice leave") + event.raw_message = None + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_leave(event) + assert "not in" in result.lower() + + @pytest.mark.asyncio + async def test_leave_success(self, runner): + """Successful leave disconnects and clears voice mode.""" + mock_adapter = AsyncMock() + mock_adapter.is_in_voice_channel = MagicMock(return_value=True) + mock_adapter.leave_voice_channel = AsyncMock() + event = self._make_discord_event("/voice leave") + runner.adapters[event.source.platform] = mock_adapter + runner._voice_mode["123"] = "all" + result = await runner._handle_voice_channel_leave(event) + assert "left" in result.lower() + assert "123" not in runner._voice_mode + mock_adapter.leave_voice_channel.assert_called_once_with(111) + + # -- _handle_voice_channel_input -- + + @pytest.mark.asyncio + async def test_input_no_adapter(self, runner): + """No Discord adapter — early return, no crash.""" + from gateway.config import Platform + # No adapters set + await runner._handle_voice_channel_input(111, 42, "Hello") + + @pytest.mark.asyncio + async def test_input_no_text_channel(self, runner): + """No text channel mapped for guild — early return.""" + from gateway.config import Platform + mock_adapter = AsyncMock() + mock_adapter._voice_text_channels = {} + mock_adapter._client = MagicMock() + runner.adapters[Platform.DISCORD] = mock_adapter + await runner._handle_voice_channel_input(111, 42, "Hello") + + @pytest.mark.asyncio + async def test_input_creates_event_and_dispatches(self, runner): + """Voice input creates synthetic event and calls handle_message.""" + from gateway.config import Platform + mock_adapter = AsyncMock() + mock_adapter._voice_text_channels = {111: 123} + mock_channel = AsyncMock() + mock_adapter._client = MagicMock() + mock_adapter._client.get_channel = MagicMock(return_value=mock_channel) + mock_adapter.handle_message = AsyncMock() + runner.adapters[Platform.DISCORD] = mock_adapter + await runner._handle_voice_channel_input(111, 42, "Hello from VC") + mock_adapter.handle_message.assert_called_once() + event = mock_adapter.handle_message.call_args[0][0] + assert event.text == "Hello from VC" + assert event.message_type == MessageType.VOICE + assert event.source.chat_id == "123" + + @pytest.mark.asyncio + async def test_input_posts_transcript_in_text_channel(self, runner): + """Voice input sends transcript message to text channel.""" + from gateway.config import Platform + mock_adapter = AsyncMock() + mock_adapter._voice_text_channels = {111: 123} + mock_channel = AsyncMock() + mock_adapter._client = MagicMock() + mock_adapter._client.get_channel = MagicMock(return_value=mock_channel) + mock_adapter.handle_message = AsyncMock() + runner.adapters[Platform.DISCORD] = mock_adapter + await runner._handle_voice_channel_input(111, 42, "Test transcript") + mock_channel.send.assert_called_once() + msg = mock_channel.send.call_args[0][0] + assert "Test transcript" in msg + assert "42" in msg # user_id in mention + + # -- _get_guild_id -- + + def test_get_guild_id_from_guild(self, runner): + event = _make_event() + mock_guild = MagicMock() + mock_guild.id = 555 + event.raw_message = SimpleNamespace(guild_id=None, guild=mock_guild) + result = runner._get_guild_id(event) + assert result == 555 + + def test_get_guild_id_from_interaction(self, runner): + event = _make_event() + event.raw_message = SimpleNamespace(guild_id=777, guild=None) + result = runner._get_guild_id(event) + assert result == 777 + + def test_get_guild_id_none(self, runner): + event = _make_event() + event.raw_message = None + result = runner._get_guild_id(event) + assert result is None + + def test_get_guild_id_dm(self, runner): + event = _make_event() + event.raw_message = SimpleNamespace(guild_id=None, guild=None) + result = runner._get_guild_id(event) + assert result is None + + +# ===================================================================== +# Discord adapter voice channel methods +# ===================================================================== + +class TestDiscordVoiceChannelMethods: + """Test DiscordAdapter voice channel methods (join, leave, play, etc.).""" + + def _make_adapter(self): + from gateway.platforms.discord import DiscordAdapter + from gateway.config import Platform, PlatformConfig + config = PlatformConfig(enabled=True, extra={}) + config.token = "fake-token" + adapter = object.__new__(DiscordAdapter) + adapter.platform = Platform.DISCORD + adapter.config = config + adapter._client = MagicMock() + adapter._voice_clients = {} + adapter._voice_text_channels = {} + adapter._voice_timeout_tasks = {} + adapter._voice_receivers = {} + adapter._voice_listen_tasks = {} + adapter._voice_input_callback = None + adapter._allowed_user_ids = set() + adapter._running = True + return adapter + + def test_is_in_voice_channel_true(self): + adapter = self._make_adapter() + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + adapter._voice_clients[111] = mock_vc + assert adapter.is_in_voice_channel(111) is True + + def test_is_in_voice_channel_false_no_client(self): + adapter = self._make_adapter() + assert adapter.is_in_voice_channel(111) is False + + def test_is_in_voice_channel_false_disconnected(self): + adapter = self._make_adapter() + mock_vc = MagicMock() + mock_vc.is_connected.return_value = False + adapter._voice_clients[111] = mock_vc + assert adapter.is_in_voice_channel(111) is False + + @pytest.mark.asyncio + async def test_leave_voice_channel_cleans_up(self): + adapter = self._make_adapter() + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + mock_vc.disconnect = AsyncMock() + adapter._voice_clients[111] = mock_vc + adapter._voice_text_channels[111] = 123 + + mock_receiver = MagicMock() + adapter._voice_receivers[111] = mock_receiver + + mock_task = MagicMock() + adapter._voice_listen_tasks[111] = mock_task + + mock_timeout = MagicMock() + adapter._voice_timeout_tasks[111] = mock_timeout + + await adapter.leave_voice_channel(111) + + mock_receiver.stop.assert_called_once() + mock_task.cancel.assert_called_once() + mock_vc.disconnect.assert_called_once() + mock_timeout.cancel.assert_called_once() + assert 111 not in adapter._voice_clients + assert 111 not in adapter._voice_text_channels + assert 111 not in adapter._voice_receivers + + @pytest.mark.asyncio + async def test_leave_voice_channel_no_connection(self): + """Leave when not connected — no crash.""" + adapter = self._make_adapter() + await adapter.leave_voice_channel(111) # should not raise + + @pytest.mark.asyncio + async def test_get_user_voice_channel_no_client(self): + adapter = self._make_adapter() + adapter._client = None + result = await adapter.get_user_voice_channel(111, "42") + assert result is None + + @pytest.mark.asyncio + async def test_get_user_voice_channel_no_guild(self): + adapter = self._make_adapter() + adapter._client.get_guild = MagicMock(return_value=None) + result = await adapter.get_user_voice_channel(111, "42") + assert result is None + + @pytest.mark.asyncio + async def test_get_user_voice_channel_user_not_in_vc(self): + adapter = self._make_adapter() + mock_guild = MagicMock() + mock_member = MagicMock() + mock_member.voice = None + mock_guild.get_member = MagicMock(return_value=mock_member) + adapter._client.get_guild = MagicMock(return_value=mock_guild) + result = await adapter.get_user_voice_channel(111, "42") + assert result is None + + @pytest.mark.asyncio + async def test_get_user_voice_channel_success(self): + adapter = self._make_adapter() + mock_vc = MagicMock() + mock_guild = MagicMock() + mock_member = MagicMock() + mock_member.voice = MagicMock() + mock_member.voice.channel = mock_vc + mock_guild.get_member = MagicMock(return_value=mock_member) + adapter._client.get_guild = MagicMock(return_value=mock_guild) + result = await adapter.get_user_voice_channel(111, "42") + assert result is mock_vc + + @pytest.mark.asyncio + async def test_play_in_voice_channel_not_connected(self): + adapter = self._make_adapter() + result = await adapter.play_in_voice_channel(111, "/tmp/test.ogg") + assert result is False + + def test_is_allowed_user_empty_list(self): + adapter = self._make_adapter() + assert adapter._is_allowed_user("42") is True + + def test_is_allowed_user_in_list(self): + adapter = self._make_adapter() + adapter._allowed_user_ids = {"42", "99"} + assert adapter._is_allowed_user("42") is True + + def test_is_allowed_user_not_in_list(self): + adapter = self._make_adapter() + adapter._allowed_user_ids = {"99"} + assert adapter._is_allowed_user("42") is False + + @pytest.mark.asyncio + async def test_process_voice_input_success(self): + """Successful voice input: PCM->WAV->STT->callback.""" + adapter = self._make_adapter() + callback = AsyncMock() + adapter._voice_input_callback = callback + adapter._allowed_user_ids = set() + + pcm_data = b"\x00" * 96000 + + with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \ + patch("tools.transcription_tools.transcribe_audio", + return_value={"success": True, "transcript": "Hello"}), \ + patch("tools.voice_mode.is_whisper_hallucination", return_value=False): + await adapter._process_voice_input(111, 42, pcm_data) + + callback.assert_called_once_with(guild_id=111, user_id=42, transcript="Hello") + + @pytest.mark.asyncio + async def test_process_voice_input_hallucination_filtered(self): + """Whisper hallucination is filtered out.""" + adapter = self._make_adapter() + callback = AsyncMock() + adapter._voice_input_callback = callback + + with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \ + patch("tools.transcription_tools.transcribe_audio", + return_value={"success": True, "transcript": "Thank you."}), \ + patch("tools.voice_mode.is_whisper_hallucination", return_value=True): + await adapter._process_voice_input(111, 42, b"\x00" * 96000) + + callback.assert_not_called() + + @pytest.mark.asyncio + async def test_process_voice_input_stt_failure(self): + """STT failure — callback not called.""" + adapter = self._make_adapter() + callback = AsyncMock() + adapter._voice_input_callback = callback + + with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \ + patch("tools.transcription_tools.transcribe_audio", + return_value={"success": False, "error": "API error"}): + await adapter._process_voice_input(111, 42, b"\x00" * 96000) + + callback.assert_not_called() + + @pytest.mark.asyncio + async def test_process_voice_input_exception_caught(self): + """Exception during processing is caught, no crash.""" + adapter = self._make_adapter() + adapter._voice_input_callback = AsyncMock() + + with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav", + side_effect=RuntimeError("ffmpeg not found")): + await adapter._process_voice_input(111, 42, b"\x00" * 96000) + # Should not raise + + +# ===================================================================== +# stream_tts_to_speaker functional tests +# ===================================================================== + +class TestStreamTtsToSpeaker: + """Functional tests for the streaming TTS pipeline.""" + + def test_none_sentinel_flushes_buffer(self): + """None sentinel causes remaining buffer to be spoken.""" + from tools.tts_tool import stream_tts_to_speaker + text_q = queue.Queue() + stop_evt = threading.Event() + done_evt = threading.Event() + spoken = [] + + def display(text): + spoken.append(text) + + text_q.put("Hello world.") + text_q.put(None) + + stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=display) + assert done_evt.is_set() + assert any("Hello" in s for s in spoken) + + def test_stop_event_aborts_early(self): + """Setting stop_event causes early exit.""" + from tools.tts_tool import stream_tts_to_speaker + text_q = queue.Queue() + stop_evt = threading.Event() + done_evt = threading.Event() + spoken = [] + + stop_evt.set() + text_q.put("Should not be spoken.") + text_q.put(None) + + stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t)) + assert done_evt.is_set() + assert len(spoken) == 0 + + def test_done_event_set_on_exception(self): + """tts_done_event is set even when an exception occurs.""" + from tools.tts_tool import stream_tts_to_speaker + text_q = queue.Queue() + stop_evt = threading.Event() + done_evt = threading.Event() + + # Put a non-string that will cause concatenation to fail + text_q.put(12345) + text_q.put(None) + + stream_tts_to_speaker(text_q, stop_evt, done_evt) + assert done_evt.is_set() + + def test_think_blocks_stripped(self): + """... content is not spoken.""" + from tools.tts_tool import stream_tts_to_speaker + text_q = queue.Queue() + stop_evt = threading.Event() + done_evt = threading.Event() + spoken = [] + + text_q.put("internal reasoning") + text_q.put("Visible response. ") + text_q.put(None) + + stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t)) + assert done_evt.is_set() + joined = " ".join(spoken) + assert "internal reasoning" not in joined + assert "Visible" in joined + + def test_sentence_splitting(self): + """Sentences are split at boundaries and spoken individually.""" + from tools.tts_tool import stream_tts_to_speaker + text_q = queue.Queue() + stop_evt = threading.Event() + done_evt = threading.Event() + spoken = [] + + # Two sentences long enough to exceed min_sentence_len (20) + text_q.put("This is the first sentence. ") + text_q.put("This is the second sentence. ") + text_q.put(None) + + stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t)) + assert done_evt.is_set() + assert len(spoken) >= 2 + + def test_markdown_stripped_in_speech(self): + """Markdown formatting is removed before display/speech.""" + from tools.tts_tool import stream_tts_to_speaker + text_q = queue.Queue() + stop_evt = threading.Event() + done_evt = threading.Event() + spoken = [] + + text_q.put("**Bold text** and `code`. ") + text_q.put(None) + + stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t)) + assert done_evt.is_set() + # Display callback gets raw text (before markdown stripping) + # But the actual TTS audio would be stripped — we verify pipeline doesn't crash + + def test_duplicate_sentences_deduped(self): + """Repeated sentences are spoken only once.""" + from tools.tts_tool import stream_tts_to_speaker + text_q = queue.Queue() + stop_evt = threading.Event() + done_evt = threading.Event() + spoken = [] + + # Same sentence twice, each long enough + text_q.put("This is a repeated sentence. ") + text_q.put("This is a repeated sentence. ") + text_q.put(None) + + stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t)) + assert done_evt.is_set() + # First occurrence is spoken, second is deduped + assert len(spoken) == 1 + + def test_no_api_key_display_only(self): + """Without ELEVENLABS_API_KEY, display callback still works.""" + from tools.tts_tool import stream_tts_to_speaker + text_q = queue.Queue() + stop_evt = threading.Event() + done_evt = threading.Event() + spoken = [] + + text_q.put("Display only text. ") + text_q.put(None) + + with patch.dict(os.environ, {"ELEVENLABS_API_KEY": ""}): + stream_tts_to_speaker(text_q, stop_evt, done_evt, + display_callback=lambda t: spoken.append(t)) + assert done_evt.is_set() + assert len(spoken) >= 1 + + def test_long_buffer_flushed_on_timeout(self): + """Buffer longer than long_flush_len is flushed on queue timeout.""" + from tools.tts_tool import stream_tts_to_speaker + text_q = queue.Queue() + stop_evt = threading.Event() + done_evt = threading.Event() + spoken = [] + + # Put a long text without sentence boundary, then None after a delay + long_text = "a" * 150 # > long_flush_len (100) + text_q.put(long_text) + + def delayed_sentinel(): + time.sleep(1.0) + text_q.put(None) + + t = threading.Thread(target=delayed_sentinel, daemon=True) + t.start() + + stream_tts_to_speaker(text_q, stop_evt, done_evt, + display_callback=lambda t: spoken.append(t)) + t.join(timeout=5) + assert done_evt.is_set() + assert len(spoken) >= 1