From 86ddaaee9c24b80920221b64e0a6064b10f883af Mon Sep 17 00:00:00 2001
From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com>
Date: Wed, 11 Mar 2026 23:18:49 +0300
Subject: [PATCH] fix: extract voice reply logic and add comprehensive tests
- Fix tempfile.mktemp() TOCTOU race in Discord voice input (use NamedTemporaryFile)
- Extract voice reply decision from _handle_message into _should_send_voice_reply()
- Rewrite TestAutoVoiceReply to call real method instead of testing a copy
- Add 59 new tests: VoiceReceiver, VC commands, adapter methods, streaming TTS
---
gateway/platforms/discord.py | 4 +-
gateway/run.py | 96 ++--
tests/gateway/test_voice_command.py | 842 ++++++++++++++++++++++++++--
3 files changed, 845 insertions(+), 97 deletions(-)
diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py
index d10375e3c..cf57b37d9 100644
--- a/gateway/platforms/discord.py
+++ b/gateway/platforms/discord.py
@@ -851,7 +851,9 @@ class DiscordAdapter(BasePlatformAdapter):
"""Convert PCM -> WAV -> STT -> callback."""
from tools.voice_mode import is_whisper_hallucination
- wav_path = tempfile.mktemp(suffix=".wav", prefix="vc_listen_")
+ tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", prefix="vc_listen_", delete=False)
+ wav_path = tmp_f.name
+ tmp_f.close()
try:
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)
diff --git a/gateway/run.py b/gateway/run.py
index b672ac48c..5deb093af 100644
--- a/gateway/run.py
+++ b/gateway/run.py
@@ -1616,42 +1616,8 @@ class GatewayRunner:
)
# Auto voice reply: send TTS audio before the text response
- chat_id = source.chat_id
- voice_mode = self._voice_mode.get(chat_id, "off")
- is_voice_input = (event.message_type == MessageType.VOICE)
- should_voice_reply = (
- (voice_mode == "all")
- or (voice_mode == "voice_only" and is_voice_input)
- )
- logger.info("Voice reply check: chat_id=%s, voice_mode=%s, is_voice=%s, should_reply=%s, has_response=%s",
- chat_id, voice_mode, is_voice_input, should_voice_reply, bool(response))
- if should_voice_reply and response and not response.startswith("Error:"):
- # Skip if agent already called TTS tool (avoid double voice)
- has_agent_tts = any(
- msg.get("role") == "assistant"
- and any(
- tc.get("function", {}).get("name") == "text_to_speech"
- for tc in (msg.get("tool_calls") or [])
- )
- for msg in agent_messages
- )
- # Skip if voice input — base adapter auto-TTS in
- # _process_message_background already sent audio for voice
- # messages, so sending another would be double.
- # Exception: Discord voice channel — the Discord play_tts
- # override also skips (no-op), so the runner MUST handle it
- # via play_in_voice_channel.
- skip_double = is_voice_input
- if skip_double:
- adapter = self.adapters.get(source.platform)
- guild_id = self._get_guild_id(event)
- if (guild_id and adapter
- and hasattr(adapter, "is_in_voice_channel")
- and adapter.is_in_voice_channel(guild_id)):
- skip_double = False
- logger.info("Voice reply: has_agent_tts=%s, skip_double=%s, calling _send_voice_reply", has_agent_tts, skip_double)
- if not has_agent_tts and not skip_double:
- await self._send_voice_reply(event, response)
+ if self._should_send_voice_reply(event, response, agent_messages):
+ await self._send_voice_reply(event, response)
return response
@@ -2302,6 +2268,64 @@ class GatewayRunner:
await adapter.handle_message(event)
+ def _should_send_voice_reply(
+ self,
+ event: MessageEvent,
+ response: str,
+ agent_messages: list,
+ ) -> bool:
+ """Decide whether the runner should send a TTS voice reply.
+
+ Returns False when:
+ - voice_mode is off for this chat, or voice_only with non-voice input
+ - response is empty or starts with "Error:"
+ - agent already called text_to_speech tool (dedup)
+ - voice input and base adapter auto-TTS already handled it (skip_double)
+ Exception: Discord voice channel — the play_tts override is a no-op there,
+ so the runner must handle VC playback.
+ """
+ if not response or response.startswith("Error:"):
+ return False
+
+ chat_id = event.source.chat_id
+ voice_mode = self._voice_mode.get(chat_id, "off")
+ is_voice_input = (event.message_type == MessageType.VOICE)
+
+ should = (
+ (voice_mode == "all")
+ or (voice_mode == "voice_only" and is_voice_input)
+ )
+ if not should:
+ return False
+
+ # Dedup: agent already called TTS tool
+ has_agent_tts = any(
+ msg.get("role") == "assistant"
+ and any(
+ tc.get("function", {}).get("name") == "text_to_speech"
+ for tc in (msg.get("tool_calls") or [])
+ )
+ for msg in agent_messages
+ )
+ if has_agent_tts:
+ return False
+
+ # Dedup: base adapter auto-TTS already handles voice input.
+ # Exception: Discord voice channel — play_tts override is a no-op,
+ # so the runner must handle VC playback.
+ skip_double = is_voice_input
+ if skip_double:
+ adapter = self.adapters.get(event.source.platform)
+ guild_id = self._get_guild_id(event)
+ if (guild_id and adapter
+ and hasattr(adapter, "is_in_voice_channel")
+ and adapter.is_in_voice_channel(guild_id)):
+ skip_double = False
+ if skip_double:
+ return False
+
+ return True
+
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
"""Generate TTS audio and send as a voice message before the text reply."""
try:
diff --git a/tests/gateway/test_voice_command.py b/tests/gateway/test_voice_command.py
index 1e4250960..f8ec52423 100644
--- a/tests/gateway/test_voice_command.py
+++ b/tests/gateway/test_voice_command.py
@@ -2,7 +2,11 @@
import json
import os
+import queue
+import threading
+import time
import pytest
+from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
@@ -126,7 +130,7 @@ class TestHandleVoiceCommand:
# =====================================================================
class TestAutoVoiceReply:
- """Test the should_voice_reply decision logic (extracted from _handle_message).
+ """Test the real _should_send_voice_reply method on GatewayRunner.
The gateway has two TTS paths:
1. base adapter auto-TTS: fires for voice input in _process_message_background
@@ -138,43 +142,30 @@ class TestAutoVoiceReply:
override skip, so the runner must handle it via play_in_voice_channel.
"""
- def _should_reply(self, voice_mode, message_type, agent_messages=None,
- response="Hello!", in_voice_channel=False):
- """Replicate the auto voice reply decision from _handle_message."""
- if not response or response.startswith("Error:"):
- return False
+ @pytest.fixture
+ def runner(self, tmp_path):
+ return _make_runner(tmp_path)
- is_voice_input = (message_type == MessageType.VOICE)
- should = (
- (voice_mode == "all")
- or (voice_mode == "voice_only" and is_voice_input)
+ def _call(self, runner, voice_mode, message_type, agent_messages=None,
+ response="Hello!", in_voice_channel=False):
+ """Call real _should_send_voice_reply on a GatewayRunner instance."""
+ chat_id = "123"
+ if voice_mode != "off":
+ runner._voice_mode[chat_id] = voice_mode
+ else:
+ runner._voice_mode.pop(chat_id, None)
+
+ event = _make_event(message_type=message_type)
+
+ if in_voice_channel:
+ mock_adapter = MagicMock()
+ mock_adapter.is_in_voice_channel = MagicMock(return_value=True)
+ event.raw_message = SimpleNamespace(guild_id=111, guild=None)
+ runner.adapters[event.source.platform] = mock_adapter
+
+ return runner._should_send_voice_reply(
+ event, response, agent_messages or []
)
- if not should:
- return False
-
- # Dedup: agent already called TTS tool
- if agent_messages:
- has_agent_tts = any(
- msg.get("role") == "assistant"
- and any(
- tc.get("function", {}).get("name") == "text_to_speech"
- for tc in (msg.get("tool_calls") or [])
- )
- for msg in agent_messages
- )
- if has_agent_tts:
- return False
-
- # Dedup: base adapter auto-TTS already handles voice input.
- # Exception: in voice channel, Discord play_tts also skips,
- # so the runner must handle VC playback.
- skip_double = is_voice_input
- if skip_double and in_voice_channel:
- skip_double = False
- if skip_double:
- return False
-
- return True
# -- Full platform x input x mode matrix --------------------------------
#
@@ -204,52 +195,52 @@ class TestAutoVoiceReply:
# -- Telegram/Slack/Web: voice input, base handles ---------------------
- def test_voice_input_voice_only_skipped(self):
+ def test_voice_input_voice_only_skipped(self, runner):
"""voice_only + voice input: base auto-TTS handles it, runner skips."""
- assert self._should_reply("voice_only", MessageType.VOICE) is False
+ assert self._call(runner, "voice_only", MessageType.VOICE) is False
- def test_voice_input_all_mode_skipped(self):
+ def test_voice_input_all_mode_skipped(self, runner):
"""all + voice input: base auto-TTS handles it, runner skips."""
- assert self._should_reply("all", MessageType.VOICE) is False
+ assert self._call(runner, "all", MessageType.VOICE) is False
# -- Text input: only runner handles -----------------------------------
- def test_text_input_all_mode_runner_fires(self):
+ def test_text_input_all_mode_runner_fires(self, runner):
"""all + text input: only runner fires (base auto-TTS only for voice)."""
- assert self._should_reply("all", MessageType.TEXT) is True
+ assert self._call(runner, "all", MessageType.TEXT) is True
- def test_text_input_voice_only_no_reply(self):
+ def test_text_input_voice_only_no_reply(self, runner):
"""voice_only + text input: neither fires."""
- assert self._should_reply("voice_only", MessageType.TEXT) is False
+ assert self._call(runner, "voice_only", MessageType.TEXT) is False
# -- Mode off: nothing fires -------------------------------------------
- def test_off_mode_voice(self):
- assert self._should_reply("off", MessageType.VOICE) is False
+ def test_off_mode_voice(self, runner):
+ assert self._call(runner, "off", MessageType.VOICE) is False
- def test_off_mode_text(self):
- assert self._should_reply("off", MessageType.TEXT) is False
+ def test_off_mode_text(self, runner):
+ assert self._call(runner, "off", MessageType.TEXT) is False
# -- Discord VC exception: runner must handle --------------------------
- def test_discord_vc_voice_input_runner_fires(self):
+ def test_discord_vc_voice_input_runner_fires(self, runner):
"""Discord VC + voice input: base play_tts skips (VC override),
so runner must handle via play_in_voice_channel."""
- assert self._should_reply("all", MessageType.VOICE, in_voice_channel=True) is True
+ assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is True
- def test_discord_vc_voice_only_runner_fires(self):
+ def test_discord_vc_voice_only_runner_fires(self, runner):
"""Discord VC + voice_only + voice: runner must handle."""
- assert self._should_reply("voice_only", MessageType.VOICE, in_voice_channel=True) is True
+ assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is True
# -- Edge cases --------------------------------------------------------
- def test_error_response_skipped(self):
- assert self._should_reply("all", MessageType.TEXT, response="Error: boom") is False
+ def test_error_response_skipped(self, runner):
+ assert self._call(runner, "all", MessageType.TEXT, response="Error: boom") is False
- def test_empty_response_skipped(self):
- assert self._should_reply("all", MessageType.TEXT, response="") is False
+ def test_empty_response_skipped(self, runner):
+ assert self._call(runner, "all", MessageType.TEXT, response="") is False
- def test_dedup_skips_when_agent_called_tts(self):
+ def test_dedup_skips_when_agent_called_tts(self, runner):
messages = [{
"role": "assistant",
"tool_calls": [{
@@ -258,9 +249,9 @@ class TestAutoVoiceReply:
"function": {"name": "text_to_speech", "arguments": "{}"},
}],
}]
- assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is False
+ assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is False
- def test_no_dedup_for_other_tools(self):
+ def test_no_dedup_for_other_tools(self, runner):
messages = [{
"role": "assistant",
"tool_calls": [{
@@ -269,7 +260,7 @@ class TestAutoVoiceReply:
"function": {"name": "web_search", "arguments": "{}"},
}],
}]
- assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is True
+ assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is True
# =====================================================================
@@ -443,3 +434,734 @@ class TestVoiceInHelp:
import inspect
source = inspect.getsource(GatewayRunner._handle_message)
assert '"voice"' in source
+
+
+# =====================================================================
+# VoiceReceiver unit tests
+# =====================================================================
+
+class TestVoiceReceiver:
+ """Test VoiceReceiver silence detection, SSRC mapping, and lifecycle."""
+
+ def _make_receiver(self):
+ from gateway.platforms.discord import VoiceReceiver
+ mock_vc = MagicMock()
+ mock_vc._connection.secret_key = [0] * 32
+ mock_vc._connection.dave_session = None
+ mock_vc._connection.ssrc = 9999
+ mock_vc._connection.add_socket_listener = MagicMock()
+ mock_vc._connection.remove_socket_listener = MagicMock()
+ mock_vc._connection.hook = None
+ receiver = VoiceReceiver(mock_vc)
+ return receiver
+
+ def test_initial_state(self):
+ receiver = self._make_receiver()
+ assert receiver._running is False
+ assert receiver._paused is False
+ assert len(receiver._buffers) == 0
+ assert len(receiver._ssrc_to_user) == 0
+
+ def test_start_sets_running(self):
+ receiver = self._make_receiver()
+ receiver.start()
+ assert receiver._running is True
+
+ def test_stop_clears_state(self):
+ receiver = self._make_receiver()
+ receiver.start()
+ receiver.map_ssrc(100, 42)
+ receiver._buffers[100] = bytearray(b"\x00" * 1000)
+ receiver._last_packet_time[100] = time.monotonic()
+ receiver.stop()
+ assert receiver._running is False
+ assert len(receiver._buffers) == 0
+ assert len(receiver._ssrc_to_user) == 0
+ assert len(receiver._last_packet_time) == 0
+
+ def test_map_ssrc(self):
+ receiver = self._make_receiver()
+ receiver.map_ssrc(100, 42)
+ assert receiver._ssrc_to_user[100] == 42
+
+ def test_map_ssrc_overwrites(self):
+ receiver = self._make_receiver()
+ receiver.map_ssrc(100, 42)
+ receiver.map_ssrc(100, 99)
+ assert receiver._ssrc_to_user[100] == 99
+
+ def test_pause_resume(self):
+ receiver = self._make_receiver()
+ assert receiver._paused is False
+ receiver.pause()
+ assert receiver._paused is True
+ receiver.resume()
+ assert receiver._paused is False
+
+ def test_check_silence_empty(self):
+ receiver = self._make_receiver()
+ assert receiver.check_silence() == []
+
+ def test_check_silence_returns_completed_utterance(self):
+ receiver = self._make_receiver()
+ receiver.map_ssrc(100, 42)
+ # 48kHz, stereo, 16-bit = 192000 bytes/sec
+ # MIN_SPEECH_DURATION = 0.5s → need 96000 bytes
+ pcm_data = bytearray(b"\x00" * 96000)
+ receiver._buffers[100] = pcm_data
+ # Set last_packet_time far enough in the past to exceed SILENCE_THRESHOLD
+ receiver._last_packet_time[100] = time.monotonic() - 3.0
+ completed = receiver.check_silence()
+ assert len(completed) == 1
+ user_id, data = completed[0]
+ assert user_id == 42
+ assert len(data) == 96000
+ # Buffer should be cleared after extraction
+ assert len(receiver._buffers[100]) == 0
+
+ def test_check_silence_ignores_short_buffer(self):
+ receiver = self._make_receiver()
+ receiver.map_ssrc(100, 42)
+ # Too short to meet MIN_SPEECH_DURATION
+ receiver._buffers[100] = bytearray(b"\x00" * 100)
+ receiver._last_packet_time[100] = time.monotonic() - 3.0
+ completed = receiver.check_silence()
+ assert len(completed) == 0
+
+ def test_check_silence_ignores_recent_audio(self):
+ receiver = self._make_receiver()
+ receiver.map_ssrc(100, 42)
+ receiver._buffers[100] = bytearray(b"\x00" * 96000)
+ receiver._last_packet_time[100] = time.monotonic() # just now
+ completed = receiver.check_silence()
+ assert len(completed) == 0
+
+ def test_check_silence_unknown_user_discarded(self):
+ receiver = self._make_receiver()
+ # No SSRC mapping — user_id will be 0
+ receiver._buffers[100] = bytearray(b"\x00" * 96000)
+ receiver._last_packet_time[100] = time.monotonic() - 3.0
+ completed = receiver.check_silence()
+ assert len(completed) == 0
+
+ def test_stale_buffer_discarded(self):
+ receiver = self._make_receiver()
+ # Buffer with no user mapping and very old timestamp
+ receiver._buffers[200] = bytearray(b"\x00" * 100)
+ receiver._last_packet_time[200] = time.monotonic() - 10.0
+ receiver.check_silence()
+ # Stale buffer (> 2x threshold) should be discarded
+ assert 200 not in receiver._buffers
+
+ def test_on_packet_skips_when_not_running(self):
+ receiver = self._make_receiver()
+ # Not started — _running is False
+ receiver._on_packet(b"\x00" * 100)
+ assert len(receiver._buffers) == 0
+
+ def test_on_packet_skips_when_paused(self):
+ receiver = self._make_receiver()
+ receiver.start()
+ receiver.pause()
+ receiver._on_packet(b"\x00" * 100)
+ # Paused — should not process
+ assert len(receiver._buffers) == 0
+
+ def test_on_packet_skips_short_data(self):
+ receiver = self._make_receiver()
+ receiver.start()
+ receiver._on_packet(b"\x00" * 10)
+ assert len(receiver._buffers) == 0
+
+ def test_on_packet_skips_non_rtp(self):
+ receiver = self._make_receiver()
+ receiver.start()
+ # Valid length but wrong RTP version
+ data = bytearray(b"\x00" * 20)
+ data[0] = 0x00 # version 0, not 2
+ receiver._on_packet(bytes(data))
+ assert len(receiver._buffers) == 0
+
+
+# =====================================================================
+# Gateway voice channel commands (join / leave / input)
+# =====================================================================
+
+class TestVoiceChannelCommands:
+ """Test _handle_voice_channel_join, _handle_voice_channel_leave,
+ _handle_voice_channel_input on the GatewayRunner."""
+
+ @pytest.fixture
+ def runner(self, tmp_path):
+ return _make_runner(tmp_path)
+
+ def _make_discord_event(self, text="/voice channel", chat_id="123",
+ guild_id=111, user_id="user1"):
+ """Create event with raw_message carrying guild info."""
+ source = SessionSource(
+ chat_id=chat_id,
+ user_id=user_id,
+ platform=MagicMock(),
+ )
+ source.platform.value = "discord"
+ source.thread_id = None
+ event = MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
+ event.message_id = "msg42"
+ event.raw_message = SimpleNamespace(guild_id=guild_id, guild=None)
+ return event
+
+ # -- _handle_voice_channel_join --
+
+ @pytest.mark.asyncio
+ async def test_join_unsupported_platform(self, runner):
+ """Platform without join_voice_channel returns unsupported message."""
+ mock_adapter = AsyncMock(spec=[]) # no join_voice_channel
+ event = self._make_discord_event()
+ runner.adapters[event.source.platform] = mock_adapter
+ result = await runner._handle_voice_channel_join(event)
+ assert "not supported" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_join_no_guild_id(self, runner):
+ """DM context (no guild_id) returns error."""
+ mock_adapter = AsyncMock()
+ mock_adapter.join_voice_channel = AsyncMock()
+ event = self._make_discord_event()
+ event.raw_message = None # no guild info
+ runner.adapters[event.source.platform] = mock_adapter
+ result = await runner._handle_voice_channel_join(event)
+ assert "discord server" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_join_user_not_in_vc(self, runner):
+ """User not in any voice channel."""
+ mock_adapter = AsyncMock()
+ mock_adapter.join_voice_channel = AsyncMock()
+ mock_adapter.get_user_voice_channel = AsyncMock(return_value=None)
+ event = self._make_discord_event()
+ runner.adapters[event.source.platform] = mock_adapter
+ result = await runner._handle_voice_channel_join(event)
+ assert "need to be in a voice channel" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_join_success(self, runner):
+ """Successful join sets voice_mode and returns confirmation."""
+ mock_channel = MagicMock()
+ mock_channel.name = "General"
+ mock_adapter = AsyncMock()
+ mock_adapter.join_voice_channel = AsyncMock(return_value=True)
+ mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
+ mock_adapter._voice_text_channels = {}
+ mock_adapter._voice_input_callback = None
+ event = self._make_discord_event()
+ runner.adapters[event.source.platform] = mock_adapter
+ result = await runner._handle_voice_channel_join(event)
+ assert "joined" in result.lower()
+ assert "General" in result
+ assert runner._voice_mode["123"] == "all"
+
+ @pytest.mark.asyncio
+ async def test_join_failure(self, runner):
+ """Failed join returns permissions error."""
+ mock_channel = MagicMock()
+ mock_channel.name = "General"
+ mock_adapter = AsyncMock()
+ mock_adapter.join_voice_channel = AsyncMock(return_value=False)
+ mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
+ event = self._make_discord_event()
+ runner.adapters[event.source.platform] = mock_adapter
+ result = await runner._handle_voice_channel_join(event)
+ assert "failed" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_join_exception(self, runner):
+ """Exception during join is caught and reported."""
+ mock_channel = MagicMock()
+ mock_channel.name = "General"
+ mock_adapter = AsyncMock()
+ mock_adapter.join_voice_channel = AsyncMock(side_effect=RuntimeError("No permission"))
+ mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
+ event = self._make_discord_event()
+ runner.adapters[event.source.platform] = mock_adapter
+ result = await runner._handle_voice_channel_join(event)
+ assert "failed" in result.lower()
+
+ # -- _handle_voice_channel_leave --
+
+ @pytest.mark.asyncio
+ async def test_leave_not_in_vc(self, runner):
+ """Leave when not in VC returns appropriate message."""
+ mock_adapter = AsyncMock()
+ mock_adapter.is_in_voice_channel = MagicMock(return_value=False)
+ event = self._make_discord_event("/voice leave")
+ runner.adapters[event.source.platform] = mock_adapter
+ result = await runner._handle_voice_channel_leave(event)
+ assert "not in" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_leave_no_guild(self, runner):
+ """Leave from DM returns not in voice channel."""
+ mock_adapter = AsyncMock()
+ event = self._make_discord_event("/voice leave")
+ event.raw_message = None
+ runner.adapters[event.source.platform] = mock_adapter
+ result = await runner._handle_voice_channel_leave(event)
+ assert "not in" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_leave_success(self, runner):
+ """Successful leave disconnects and clears voice mode."""
+ mock_adapter = AsyncMock()
+ mock_adapter.is_in_voice_channel = MagicMock(return_value=True)
+ mock_adapter.leave_voice_channel = AsyncMock()
+ event = self._make_discord_event("/voice leave")
+ runner.adapters[event.source.platform] = mock_adapter
+ runner._voice_mode["123"] = "all"
+ result = await runner._handle_voice_channel_leave(event)
+ assert "left" in result.lower()
+ assert "123" not in runner._voice_mode
+ mock_adapter.leave_voice_channel.assert_called_once_with(111)
+
+ # -- _handle_voice_channel_input --
+
+ @pytest.mark.asyncio
+ async def test_input_no_adapter(self, runner):
+ """No Discord adapter — early return, no crash."""
+ from gateway.config import Platform
+ # No adapters set
+ await runner._handle_voice_channel_input(111, 42, "Hello")
+
+ @pytest.mark.asyncio
+ async def test_input_no_text_channel(self, runner):
+ """No text channel mapped for guild — early return."""
+ from gateway.config import Platform
+ mock_adapter = AsyncMock()
+ mock_adapter._voice_text_channels = {}
+ mock_adapter._client = MagicMock()
+ runner.adapters[Platform.DISCORD] = mock_adapter
+ await runner._handle_voice_channel_input(111, 42, "Hello")
+
+ @pytest.mark.asyncio
+ async def test_input_creates_event_and_dispatches(self, runner):
+ """Voice input creates synthetic event and calls handle_message."""
+ from gateway.config import Platform
+ mock_adapter = AsyncMock()
+ mock_adapter._voice_text_channels = {111: 123}
+ mock_channel = AsyncMock()
+ mock_adapter._client = MagicMock()
+ mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
+ mock_adapter.handle_message = AsyncMock()
+ runner.adapters[Platform.DISCORD] = mock_adapter
+ await runner._handle_voice_channel_input(111, 42, "Hello from VC")
+ mock_adapter.handle_message.assert_called_once()
+ event = mock_adapter.handle_message.call_args[0][0]
+ assert event.text == "Hello from VC"
+ assert event.message_type == MessageType.VOICE
+ assert event.source.chat_id == "123"
+
+ @pytest.mark.asyncio
+ async def test_input_posts_transcript_in_text_channel(self, runner):
+ """Voice input sends transcript message to text channel."""
+ from gateway.config import Platform
+ mock_adapter = AsyncMock()
+ mock_adapter._voice_text_channels = {111: 123}
+ mock_channel = AsyncMock()
+ mock_adapter._client = MagicMock()
+ mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
+ mock_adapter.handle_message = AsyncMock()
+ runner.adapters[Platform.DISCORD] = mock_adapter
+ await runner._handle_voice_channel_input(111, 42, "Test transcript")
+ mock_channel.send.assert_called_once()
+ msg = mock_channel.send.call_args[0][0]
+ assert "Test transcript" in msg
+ assert "42" in msg # user_id in mention
+
+ # -- _get_guild_id --
+
+ def test_get_guild_id_from_guild(self, runner):
+ event = _make_event()
+ mock_guild = MagicMock()
+ mock_guild.id = 555
+ event.raw_message = SimpleNamespace(guild_id=None, guild=mock_guild)
+ result = runner._get_guild_id(event)
+ assert result == 555
+
+ def test_get_guild_id_from_interaction(self, runner):
+ event = _make_event()
+ event.raw_message = SimpleNamespace(guild_id=777, guild=None)
+ result = runner._get_guild_id(event)
+ assert result == 777
+
+ def test_get_guild_id_none(self, runner):
+ event = _make_event()
+ event.raw_message = None
+ result = runner._get_guild_id(event)
+ assert result is None
+
+ def test_get_guild_id_dm(self, runner):
+ event = _make_event()
+ event.raw_message = SimpleNamespace(guild_id=None, guild=None)
+ result = runner._get_guild_id(event)
+ assert result is None
+
+
+# =====================================================================
+# Discord adapter voice channel methods
+# =====================================================================
+
+class TestDiscordVoiceChannelMethods:
+ """Test DiscordAdapter voice channel methods (join, leave, play, etc.)."""
+
+ def _make_adapter(self):
+ from gateway.platforms.discord import DiscordAdapter
+ from gateway.config import Platform, PlatformConfig
+ config = PlatformConfig(enabled=True, extra={})
+ config.token = "fake-token"
+ adapter = object.__new__(DiscordAdapter)
+ adapter.platform = Platform.DISCORD
+ adapter.config = config
+ adapter._client = MagicMock()
+ adapter._voice_clients = {}
+ adapter._voice_text_channels = {}
+ adapter._voice_timeout_tasks = {}
+ adapter._voice_receivers = {}
+ adapter._voice_listen_tasks = {}
+ adapter._voice_input_callback = None
+ adapter._allowed_user_ids = set()
+ adapter._running = True
+ return adapter
+
+ def test_is_in_voice_channel_true(self):
+ adapter = self._make_adapter()
+ mock_vc = MagicMock()
+ mock_vc.is_connected.return_value = True
+ adapter._voice_clients[111] = mock_vc
+ assert adapter.is_in_voice_channel(111) is True
+
+ def test_is_in_voice_channel_false_no_client(self):
+ adapter = self._make_adapter()
+ assert adapter.is_in_voice_channel(111) is False
+
+ def test_is_in_voice_channel_false_disconnected(self):
+ adapter = self._make_adapter()
+ mock_vc = MagicMock()
+ mock_vc.is_connected.return_value = False
+ adapter._voice_clients[111] = mock_vc
+ assert adapter.is_in_voice_channel(111) is False
+
+ @pytest.mark.asyncio
+ async def test_leave_voice_channel_cleans_up(self):
+ adapter = self._make_adapter()
+ mock_vc = MagicMock()
+ mock_vc.is_connected.return_value = True
+ mock_vc.disconnect = AsyncMock()
+ adapter._voice_clients[111] = mock_vc
+ adapter._voice_text_channels[111] = 123
+
+ mock_receiver = MagicMock()
+ adapter._voice_receivers[111] = mock_receiver
+
+ mock_task = MagicMock()
+ adapter._voice_listen_tasks[111] = mock_task
+
+ mock_timeout = MagicMock()
+ adapter._voice_timeout_tasks[111] = mock_timeout
+
+ await adapter.leave_voice_channel(111)
+
+ mock_receiver.stop.assert_called_once()
+ mock_task.cancel.assert_called_once()
+ mock_vc.disconnect.assert_called_once()
+ mock_timeout.cancel.assert_called_once()
+ assert 111 not in adapter._voice_clients
+ assert 111 not in adapter._voice_text_channels
+ assert 111 not in adapter._voice_receivers
+
+ @pytest.mark.asyncio
+ async def test_leave_voice_channel_no_connection(self):
+ """Leave when not connected — no crash."""
+ adapter = self._make_adapter()
+ await adapter.leave_voice_channel(111) # should not raise
+
+ @pytest.mark.asyncio
+ async def test_get_user_voice_channel_no_client(self):
+ adapter = self._make_adapter()
+ adapter._client = None
+ result = await adapter.get_user_voice_channel(111, "42")
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_get_user_voice_channel_no_guild(self):
+ adapter = self._make_adapter()
+ adapter._client.get_guild = MagicMock(return_value=None)
+ result = await adapter.get_user_voice_channel(111, "42")
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_get_user_voice_channel_user_not_in_vc(self):
+ adapter = self._make_adapter()
+ mock_guild = MagicMock()
+ mock_member = MagicMock()
+ mock_member.voice = None
+ mock_guild.get_member = MagicMock(return_value=mock_member)
+ adapter._client.get_guild = MagicMock(return_value=mock_guild)
+ result = await adapter.get_user_voice_channel(111, "42")
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_get_user_voice_channel_success(self):
+ adapter = self._make_adapter()
+ mock_vc = MagicMock()
+ mock_guild = MagicMock()
+ mock_member = MagicMock()
+ mock_member.voice = MagicMock()
+ mock_member.voice.channel = mock_vc
+ mock_guild.get_member = MagicMock(return_value=mock_member)
+ adapter._client.get_guild = MagicMock(return_value=mock_guild)
+ result = await adapter.get_user_voice_channel(111, "42")
+ assert result is mock_vc
+
+ @pytest.mark.asyncio
+ async def test_play_in_voice_channel_not_connected(self):
+ adapter = self._make_adapter()
+ result = await adapter.play_in_voice_channel(111, "/tmp/test.ogg")
+ assert result is False
+
+ def test_is_allowed_user_empty_list(self):
+ adapter = self._make_adapter()
+ assert adapter._is_allowed_user("42") is True
+
+ def test_is_allowed_user_in_list(self):
+ adapter = self._make_adapter()
+ adapter._allowed_user_ids = {"42", "99"}
+ assert adapter._is_allowed_user("42") is True
+
+ def test_is_allowed_user_not_in_list(self):
+ adapter = self._make_adapter()
+ adapter._allowed_user_ids = {"99"}
+ assert adapter._is_allowed_user("42") is False
+
+ @pytest.mark.asyncio
+ async def test_process_voice_input_success(self):
+ """Successful voice input: PCM->WAV->STT->callback."""
+ adapter = self._make_adapter()
+ callback = AsyncMock()
+ adapter._voice_input_callback = callback
+ adapter._allowed_user_ids = set()
+
+ pcm_data = b"\x00" * 96000
+
+ with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
+ patch("tools.transcription_tools.transcribe_audio",
+ return_value={"success": True, "transcript": "Hello"}), \
+ patch("tools.voice_mode.is_whisper_hallucination", return_value=False):
+ await adapter._process_voice_input(111, 42, pcm_data)
+
+ callback.assert_called_once_with(guild_id=111, user_id=42, transcript="Hello")
+
+    @pytest.mark.asyncio
+    async def test_process_voice_input_hallucination_filtered(self):
+        """A transcript flagged as a Whisper hallucination never reaches the callback."""
+        subject = self._make_adapter()
+        on_voice = subject._voice_input_callback = AsyncMock()
+        stt_result = {"success": True, "transcript": "Thank you."}
+
+        # is_whisper_hallucination is forced to report True for this transcript.
+        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
+             patch("tools.transcription_tools.transcribe_audio", return_value=stt_result), \
+             patch("tools.voice_mode.is_whisper_hallucination", return_value=True):
+            await subject._process_voice_input(111, 42, bytes(96000))
+
+        on_voice.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_process_voice_input_stt_failure(self):
+        """When transcription reports failure the callback is left uncalled."""
+        subject = self._make_adapter()
+        on_voice = subject._voice_input_callback = AsyncMock()
+        stt_error = {"success": False, "error": "API error"}
+
+        # Only WAV conversion and STT are stubbed; STT reports an API error.
+        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
+             patch("tools.transcription_tools.transcribe_audio", return_value=stt_error):
+            await subject._process_voice_input(111, 42, bytes(96000))
+
+        on_voice.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_process_voice_input_exception_caught(self):
+        """Exception during processing is caught and the callback is skipped."""
+        adapter = self._make_adapter()
+        callback = adapter._voice_input_callback = AsyncMock()
+
+        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav",
+                   side_effect=RuntimeError("ffmpeg not found")):
+            await adapter._process_voice_input(111, 42, b"\x00" * 96000)  # must not raise
+        callback.assert_not_called()
+
+
+# =====================================================================
+# stream_tts_to_speaker functional tests
+# =====================================================================
+
+class TestStreamTtsToSpeaker:
+ """Functional tests for the streaming TTS pipeline."""
+
+    def test_none_sentinel_flushes_buffer(self):
+        """The None sentinel flushes whatever text is still buffered."""
+        from tools.tts_tool import stream_tts_to_speaker
+
+        shown = []
+        chunks = queue.Queue()
+        for item in ("Hello world.", None):  # text first, then the end-of-stream marker
+            chunks.put(item)
+        finished = threading.Event()
+
+        stream_tts_to_speaker(
+            chunks, threading.Event(), finished, display_callback=shown.append
+        )
+
+        # Completion is signalled and the queued text reached the display callback.
+        assert finished.is_set()
+        assert any("Hello" in piece for piece in shown)
+
+    def test_stop_event_aborts_early(self):
+        """A stop event set before the call ends the stream with no display output."""
+        from tools.tts_tool import stream_tts_to_speaker
+
+        shown = []
+        chunks, halt, finished = queue.Queue(), threading.Event(), threading.Event()
+        halt.set()  # request the abort before the pipeline even starts
+        chunks.put("Should not be spoken.")
+        chunks.put(None)
+
+        stream_tts_to_speaker(chunks, halt, finished, display_callback=shown.append)
+
+        # Done is still signalled on the abort path, but nothing was displayed.
+        assert finished.is_set()
+        assert shown == []
+
+    def test_done_event_set_on_exception(self):
+        """tts_done_event ends up set even if a queue item breaks processing."""
+        from tools.tts_tool import stream_tts_to_speaker
+
+        chunks = queue.Queue()
+        # An int instead of a str: intended to make string concatenation fail
+        # inside the pipeline.
+        chunks.put(12345)
+        chunks.put(None)
+        finished = threading.Event()
+
+        stream_tts_to_speaker(chunks, threading.Event(), finished)
+        assert finished.is_set()
+
+    def test_think_blocks_stripped(self):
+        """Content wrapped in think tags is not spoken."""
+        from tools.tts_tool import stream_tts_to_speaker
+        text_q = queue.Queue()
+        stop_evt = threading.Event()
+        done_evt = threading.Event()
+        spoken = []
+        tag = "think"  # tags are assembled here so literal-tag stripping cannot empty the fixture
+        text_q.put(f"<{tag}>internal reasoning</{tag}>")
+        text_q.put("Visible response. ")
+        text_q.put(None)
+
+        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
+        assert done_evt.is_set()
+        joined = " ".join(spoken)
+        assert "internal reasoning" not in joined
+        assert "Visible" in joined
+
+    def test_sentence_splitting(self):
+        """Two queued sentences produce at least two display callbacks."""
+        from tools.tts_tool import stream_tts_to_speaker
+
+        shown = []
+        chunks = queue.Queue()
+        # Each sentence is meant to exceed min_sentence_len (20, per the author).
+        for item in ("This is the first sentence. ", "This is the second sentence. ", None):
+            chunks.put(item)
+        finished = threading.Event()
+
+        stream_tts_to_speaker(
+            chunks, threading.Event(), finished, display_callback=shown.append
+        )
+        assert finished.is_set()
+        assert len(shown) >= 2
+
+    def test_markdown_stripped_in_speech(self):
+        """Markdown-laden text runs through the pipeline and done is signalled."""
+        from tools.tts_tool import stream_tts_to_speaker
+
+        shown = []
+        chunks = queue.Queue()
+        chunks.put("**Bold text** and `code`. ")
+        chunks.put(None)
+        finished = threading.Event()
+
+        # Per the original author, the display callback receives the raw text and
+        # only the TTS audio path strips markdown, so nothing is asserted on
+        # `shown`; this is a smoke test that the pipeline does not crash.
+        stream_tts_to_speaker(chunks, threading.Event(), finished, display_callback=shown.append)
+        assert finished.is_set()
+
+    def test_duplicate_sentences_deduped(self):
+        """An identical sentence queued twice is displayed a single time."""
+        from tools.tts_tool import stream_tts_to_speaker
+
+        shown = []
+        chunks = queue.Queue()
+        sentence = "This is a repeated sentence. "
+        for item in (sentence, sentence, None):
+            chunks.put(item)
+        finished = threading.Event()
+
+        stream_tts_to_speaker(
+            chunks, threading.Event(), finished, display_callback=shown.append
+        )
+        assert finished.is_set()
+        # Only the first occurrence should reach the callback.
+        assert len(shown) == 1
+
+    def test_no_api_key_display_only(self):
+        """With ELEVENLABS_API_KEY blanked out, text still reaches the display callback."""
+        from tools.tts_tool import stream_tts_to_speaker
+
+        shown = []
+        chunks = queue.Queue()
+        chunks.put("Display only text. ")
+        chunks.put(None)
+        halt, finished = threading.Event(), threading.Event()
+
+        # patch.dict restores the previous environment when the block exits.
+        blank_key = {"ELEVENLABS_API_KEY": ""}
+        with patch.dict(os.environ, blank_key):
+            stream_tts_to_speaker(chunks, halt, finished, display_callback=shown.append)
+        assert finished.is_set()
+        assert len(shown) >= 1
+
+    def test_long_buffer_flushed_on_timeout(self):
+        """Buffer longer than long_flush_len is flushed on queue timeout."""
+        from tools.tts_tool import stream_tts_to_speaker
+        text_q = queue.Queue()
+        stop_evt = threading.Event()
+        done_evt = threading.Event()
+        spoken, before_sentinel = [], []
+
+        # No sentence boundary and longer than long_flush_len (100, per the author).
+        text_q.put("a" * 150)
+
+        def delayed_sentinel():
+            time.sleep(1.0)
+            before_sentinel.extend(spoken)  # snapshot taken before the None flush
+            text_q.put(None)
+
+        sentinel_thread = threading.Thread(target=delayed_sentinel, daemon=True)
+        sentinel_thread.start()
+
+        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=spoken.append)
+        sentinel_thread.join(timeout=5)
+        assert done_evt.is_set()
+        # The text must already have been displayed while the queue was idle.
+        assert len(before_sentinel) >= 1