test: add double TTS prevention tests for voice reply logic
- Update TestAutoVoiceReply to include skip_double logic: voice input is handled by base adapter auto-TTS, gateway runner skips to prevent duplicate audio - Add TestDiscordPlayTtsSkip: verifies Discord adapter skips play_tts when bot is in a voice channel (VC playback handled by runner) - Add TestWebPlayTts: verifies Web adapter sends invisible play_audio instead of voice bubble
This commit is contained in:
@@ -126,7 +126,15 @@ class TestHandleVoiceCommand:
|
||||
# =====================================================================
|
||||
|
||||
class TestAutoVoiceReply:
|
||||
"""Test the should_voice_reply decision logic (extracted from _handle_message)."""
|
||||
"""Test the should_voice_reply decision logic (extracted from _handle_message).
|
||||
|
||||
The gateway has two TTS paths:
|
||||
1. base adapter auto-TTS: fires for voice input in _process_message_background
|
||||
2. gateway _send_voice_reply: fires based on voice_mode setting
|
||||
|
||||
To prevent double audio, _send_voice_reply is skipped when voice input
|
||||
already triggered base adapter auto-TTS (skip_double = is_voice_input).
|
||||
"""
|
||||
|
||||
def _should_reply(self, voice_mode, message_type, agent_messages=None, response="Hello!"):
|
||||
"""Replicate the auto voice reply decision from _handle_message."""
|
||||
@@ -141,7 +149,7 @@ class TestAutoVoiceReply:
|
||||
if not should:
|
||||
return False
|
||||
|
||||
# Dedup check
|
||||
# Dedup: agent already called TTS tool
|
||||
if agent_messages:
|
||||
has_agent_tts = any(
|
||||
msg.get("role") == "assistant"
|
||||
@@ -154,24 +162,36 @@ class TestAutoVoiceReply:
|
||||
if has_agent_tts:
|
||||
return False
|
||||
|
||||
# Dedup: base adapter auto-TTS already handles voice input
|
||||
skip_double = is_voice_input
|
||||
if skip_double:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def test_voice_only_voice_input(self):
|
||||
assert self._should_reply("voice_only", MessageType.VOICE) is True
|
||||
# -- voice_mode + message_type matrix ----------------------------------
|
||||
|
||||
def test_voice_only_voice_input_skipped_double(self):
|
||||
"""voice_only + voice input: base auto-TTS handles it, runner skips."""
|
||||
assert self._should_reply("voice_only", MessageType.VOICE) is False
|
||||
|
||||
def test_voice_only_text_input(self):
|
||||
assert self._should_reply("voice_only", MessageType.TEXT) is False
|
||||
|
||||
def test_all_mode_text_input(self):
|
||||
"""all + text input: only runner fires (base auto-TTS only for voice)."""
|
||||
assert self._should_reply("all", MessageType.TEXT) is True
|
||||
|
||||
def test_all_mode_voice_input(self):
|
||||
assert self._should_reply("all", MessageType.VOICE) is True
|
||||
def test_all_mode_voice_input_skipped_double(self):
|
||||
"""all + voice input: base auto-TTS handles it, runner skips."""
|
||||
assert self._should_reply("all", MessageType.VOICE) is False
|
||||
|
||||
def test_off_mode(self):
|
||||
assert self._should_reply("off", MessageType.VOICE) is False
|
||||
assert self._should_reply("off", MessageType.TEXT) is False
|
||||
|
||||
# -- edge cases --------------------------------------------------------
|
||||
|
||||
def test_error_response_skipped(self):
|
||||
assert self._should_reply("all", MessageType.TEXT, response="Error: boom") is False
|
||||
|
||||
@@ -266,6 +286,95 @@ class TestSendVoiceReply:
|
||||
await runner._send_voice_reply(event, "Hello")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Discord play_tts skip when in voice channel
|
||||
# =====================================================================
|
||||
|
||||
class TestDiscordPlayTtsSkip:
|
||||
"""Discord adapter skips play_tts when bot is in a voice channel."""
|
||||
|
||||
def _make_discord_adapter(self):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
adapter._client = None
|
||||
adapter._broadcast = AsyncMock()
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_skipped_when_in_vc(self):
|
||||
adapter = self._make_discord_adapter()
|
||||
# Simulate bot in voice channel for guild 111, text channel 123
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
|
||||
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/test.ogg")
|
||||
assert result.success is True
|
||||
# send_voice should NOT have been called (no client, would fail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_not_skipped_when_not_in_vc(self):
|
||||
adapter = self._make_discord_adapter()
|
||||
# No voice connection — play_tts falls through to send_voice
|
||||
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/test.ogg")
|
||||
# send_voice will fail (no client), but play_tts should NOT return early
|
||||
assert result.success is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_not_skipped_for_different_channel(self):
|
||||
adapter = self._make_discord_adapter()
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 999 # different channel
|
||||
|
||||
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/test.ogg")
|
||||
# Different channel — should NOT skip, falls through to send_voice (fails)
|
||||
assert result.success is False
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Web play_tts sends play_audio (not voice bubble)
|
||||
# =====================================================================
|
||||
|
||||
class TestWebPlayTts:
|
||||
"""Web adapter play_tts sends invisible play_audio, not a voice bubble."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_tts_sends_play_audio(self, tmp_path):
|
||||
from gateway.platforms.web import WebAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
config = PlatformConfig(enabled=True, extra={
|
||||
"port": 0, "host": "127.0.0.1", "token": "tok",
|
||||
})
|
||||
adapter = WebAdapter(config)
|
||||
adapter._broadcast = AsyncMock()
|
||||
adapter._media_dir = tmp_path / "media"
|
||||
adapter._media_dir.mkdir()
|
||||
|
||||
audio_file = tmp_path / "test.ogg"
|
||||
audio_file.write_bytes(b"fake audio")
|
||||
|
||||
result = await adapter.play_tts(chat_id="web", audio_path=str(audio_file))
|
||||
assert result.success is True
|
||||
|
||||
payload = adapter._broadcast.call_args[0][0]
|
||||
assert payload["type"] == "play_audio"
|
||||
assert "/media/" in payload["url"]
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Help text + known commands
|
||||
# =====================================================================
|
||||
|
||||
Reference in New Issue
Block a user