# =====================================================================
# Discord Voice Channel Flow Tests
# =====================================================================


class TestVoiceReception:
    """Audio reception: SSRC mapping, DAVE passthrough, buffer lifecycle."""

    @staticmethod
    def _make_voice_client(dave_session=None, members=None, bot_id=9999):
        """Build a MagicMock voice client with the attributes these tests configure.

        Shared by ``_make_receiver`` and ``_make_receiver_with_nacl`` so the
        mock wiring lives in exactly one place.
        """
        vc = MagicMock()
        vc._connection.secret_key = [0] * 32
        vc._connection.dave_session = dave_session
        vc._connection.ssrc = bot_id
        vc._connection.add_socket_listener = MagicMock()
        vc._connection.remove_socket_listener = MagicMock()
        vc._connection.hook = None
        vc.user = SimpleNamespace(id=bot_id)
        vc.channel = MagicMock()
        vc.channel.members = members or []
        return vc

    @staticmethod
    def _make_receiver(allowed_ids=None, members=None, dave=False, bot_id=9999):
        """Create a (not yet started) VoiceReceiver around a mocked voice client.

        ``dave=True`` installs a MagicMock DAVE session; otherwise the
        session is ``None``.
        """
        from gateway.platforms.discord import VoiceReceiver

        vc = TestVoiceReception._make_voice_client(
            dave_session=MagicMock() if dave else None,
            members=members,
            bot_id=bot_id,
        )
        return VoiceReceiver(vc, allowed_user_ids=allowed_ids)

    @staticmethod
    def _fill_buffer(receiver, ssrc, duration_s=1.0, age_s=3.0):
        """Add PCM data to buffer. 48kHz stereo 16-bit = 192000 bytes/sec.

        ``age_s`` is how long ago (relative to ``time.monotonic()``) the last
        packet for ``ssrc`` is recorded as having arrived.
        """
        size = int(192000 * duration_s)
        receiver._buffers[ssrc] = bytearray(b"\x00" * size)
        receiver._last_packet_time[ssrc] = time.monotonic() - age_s

    # -- Known SSRC (normal flow) --

    def test_known_ssrc_returns_completed(self):
        """Mapped SSRC with an aged 1s buffer is returned and the buffer cleared."""
        receiver = self._make_receiver()
        receiver.start()
        receiver.map_ssrc(100, 42)
        self._fill_buffer(receiver, 100)
        completed = receiver.check_silence()
        assert len(completed) == 1
        assert completed[0][0] == 42
        assert len(receiver._buffers[100]) == 0  # cleared

    def test_known_ssrc_short_buffer_ignored(self):
        """Mapped SSRC with only 0.1s of audio yields no completed utterance."""
        receiver = self._make_receiver()
        receiver.start()
        receiver.map_ssrc(100, 42)
        self._fill_buffer(receiver, 100, duration_s=0.1)  # too short
        completed = receiver.check_silence()
        assert not completed

    def test_known_ssrc_recent_audio_waits(self):
        """Mapped SSRC whose last packet just arrived is not returned yet."""
        receiver = self._make_receiver()
        receiver.start()
        receiver.map_ssrc(100, 42)
        self._fill_buffer(receiver, 100, age_s=0.0)  # just arrived
        completed = receiver.check_silence()
        assert not completed

    # -- Unknown SSRC + DAVE passthrough --

    def test_unknown_ssrc_no_automap_no_completed(self):
        """Unknown SSRC, no members to infer — buffer cleared, not returned."""
        receiver = self._make_receiver(dave=True, members=[])
        receiver.start()
        self._fill_buffer(receiver, 100)
        completed = receiver.check_silence()
        assert not completed
        assert len(receiver._buffers[100]) == 0

    def test_unknown_ssrc_late_speaking_event(self):
        """Audio buffered before SPEAKING → SPEAKING maps → next check returns it."""
        receiver = self._make_receiver(dave=True)
        receiver.start()
        self._fill_buffer(receiver, 100, age_s=0.0)  # still receiving
        # No user yet
        assert receiver.check_silence() == []
        # SPEAKING event arrives
        receiver.map_ssrc(100, 42)
        # Silence kicks in
        receiver._last_packet_time[100] = time.monotonic() - 3.0
        completed = receiver.check_silence()
        assert len(completed) == 1
        assert completed[0][0] == 42

    # -- SSRC auto-mapping --

    def test_automap_single_allowed_user(self):
        """One allowed non-bot member in the channel → SSRC is mapped to them."""
        members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        receiver = self._make_receiver(allowed_ids={"42"}, members=members)
        receiver.start()
        self._fill_buffer(receiver, 100)
        completed = receiver.check_silence()
        assert len(completed) == 1
        assert completed[0][0] == 42
        assert receiver._ssrc_to_user[100] == 42

    def test_automap_multiple_allowed_users_no_map(self):
        """Two allowed members in the channel → ambiguous, nothing returned."""
        members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
            SimpleNamespace(id=43, name="Bob"),
        ]
        receiver = self._make_receiver(allowed_ids={"42", "43"}, members=members)
        receiver.start()
        self._fill_buffer(receiver, 100)
        completed = receiver.check_silence()
        assert not completed

    def test_automap_no_allowlist_single_member(self):
        """No allowed_user_ids → sole non-bot member inferred."""
        members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        receiver = self._make_receiver(allowed_ids=None, members=members)
        receiver.start()
        self._fill_buffer(receiver, 100)
        completed = receiver.check_silence()
        assert len(completed) == 1
        assert completed[0][0] == 42

    def test_automap_unallowed_user_rejected(self):
        """User in channel but not in allowed list — not mapped."""
        members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        receiver = self._make_receiver(allowed_ids={"99"}, members=members)
        receiver.start()
        self._fill_buffer(receiver, 100)
        completed = receiver.check_silence()
        assert not completed

    def test_automap_only_bot_in_channel(self):
        """Only bot in channel — no one to map to."""
        members = [SimpleNamespace(id=9999, name="Bot")]
        receiver = self._make_receiver(allowed_ids=None, members=members)
        receiver.start()
        self._fill_buffer(receiver, 100)
        completed = receiver.check_silence()
        assert not completed

    def test_automap_persists_across_calls(self):
        """Auto-mapped SSRC stays mapped for subsequent checks."""
        members = [
            SimpleNamespace(id=9999, name="Bot"),
            SimpleNamespace(id=42, name="Alice"),
        ]
        receiver = self._make_receiver(allowed_ids={"42"}, members=members)
        receiver.start()
        self._fill_buffer(receiver, 100)
        receiver.check_silence()
        assert receiver._ssrc_to_user[100] == 42
        # Second utterance — should use cached mapping
        self._fill_buffer(receiver, 100)
        completed = receiver.check_silence()
        assert len(completed) == 1
        assert completed[0][0] == 42

    # -- Stale buffer cleanup --

    def test_stale_unknown_buffer_discarded(self):
        """Buffer with no user and very old timestamp is discarded."""
        receiver = self._make_receiver()
        receiver.start()
        receiver._buffers[200] = bytearray(b"\x00" * 100)
        receiver._last_packet_time[200] = time.monotonic() - 10.0
        receiver.check_silence()
        assert 200 not in receiver._buffers

    # -- Pause / resume (echo prevention) --

    def test_paused_receiver_ignores_packets(self):
        """A packet delivered while paused leaves the buffers empty."""
        receiver = self._make_receiver()
        receiver.start()
        receiver.pause()
        # NOTE(review): this payload is 100 zero bytes, not a packet built by
        # _build_rtp_packet — confirm it would be buffered when NOT paused,
        # otherwise this test cannot distinguish "paused" from "malformed".
        receiver._on_packet(b"\x00" * 100)
        assert len(receiver._buffers) == 0

    def test_resumed_receiver_accepts_packets(self):
        """pause() followed by resume() leaves the receiver un-paused."""
        receiver = self._make_receiver()
        receiver.start()
        receiver.pause()
        receiver.resume()
        # NOTE(review): only the flag is checked here; no packet is actually
        # delivered after resume despite the test name.
        assert receiver._paused is False

    # -- _on_packet DAVE passthrough behavior --

    def _make_receiver_with_nacl(self, dave_session=None, mapped_ssrcs=None):
        """Create a receiver that can process _on_packet with mocked NaCl + Opus.

        The receiver is started before returning; ``mapped_ssrcs`` is an
        optional ``{ssrc: user_id}`` dict applied via ``map_ssrc``.
        """
        from gateway.platforms.discord import VoiceReceiver

        vc = self._make_voice_client(dave_session=dave_session)
        receiver = VoiceReceiver(vc)
        receiver.start()
        # Pre-map SSRCs if provided
        if mapped_ssrcs:
            for ssrc, uid in mapped_ssrcs.items():
                receiver.map_ssrc(ssrc, uid)
        return receiver

    @staticmethod
    def _build_rtp_packet(ssrc=100, seq=1, timestamp=960):
        """Build a minimal valid RTP packet for _on_packet.

        We need: RTP header (12 bytes) + encrypted payload + 4-byte nonce.
        NaCl decrypt is mocked so payload content doesn't matter.
        """
        import struct

        # RTP header: version=2, payload_type=0x78, no extension, no CSRC
        header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc)
        # Fake encrypted payload (NaCl will be mocked) + 4 byte nonce
        payload = b"\x00" * 20 + b"\x00\x00\x00\x01"
        return header + payload

    def _inject_mock_decoder(self, receiver, ssrc):
        """Pre-inject a mock Opus decoder for the given SSRC."""
        mock_decoder = MagicMock()
        mock_decoder.decode.return_value = b"\x00" * 3840
        receiver._decoders[ssrc] = mock_decoder
        return mock_decoder

    def test_on_packet_dave_known_user_decrypt_ok(self):
        """Known SSRC + DAVE decrypt success → audio buffered."""
        dave = MagicMock()
        dave.decrypt.return_value = b"\xf8\xff\xfe"
        receiver = self._make_receiver_with_nacl(
            dave_session=dave, mapped_ssrcs={100: 42}
        )
        self._inject_mock_decoder(receiver, 100)

        with patch("nacl.secret.Aead") as mock_aead:
            mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
            receiver._on_packet(self._build_rtp_packet(ssrc=100))

        assert 100 in receiver._buffers
        assert len(receiver._buffers[100]) > 0
        dave.decrypt.assert_called_once()

    def test_on_packet_dave_unknown_ssrc_passthrough(self):
        """Unknown SSRC + DAVE → skip DAVE, attempt Opus decode (passthrough)."""
        dave = MagicMock()
        receiver = self._make_receiver_with_nacl(dave_session=dave)
        self._inject_mock_decoder(receiver, 100)

        with patch("nacl.secret.Aead") as mock_aead:
            mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
            receiver._on_packet(self._build_rtp_packet(ssrc=100))

        dave.decrypt.assert_not_called()
        assert 100 in receiver._buffers
        assert len(receiver._buffers[100]) > 0

    def test_on_packet_dave_unencrypted_error_passthrough(self):
        """DAVE decrypt 'Unencrypted' error → use data as-is, don't drop."""
        dave = MagicMock()
        dave.decrypt.side_effect = Exception(
            "Failed to decrypt: DecryptionFailed(UnencryptedWhenPassthroughDisabled)"
        )
        receiver = self._make_receiver_with_nacl(
            dave_session=dave, mapped_ssrcs={100: 42}
        )
        self._inject_mock_decoder(receiver, 100)

        with patch("nacl.secret.Aead") as mock_aead:
            mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
            receiver._on_packet(self._build_rtp_packet(ssrc=100))

        assert 100 in receiver._buffers
        assert len(receiver._buffers[100]) > 0

    def test_on_packet_dave_other_error_drops(self):
        """DAVE decrypt non-Unencrypted error → packet dropped."""
        dave = MagicMock()
        dave.decrypt.side_effect = Exception("KeyRotationFailed")
        receiver = self._make_receiver_with_nacl(
            dave_session=dave, mapped_ssrcs={100: 42}
        )

        with patch("nacl.secret.Aead") as mock_aead:
            mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
            receiver._on_packet(self._build_rtp_packet(ssrc=100))

        assert len(receiver._buffers.get(100, b"")) == 0

    def test_on_packet_no_dave_direct_decode(self):
        """No DAVE session → decode directly."""
        receiver = self._make_receiver_with_nacl(dave_session=None)
        self._inject_mock_decoder(receiver, 100)

        with patch("nacl.secret.Aead") as mock_aead:
            mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
            receiver._on_packet(self._build_rtp_packet(ssrc=100))

        assert 100 in receiver._buffers
        assert len(receiver._buffers[100]) > 0

    def test_on_packet_bot_own_ssrc_ignored(self):
        """Bot's own SSRC → dropped (echo prevention)."""
        receiver = self._make_receiver_with_nacl()
        with patch("nacl.secret.Aead"):
            receiver._on_packet(self._build_rtp_packet(ssrc=9999))
        assert len(receiver._buffers) == 0

    def test_on_packet_multiple_ssrcs_separate_buffers(self):
        """Different SSRCs → separate buffers."""
        receiver = self._make_receiver_with_nacl(dave_session=None)
        self._inject_mock_decoder(receiver, 100)
        self._inject_mock_decoder(receiver, 200)

        with patch("nacl.secret.Aead") as mock_aead:
            mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
            receiver._on_packet(self._build_rtp_packet(ssrc=100))
            receiver._on_packet(self._build_rtp_packet(ssrc=200))

        assert 100 in receiver._buffers
        assert 200 in receiver._buffers


