diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index a7fd45f6a..c2cc643fd 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -10,7 +10,13 @@ Uses discord.py library for: import asyncio import logging import os -from typing import Dict, List, Optional, Any +import struct +import subprocess +import tempfile +import threading +import time +from collections import defaultdict +from typing import Callable, Dict, List, Optional, Any logger = logging.getLogger(__name__) @@ -65,6 +71,294 @@ def check_discord_requirements() -> bool: return DISCORD_AVAILABLE +class VoiceReceiver: + """Captures and decodes voice audio from a Discord voice channel. + + Attaches to a VoiceClient's socket listener, decrypts RTP packets + (NaCl transport + DAVE E2EE), decodes Opus to PCM, and buffers + per-user audio. A polling loop detects silence and delivers + completed utterances via a callback. + """ + + SILENCE_THRESHOLD = 1.5 # seconds of silence → end of utterance + MIN_SPEECH_DURATION = 0.5 # minimum seconds to process (skip noise) + SAMPLE_RATE = 48000 # Discord native rate + CHANNELS = 2 # Discord sends stereo + + def __init__(self, voice_client): + self._vc = voice_client + self._running = False + + # Decryption + self._secret_key: Optional[bytes] = None + self._dave_session = None + self._bot_ssrc: int = 0 + + # SSRC -> user_id mapping (populated from SPEAKING events) + self._ssrc_to_user: Dict[int, int] = {} + self._lock = threading.Lock() + + # Per-user audio buffers + self._buffers: Dict[int, bytearray] = defaultdict(bytearray) + self._last_packet_time: Dict[int, float] = {} + + # Opus decoder per SSRC (each user needs own decoder state) + self._decoders: Dict[int, object] = {} + + # Pause flag: don't capture while bot is playing TTS + self._paused = False + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def start(self): + """Start listening for voice packets.""" + conn = self._vc._connection + self._secret_key = bytes(conn.secret_key) + self._dave_session = conn.dave_session + self._bot_ssrc = conn.ssrc + + self._install_speaking_hook(conn) + conn.add_socket_listener(self._on_packet) + self._running = True + logger.info("VoiceReceiver started (bot_ssrc=%d)", self._bot_ssrc) + + def stop(self): + """Stop listening and clean up.""" + self._running = False + try: + self._vc._connection.remove_socket_listener(self._on_packet) + except Exception: + pass + self._buffers.clear() + self._last_packet_time.clear() + self._decoders.clear() + self._ssrc_to_user.clear() + logger.info("VoiceReceiver stopped") + + def pause(self): + self._paused = True + + def resume(self): + self._paused = False + + # ------------------------------------------------------------------ + # SSRC -> user_id mapping via SPEAKING opcode hook + # ------------------------------------------------------------------ + + def map_ssrc(self, ssrc: int, user_id: int): + with self._lock: + self._ssrc_to_user[ssrc] = user_id + + def _install_speaking_hook(self, conn): + """Wrap the voice websocket hook to capture SPEAKING events (op 5). + + VoiceConnectionState stores the hook as ``conn.hook`` (public attr). + It is passed to DiscordVoiceWebSocket on each (re)connect, so we + must wrap it on the VoiceConnectionState level AND on the current + live websocket instance. + """ + original_hook = conn.hook + receiver_self = self + + async def wrapped_hook(ws, msg): + if isinstance(msg, dict) and msg.get("op") == 5: + data = msg.get("d", {}) + ssrc = data.get("ssrc") + user_id = data.get("user_id") + if ssrc and user_id: + logger.info("SPEAKING event: ssrc=%d -> user=%s", ssrc, user_id) + receiver_self.map_ssrc(int(ssrc), int(user_id)) + if original_hook: + await original_hook(ws, msg) + + # Set on connection state (for future reconnects) + conn.hook = wrapped_hook + # Set on the current live websocket (for immediate effect) + try: + from discord.utils import MISSING + if hasattr(conn, 'ws') and conn.ws is not MISSING: + conn.ws._hook = wrapped_hook + logger.info("Speaking hook installed on live websocket") + except Exception as e: + logger.warning("Could not install hook on live ws: %s", e) + + # ------------------------------------------------------------------ + # Packet handler (called from SocketReader thread) + # ------------------------------------------------------------------ + + _packet_debug_count = 0 # class-level counter for debug logging + + def _on_packet(self, data: bytes): + if not self._running or self._paused: + return + + # Log first few raw packets for debugging + VoiceReceiver._packet_debug_count += 1 + if VoiceReceiver._packet_debug_count <= 5: + logger.info( + "Raw UDP packet: len=%d, first_bytes=%s", + len(data), data[:4].hex() if len(data) >= 4 else "short", + ) + + if len(data) < 16: + return + + # RTP version check: top 2 bits must be 10 (version 2). + # Lower bits may vary (padding, extension, CSRC count). + # Payload type (byte 1 lower 7 bits) = 0x78 (120) for voice. + if (data[0] >> 6) != 2 or (data[1] & 0x7F) != 0x78: + if VoiceReceiver._packet_debug_count <= 5: + logger.info("Skipped non-RTP: byte0=0x%02x byte1=0x%02x", data[0], data[1]) + return + + first_byte = data[0] + _, _, seq, timestamp, ssrc = struct.unpack_from(">BBHII", data, 0) + + # Skip bot's own audio + if ssrc == self._bot_ssrc: + return + + # Calculate dynamic RTP header size (RFC 9335 / rtpsize mode) + cc = first_byte & 0x0F # CSRC count + has_extension = bool(first_byte & 0x10) # extension bit + header_size = 12 + (4 * cc) + (4 if has_extension else 0) + + if len(data) < header_size + 4: # need at least header + nonce + return + + # Read extension length from preamble (for skipping after decrypt) + ext_data_len = 0 + if has_extension: + ext_preamble_offset = 12 + (4 * cc) + ext_words = struct.unpack_from(">H", data, ext_preamble_offset + 2)[0] + ext_data_len = ext_words * 4 + + if VoiceReceiver._packet_debug_count <= 10: + with self._lock: + known_user = self._ssrc_to_user.get(ssrc, "unknown") + logger.info( + "RTP packet: ssrc=%d, seq=%d, user=%s, hdr=%d, ext_data=%d", + ssrc, seq, known_user, header_size, ext_data_len, + ) + + header = bytes(data[:header_size]) + payload_with_nonce = data[header_size:] + + # --- NaCl transport decrypt (aead_xchacha20_poly1305_rtpsize) --- + if len(payload_with_nonce) < 4: + return + nonce = bytearray(24) + nonce[:4] = payload_with_nonce[-4:] + encrypted = bytes(payload_with_nonce[:-4]) + + try: + import nacl.secret # noqa: delayed import – only in voice path + box = nacl.secret.Aead(self._secret_key) + decrypted = box.decrypt(encrypted, header, bytes(nonce)) + except Exception as e: + if VoiceReceiver._packet_debug_count <= 10: + logger.warning("NaCl decrypt failed: %s (hdr=%d, enc=%d)", e, header_size, len(encrypted)) + return + + # Skip encrypted extension data to get the actual opus payload + if ext_data_len and len(decrypted) > ext_data_len: + decrypted = decrypted[ext_data_len:] + + # --- DAVE E2EE decrypt --- + if self._dave_session: + with self._lock: + user_id = self._ssrc_to_user.get(ssrc, 0) + if user_id == 0: + if VoiceReceiver._packet_debug_count <= 10: + logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc) + return # unknown user, can't DAVE-decrypt + try: + import davey + decrypted = self._dave_session.decrypt( + user_id, davey.MediaType.audio, decrypted + ) + except Exception as e: + if VoiceReceiver._packet_debug_count <= 10: + logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e) + return + + # --- Opus decode -> PCM --- + try: + if ssrc not in self._decoders: + self._decoders[ssrc] = discord.opus.Decoder() + pcm = self._decoders[ssrc].decode(decrypted) + self._buffers[ssrc].extend(pcm) + self._last_packet_time[ssrc] = time.monotonic() + except Exception: + return + + # ------------------------------------------------------------------ + # Silence detection + # ------------------------------------------------------------------ + + def check_silence(self) -> list: + """Return list of (user_id, pcm_bytes) for completed utterances.""" + now = time.monotonic() + completed = [] + + with self._lock: + ssrc_user_map = dict(self._ssrc_to_user) + + for ssrc in list(self._buffers.keys()): + last_time = self._last_packet_time.get(ssrc, now) + silence_duration = now - last_time + buf = self._buffers[ssrc] + # 48kHz, 16-bit, stereo = 192000 bytes/sec + buf_duration = len(buf) / (self.SAMPLE_RATE * self.CHANNELS * 2) + + if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION: + user_id = ssrc_user_map.get(ssrc, 0) + if user_id: + completed.append((user_id, bytes(buf))) + self._buffers[ssrc] = bytearray() + self._last_packet_time.pop(ssrc, None) + elif silence_duration >= self.SILENCE_THRESHOLD * 2: + # Stale buffer with no valid user — discard + self._buffers.pop(ssrc, None) + self._last_packet_time.pop(ssrc, None) + + return completed + + # ------------------------------------------------------------------ + # PCM -> WAV conversion (for Whisper STT) + # ------------------------------------------------------------------ + + @staticmethod + def pcm_to_wav(pcm_data: bytes, output_path: str, + src_rate: int = 48000, src_channels: int = 2): + """Convert raw PCM to 16kHz mono WAV via ffmpeg.""" + with tempfile.NamedTemporaryFile(suffix=".pcm", delete=False) as f: + f.write(pcm_data) + pcm_path = f.name + try: + subprocess.run( + [ + "ffmpeg", "-y", "-loglevel", "error", + "-f", "s16le", + "-ar", str(src_rate), + "-ac", str(src_channels), + "-i", pcm_path, + "-ar", "16000", + "-ac", "1", + output_path, + ], + check=True, + timeout=10, + ) + finally: + try: + os.unlink(pcm_path) + except OSError: + pass + + class DiscordAdapter(BasePlatformAdapter): """ Discord bot adapter. @@ -94,6 +388,10 @@ class DiscordAdapter(BasePlatformAdapter): self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task + # Phase 2: voice listening + self._voice_receivers: Dict[int, VoiceReceiver] = {} # guild_id -> VoiceReceiver + self._voice_listen_tasks: Dict[int, asyncio.Task] = {} # guild_id -> listen loop + self._voice_input_callback: Optional[Callable] = None # set by run.py async def connect(self) -> bool: """Connect to Discord and start receiving events.""" @@ -402,10 +700,30 @@ class DiscordAdapter(BasePlatformAdapter): vc = await channel.connect() self._voice_clients[guild_id] = vc self._reset_voice_timeout(guild_id) + + # Start voice receiver (Phase 2: listen to users) + try: + receiver = VoiceReceiver(vc) + receiver.start() + self._voice_receivers[guild_id] = receiver + self._voice_listen_tasks[guild_id] = asyncio.ensure_future( + self._voice_listen_loop(guild_id) + ) + except Exception as e: + logger.warning("Voice receiver failed to start: %s", e) + return True async def leave_voice_channel(self, guild_id: int) -> None: """Disconnect from the voice channel in a guild.""" + # Stop voice receiver first + receiver = self._voice_receivers.pop(guild_id, None) + if receiver: + receiver.stop() + listen_task = self._voice_listen_tasks.pop(guild_id, None) + if listen_task: + listen_task.cancel() + vc = self._voice_clients.pop(guild_id, None) if vc and vc.is_connected(): await vc.disconnect() @@ -420,24 +738,33 @@ class DiscordAdapter(BasePlatformAdapter): if not vc or not vc.is_connected(): return False - # Wait for current playback to finish - while vc.is_playing(): - await asyncio.sleep(0.1) + # Pause voice receiver while playing (echo prevention) + receiver = self._voice_receivers.get(guild_id) + if receiver: + receiver.pause() - done = asyncio.Event() - loop = asyncio.get_event_loop() + try: + # Wait for current playback to finish + while vc.is_playing(): + await asyncio.sleep(0.1) - def _after(error): - if error: - logger.error("Voice playback error: %s", error) - loop.call_soon_threadsafe(done.set) + done = asyncio.Event() + loop = asyncio.get_event_loop() - source = discord.FFmpegPCMAudio(audio_path) - source = discord.PCMVolumeTransformer(source, volume=1.0) - vc.play(source, after=_after) - await done.wait() - self._reset_voice_timeout(guild_id) - return True + def _after(error): + if error: + logger.error("Voice playback error: %s", error) + loop.call_soon_threadsafe(done.set) + + source = discord.FFmpegPCMAudio(audio_path) + source = discord.PCMVolumeTransformer(source, volume=1.0) + vc.play(source, after=_after) + await done.wait() + self._reset_voice_timeout(guild_id) + return True + finally: + if receiver: + receiver.resume() async def get_user_voice_channel(self, guild_id: int, user_id: str): """Return the voice channel the user is currently in, or None.""" @@ -481,6 +808,67 @@ class DiscordAdapter(BasePlatformAdapter): vc = self._voice_clients.get(guild_id) return vc is not None and vc.is_connected() + # ------------------------------------------------------------------ + # Voice listening (Phase 2) + # ------------------------------------------------------------------ + + async def _voice_listen_loop(self, guild_id: int): + """Periodically check for completed utterances and process them.""" + receiver = self._voice_receivers.get(guild_id) + if not receiver: + return + try: + while receiver._running: + await asyncio.sleep(0.2) + completed = receiver.check_silence() + for user_id, pcm_data in completed: + if not self._is_allowed_user(str(user_id)): + continue + await self._process_voice_input(guild_id, user_id, pcm_data) + except asyncio.CancelledError: + pass + except Exception as e: + logger.error("Voice listen loop error: %s", e, exc_info=True) + + async def _process_voice_input(self, guild_id: int, user_id: int, pcm_data: bytes): + """Convert PCM -> WAV -> STT -> callback.""" + from tools.voice_mode import is_whisper_hallucination + + wav_path = tempfile.mktemp(suffix=".wav", prefix="vc_listen_") + try: + await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path) + + from tools.transcription_tools import transcribe_audio + result = await asyncio.to_thread(transcribe_audio, wav_path) + + if not result.get("success"): + return + transcript = result.get("transcript", "").strip() + if not transcript or is_whisper_hallucination(transcript): + return + + logger.info("Voice input from user %d: %s", user_id, transcript[:100]) + + if self._voice_input_callback: + await self._voice_input_callback( + guild_id=guild_id, + user_id=user_id, + transcript=transcript, + ) + except Exception as e: + logger.warning("Voice input processing failed: %s", e, exc_info=True) + finally: + try: + os.unlink(wav_path) + except OSError: + pass + + def _is_allowed_user(self, user_id: str) -> bool: + """Check if user is in DISCORD_ALLOWED_USERS.""" + if not self._allowed_user_ids: + return True + return user_id in self._allowed_user_ids + async def send_image_file( self, chat_id: str, diff --git a/gateway/run.py b/gateway/run.py index 4674548aa..bee9b62a1 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1608,6 +1608,8 @@ class GatewayRunner: (voice_mode == "all") or (voice_mode == "voice_only" and is_voice_input) ) + logger.info("Voice reply check: chat_id=%s, voice_mode=%s, is_voice=%s, should_reply=%s, has_response=%s", + chat_id, voice_mode, is_voice_input, should_voice_reply, bool(response)) if should_voice_reply and response and not response.startswith("Error:"): # Skip if agent already called TTS tool (avoid double voice) has_agent_tts = any( @@ -1618,6 +1620,7 @@ class GatewayRunner: ) for msg in agent_messages ) + logger.info("Voice reply: has_agent_tts=%s, calling _send_voice_reply", has_agent_tts) if not has_agent_tts: await self._send_voice_reply(event, response) @@ -2201,9 +2204,12 @@ class GatewayRunner: adapter._voice_text_channels[guild_id] = int(event.source.chat_id) self._voice_mode[event.source.chat_id] = "all" self._save_voice_modes() + # Wire voice input callback so the adapter can deliver transcripts + if hasattr(adapter, "_voice_input_callback"): + adapter._voice_input_callback = self._handle_voice_channel_input return ( f"Joined voice channel **{voice_channel.name}**.\n" - f"I'll speak my replies here. Use /voice leave to disconnect." + f"I'll speak my replies and listen to you. Use /voice leave to disconnect." ) return "Failed to join voice channel. Check bot permissions (Connect + Speak)." @@ -2223,6 +2229,49 @@ class GatewayRunner: self._save_voice_modes() return "Left voice channel." + async def _handle_voice_channel_input( + self, guild_id: int, user_id: int, transcript: str + ): + """Handle transcribed voice from a user in a voice channel. + + Creates a synthetic MessageEvent and processes it through the + adapter's full message pipeline (session, typing, agent, TTS reply). + """ + adapter = self.adapters.get(Platform.DISCORD) + if not adapter: + return + + text_ch_id = adapter._voice_text_channels.get(guild_id) + if not text_ch_id: + return + + # Show transcript in text channel + try: + channel = adapter._client.get_channel(text_ch_id) + if channel: + await channel.send(f"**[Voice]** <@{user_id}>: {transcript}") + except Exception: + pass + + # Build a synthetic MessageEvent and feed through the normal pipeline + source = SessionSource( + platform=Platform.DISCORD, + chat_id=str(text_ch_id), + user_id=str(user_id), + user_name=str(user_id), + ) + # Use SimpleNamespace as raw_message so _get_guild_id() can extract + # guild_id and _send_voice_reply() plays audio in the voice channel. + from types import SimpleNamespace + event = MessageEvent( + source=source, + text=transcript, + message_type=MessageType.VOICE, + raw_message=SimpleNamespace(guild_id=guild_id, guild=None), + ) + + await adapter.handle_message(event) + async def _send_voice_reply(self, event: MessageEvent, text: str) -> None: """Generate TTS audio and send as a voice message before the text reply.""" try: