fix: extract voice reply logic and add comprehensive tests

- Fix tempfile.mktemp() TOCTOU race in Discord voice input (use NamedTemporaryFile)
- Extract voice reply decision from _handle_message into _should_send_voice_reply()
- Rewrite TestAutoVoiceReply to call real method instead of testing a copy
- Add 59 new tests: VoiceReceiver, VC commands, adapter methods, streaming TTS
This commit is contained in:
0xbyt4
2026-03-11 23:18:49 +03:00
parent 0d56b79685
commit 86ddaaee9c
3 changed files with 845 additions and 97 deletions

View File

@@ -851,7 +851,9 @@ class DiscordAdapter(BasePlatformAdapter):
"""Convert PCM -> WAV -> STT -> callback."""
from tools.voice_mode import is_whisper_hallucination
wav_path = tempfile.mktemp(suffix=".wav", prefix="vc_listen_")
tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", prefix="vc_listen_", delete=False)
wav_path = tmp_f.name
tmp_f.close()
try:
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)

View File

@@ -1616,42 +1616,8 @@ class GatewayRunner:
)
# Auto voice reply: send TTS audio before the text response
chat_id = source.chat_id
voice_mode = self._voice_mode.get(chat_id, "off")
is_voice_input = (event.message_type == MessageType.VOICE)
should_voice_reply = (
(voice_mode == "all")
or (voice_mode == "voice_only" and is_voice_input)
)
logger.info("Voice reply check: chat_id=%s, voice_mode=%s, is_voice=%s, should_reply=%s, has_response=%s",
chat_id, voice_mode, is_voice_input, should_voice_reply, bool(response))
if should_voice_reply and response and not response.startswith("Error:"):
# Skip if agent already called TTS tool (avoid double voice)
has_agent_tts = any(
msg.get("role") == "assistant"
and any(
tc.get("function", {}).get("name") == "text_to_speech"
for tc in (msg.get("tool_calls") or [])
)
for msg in agent_messages
)
# Skip if voice input — base adapter auto-TTS in
# _process_message_background already sent audio for voice
# messages, so sending another would be double.
# Exception: Discord voice channel — the Discord play_tts
# override also skips (no-op), so the runner MUST handle it
# via play_in_voice_channel.
skip_double = is_voice_input
if skip_double:
adapter = self.adapters.get(source.platform)
guild_id = self._get_guild_id(event)
if (guild_id and adapter
and hasattr(adapter, "is_in_voice_channel")
and adapter.is_in_voice_channel(guild_id)):
skip_double = False
logger.info("Voice reply: has_agent_tts=%s, skip_double=%s, calling _send_voice_reply", has_agent_tts, skip_double)
if not has_agent_tts and not skip_double:
await self._send_voice_reply(event, response)
if self._should_send_voice_reply(event, response, agent_messages):
await self._send_voice_reply(event, response)
return response
@@ -2302,6 +2268,64 @@ class GatewayRunner:
await adapter.handle_message(event)
def _should_send_voice_reply(
    self,
    event: MessageEvent,
    response: str,
    agent_messages: list,
) -> bool:
    """Decide whether the runner should send a TTS voice reply.

    Args:
        event: Incoming message; its ``source.chat_id`` selects the voice
            mode and its ``message_type`` tells voice input from text.
        response: Text the agent produced for this message.
        agent_messages: Messages produced during the agent run, scanned for
            a ``text_to_speech`` tool call. ``None`` is treated as empty.

    Returns:
        ``True`` when the runner itself must produce the voice reply.
        ``False`` when:
        - response is empty or starts with ``"Error:"``
        - voice mode for this chat is neither ``"all"`` nor
          ``"voice_only"`` with voice input (default mode is ``"off"``)
        - the agent already called the ``text_to_speech`` tool (dedup)
        - the input was voice and the base adapter auto-TTS already
          handles it (avoids double audio)
        Exception to the last rule: Discord voice channel — the play_tts
        override is a no-op there, so the runner must handle VC playback.
    """
    if not response or response.startswith("Error:"):
        return False

    voice_mode = self._voice_mode.get(event.source.chat_id, "off")
    is_voice_input = event.message_type == MessageType.VOICE
    mode_allows_reply = voice_mode == "all" or (
        voice_mode == "voice_only" and is_voice_input
    )
    if not mode_allows_reply:
        return False

    # Dedup: agent already called the TTS tool. Every level tolerates a
    # missing/None value so a sparse message cannot crash the decision.
    for msg in agent_messages or []:
        if msg.get("role") != "assistant":
            continue
        for tool_call in msg.get("tool_calls") or []:
            function = tool_call.get("function") or {}
            if function.get("name") == "text_to_speech":
                return False

    # Text input: nothing else produces audio, so the runner replies.
    if not is_voice_input:
        return True

    # Voice input is normally covered by the base adapter auto-TTS; the
    # runner only steps in when the adapter reports an active voice channel
    # for this guild.
    adapter = self.adapters.get(event.source.platform)
    guild_id = self._get_guild_id(event)
    return bool(
        guild_id
        and adapter
        and hasattr(adapter, "is_in_voice_channel")
        and adapter.is_in_voice_channel(guild_id)
    )
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
"""Generate TTS audio and send as a voice message before the text reply."""
try:

View File

@@ -2,7 +2,11 @@
import json
import os
import queue
import threading
import time
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
@@ -126,7 +130,7 @@ class TestHandleVoiceCommand:
# =====================================================================
class TestAutoVoiceReply:
"""Test the should_voice_reply decision logic (extracted from _handle_message).
"""Test the real _should_send_voice_reply method on GatewayRunner.
The gateway has two TTS paths:
1. base adapter auto-TTS: fires for voice input in _process_message_background
@@ -138,43 +142,30 @@ class TestAutoVoiceReply:
override skip, so the runner must handle it via play_in_voice_channel.
"""
def _should_reply(self, voice_mode, message_type, agent_messages=None,
response="Hello!", in_voice_channel=False):
"""Replicate the auto voice reply decision from _handle_message."""
if not response or response.startswith("Error:"):
return False
@pytest.fixture
def runner(self, tmp_path):
return _make_runner(tmp_path)
is_voice_input = (message_type == MessageType.VOICE)
should = (
(voice_mode == "all")
or (voice_mode == "voice_only" and is_voice_input)
def _call(self, runner, voice_mode, message_type, agent_messages=None,
response="Hello!", in_voice_channel=False):
"""Call real _should_send_voice_reply on a GatewayRunner instance."""
chat_id = "123"
if voice_mode != "off":
runner._voice_mode[chat_id] = voice_mode
else:
runner._voice_mode.pop(chat_id, None)
event = _make_event(message_type=message_type)
if in_voice_channel:
mock_adapter = MagicMock()
mock_adapter.is_in_voice_channel = MagicMock(return_value=True)
event.raw_message = SimpleNamespace(guild_id=111, guild=None)
runner.adapters[event.source.platform] = mock_adapter
return runner._should_send_voice_reply(
event, response, agent_messages or []
)
if not should:
return False
# Dedup: agent already called TTS tool
if agent_messages:
has_agent_tts = any(
msg.get("role") == "assistant"
and any(
tc.get("function", {}).get("name") == "text_to_speech"
for tc in (msg.get("tool_calls") or [])
)
for msg in agent_messages
)
if has_agent_tts:
return False
# Dedup: base adapter auto-TTS already handles voice input.
# Exception: in voice channel, Discord play_tts also skips,
# so the runner must handle VC playback.
skip_double = is_voice_input
if skip_double and in_voice_channel:
skip_double = False
if skip_double:
return False
return True
# -- Full platform x input x mode matrix --------------------------------
#
@@ -204,52 +195,52 @@ class TestAutoVoiceReply:
# -- Telegram/Slack/Web: voice input, base handles ---------------------
def test_voice_input_voice_only_skipped(self):
def test_voice_input_voice_only_skipped(self, runner):
"""voice_only + voice input: base auto-TTS handles it, runner skips."""
assert self._should_reply("voice_only", MessageType.VOICE) is False
assert self._call(runner, "voice_only", MessageType.VOICE) is False
def test_voice_input_all_mode_skipped(self):
def test_voice_input_all_mode_skipped(self, runner):
"""all + voice input: base auto-TTS handles it, runner skips."""
assert self._should_reply("all", MessageType.VOICE) is False
assert self._call(runner, "all", MessageType.VOICE) is False
# -- Text input: only runner handles -----------------------------------
def test_text_input_all_mode_runner_fires(self):
def test_text_input_all_mode_runner_fires(self, runner):
"""all + text input: only runner fires (base auto-TTS only for voice)."""
assert self._should_reply("all", MessageType.TEXT) is True
assert self._call(runner, "all", MessageType.TEXT) is True
def test_text_input_voice_only_no_reply(self):
def test_text_input_voice_only_no_reply(self, runner):
"""voice_only + text input: neither fires."""
assert self._should_reply("voice_only", MessageType.TEXT) is False
assert self._call(runner, "voice_only", MessageType.TEXT) is False
# -- Mode off: nothing fires -------------------------------------------
def test_off_mode_voice(self):
assert self._should_reply("off", MessageType.VOICE) is False
def test_off_mode_voice(self, runner):
assert self._call(runner, "off", MessageType.VOICE) is False
def test_off_mode_text(self):
assert self._should_reply("off", MessageType.TEXT) is False
def test_off_mode_text(self, runner):
assert self._call(runner, "off", MessageType.TEXT) is False
# -- Discord VC exception: runner must handle --------------------------
def test_discord_vc_voice_input_runner_fires(self):
def test_discord_vc_voice_input_runner_fires(self, runner):
"""Discord VC + voice input: base play_tts skips (VC override),
so runner must handle via play_in_voice_channel."""
assert self._should_reply("all", MessageType.VOICE, in_voice_channel=True) is True
assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is True
def test_discord_vc_voice_only_runner_fires(self):
def test_discord_vc_voice_only_runner_fires(self, runner):
"""Discord VC + voice_only + voice: runner must handle."""
assert self._should_reply("voice_only", MessageType.VOICE, in_voice_channel=True) is True
assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is True
# -- Edge cases --------------------------------------------------------
def test_error_response_skipped(self):
assert self._should_reply("all", MessageType.TEXT, response="Error: boom") is False
def test_error_response_skipped(self, runner):
assert self._call(runner, "all", MessageType.TEXT, response="Error: boom") is False
def test_empty_response_skipped(self):
assert self._should_reply("all", MessageType.TEXT, response="") is False
def test_empty_response_skipped(self, runner):
assert self._call(runner, "all", MessageType.TEXT, response="") is False
def test_dedup_skips_when_agent_called_tts(self):
def test_dedup_skips_when_agent_called_tts(self, runner):
messages = [{
"role": "assistant",
"tool_calls": [{
@@ -258,9 +249,9 @@ class TestAutoVoiceReply:
"function": {"name": "text_to_speech", "arguments": "{}"},
}],
}]
assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is False
assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is False
def test_no_dedup_for_other_tools(self):
def test_no_dedup_for_other_tools(self, runner):
messages = [{
"role": "assistant",
"tool_calls": [{
@@ -269,7 +260,7 @@ class TestAutoVoiceReply:
"function": {"name": "web_search", "arguments": "{}"},
}],
}]
assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is True
assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is True
# =====================================================================
@@ -443,3 +434,734 @@ class TestVoiceInHelp:
import inspect
source = inspect.getsource(GatewayRunner._handle_message)
assert '"voice"' in source
# =====================================================================
# VoiceReceiver unit tests
# =====================================================================
class TestVoiceReceiver:
    """Exercise VoiceReceiver lifecycle, SSRC mapping, silence detection and
    the early-exit branches of ``_on_packet``."""

    @staticmethod
    def _seed(rx, ssrc, size, age):
        """Give *rx* a zero-filled buffer of ``size`` bytes for ``ssrc`` whose
        last packet was recorded ``age`` seconds before now."""
        rx._buffers[ssrc] = bytearray(b"\x00" * size)
        rx._last_packet_time[ssrc] = time.monotonic() - age

    def _make_receiver(self):
        """Build a VoiceReceiver wrapped around a mocked voice client."""
        from gateway.platforms.discord import VoiceReceiver

        voice_client = MagicMock()
        conn = voice_client._connection
        conn.secret_key = [0] * 32
        conn.dave_session = None
        conn.ssrc = 9999
        conn.add_socket_listener = MagicMock()
        conn.remove_socket_listener = MagicMock()
        conn.hook = None
        return VoiceReceiver(voice_client)

    # -- lifecycle ---------------------------------------------------------

    def test_initial_state(self):
        rx = self._make_receiver()
        assert rx._running is False
        assert rx._paused is False
        assert len(rx._buffers) == 0
        assert len(rx._ssrc_to_user) == 0

    def test_start_sets_running(self):
        rx = self._make_receiver()
        rx.start()
        assert rx._running is True

    def test_stop_clears_state(self):
        rx = self._make_receiver()
        rx.start()
        rx.map_ssrc(100, 42)
        self._seed(rx, 100, 1000, 0)

        rx.stop()

        assert rx._running is False
        assert len(rx._buffers) == 0
        assert len(rx._ssrc_to_user) == 0
        assert len(rx._last_packet_time) == 0

    # -- SSRC mapping / pause ----------------------------------------------

    def test_map_ssrc(self):
        rx = self._make_receiver()
        rx.map_ssrc(100, 42)
        assert rx._ssrc_to_user[100] == 42

    def test_map_ssrc_overwrites(self):
        rx = self._make_receiver()
        rx.map_ssrc(100, 42)
        rx.map_ssrc(100, 99)
        assert rx._ssrc_to_user[100] == 99

    def test_pause_resume(self):
        rx = self._make_receiver()
        assert rx._paused is False
        rx.pause()
        assert rx._paused is True
        rx.resume()
        assert rx._paused is False

    # -- check_silence -----------------------------------------------------

    def test_check_silence_empty(self):
        rx = self._make_receiver()
        assert rx.check_silence() == []

    def test_check_silence_returns_completed_utterance(self):
        rx = self._make_receiver()
        rx.map_ssrc(100, 42)
        # 48kHz, stereo, 16-bit = 192000 bytes/sec
        # MIN_SPEECH_DURATION = 0.5s → need 96000 bytes.
        # Last packet far enough in the past to exceed SILENCE_THRESHOLD.
        self._seed(rx, 100, 96000, 3.0)

        utterances = rx.check_silence()

        assert len(utterances) == 1
        speaker, pcm = utterances[0]
        assert speaker == 42
        assert len(pcm) == 96000
        # Buffer should be cleared after extraction
        assert len(rx._buffers[100]) == 0

    def test_check_silence_ignores_short_buffer(self):
        rx = self._make_receiver()
        rx.map_ssrc(100, 42)
        # Too short to meet MIN_SPEECH_DURATION
        self._seed(rx, 100, 100, 3.0)
        utterances = rx.check_silence()
        assert len(utterances) == 0

    def test_check_silence_ignores_recent_audio(self):
        rx = self._make_receiver()
        rx.map_ssrc(100, 42)
        # Last packet timestamp is "just now"
        self._seed(rx, 100, 96000, 0)
        utterances = rx.check_silence()
        assert len(utterances) == 0

    def test_check_silence_unknown_user_discarded(self):
        rx = self._make_receiver()
        # No SSRC mapping — user_id will be 0
        self._seed(rx, 100, 96000, 3.0)
        utterances = rx.check_silence()
        assert len(utterances) == 0

    def test_stale_buffer_discarded(self):
        rx = self._make_receiver()
        # Buffer with no user mapping and very old timestamp
        self._seed(rx, 200, 100, 10.0)
        rx.check_silence()
        # Stale buffer (> 2x threshold) should be discarded
        assert 200 not in rx._buffers

    # -- _on_packet early exits --------------------------------------------

    def test_on_packet_skips_when_not_running(self):
        rx = self._make_receiver()
        # Never started — _running is False
        rx._on_packet(b"\x00" * 100)
        assert len(rx._buffers) == 0

    def test_on_packet_skips_when_paused(self):
        rx = self._make_receiver()
        rx.start()
        rx.pause()
        rx._on_packet(b"\x00" * 100)
        # Paused — should not process
        assert len(rx._buffers) == 0

    def test_on_packet_skips_short_data(self):
        rx = self._make_receiver()
        rx.start()
        rx._on_packet(b"\x00" * 10)
        assert len(rx._buffers) == 0

    def test_on_packet_skips_non_rtp(self):
        rx = self._make_receiver()
        rx.start()
        # Valid length but wrong RTP version
        packet = bytearray(b"\x00" * 20)
        packet[0] = 0x00  # version 0, not 2
        rx._on_packet(bytes(packet))
        assert len(rx._buffers) == 0
# =====================================================================
# Gateway voice channel commands (join / leave / input)
# =====================================================================
class TestVoiceChannelCommands:
    """Cover the GatewayRunner voice-channel handlers
    (_handle_voice_channel_join / _leave / _input) and _get_guild_id."""

    @pytest.fixture
    def runner(self, tmp_path):
        return _make_runner(tmp_path)

    # -- helpers -----------------------------------------------------------

    def _make_discord_event(self, text="/voice channel", chat_id="123",
                            guild_id=111, user_id="user1"):
        """Create event with raw_message carrying guild info."""
        source = SessionSource(
            chat_id=chat_id,
            user_id=user_id,
            platform=MagicMock(),
        )
        source.platform.value = "discord"
        source.thread_id = None
        evt = MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
        evt.message_id = "msg42"
        evt.raw_message = SimpleNamespace(guild_id=guild_id, guild=None)
        return evt

    @staticmethod
    def _register(runner, event, adapter):
        """Install *adapter* as the handler for the event's platform."""
        runner.adapters[event.source.platform] = adapter

    @staticmethod
    def _named_channel():
        """Mock voice channel whose ``name`` attribute is "General"."""
        channel = MagicMock()
        channel.name = "General"
        return channel

    @staticmethod
    def _join_adapter(join, channel):
        """Adapter mock with the given join coroutine mock and a
        get_user_voice_channel that resolves to *channel*."""
        adapter = AsyncMock()
        adapter.join_voice_channel = join
        adapter.get_user_voice_channel = AsyncMock(return_value=channel)
        return adapter

    @staticmethod
    def _attach_vc_adapter(runner):
        """Register a Discord adapter mapping guild 111 to text channel 123;
        return ``(adapter, text_channel)``."""
        from gateway.config import Platform

        adapter = AsyncMock()
        adapter._voice_text_channels = {111: 123}
        text_channel = AsyncMock()
        adapter._client = MagicMock()
        adapter._client.get_channel = MagicMock(return_value=text_channel)
        adapter.handle_message = AsyncMock()
        runner.adapters[Platform.DISCORD] = adapter
        return adapter, text_channel

    # -- _handle_voice_channel_join --

    @pytest.mark.asyncio
    async def test_join_unsupported_platform(self, runner):
        """Platform without join_voice_channel returns unsupported message."""
        adapter = AsyncMock(spec=[])  # no join_voice_channel
        evt = self._make_discord_event()
        self._register(runner, evt, adapter)

        reply = await runner._handle_voice_channel_join(evt)

        assert "not supported" in reply.lower()

    @pytest.mark.asyncio
    async def test_join_no_guild_id(self, runner):
        """DM context (no guild_id) returns error."""
        adapter = AsyncMock()
        adapter.join_voice_channel = AsyncMock()
        evt = self._make_discord_event()
        evt.raw_message = None  # no guild info
        self._register(runner, evt, adapter)

        reply = await runner._handle_voice_channel_join(evt)

        assert "discord server" in reply.lower()

    @pytest.mark.asyncio
    async def test_join_user_not_in_vc(self, runner):
        """User not in any voice channel."""
        adapter = self._join_adapter(AsyncMock(), None)
        evt = self._make_discord_event()
        self._register(runner, evt, adapter)

        reply = await runner._handle_voice_channel_join(evt)

        assert "need to be in a voice channel" in reply.lower()

    @pytest.mark.asyncio
    async def test_join_success(self, runner):
        """Successful join sets voice_mode and returns confirmation."""
        adapter = self._join_adapter(
            AsyncMock(return_value=True), self._named_channel()
        )
        adapter._voice_text_channels = {}
        adapter._voice_input_callback = None
        evt = self._make_discord_event()
        self._register(runner, evt, adapter)

        reply = await runner._handle_voice_channel_join(evt)

        assert "joined" in reply.lower()
        assert "General" in reply
        assert runner._voice_mode["123"] == "all"

    @pytest.mark.asyncio
    async def test_join_failure(self, runner):
        """Failed join returns permissions error."""
        adapter = self._join_adapter(
            AsyncMock(return_value=False), self._named_channel()
        )
        evt = self._make_discord_event()
        self._register(runner, evt, adapter)

        reply = await runner._handle_voice_channel_join(evt)

        assert "failed" in reply.lower()

    @pytest.mark.asyncio
    async def test_join_exception(self, runner):
        """Exception during join is caught and reported."""
        adapter = self._join_adapter(
            AsyncMock(side_effect=RuntimeError("No permission")),
            self._named_channel(),
        )
        evt = self._make_discord_event()
        self._register(runner, evt, adapter)

        reply = await runner._handle_voice_channel_join(evt)

        assert "failed" in reply.lower()

    # -- _handle_voice_channel_leave --

    @pytest.mark.asyncio
    async def test_leave_not_in_vc(self, runner):
        """Leave when not in VC returns appropriate message."""
        adapter = AsyncMock()
        adapter.is_in_voice_channel = MagicMock(return_value=False)
        evt = self._make_discord_event("/voice leave")
        self._register(runner, evt, adapter)

        reply = await runner._handle_voice_channel_leave(evt)

        assert "not in" in reply.lower()

    @pytest.mark.asyncio
    async def test_leave_no_guild(self, runner):
        """Leave from DM returns not in voice channel."""
        adapter = AsyncMock()
        evt = self._make_discord_event("/voice leave")
        evt.raw_message = None
        self._register(runner, evt, adapter)

        reply = await runner._handle_voice_channel_leave(evt)

        assert "not in" in reply.lower()

    @pytest.mark.asyncio
    async def test_leave_success(self, runner):
        """Successful leave disconnects and clears voice mode."""
        adapter = AsyncMock()
        adapter.is_in_voice_channel = MagicMock(return_value=True)
        adapter.leave_voice_channel = AsyncMock()
        evt = self._make_discord_event("/voice leave")
        self._register(runner, evt, adapter)
        runner._voice_mode["123"] = "all"

        reply = await runner._handle_voice_channel_leave(evt)

        assert "left" in reply.lower()
        assert "123" not in runner._voice_mode
        adapter.leave_voice_channel.assert_called_once_with(111)

    # -- _handle_voice_channel_input --

    @pytest.mark.asyncio
    async def test_input_no_adapter(self, runner):
        """No Discord adapter — early return, no crash."""
        from gateway.config import Platform  # not referenced below

        # No adapters set
        await runner._handle_voice_channel_input(111, 42, "Hello")

    @pytest.mark.asyncio
    async def test_input_no_text_channel(self, runner):
        """No text channel mapped for guild — early return."""
        from gateway.config import Platform

        adapter = AsyncMock()
        adapter._voice_text_channels = {}
        adapter._client = MagicMock()
        runner.adapters[Platform.DISCORD] = adapter

        await runner._handle_voice_channel_input(111, 42, "Hello")

    @pytest.mark.asyncio
    async def test_input_creates_event_and_dispatches(self, runner):
        """Voice input creates synthetic event and calls handle_message."""
        adapter, _text_channel = self._attach_vc_adapter(runner)

        await runner._handle_voice_channel_input(111, 42, "Hello from VC")

        adapter.handle_message.assert_called_once()
        dispatched = adapter.handle_message.call_args[0][0]
        assert dispatched.text == "Hello from VC"
        assert dispatched.message_type == MessageType.VOICE
        assert dispatched.source.chat_id == "123"

    @pytest.mark.asyncio
    async def test_input_posts_transcript_in_text_channel(self, runner):
        """Voice input sends transcript message to text channel."""
        _adapter, text_channel = self._attach_vc_adapter(runner)

        await runner._handle_voice_channel_input(111, 42, "Test transcript")

        text_channel.send.assert_called_once()
        posted = text_channel.send.call_args[0][0]
        assert "Test transcript" in posted
        assert "42" in posted  # user_id in mention

    # -- _get_guild_id --

    def test_get_guild_id_from_guild(self, runner):
        evt = _make_event()
        guild = MagicMock()
        guild.id = 555
        evt.raw_message = SimpleNamespace(guild_id=None, guild=guild)
        assert runner._get_guild_id(evt) == 555

    def test_get_guild_id_from_interaction(self, runner):
        evt = _make_event()
        evt.raw_message = SimpleNamespace(guild_id=777, guild=None)
        assert runner._get_guild_id(evt) == 777

    def test_get_guild_id_none(self, runner):
        evt = _make_event()
        evt.raw_message = None
        assert runner._get_guild_id(evt) is None

    def test_get_guild_id_dm(self, runner):
        evt = _make_event()
        evt.raw_message = SimpleNamespace(guild_id=None, guild=None)
        assert runner._get_guild_id(evt) is None