class TestVoiceTTSPlayback:
    """TTS playback: play_tts in VC, dedup, fallback."""

    @staticmethod
    def _make_discord_adapter():
        """Create a DiscordAdapter without running ``__init__``, with empty voice state."""
        from gateway.config import Platform, PlatformConfig
        from gateway.platforms.discord import DiscordAdapter

        config = PlatformConfig(enabled=True, extra={})
        config.token = "fake-token"
        adapter = object.__new__(DiscordAdapter)
        adapter.platform = Platform.DISCORD
        adapter.config = config
        adapter._voice_clients = {}
        adapter._voice_text_channels = {}
        adapter._voice_receivers = {}
        return adapter

    # -- play_tts behavior --

    @pytest.mark.asyncio
    async def test_play_tts_plays_in_vc(self):
        """play_tts calls play_in_voice_channel when bot is in VC."""
        adapter = self._make_discord_adapter()
        mock_vc = MagicMock()
        mock_vc.is_connected.return_value = True
        adapter._voice_clients[111] = mock_vc
        adapter._voice_text_channels[111] = 123

        played = []

        async def fake_play(gid, path):
            played.append((gid, path))
            return True

        adapter.play_in_voice_channel = fake_play

        result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg")
        assert result.success is True
        assert played == [(111, "/tmp/tts.ogg")]

    @pytest.mark.asyncio
    async def test_play_tts_fallback_when_not_in_vc(self):
        """play_tts sends as file attachment when bot is not in VC."""
        from gateway.platforms.base import SendResult

        adapter = self._make_discord_adapter()
        adapter.send_voice = AsyncMock(
            return_value=SendResult(success=False, error="no client")
        )
        result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg")
        assert result.success is False
        adapter.send_voice.assert_called_once()

    @pytest.mark.asyncio
    async def test_play_tts_wrong_channel_no_match(self):
        """play_tts doesn't match if chat_id is for a different channel."""
        from gateway.platforms.base import SendResult

        adapter = self._make_discord_adapter()
        mock_vc = MagicMock()
        mock_vc.is_connected.return_value = True
        adapter._voice_clients[111] = mock_vc
        adapter._voice_text_channels[111] = 123

        adapter.send_voice = AsyncMock(return_value=SendResult(success=True))
        # Spy on VC playback so a wrong-channel match is detected directly
        # instead of being inferred only from the fallback having fired.
        adapter.play_in_voice_channel = AsyncMock(return_value=True)

        # Different chat_id — shouldn't match VC
        result = await adapter.play_tts(chat_id="999", audio_path="/tmp/tts.ogg")

        adapter.play_in_voice_channel.assert_not_called()
        adapter.send_voice.assert_called_once()
        assert result.success is True

    # -- Runner dedup --

    @staticmethod
    def _make_runner():
        """Create a GatewayRunner without running ``__init__``, with empty voice state."""
        from gateway.run import GatewayRunner

        runner = object.__new__(GatewayRunner)
        runner._voice_mode = {}
        runner.adapters = {}
        return runner

    def _call_should_reply(self, runner, voice_mode, msg_type, response="Hello", agent_msgs=None):
        """Set ``voice_mode`` for chat "ch1" and return ``_should_send_voice_reply``'s verdict."""
        from gateway.config import Platform
        from gateway.platforms.base import MessageEvent, SessionSource

        runner._voice_mode["ch1"] = voice_mode
        source = SessionSource(
            platform=Platform.DISCORD, chat_id="ch1",
            user_id="1", user_name="test", chat_type="channel",
        )
        event = MessageEvent(source=source, text="test", message_type=msg_type)
        return runner._should_send_voice_reply(event, response, agent_msgs or [])

    def test_voice_input_runner_skips(self):
        """Voice input: runner skips — base adapter handles via play_tts."""
        from gateway.platforms.base import MessageType

        runner = self._make_runner()
        assert self._call_should_reply(runner, "all", MessageType.VOICE) is False

    def test_text_input_voice_all_runner_fires(self):
        """Text input + voice_mode=all: runner generates TTS."""
        from gateway.platforms.base import MessageType

        runner = self._make_runner()
        assert self._call_should_reply(runner, "all", MessageType.TEXT) is True

    def test_text_input_voice_off_no_tts(self):
        """Text input + voice_mode=off: no TTS."""
        from gateway.platforms.base import MessageType

        runner = self._make_runner()
        assert self._call_should_reply(runner, "off", MessageType.TEXT) is False

    def test_text_input_voice_only_no_tts(self):
        """Text input + voice_mode=voice_only: no TTS for text."""
        from gateway.platforms.base import MessageType

        runner = self._make_runner()
        assert self._call_should_reply(runner, "voice_only", MessageType.TEXT) is False

    def test_error_response_no_tts(self):
        """Error response: no TTS regardless of voice_mode."""
        from gateway.platforms.base import MessageType

        runner = self._make_runner()
        assert self._call_should_reply(
            runner, "all", MessageType.TEXT, response="Error: boom"
        ) is False

    def test_empty_response_no_tts(self):
        """Empty response: no TTS."""
        from gateway.platforms.base import MessageType

        runner = self._make_runner()
        assert self._call_should_reply(
            runner, "all", MessageType.TEXT, response=""
        ) is False

    def test_agent_tts_tool_dedup(self):
        """Agent already called text_to_speech tool: runner skips."""
        from gateway.platforms.base import MessageType

        runner = self._make_runner()
        agent_msgs = [{"role": "assistant", "tool_calls": [
            {"id": "1", "type": "function", "function": {"name": "text_to_speech", "arguments": "{}"}}
        ]}]
        assert self._call_should_reply(
            runner, "all", MessageType.TEXT, agent_msgs=agent_msgs
        ) is False


