fix: extract voice reply logic and add comprehensive tests
- Fix tempfile.mktemp() TOCTOU race in Discord voice input (use NamedTemporaryFile) - Extract voice reply decision from _handle_message into _should_send_voice_reply() - Rewrite TestAutoVoiceReply to call real method instead of testing a copy - Add 59 new tests: VoiceReceiver, VC commands, adapter methods, streaming TTS
This commit is contained in:
@@ -851,7 +851,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
"""Convert PCM -> WAV -> STT -> callback."""
|
||||
from tools.voice_mode import is_whisper_hallucination
|
||||
|
||||
wav_path = tempfile.mktemp(suffix=".wav", prefix="vc_listen_")
|
||||
tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", prefix="vc_listen_", delete=False)
|
||||
wav_path = tmp_f.name
|
||||
tmp_f.close()
|
||||
try:
|
||||
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)
|
||||
|
||||
|
||||
@@ -1616,42 +1616,8 @@ class GatewayRunner:
|
||||
)
|
||||
|
||||
# Auto voice reply: send TTS audio before the text response
|
||||
chat_id = source.chat_id
|
||||
voice_mode = self._voice_mode.get(chat_id, "off")
|
||||
is_voice_input = (event.message_type == MessageType.VOICE)
|
||||
should_voice_reply = (
|
||||
(voice_mode == "all")
|
||||
or (voice_mode == "voice_only" and is_voice_input)
|
||||
)
|
||||
logger.info("Voice reply check: chat_id=%s, voice_mode=%s, is_voice=%s, should_reply=%s, has_response=%s",
|
||||
chat_id, voice_mode, is_voice_input, should_voice_reply, bool(response))
|
||||
if should_voice_reply and response and not response.startswith("Error:"):
|
||||
# Skip if agent already called TTS tool (avoid double voice)
|
||||
has_agent_tts = any(
|
||||
msg.get("role") == "assistant"
|
||||
and any(
|
||||
tc.get("function", {}).get("name") == "text_to_speech"
|
||||
for tc in (msg.get("tool_calls") or [])
|
||||
)
|
||||
for msg in agent_messages
|
||||
)
|
||||
# Skip if voice input — base adapter auto-TTS in
|
||||
# _process_message_background already sent audio for voice
|
||||
# messages, so sending another would be double.
|
||||
# Exception: Discord voice channel — the Discord play_tts
|
||||
# override also skips (no-op), so the runner MUST handle it
|
||||
# via play_in_voice_channel.
|
||||
skip_double = is_voice_input
|
||||
if skip_double:
|
||||
adapter = self.adapters.get(source.platform)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if (guild_id and adapter
|
||||
and hasattr(adapter, "is_in_voice_channel")
|
||||
and adapter.is_in_voice_channel(guild_id)):
|
||||
skip_double = False
|
||||
logger.info("Voice reply: has_agent_tts=%s, skip_double=%s, calling _send_voice_reply", has_agent_tts, skip_double)
|
||||
if not has_agent_tts and not skip_double:
|
||||
await self._send_voice_reply(event, response)
|
||||
if self._should_send_voice_reply(event, response, agent_messages):
|
||||
await self._send_voice_reply(event, response)
|
||||
|
||||
return response
|
||||
|
||||
@@ -2302,6 +2268,64 @@ class GatewayRunner:
|
||||
|
||||
await adapter.handle_message(event)
|
||||
|
||||
def _should_send_voice_reply(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
response: str,
|
||||
agent_messages: list,
|
||||
) -> bool:
|
||||
"""Decide whether the runner should send a TTS voice reply.
|
||||
|
||||
Returns False when:
|
||||
- voice_mode is off for this chat
|
||||
- response is empty or an error
|
||||
- agent already called text_to_speech tool (dedup)
|
||||
- voice input and base adapter auto-TTS already handled it (skip_double)
|
||||
Exception: Discord voice channel — base play_tts is a no-op there,
|
||||
so the runner must handle VC playback.
|
||||
"""
|
||||
if not response or response.startswith("Error:"):
|
||||
return False
|
||||
|
||||
chat_id = event.source.chat_id
|
||||
voice_mode = self._voice_mode.get(chat_id, "off")
|
||||
is_voice_input = (event.message_type == MessageType.VOICE)
|
||||
|
||||
should = (
|
||||
(voice_mode == "all")
|
||||
or (voice_mode == "voice_only" and is_voice_input)
|
||||
)
|
||||
if not should:
|
||||
return False
|
||||
|
||||
# Dedup: agent already called TTS tool
|
||||
has_agent_tts = any(
|
||||
msg.get("role") == "assistant"
|
||||
and any(
|
||||
tc.get("function", {}).get("name") == "text_to_speech"
|
||||
for tc in (msg.get("tool_calls") or [])
|
||||
)
|
||||
for msg in agent_messages
|
||||
)
|
||||
if has_agent_tts:
|
||||
return False
|
||||
|
||||
# Dedup: base adapter auto-TTS already handles voice input.
|
||||
# Exception: Discord voice channel — play_tts override is a no-op,
|
||||
# so the runner must handle VC playback.
|
||||
skip_double = is_voice_input
|
||||
if skip_double:
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if (guild_id and adapter
|
||||
and hasattr(adapter, "is_in_voice_channel")
|
||||
and adapter.is_in_voice_channel(guild_id)):
|
||||
skip_double = False
|
||||
if skip_double:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
|
||||
"""Generate TTS audio and send as a voice message before the text reply."""
|
||||
try:
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
|
||||
@@ -126,7 +130,7 @@ class TestHandleVoiceCommand:
|
||||
# =====================================================================
|
||||
|
||||
class TestAutoVoiceReply:
|
||||
"""Test the should_voice_reply decision logic (extracted from _handle_message).
|
||||
"""Test the real _should_send_voice_reply method on GatewayRunner.
|
||||
|
||||
The gateway has two TTS paths:
|
||||
1. base adapter auto-TTS: fires for voice input in _process_message_background
|
||||
@@ -138,43 +142,30 @@ class TestAutoVoiceReply:
|
||||
override skip, so the runner must handle it via play_in_voice_channel.
|
||||
"""
|
||||
|
||||
def _should_reply(self, voice_mode, message_type, agent_messages=None,
|
||||
response="Hello!", in_voice_channel=False):
|
||||
"""Replicate the auto voice reply decision from _handle_message."""
|
||||
if not response or response.startswith("Error:"):
|
||||
return False
|
||||
@pytest.fixture
|
||||
def runner(self, tmp_path):
|
||||
return _make_runner(tmp_path)
|
||||
|
||||
is_voice_input = (message_type == MessageType.VOICE)
|
||||
should = (
|
||||
(voice_mode == "all")
|
||||
or (voice_mode == "voice_only" and is_voice_input)
|
||||
def _call(self, runner, voice_mode, message_type, agent_messages=None,
|
||||
response="Hello!", in_voice_channel=False):
|
||||
"""Call real _should_send_voice_reply on a GatewayRunner instance."""
|
||||
chat_id = "123"
|
||||
if voice_mode != "off":
|
||||
runner._voice_mode[chat_id] = voice_mode
|
||||
else:
|
||||
runner._voice_mode.pop(chat_id, None)
|
||||
|
||||
event = _make_event(message_type=message_type)
|
||||
|
||||
if in_voice_channel:
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.is_in_voice_channel = MagicMock(return_value=True)
|
||||
event.raw_message = SimpleNamespace(guild_id=111, guild=None)
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
|
||||
return runner._should_send_voice_reply(
|
||||
event, response, agent_messages or []
|
||||
)
|
||||
if not should:
|
||||
return False
|
||||
|
||||
# Dedup: agent already called TTS tool
|
||||
if agent_messages:
|
||||
has_agent_tts = any(
|
||||
msg.get("role") == "assistant"
|
||||
and any(
|
||||
tc.get("function", {}).get("name") == "text_to_speech"
|
||||
for tc in (msg.get("tool_calls") or [])
|
||||
)
|
||||
for msg in agent_messages
|
||||
)
|
||||
if has_agent_tts:
|
||||
return False
|
||||
|
||||
# Dedup: base adapter auto-TTS already handles voice input.
|
||||
# Exception: in voice channel, Discord play_tts also skips,
|
||||
# so the runner must handle VC playback.
|
||||
skip_double = is_voice_input
|
||||
if skip_double and in_voice_channel:
|
||||
skip_double = False
|
||||
if skip_double:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# -- Full platform x input x mode matrix --------------------------------
|
||||
#
|
||||
@@ -204,52 +195,52 @@ class TestAutoVoiceReply:
|
||||
|
||||
# -- Telegram/Slack/Web: voice input, base handles ---------------------
|
||||
|
||||
def test_voice_input_voice_only_skipped(self):
|
||||
def test_voice_input_voice_only_skipped(self, runner):
|
||||
"""voice_only + voice input: base auto-TTS handles it, runner skips."""
|
||||
assert self._should_reply("voice_only", MessageType.VOICE) is False
|
||||
assert self._call(runner, "voice_only", MessageType.VOICE) is False
|
||||
|
||||
def test_voice_input_all_mode_skipped(self):
|
||||
def test_voice_input_all_mode_skipped(self, runner):
|
||||
"""all + voice input: base auto-TTS handles it, runner skips."""
|
||||
assert self._should_reply("all", MessageType.VOICE) is False
|
||||
assert self._call(runner, "all", MessageType.VOICE) is False
|
||||
|
||||
# -- Text input: only runner handles -----------------------------------
|
||||
|
||||
def test_text_input_all_mode_runner_fires(self):
|
||||
def test_text_input_all_mode_runner_fires(self, runner):
|
||||
"""all + text input: only runner fires (base auto-TTS only for voice)."""
|
||||
assert self._should_reply("all", MessageType.TEXT) is True
|
||||
assert self._call(runner, "all", MessageType.TEXT) is True
|
||||
|
||||
def test_text_input_voice_only_no_reply(self):
|
||||
def test_text_input_voice_only_no_reply(self, runner):
|
||||
"""voice_only + text input: neither fires."""
|
||||
assert self._should_reply("voice_only", MessageType.TEXT) is False
|
||||
assert self._call(runner, "voice_only", MessageType.TEXT) is False
|
||||
|
||||
# -- Mode off: nothing fires -------------------------------------------
|
||||
|
||||
def test_off_mode_voice(self):
|
||||
assert self._should_reply("off", MessageType.VOICE) is False
|
||||
def test_off_mode_voice(self, runner):
|
||||
assert self._call(runner, "off", MessageType.VOICE) is False
|
||||
|
||||
def test_off_mode_text(self):
|
||||
assert self._should_reply("off", MessageType.TEXT) is False
|
||||
def test_off_mode_text(self, runner):
|
||||
assert self._call(runner, "off", MessageType.TEXT) is False
|
||||
|
||||
# -- Discord VC exception: runner must handle --------------------------
|
||||
|
||||
def test_discord_vc_voice_input_runner_fires(self):
|
||||
def test_discord_vc_voice_input_runner_fires(self, runner):
|
||||
"""Discord VC + voice input: base play_tts skips (VC override),
|
||||
so runner must handle via play_in_voice_channel."""
|
||||
assert self._should_reply("all", MessageType.VOICE, in_voice_channel=True) is True
|
||||
assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is True
|
||||
|
||||
def test_discord_vc_voice_only_runner_fires(self):
|
||||
def test_discord_vc_voice_only_runner_fires(self, runner):
|
||||
"""Discord VC + voice_only + voice: runner must handle."""
|
||||
assert self._should_reply("voice_only", MessageType.VOICE, in_voice_channel=True) is True
|
||||
assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is True
|
||||
|
||||
# -- Edge cases --------------------------------------------------------
|
||||
|
||||
def test_error_response_skipped(self):
|
||||
assert self._should_reply("all", MessageType.TEXT, response="Error: boom") is False
|
||||
def test_error_response_skipped(self, runner):
|
||||
assert self._call(runner, "all", MessageType.TEXT, response="Error: boom") is False
|
||||
|
||||
def test_empty_response_skipped(self):
|
||||
assert self._should_reply("all", MessageType.TEXT, response="") is False
|
||||
def test_empty_response_skipped(self, runner):
|
||||
assert self._call(runner, "all", MessageType.TEXT, response="") is False
|
||||
|
||||
def test_dedup_skips_when_agent_called_tts(self):
|
||||
def test_dedup_skips_when_agent_called_tts(self, runner):
|
||||
messages = [{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
@@ -258,9 +249,9 @@ class TestAutoVoiceReply:
|
||||
"function": {"name": "text_to_speech", "arguments": "{}"},
|
||||
}],
|
||||
}]
|
||||
assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is False
|
||||
assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is False
|
||||
|
||||
def test_no_dedup_for_other_tools(self):
|
||||
def test_no_dedup_for_other_tools(self, runner):
|
||||
messages = [{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
@@ -269,7 +260,7 @@ class TestAutoVoiceReply:
|
||||
"function": {"name": "web_search", "arguments": "{}"},
|
||||
}],
|
||||
}]
|
||||
assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is True
|
||||
assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is True
|
||||
|
||||
|
||||
# =====================================================================
|
||||
@@ -443,3 +434,734 @@ class TestVoiceInHelp:
|
||||
import inspect
|
||||
source = inspect.getsource(GatewayRunner._handle_message)
|
||||
assert '"voice"' in source
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# VoiceReceiver unit tests
|
||||
# =====================================================================
|
||||
|
||||
class TestVoiceReceiver:
|
||||
"""Test VoiceReceiver silence detection, SSRC mapping, and lifecycle."""
|
||||
|
||||
def _make_receiver(self):
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
mock_vc = MagicMock()
|
||||
mock_vc._connection.secret_key = [0] * 32
|
||||
mock_vc._connection.dave_session = None
|
||||
mock_vc._connection.ssrc = 9999
|
||||
mock_vc._connection.add_socket_listener = MagicMock()
|
||||
mock_vc._connection.remove_socket_listener = MagicMock()
|
||||
mock_vc._connection.hook = None
|
||||
receiver = VoiceReceiver(mock_vc)
|
||||
return receiver
|
||||
|
||||
def test_initial_state(self):
|
||||
receiver = self._make_receiver()
|
||||
assert receiver._running is False
|
||||
assert receiver._paused is False
|
||||
assert len(receiver._buffers) == 0
|
||||
assert len(receiver._ssrc_to_user) == 0
|
||||
|
||||
def test_start_sets_running(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
assert receiver._running is True
|
||||
|
||||
def test_stop_clears_state(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.map_ssrc(100, 42)
|
||||
receiver._buffers[100] = bytearray(b"\x00" * 1000)
|
||||
receiver._last_packet_time[100] = time.monotonic()
|
||||
receiver.stop()
|
||||
assert receiver._running is False
|
||||
assert len(receiver._buffers) == 0
|
||||
assert len(receiver._ssrc_to_user) == 0
|
||||
assert len(receiver._last_packet_time) == 0
|
||||
|
||||
def test_map_ssrc(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.map_ssrc(100, 42)
|
||||
assert receiver._ssrc_to_user[100] == 42
|
||||
|
||||
def test_map_ssrc_overwrites(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.map_ssrc(100, 42)
|
||||
receiver.map_ssrc(100, 99)
|
||||
assert receiver._ssrc_to_user[100] == 99
|
||||
|
||||
def test_pause_resume(self):
|
||||
receiver = self._make_receiver()
|
||||
assert receiver._paused is False
|
||||
receiver.pause()
|
||||
assert receiver._paused is True
|
||||
receiver.resume()
|
||||
assert receiver._paused is False
|
||||
|
||||
def test_check_silence_empty(self):
|
||||
receiver = self._make_receiver()
|
||||
assert receiver.check_silence() == []
|
||||
|
||||
def test_check_silence_returns_completed_utterance(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.map_ssrc(100, 42)
|
||||
# 48kHz, stereo, 16-bit = 192000 bytes/sec
|
||||
# MIN_SPEECH_DURATION = 0.5s → need 96000 bytes
|
||||
pcm_data = bytearray(b"\x00" * 96000)
|
||||
receiver._buffers[100] = pcm_data
|
||||
# Set last_packet_time far enough in the past to exceed SILENCE_THRESHOLD
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 1
|
||||
user_id, data = completed[0]
|
||||
assert user_id == 42
|
||||
assert len(data) == 96000
|
||||
# Buffer should be cleared after extraction
|
||||
assert len(receiver._buffers[100]) == 0
|
||||
|
||||
def test_check_silence_ignores_short_buffer(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.map_ssrc(100, 42)
|
||||
# Too short to meet MIN_SPEECH_DURATION
|
||||
receiver._buffers[100] = bytearray(b"\x00" * 100)
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_check_silence_ignores_recent_audio(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.map_ssrc(100, 42)
|
||||
receiver._buffers[100] = bytearray(b"\x00" * 96000)
|
||||
receiver._last_packet_time[100] = time.monotonic() # just now
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_check_silence_unknown_user_discarded(self):
|
||||
receiver = self._make_receiver()
|
||||
# No SSRC mapping — user_id will be 0
|
||||
receiver._buffers[100] = bytearray(b"\x00" * 96000)
|
||||
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||
completed = receiver.check_silence()
|
||||
assert len(completed) == 0
|
||||
|
||||
def test_stale_buffer_discarded(self):
|
||||
receiver = self._make_receiver()
|
||||
# Buffer with no user mapping and very old timestamp
|
||||
receiver._buffers[200] = bytearray(b"\x00" * 100)
|
||||
receiver._last_packet_time[200] = time.monotonic() - 10.0
|
||||
receiver.check_silence()
|
||||
# Stale buffer (> 2x threshold) should be discarded
|
||||
assert 200 not in receiver._buffers
|
||||
|
||||
def test_on_packet_skips_when_not_running(self):
|
||||
receiver = self._make_receiver()
|
||||
# Not started — _running is False
|
||||
receiver._on_packet(b"\x00" * 100)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_on_packet_skips_when_paused(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver.pause()
|
||||
receiver._on_packet(b"\x00" * 100)
|
||||
# Paused — should not process
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_on_packet_skips_short_data(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver._on_packet(b"\x00" * 10)
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
def test_on_packet_skips_non_rtp(self):
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
# Valid length but wrong RTP version
|
||||
data = bytearray(b"\x00" * 20)
|
||||
data[0] = 0x00 # version 0, not 2
|
||||
receiver._on_packet(bytes(data))
|
||||
assert len(receiver._buffers) == 0
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Gateway voice channel commands (join / leave / input)
|
||||
# =====================================================================
|
||||
|
||||
class TestVoiceChannelCommands:
|
||||
"""Test _handle_voice_channel_join, _handle_voice_channel_leave,
|
||||
_handle_voice_channel_input on the GatewayRunner."""
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self, tmp_path):
|
||||
return _make_runner(tmp_path)
|
||||
|
||||
def _make_discord_event(self, text="/voice channel", chat_id="123",
|
||||
guild_id=111, user_id="user1"):
|
||||
"""Create event with raw_message carrying guild info."""
|
||||
source = SessionSource(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
platform=MagicMock(),
|
||||
)
|
||||
source.platform.value = "discord"
|
||||
source.thread_id = None
|
||||
event = MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
|
||||
event.message_id = "msg42"
|
||||
event.raw_message = SimpleNamespace(guild_id=guild_id, guild=None)
|
||||
return event
|
||||
|
||||
# -- _handle_voice_channel_join --
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_unsupported_platform(self, runner):
|
||||
"""Platform without join_voice_channel returns unsupported message."""
|
||||
mock_adapter = AsyncMock(spec=[]) # no join_voice_channel
|
||||
event = self._make_discord_event()
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "not supported" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_no_guild_id(self, runner):
|
||||
"""DM context (no guild_id) returns error."""
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.join_voice_channel = AsyncMock()
|
||||
event = self._make_discord_event()
|
||||
event.raw_message = None # no guild info
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "discord server" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_user_not_in_vc(self, runner):
|
||||
"""User not in any voice channel."""
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.join_voice_channel = AsyncMock()
|
||||
mock_adapter.get_user_voice_channel = AsyncMock(return_value=None)
|
||||
event = self._make_discord_event()
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "need to be in a voice channel" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_success(self, runner):
|
||||
"""Successful join sets voice_mode and returns confirmation."""
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.name = "General"
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.join_voice_channel = AsyncMock(return_value=True)
|
||||
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
|
||||
mock_adapter._voice_text_channels = {}
|
||||
mock_adapter._voice_input_callback = None
|
||||
event = self._make_discord_event()
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "joined" in result.lower()
|
||||
assert "General" in result
|
||||
assert runner._voice_mode["123"] == "all"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_failure(self, runner):
|
||||
"""Failed join returns permissions error."""
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.name = "General"
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.join_voice_channel = AsyncMock(return_value=False)
|
||||
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
|
||||
event = self._make_discord_event()
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "failed" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_exception(self, runner):
|
||||
"""Exception during join is caught and reported."""
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.name = "General"
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.join_voice_channel = AsyncMock(side_effect=RuntimeError("No permission"))
|
||||
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
|
||||
event = self._make_discord_event()
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "failed" in result.lower()
|
||||
|
||||
# -- _handle_voice_channel_leave --
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_not_in_vc(self, runner):
|
||||
"""Leave when not in VC returns appropriate message."""
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.is_in_voice_channel = MagicMock(return_value=False)
|
||||
event = self._make_discord_event("/voice leave")
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_leave(event)
|
||||
assert "not in" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_no_guild(self, runner):
|
||||
"""Leave from DM returns not in voice channel."""
|
||||
mock_adapter = AsyncMock()
|
||||
event = self._make_discord_event("/voice leave")
|
||||
event.raw_message = None
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_leave(event)
|
||||
assert "not in" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_success(self, runner):
|
||||
"""Successful leave disconnects and clears voice mode."""
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter.is_in_voice_channel = MagicMock(return_value=True)
|
||||
mock_adapter.leave_voice_channel = AsyncMock()
|
||||
event = self._make_discord_event("/voice leave")
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
runner._voice_mode["123"] = "all"
|
||||
result = await runner._handle_voice_channel_leave(event)
|
||||
assert "left" in result.lower()
|
||||
assert "123" not in runner._voice_mode
|
||||
mock_adapter.leave_voice_channel.assert_called_once_with(111)
|
||||
|
||||
# -- _handle_voice_channel_input --
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_no_adapter(self, runner):
|
||||
"""No Discord adapter — early return, no crash."""
|
||||
from gateway.config import Platform
|
||||
# No adapters set
|
||||
await runner._handle_voice_channel_input(111, 42, "Hello")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_no_text_channel(self, runner):
|
||||
"""No text channel mapped for guild — early return."""
|
||||
from gateway.config import Platform
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter._voice_text_channels = {}
|
||||
mock_adapter._client = MagicMock()
|
||||
runner.adapters[Platform.DISCORD] = mock_adapter
|
||||
await runner._handle_voice_channel_input(111, 42, "Hello")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_creates_event_and_dispatches(self, runner):
|
||||
"""Voice input creates synthetic event and calls handle_message."""
|
||||
from gateway.config import Platform
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter._voice_text_channels = {111: 123}
|
||||
mock_channel = AsyncMock()
|
||||
mock_adapter._client = MagicMock()
|
||||
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
mock_adapter.handle_message = AsyncMock()
|
||||
runner.adapters[Platform.DISCORD] = mock_adapter
|
||||
await runner._handle_voice_channel_input(111, 42, "Hello from VC")
|
||||
mock_adapter.handle_message.assert_called_once()
|
||||
event = mock_adapter.handle_message.call_args[0][0]
|
||||
assert event.text == "Hello from VC"
|
||||
assert event.message_type == MessageType.VOICE
|
||||
assert event.source.chat_id == "123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_posts_transcript_in_text_channel(self, runner):
|
||||
"""Voice input sends transcript message to text channel."""
|
||||
from gateway.config import Platform
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter._voice_text_channels = {111: 123}
|
||||
mock_channel = AsyncMock()
|
||||
mock_adapter._client = MagicMock()
|
||||
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
mock_adapter.handle_message = AsyncMock()
|
||||
runner.adapters[Platform.DISCORD] = mock_adapter
|
||||
await runner._handle_voice_channel_input(111, 42, "Test transcript")
|
||||
mock_channel.send.assert_called_once()
|
||||
msg = mock_channel.send.call_args[0][0]
|
||||
assert "Test transcript" in msg
|
||||
assert "42" in msg # user_id in mention
|
||||
|
||||
# -- _get_guild_id --
|
||||
|
||||
def test_get_guild_id_from_guild(self, runner):
|
||||
event = _make_event()
|
||||
mock_guild = MagicMock()
|
||||
mock_guild.id = 555
|
||||
event.raw_message = SimpleNamespace(guild_id=None, guild=mock_guild)
|
||||
result = runner._get_guild_id(event)
|
||||
assert result == 555
|
||||
|
||||
def test_get_guild_id_from_interaction(self, runner):
|
||||
event = _make_event()
|
||||
event.raw_message = SimpleNamespace(guild_id=777, guild=None)
|
||||
result = runner._get_guild_id(event)
|
||||
assert result == 777
|
||||
|
||||
def test_get_guild_id_none(self, runner):
|
||||
event = _make_event()
|
||||
event.raw_message = None
|
||||
result = runner._get_guild_id(event)
|
||||
assert result is None
|
||||
|
||||
def test_get_guild_id_dm(self, runner):
|
||||
event = _make_event()
|
||||
event.raw_message = SimpleNamespace(guild_id=None, guild=None)
|
||||
result = runner._get_guild_id(event)
|
||||
assert result is None
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Discord adapter voice channel methods
|
||||
# =====================================================================
|
||||
|
||||
class TestDiscordVoiceChannelMethods:
|
||||
"""Test DiscordAdapter voice channel methods (join, leave, play, etc.)."""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._client = MagicMock()
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
adapter._voice_input_callback = None
|
||||
adapter._allowed_user_ids = set()
|
||||
adapter._running = True
|
||||
return adapter
|
||||
|
||||
def test_is_in_voice_channel_true(self):
|
||||
adapter = self._make_adapter()
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
assert adapter.is_in_voice_channel(111) is True
|
||||
|
||||
def test_is_in_voice_channel_false_no_client(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter.is_in_voice_channel(111) is False
|
||||
|
||||
def test_is_in_voice_channel_false_disconnected(self):
|
||||
adapter = self._make_adapter()
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = False
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
assert adapter.is_in_voice_channel(111) is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_voice_channel_cleans_up(self):
|
||||
adapter = self._make_adapter()
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
mock_vc.disconnect = AsyncMock()
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
|
||||
mock_receiver = MagicMock()
|
||||
adapter._voice_receivers[111] = mock_receiver
|
||||
|
||||
mock_task = MagicMock()
|
||||
adapter._voice_listen_tasks[111] = mock_task
|
||||
|
||||
mock_timeout = MagicMock()
|
||||
adapter._voice_timeout_tasks[111] = mock_timeout
|
||||
|
||||
await adapter.leave_voice_channel(111)
|
||||
|
||||
mock_receiver.stop.assert_called_once()
|
||||
mock_task.cancel.assert_called_once()
|
||||
mock_vc.disconnect.assert_called_once()
|
||||
mock_timeout.cancel.assert_called_once()
|
||||
assert 111 not in adapter._voice_clients
|
||||
assert 111 not in adapter._voice_text_channels
|
||||
assert 111 not in adapter._voice_receivers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_voice_channel_no_connection(self):
|
||||
"""Leave when not connected — no crash."""
|
||||
adapter = self._make_adapter()
|
||||
await adapter.leave_voice_channel(111) # should not raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_voice_channel_no_client(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._client = None
|
||||
result = await adapter.get_user_voice_channel(111, "42")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_voice_channel_no_guild(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._client.get_guild = MagicMock(return_value=None)
|
||||
result = await adapter.get_user_voice_channel(111, "42")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_voice_channel_user_not_in_vc(self):
|
||||
adapter = self._make_adapter()
|
||||
mock_guild = MagicMock()
|
||||
mock_member = MagicMock()
|
||||
mock_member.voice = None
|
||||
mock_guild.get_member = MagicMock(return_value=mock_member)
|
||||
adapter._client.get_guild = MagicMock(return_value=mock_guild)
|
||||
result = await adapter.get_user_voice_channel(111, "42")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_voice_channel_success(self):
|
||||
adapter = self._make_adapter()
|
||||
mock_vc = MagicMock()
|
||||
mock_guild = MagicMock()
|
||||
mock_member = MagicMock()
|
||||
mock_member.voice = MagicMock()
|
||||
mock_member.voice.channel = mock_vc
|
||||
mock_guild.get_member = MagicMock(return_value=mock_member)
|
||||
adapter._client.get_guild = MagicMock(return_value=mock_guild)
|
||||
result = await adapter.get_user_voice_channel(111, "42")
|
||||
assert result is mock_vc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_in_voice_channel_not_connected(self):
|
||||
adapter = self._make_adapter()
|
||||
result = await adapter.play_in_voice_channel(111, "/tmp/test.ogg")
|
||||
assert result is False
|
||||
|
||||
def test_is_allowed_user_empty_list(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter._is_allowed_user("42") is True
|
||||
|
||||
def test_is_allowed_user_in_list(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._allowed_user_ids = {"42", "99"}
|
||||
assert adapter._is_allowed_user("42") is True
|
||||
|
||||
def test_is_allowed_user_not_in_list(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._allowed_user_ids = {"99"}
|
||||
assert adapter._is_allowed_user("42") is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_voice_input_success(self):
|
||||
"""Successful voice input: PCM->WAV->STT->callback."""
|
||||
adapter = self._make_adapter()
|
||||
callback = AsyncMock()
|
||||
adapter._voice_input_callback = callback
|
||||
adapter._allowed_user_ids = set()
|
||||
|
||||
pcm_data = b"\x00" * 96000
|
||||
|
||||
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
|
||||
patch("tools.transcription_tools.transcribe_audio",
|
||||
return_value={"success": True, "transcript": "Hello"}), \
|
||||
patch("tools.voice_mode.is_whisper_hallucination", return_value=False):
|
||||
await adapter._process_voice_input(111, 42, pcm_data)
|
||||
|
||||
callback.assert_called_once_with(guild_id=111, user_id=42, transcript="Hello")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_voice_input_hallucination_filtered(self):
|
||||
"""Whisper hallucination is filtered out."""
|
||||
adapter = self._make_adapter()
|
||||
callback = AsyncMock()
|
||||
adapter._voice_input_callback = callback
|
||||
|
||||
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
|
||||
patch("tools.transcription_tools.transcribe_audio",
|
||||
return_value={"success": True, "transcript": "Thank you."}), \
|
||||
patch("tools.voice_mode.is_whisper_hallucination", return_value=True):
|
||||
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
|
||||
|
||||
callback.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_voice_input_stt_failure(self):
|
||||
"""STT failure — callback not called."""
|
||||
adapter = self._make_adapter()
|
||||
callback = AsyncMock()
|
||||
adapter._voice_input_callback = callback
|
||||
|
||||
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
|
||||
patch("tools.transcription_tools.transcribe_audio",
|
||||
return_value={"success": False, "error": "API error"}):
|
||||
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
|
||||
|
||||
callback.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_voice_input_exception_caught(self):
|
||||
"""Exception during processing is caught, no crash."""
|
||||
adapter = self._make_adapter()
|
||||
adapter._voice_input_callback = AsyncMock()
|
||||
|
||||
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav",
|
||||
side_effect=RuntimeError("ffmpeg not found")):
|
||||
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
|
||||
# Should not raise
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# stream_tts_to_speaker functional tests
|
||||
# =====================================================================
|
||||
|
||||
class TestStreamTtsToSpeaker:
|
||||
"""Functional tests for the streaming TTS pipeline."""
|
||||
|
||||
def test_none_sentinel_flushes_buffer(self):
|
||||
"""None sentinel causes remaining buffer to be spoken."""
|
||||
from tools.tts_tool import stream_tts_to_speaker
|
||||
text_q = queue.Queue()
|
||||
stop_evt = threading.Event()
|
||||
done_evt = threading.Event()
|
||||
spoken = []
|
||||
|
||||
def display(text):
|
||||
spoken.append(text)
|
||||
|
||||
text_q.put("Hello world.")
|
||||
text_q.put(None)
|
||||
|
||||
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=display)
|
||||
assert done_evt.is_set()
|
||||
assert any("Hello" in s for s in spoken)
|
||||
|
||||
def test_stop_event_aborts_early(self):
|
||||
"""Setting stop_event causes early exit."""
|
||||
from tools.tts_tool import stream_tts_to_speaker
|
||||
text_q = queue.Queue()
|
||||
stop_evt = threading.Event()
|
||||
done_evt = threading.Event()
|
||||
spoken = []
|
||||
|
||||
stop_evt.set()
|
||||
text_q.put("Should not be spoken.")
|
||||
text_q.put(None)
|
||||
|
||||
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
|
||||
assert done_evt.is_set()
|
||||
assert len(spoken) == 0
|
||||
|
||||
def test_done_event_set_on_exception(self):
|
||||
"""tts_done_event is set even when an exception occurs."""
|
||||
from tools.tts_tool import stream_tts_to_speaker
|
||||
text_q = queue.Queue()
|
||||
stop_evt = threading.Event()
|
||||
done_evt = threading.Event()
|
||||
|
||||
# Put a non-string that will cause concatenation to fail
|
||||
text_q.put(12345)
|
||||
text_q.put(None)
|
||||
|
||||
stream_tts_to_speaker(text_q, stop_evt, done_evt)
|
||||
assert done_evt.is_set()
|
||||
|
||||
def test_think_blocks_stripped(self):
|
||||
"""<think>...</think> content is not spoken."""
|
||||
from tools.tts_tool import stream_tts_to_speaker
|
||||
text_q = queue.Queue()
|
||||
stop_evt = threading.Event()
|
||||
done_evt = threading.Event()
|
||||
spoken = []
|
||||
|
||||
text_q.put("<think>internal reasoning</think>")
|
||||
text_q.put("Visible response. ")
|
||||
text_q.put(None)
|
||||
|
||||
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
|
||||
assert done_evt.is_set()
|
||||
joined = " ".join(spoken)
|
||||
assert "internal reasoning" not in joined
|
||||
assert "Visible" in joined
|
||||
|
||||
def test_sentence_splitting(self):
|
||||
"""Sentences are split at boundaries and spoken individually."""
|
||||
from tools.tts_tool import stream_tts_to_speaker
|
||||
text_q = queue.Queue()
|
||||
stop_evt = threading.Event()
|
||||
done_evt = threading.Event()
|
||||
spoken = []
|
||||
|
||||
# Two sentences long enough to exceed min_sentence_len (20)
|
||||
text_q.put("This is the first sentence. ")
|
||||
text_q.put("This is the second sentence. ")
|
||||
text_q.put(None)
|
||||
|
||||
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
|
||||
assert done_evt.is_set()
|
||||
assert len(spoken) >= 2
|
||||
|
||||
def test_markdown_stripped_in_speech(self):
|
||||
"""Markdown formatting is removed before display/speech."""
|
||||
from tools.tts_tool import stream_tts_to_speaker
|
||||
text_q = queue.Queue()
|
||||
stop_evt = threading.Event()
|
||||
done_evt = threading.Event()
|
||||
spoken = []
|
||||
|
||||
text_q.put("**Bold text** and `code`. ")
|
||||
text_q.put(None)
|
||||
|
||||
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
|
||||
assert done_evt.is_set()
|
||||
# Display callback gets raw text (before markdown stripping)
|
||||
# But the actual TTS audio would be stripped — we verify pipeline doesn't crash
|
||||
|
||||
def test_duplicate_sentences_deduped(self):
|
||||
"""Repeated sentences are spoken only once."""
|
||||
from tools.tts_tool import stream_tts_to_speaker
|
||||
text_q = queue.Queue()
|
||||
stop_evt = threading.Event()
|
||||
done_evt = threading.Event()
|
||||
spoken = []
|
||||
|
||||
# Same sentence twice, each long enough
|
||||
text_q.put("This is a repeated sentence. ")
|
||||
text_q.put("This is a repeated sentence. ")
|
||||
text_q.put(None)
|
||||
|
||||
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
|
||||
assert done_evt.is_set()
|
||||
# First occurrence is spoken, second is deduped
|
||||
assert len(spoken) == 1
|
||||
|
||||
def test_no_api_key_display_only(self):
|
||||
"""Without ELEVENLABS_API_KEY, display callback still works."""
|
||||
from tools.tts_tool import stream_tts_to_speaker
|
||||
text_q = queue.Queue()
|
||||
stop_evt = threading.Event()
|
||||
done_evt = threading.Event()
|
||||
spoken = []
|
||||
|
||||
text_q.put("Display only text. ")
|
||||
text_q.put(None)
|
||||
|
||||
with patch.dict(os.environ, {"ELEVENLABS_API_KEY": ""}):
|
||||
stream_tts_to_speaker(text_q, stop_evt, done_evt,
|
||||
display_callback=lambda t: spoken.append(t))
|
||||
assert done_evt.is_set()
|
||||
assert len(spoken) >= 1
|
||||
|
||||
def test_long_buffer_flushed_on_timeout(self):
|
||||
"""Buffer longer than long_flush_len is flushed on queue timeout."""
|
||||
from tools.tts_tool import stream_tts_to_speaker
|
||||
text_q = queue.Queue()
|
||||
stop_evt = threading.Event()
|
||||
done_evt = threading.Event()
|
||||
spoken = []
|
||||
|
||||
# Put a long text without sentence boundary, then None after a delay
|
||||
long_text = "a" * 150 # > long_flush_len (100)
|
||||
text_q.put(long_text)
|
||||
|
||||
def delayed_sentinel():
|
||||
time.sleep(1.0)
|
||||
text_q.put(None)
|
||||
|
||||
t = threading.Thread(target=delayed_sentinel, daemon=True)
|
||||
t.start()
|
||||
|
||||
stream_tts_to_speaker(text_q, stop_evt, done_evt,
|
||||
display_callback=lambda t: spoken.append(t))
|
||||
t.join(timeout=5)
|
||||
assert done_evt.is_set()
|
||||
assert len(spoken) >= 1
|
||||
|
||||
Reference in New Issue
Block a user