# =====================================================================
# Discord adapter voice channel methods
# =====================================================================
class TestDiscordVoiceChannelMethods:
    """Cover the DiscordAdapter voice-channel surface: connection checks,
    leave cleanup, user channel lookup, playback, allow-list and voice input."""

    # -- helpers -----------------------------------------------------------

    def _make_adapter(self):
        """Build a DiscordAdapter without running ``__init__`` and populate
        the attributes these tests touch."""
        from gateway.platforms.discord import DiscordAdapter
        from gateway.config import Platform, PlatformConfig

        cfg = PlatformConfig(enabled=True, extra={})
        cfg.token = "fake-token"

        adapter = object.__new__(DiscordAdapter)
        adapter.platform = Platform.DISCORD
        adapter.config = cfg
        adapter._client = MagicMock()
        # Per-guild registries all start out empty (a fresh dict each).
        for registry in (
            "_voice_clients",
            "_voice_text_channels",
            "_voice_timeout_tasks",
            "_voice_receivers",
            "_voice_listen_tasks",
        ):
            setattr(adapter, registry, {})
        adapter._voice_input_callback = None
        adapter._allowed_user_ids = set()
        adapter._running = True
        return adapter

    @staticmethod
    def _voice_client(connected):
        """Mock voice client whose ``is_connected()`` returns *connected*."""
        client = MagicMock()
        client.is_connected.return_value = connected
        return client

    @staticmethod
    def _guild_with_member(voice_state):
        """Mock guild whose ``get_member`` yields a member with the given
        ``voice`` attribute."""
        member = MagicMock()
        member.voice = voice_state
        guild = MagicMock()
        guild.get_member = MagicMock(return_value=member)
        return guild

    # -- is_in_voice_channel -----------------------------------------------

    def test_is_in_voice_channel_true(self):
        adapter = self._make_adapter()
        adapter._voice_clients[111] = self._voice_client(True)
        assert adapter.is_in_voice_channel(111) is True

    def test_is_in_voice_channel_false_no_client(self):
        adapter = self._make_adapter()
        assert adapter.is_in_voice_channel(111) is False

    def test_is_in_voice_channel_false_disconnected(self):
        adapter = self._make_adapter()
        adapter._voice_clients[111] = self._voice_client(False)
        assert adapter.is_in_voice_channel(111) is False

    # -- leave_voice_channel -----------------------------------------------

    @pytest.mark.asyncio
    async def test_leave_voice_channel_cleans_up(self):
        adapter = self._make_adapter()
        client = self._voice_client(True)
        client.disconnect = AsyncMock()
        receiver = MagicMock()
        listen_task = MagicMock()
        timeout_task = MagicMock()
        adapter._voice_clients[111] = client
        adapter._voice_text_channels[111] = 123
        adapter._voice_receivers[111] = receiver
        adapter._voice_listen_tasks[111] = listen_task
        adapter._voice_timeout_tasks[111] = timeout_task

        await adapter.leave_voice_channel(111)

        receiver.stop.assert_called_once()
        listen_task.cancel.assert_called_once()
        client.disconnect.assert_called_once()
        timeout_task.cancel.assert_called_once()
        assert 111 not in adapter._voice_clients
        assert 111 not in adapter._voice_text_channels
        assert 111 not in adapter._voice_receivers

    @pytest.mark.asyncio
    async def test_leave_voice_channel_no_connection(self):
        """Leave when not connected — no crash."""
        adapter = self._make_adapter()
        await adapter.leave_voice_channel(111)  # should not raise

    # -- get_user_voice_channel --------------------------------------------

    @pytest.mark.asyncio
    async def test_get_user_voice_channel_no_client(self):
        adapter = self._make_adapter()
        adapter._client = None
        found = await adapter.get_user_voice_channel(111, "42")
        assert found is None

    @pytest.mark.asyncio
    async def test_get_user_voice_channel_no_guild(self):
        adapter = self._make_adapter()
        adapter._client.get_guild = MagicMock(return_value=None)
        found = await adapter.get_user_voice_channel(111, "42")
        assert found is None

    @pytest.mark.asyncio
    async def test_get_user_voice_channel_user_not_in_vc(self):
        adapter = self._make_adapter()
        guild = self._guild_with_member(None)
        adapter._client.get_guild = MagicMock(return_value=guild)
        found = await adapter.get_user_voice_channel(111, "42")
        assert found is None

    @pytest.mark.asyncio
    async def test_get_user_voice_channel_success(self):
        adapter = self._make_adapter()
        target_channel = MagicMock()
        voice_state = MagicMock()
        voice_state.channel = target_channel
        guild = self._guild_with_member(voice_state)
        adapter._client.get_guild = MagicMock(return_value=guild)
        found = await adapter.get_user_voice_channel(111, "42")
        assert found is target_channel

    # -- play_in_voice_channel ---------------------------------------------

    @pytest.mark.asyncio
    async def test_play_in_voice_channel_not_connected(self):
        adapter = self._make_adapter()
        played = await adapter.play_in_voice_channel(111, "/tmp/test.ogg")
        assert played is False

    # -- _is_allowed_user --------------------------------------------------

    def test_is_allowed_user_empty_list(self):
        adapter = self._make_adapter()
        assert adapter._is_allowed_user("42") is True

    def test_is_allowed_user_in_list(self):
        adapter = self._make_adapter()
        adapter._allowed_user_ids = {"42", "99"}
        assert adapter._is_allowed_user("42") is True

    def test_is_allowed_user_not_in_list(self):
        adapter = self._make_adapter()
        adapter._allowed_user_ids = {"99"}
        assert adapter._is_allowed_user("42") is False

    # -- _process_voice_input ----------------------------------------------

    @pytest.mark.asyncio
    async def test_process_voice_input_success(self):
        """Successful voice input: PCM->WAV->STT->callback."""
        adapter = self._make_adapter()
        on_transcript = AsyncMock()
        adapter._voice_input_callback = on_transcript
        adapter._allowed_user_ids = set()
        pcm = b"\x00" * 96000
        stt_result = {"success": True, "transcript": "Hello"}

        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
             patch("tools.transcription_tools.transcribe_audio",
                   return_value=stt_result), \
             patch("tools.voice_mode.is_whisper_hallucination", return_value=False):
            await adapter._process_voice_input(111, 42, pcm)

        on_transcript.assert_called_once_with(
            guild_id=111, user_id=42, transcript="Hello"
        )

    @pytest.mark.asyncio
    async def test_process_voice_input_hallucination_filtered(self):
        """Whisper hallucination is filtered out."""
        adapter = self._make_adapter()
        on_transcript = AsyncMock()
        adapter._voice_input_callback = on_transcript
        stt_result = {"success": True, "transcript": "Thank you."}

        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
             patch("tools.transcription_tools.transcribe_audio",
                   return_value=stt_result), \
             patch("tools.voice_mode.is_whisper_hallucination", return_value=True):
            await adapter._process_voice_input(111, 42, b"\x00" * 96000)

        on_transcript.assert_not_called()

    @pytest.mark.asyncio
    async def test_process_voice_input_stt_failure(self):
        """STT failure — callback not called."""
        adapter = self._make_adapter()
        on_transcript = AsyncMock()
        adapter._voice_input_callback = on_transcript
        stt_result = {"success": False, "error": "API error"}

        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
             patch("tools.transcription_tools.transcribe_audio",
                   return_value=stt_result):
            await adapter._process_voice_input(111, 42, b"\x00" * 96000)

        on_transcript.assert_not_called()

    @pytest.mark.asyncio
    async def test_process_voice_input_exception_caught(self):
        """Exception during processing is caught, no crash."""
        adapter = self._make_adapter()
        adapter._voice_input_callback = AsyncMock()

        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav",
                   side_effect=RuntimeError("ffmpeg not found")):
            await adapter._process_voice_input(111, 42, b"\x00" * 96000)
        # Should not raise