class TestUDPKeepalive:
    """UDP keepalive prevents Discord from dropping the voice session."""

    def test_keepalive_interval_is_reasonable(self):
        """The configured keepalive interval lies within 5-30 seconds."""
        from gateway.platforms.discord import DiscordAdapter

        interval = DiscordAdapter._KEEPALIVE_INTERVAL
        assert 5 <= interval <= 30, f"Keepalive interval {interval}s should be between 5-30s"

    @pytest.mark.asyncio
    async def test_keepalive_sends_silence_frame(self):
        """Listen loop sends silence frame via send_packet after interval."""
        import asyncio

        from gateway.config import Platform, PlatformConfig
        from gateway.platforms.discord import DiscordAdapter, VoiceReceiver

        config = PlatformConfig(enabled=True, extra={})
        config.token = "fake"
        adapter = object.__new__(DiscordAdapter)
        adapter.platform = Platform.DISCORD
        adapter.config = config
        adapter._voice_clients = {}
        adapter._voice_text_channels = {}
        adapter._voice_receivers = {}
        adapter._voice_listen_tasks = {}

        # Mock VC and receiver
        mock_vc = MagicMock()
        mock_vc.is_connected.return_value = True
        mock_conn = MagicMock()
        adapter._voice_clients[111] = mock_vc
        mock_vc._connection = mock_conn

        mock_receiver_vc = MagicMock()
        mock_receiver_vc._connection.secret_key = [0] * 32
        mock_receiver_vc._connection.dave_session = None
        mock_receiver_vc._connection.ssrc = 9999
        mock_receiver_vc._connection.add_socket_listener = MagicMock()
        mock_receiver_vc._connection.remove_socket_listener = MagicMock()
        mock_receiver_vc._connection.hook = None
        receiver = VoiceReceiver(mock_receiver_vc)
        receiver.start()
        adapter._voice_receivers[111] = receiver

        # Set keepalive interval very short for test. patch.object restores
        # the class attribute on exit even if an assertion or the loop raises.
        with patch.object(DiscordAdapter, "_KEEPALIVE_INTERVAL", 0.1):
            # Run listen loop briefly
            loop_task = asyncio.create_task(adapter._voice_listen_loop(111))
            await asyncio.sleep(0.3)
            receiver._running = False  # stop loop
            await asyncio.sleep(0.1)
            loop_task.cancel()
            try:
                await loop_task
            except asyncio.CancelledError:
                pass

            # send_packet should have been called with silence frame
            mock_conn.send_packet.assert_called_with(b"\xf8\xff\xfe")