# =====================================================================
# stream_tts_to_speaker functional tests
# =====================================================================
class TestStreamTtsToSpeaker:
    """Functional tests that drive stream_tts_to_speaker through a text
    queue and observe what reaches the display callback."""

    # -- helpers -----------------------------------------------------------

    @staticmethod
    def _pipeline(*chunks):
        """Return ``(text_queue, stop_event, done_event)`` with *chunks*
        already queued in order."""
        text_q = queue.Queue()
        for chunk in chunks:
            text_q.put(chunk)
        return text_q, threading.Event(), threading.Event()

    @staticmethod
    def _collector():
        """Return ``(spoken, record)`` where ``record`` appends its single
        argument to the ``spoken`` list."""
        spoken = []

        def record(text):
            spoken.append(text)

        return spoken, record

    # -- tests -------------------------------------------------------------

    def test_none_sentinel_flushes_buffer(self):
        """None sentinel causes remaining buffer to be spoken."""
        from tools.tts_tool import stream_tts_to_speaker

        text_q, stop_evt, done_evt = self._pipeline("Hello world.", None)
        spoken, record = self._collector()

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=record)

        assert done_evt.is_set()
        assert any("Hello" in s for s in spoken)

    def test_stop_event_aborts_early(self):
        """Setting stop_event causes early exit."""
        from tools.tts_tool import stream_tts_to_speaker

        text_q, stop_evt, done_evt = self._pipeline("Should not be spoken.", None)
        spoken, record = self._collector()
        stop_evt.set()

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=record)

        assert done_evt.is_set()
        assert len(spoken) == 0

    def test_done_event_set_on_exception(self):
        """tts_done_event is set even when an exception occurs."""
        from tools.tts_tool import stream_tts_to_speaker

        # A non-string chunk that will cause concatenation to fail
        text_q, stop_evt, done_evt = self._pipeline(12345, None)

        stream_tts_to_speaker(text_q, stop_evt, done_evt)

        assert done_evt.is_set()

    def test_think_blocks_stripped(self):
        """<think>...</think> content is not spoken."""
        from tools.tts_tool import stream_tts_to_speaker

        text_q, stop_evt, done_evt = self._pipeline(
            "<think>internal reasoning</think>",
            "Visible response. ",
            None,
        )
        spoken, record = self._collector()

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=record)

        assert done_evt.is_set()
        transcript = " ".join(spoken)
        assert "internal reasoning" not in transcript
        assert "Visible" in transcript

    def test_sentence_splitting(self):
        """Sentences are split at boundaries and spoken individually."""
        from tools.tts_tool import stream_tts_to_speaker

        # Two sentences long enough to exceed min_sentence_len (20)
        text_q, stop_evt, done_evt = self._pipeline(
            "This is the first sentence. ",
            "This is the second sentence. ",
            None,
        )
        spoken, record = self._collector()

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=record)

        assert done_evt.is_set()
        assert len(spoken) >= 2

    def test_markdown_stripped_in_speech(self):
        """Markdown formatting is removed before display/speech."""
        from tools.tts_tool import stream_tts_to_speaker

        text_q, stop_evt, done_evt = self._pipeline(
            "**Bold text** and `code`. ", None
        )
        _spoken, record = self._collector()

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=record)

        assert done_evt.is_set()
        # Display callback gets raw text (before markdown stripping)
        # But the actual TTS audio would be stripped — we verify pipeline doesn't crash

    def test_duplicate_sentences_deduped(self):
        """Repeated sentences are spoken only once."""
        from tools.tts_tool import stream_tts_to_speaker

        # Same sentence twice, each long enough
        text_q, stop_evt, done_evt = self._pipeline(
            "This is a repeated sentence. ",
            "This is a repeated sentence. ",
            None,
        )
        spoken, record = self._collector()

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=record)

        assert done_evt.is_set()
        # First occurrence is spoken, second is deduped
        assert len(spoken) == 1

    def test_no_api_key_display_only(self):
        """Without ELEVENLABS_API_KEY, display callback still works."""
        from tools.tts_tool import stream_tts_to_speaker

        text_q, stop_evt, done_evt = self._pipeline("Display only text. ", None)
        spoken, record = self._collector()

        with patch.dict(os.environ, {"ELEVENLABS_API_KEY": ""}):
            stream_tts_to_speaker(text_q, stop_evt, done_evt,
                                  display_callback=record)

        assert done_evt.is_set()
        assert len(spoken) >= 1

    def test_long_buffer_flushed_on_timeout(self):
        """Buffer longer than long_flush_len is flushed on queue timeout."""
        from tools.tts_tool import stream_tts_to_speaker

        # Long text without sentence boundary: > long_flush_len (100).
        # The None sentinel is only pushed after a delay.
        text_q, stop_evt, done_evt = self._pipeline("a" * 150)
        spoken, record = self._collector()

        def push_sentinel_later():
            time.sleep(1.0)
            text_q.put(None)

        feeder = threading.Thread(target=push_sentinel_later, daemon=True)
        feeder.start()

        stream_tts_to_speaker(text_q, stop_evt, done_evt,
                              display_callback=record)
        feeder.join(timeout=5)

        assert done_evt.is_set()
        assert len(spoken) >= 1