feat: add Discord voice channel listening — STT transcription and agent response pipeline

Phase 2 of voice channel support: bot listens to users speaking in VC,
transcribes speech via Groq Whisper, and processes through the agent pipeline.

- Add VoiceReceiver class for RTP packet capture, NaCl/DAVE decryption, Opus decode
- Add silence detection and per-user PCM buffering
- Wire voice input callback from adapter to GatewayRunner
- Fix adapter dict key: use Platform.DISCORD enum instead of string
- Fix guild_id extraction for synthetic voice events via SimpleNamespace raw_message
- Pause/resume receiver during TTS playback to prevent echo
This commit is contained in:
0xbyt4
2026-03-11 04:34:58 +03:00
parent cc974904f8
commit c0c358d051
2 changed files with 454 additions and 17 deletions

View File

@@ -10,7 +10,13 @@ Uses discord.py library for:
import asyncio
import logging
import os
from typing import Dict, List, Optional, Any
import struct
import subprocess
import tempfile
import threading
import time
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Any
logger = logging.getLogger(__name__)
@@ -65,6 +71,294 @@ def check_discord_requirements() -> bool:
return DISCORD_AVAILABLE
class VoiceReceiver:
"""Captures and decodes voice audio from a Discord voice channel.
Attaches to a VoiceClient's socket listener, decrypts RTP packets
(NaCl transport + DAVE E2EE), decodes Opus to PCM, and buffers
per-user audio. A polling loop detects silence and delivers
completed utterances via a callback.
"""
SILENCE_THRESHOLD = 1.5 # seconds of silence → end of utterance
MIN_SPEECH_DURATION = 0.5 # minimum seconds to process (skip noise)
SAMPLE_RATE = 48000 # Discord native rate
CHANNELS = 2 # Discord sends stereo
def __init__(self, voice_client):
self._vc = voice_client
self._running = False
# Decryption
self._secret_key: Optional[bytes] = None
self._dave_session = None
self._bot_ssrc: int = 0
# SSRC -> user_id mapping (populated from SPEAKING events)
self._ssrc_to_user: Dict[int, int] = {}
self._lock = threading.Lock()
# Per-user audio buffers
self._buffers: Dict[int, bytearray] = defaultdict(bytearray)
self._last_packet_time: Dict[int, float] = {}
# Opus decoder per SSRC (each user needs own decoder state)
self._decoders: Dict[int, object] = {}
# Pause flag: don't capture while bot is playing TTS
self._paused = False
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def start(self):
"""Start listening for voice packets."""
conn = self._vc._connection
self._secret_key = bytes(conn.secret_key)
self._dave_session = conn.dave_session
self._bot_ssrc = conn.ssrc
self._install_speaking_hook(conn)
conn.add_socket_listener(self._on_packet)
self._running = True
logger.info("VoiceReceiver started (bot_ssrc=%d)", self._bot_ssrc)
def stop(self):
"""Stop listening and clean up."""
self._running = False
try:
self._vc._connection.remove_socket_listener(self._on_packet)
except Exception:
pass
self._buffers.clear()
self._last_packet_time.clear()
self._decoders.clear()
self._ssrc_to_user.clear()
logger.info("VoiceReceiver stopped")
def pause(self):
self._paused = True
def resume(self):
self._paused = False
# ------------------------------------------------------------------
# SSRC -> user_id mapping via SPEAKING opcode hook
# ------------------------------------------------------------------
def map_ssrc(self, ssrc: int, user_id: int):
with self._lock:
self._ssrc_to_user[ssrc] = user_id
def _install_speaking_hook(self, conn):
"""Wrap the voice websocket hook to capture SPEAKING events (op 5).
VoiceConnectionState stores the hook as ``conn.hook`` (public attr).
It is passed to DiscordVoiceWebSocket on each (re)connect, so we
must wrap it on the VoiceConnectionState level AND on the current
live websocket instance.
"""
original_hook = conn.hook
receiver_self = self
async def wrapped_hook(ws, msg):
if isinstance(msg, dict) and msg.get("op") == 5:
data = msg.get("d", {})
ssrc = data.get("ssrc")
user_id = data.get("user_id")
if ssrc and user_id:
logger.info("SPEAKING event: ssrc=%d -> user=%s", ssrc, user_id)
receiver_self.map_ssrc(int(ssrc), int(user_id))
if original_hook:
await original_hook(ws, msg)
# Set on connection state (for future reconnects)
conn.hook = wrapped_hook
# Set on the current live websocket (for immediate effect)
try:
from discord.utils import MISSING
if hasattr(conn, 'ws') and conn.ws is not MISSING:
conn.ws._hook = wrapped_hook
logger.info("Speaking hook installed on live websocket")
except Exception as e:
logger.warning("Could not install hook on live ws: %s", e)
# ------------------------------------------------------------------
# Packet handler (called from SocketReader thread)
# ------------------------------------------------------------------
_packet_debug_count = 0 # class-level counter for debug logging
def _on_packet(self, data: bytes):
if not self._running or self._paused:
return
# Log first few raw packets for debugging
VoiceReceiver._packet_debug_count += 1
if VoiceReceiver._packet_debug_count <= 5:
logger.info(
"Raw UDP packet: len=%d, first_bytes=%s",
len(data), data[:4].hex() if len(data) >= 4 else "short",
)
if len(data) < 16:
return
# RTP version check: top 2 bits must be 10 (version 2).
# Lower bits may vary (padding, extension, CSRC count).
# Payload type (byte 1 lower 7 bits) = 0x78 (120) for voice.
if (data[0] >> 6) != 2 or (data[1] & 0x7F) != 0x78:
if VoiceReceiver._packet_debug_count <= 5:
logger.info("Skipped non-RTP: byte0=0x%02x byte1=0x%02x", data[0], data[1])
return
first_byte = data[0]
_, _, seq, timestamp, ssrc = struct.unpack_from(">BBHII", data, 0)
# Skip bot's own audio
if ssrc == self._bot_ssrc:
return
# Calculate dynamic RTP header size (RFC 9335 / rtpsize mode)
cc = first_byte & 0x0F # CSRC count
has_extension = bool(first_byte & 0x10) # extension bit
header_size = 12 + (4 * cc) + (4 if has_extension else 0)
if len(data) < header_size + 4: # need at least header + nonce
return
# Read extension length from preamble (for skipping after decrypt)
ext_data_len = 0
if has_extension:
ext_preamble_offset = 12 + (4 * cc)
ext_words = struct.unpack_from(">H", data, ext_preamble_offset + 2)[0]
ext_data_len = ext_words * 4
if VoiceReceiver._packet_debug_count <= 10:
with self._lock:
known_user = self._ssrc_to_user.get(ssrc, "unknown")
logger.info(
"RTP packet: ssrc=%d, seq=%d, user=%s, hdr=%d, ext_data=%d",
ssrc, seq, known_user, header_size, ext_data_len,
)
header = bytes(data[:header_size])
payload_with_nonce = data[header_size:]
# --- NaCl transport decrypt (aead_xchacha20_poly1305_rtpsize) ---
if len(payload_with_nonce) < 4:
return
nonce = bytearray(24)
nonce[:4] = payload_with_nonce[-4:]
encrypted = bytes(payload_with_nonce[:-4])
try:
import nacl.secret # noqa: delayed import only in voice path
box = nacl.secret.Aead(self._secret_key)
decrypted = box.decrypt(encrypted, header, bytes(nonce))
except Exception as e:
if VoiceReceiver._packet_debug_count <= 10:
logger.warning("NaCl decrypt failed: %s (hdr=%d, enc=%d)", e, header_size, len(encrypted))
return
# Skip encrypted extension data to get the actual opus payload
if ext_data_len and len(decrypted) > ext_data_len:
decrypted = decrypted[ext_data_len:]
# --- DAVE E2EE decrypt ---
if self._dave_session:
with self._lock:
user_id = self._ssrc_to_user.get(ssrc, 0)
if user_id == 0:
if VoiceReceiver._packet_debug_count <= 10:
logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc)
return # unknown user, can't DAVE-decrypt
try:
import davey
decrypted = self._dave_session.decrypt(
user_id, davey.MediaType.audio, decrypted
)
except Exception as e:
if VoiceReceiver._packet_debug_count <= 10:
logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
return
# --- Opus decode -> PCM ---
try:
if ssrc not in self._decoders:
self._decoders[ssrc] = discord.opus.Decoder()
pcm = self._decoders[ssrc].decode(decrypted)
self._buffers[ssrc].extend(pcm)
self._last_packet_time[ssrc] = time.monotonic()
except Exception:
return
# ------------------------------------------------------------------
# Silence detection
# ------------------------------------------------------------------
def check_silence(self) -> list:
"""Return list of (user_id, pcm_bytes) for completed utterances."""
now = time.monotonic()
completed = []
with self._lock:
ssrc_user_map = dict(self._ssrc_to_user)
for ssrc in list(self._buffers.keys()):
last_time = self._last_packet_time.get(ssrc, now)
silence_duration = now - last_time
buf = self._buffers[ssrc]
# 48kHz, 16-bit, stereo = 192000 bytes/sec
buf_duration = len(buf) / (self.SAMPLE_RATE * self.CHANNELS * 2)
if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION:
user_id = ssrc_user_map.get(ssrc, 0)
if user_id:
completed.append((user_id, bytes(buf)))
self._buffers[ssrc] = bytearray()
self._last_packet_time.pop(ssrc, None)
elif silence_duration >= self.SILENCE_THRESHOLD * 2:
# Stale buffer with no valid user — discard
self._buffers.pop(ssrc, None)
self._last_packet_time.pop(ssrc, None)
return completed
# ------------------------------------------------------------------
# PCM -> WAV conversion (for Whisper STT)
# ------------------------------------------------------------------
@staticmethod
def pcm_to_wav(pcm_data: bytes, output_path: str,
src_rate: int = 48000, src_channels: int = 2):
"""Convert raw PCM to 16kHz mono WAV via ffmpeg."""
with tempfile.NamedTemporaryFile(suffix=".pcm", delete=False) as f:
f.write(pcm_data)
pcm_path = f.name
try:
subprocess.run(
[
"ffmpeg", "-y", "-loglevel", "error",
"-f", "s16le",
"-ar", str(src_rate),
"-ac", str(src_channels),
"-i", pcm_path,
"-ar", "16000",
"-ac", "1",
output_path,
],
check=True,
timeout=10,
)
finally:
try:
os.unlink(pcm_path)
except OSError:
pass
class DiscordAdapter(BasePlatformAdapter):
"""
Discord bot adapter.
@@ -94,6 +388,10 @@ class DiscordAdapter(BasePlatformAdapter):
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id
self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task
# Phase 2: voice listening
self._voice_receivers: Dict[int, VoiceReceiver] = {} # guild_id -> VoiceReceiver
self._voice_listen_tasks: Dict[int, asyncio.Task] = {} # guild_id -> listen loop
self._voice_input_callback: Optional[Callable] = None # set by run.py
async def connect(self) -> bool:
"""Connect to Discord and start receiving events."""
@@ -402,10 +700,30 @@ class DiscordAdapter(BasePlatformAdapter):
vc = await channel.connect()
self._voice_clients[guild_id] = vc
self._reset_voice_timeout(guild_id)
# Start voice receiver (Phase 2: listen to users)
try:
receiver = VoiceReceiver(vc)
receiver.start()
self._voice_receivers[guild_id] = receiver
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
self._voice_listen_loop(guild_id)
)
except Exception as e:
logger.warning("Voice receiver failed to start: %s", e)
return True
async def leave_voice_channel(self, guild_id: int) -> None:
"""Disconnect from the voice channel in a guild."""
# Stop voice receiver first
receiver = self._voice_receivers.pop(guild_id, None)
if receiver:
receiver.stop()
listen_task = self._voice_listen_tasks.pop(guild_id, None)
if listen_task:
listen_task.cancel()
vc = self._voice_clients.pop(guild_id, None)
if vc and vc.is_connected():
await vc.disconnect()
@@ -420,24 +738,33 @@ class DiscordAdapter(BasePlatformAdapter):
if not vc or not vc.is_connected():
return False
# Wait for current playback to finish
while vc.is_playing():
await asyncio.sleep(0.1)
# Pause voice receiver while playing (echo prevention)
receiver = self._voice_receivers.get(guild_id)
if receiver:
receiver.pause()
done = asyncio.Event()
loop = asyncio.get_event_loop()
try:
# Wait for current playback to finish
while vc.is_playing():
await asyncio.sleep(0.1)
def _after(error):
if error:
logger.error("Voice playback error: %s", error)
loop.call_soon_threadsafe(done.set)
done = asyncio.Event()
loop = asyncio.get_event_loop()
source = discord.FFmpegPCMAudio(audio_path)
source = discord.PCMVolumeTransformer(source, volume=1.0)
vc.play(source, after=_after)
await done.wait()
self._reset_voice_timeout(guild_id)
return True
def _after(error):
if error:
logger.error("Voice playback error: %s", error)
loop.call_soon_threadsafe(done.set)
source = discord.FFmpegPCMAudio(audio_path)
source = discord.PCMVolumeTransformer(source, volume=1.0)
vc.play(source, after=_after)
await done.wait()
self._reset_voice_timeout(guild_id)
return True
finally:
if receiver:
receiver.resume()
async def get_user_voice_channel(self, guild_id: int, user_id: str):
"""Return the voice channel the user is currently in, or None."""
@@ -481,6 +808,67 @@ class DiscordAdapter(BasePlatformAdapter):
vc = self._voice_clients.get(guild_id)
return vc is not None and vc.is_connected()
# ------------------------------------------------------------------
# Voice listening (Phase 2)
# ------------------------------------------------------------------
async def _voice_listen_loop(self, guild_id: int):
"""Periodically check for completed utterances and process them."""
receiver = self._voice_receivers.get(guild_id)
if not receiver:
return
try:
while receiver._running:
await asyncio.sleep(0.2)
completed = receiver.check_silence()
for user_id, pcm_data in completed:
if not self._is_allowed_user(str(user_id)):
continue
await self._process_voice_input(guild_id, user_id, pcm_data)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error("Voice listen loop error: %s", e, exc_info=True)
async def _process_voice_input(self, guild_id: int, user_id: int, pcm_data: bytes):
"""Convert PCM -> WAV -> STT -> callback."""
from tools.voice_mode import is_whisper_hallucination
wav_path = tempfile.mktemp(suffix=".wav", prefix="vc_listen_")
try:
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)
from tools.transcription_tools import transcribe_audio
result = await asyncio.to_thread(transcribe_audio, wav_path)
if not result.get("success"):
return
transcript = result.get("transcript", "").strip()
if not transcript or is_whisper_hallucination(transcript):
return
logger.info("Voice input from user %d: %s", user_id, transcript[:100])
if self._voice_input_callback:
await self._voice_input_callback(
guild_id=guild_id,
user_id=user_id,
transcript=transcript,
)
except Exception as e:
logger.warning("Voice input processing failed: %s", e, exc_info=True)
finally:
try:
os.unlink(wav_path)
except OSError:
pass
def _is_allowed_user(self, user_id: str) -> bool:
"""Check if user is in DISCORD_ALLOWED_USERS."""
if not self._allowed_user_ids:
return True
return user_id in self._allowed_user_ids
async def send_image_file(
self,
chat_id: str,

View File

@@ -1608,6 +1608,8 @@ class GatewayRunner:
(voice_mode == "all")
or (voice_mode == "voice_only" and is_voice_input)
)
logger.info("Voice reply check: chat_id=%s, voice_mode=%s, is_voice=%s, should_reply=%s, has_response=%s",
chat_id, voice_mode, is_voice_input, should_voice_reply, bool(response))
if should_voice_reply and response and not response.startswith("Error:"):
# Skip if agent already called TTS tool (avoid double voice)
has_agent_tts = any(
@@ -1618,6 +1620,7 @@ class GatewayRunner:
)
for msg in agent_messages
)
logger.info("Voice reply: has_agent_tts=%s, calling _send_voice_reply", has_agent_tts)
if not has_agent_tts:
await self._send_voice_reply(event, response)
@@ -2201,9 +2204,12 @@ class GatewayRunner:
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
self._voice_mode[event.source.chat_id] = "all"
self._save_voice_modes()
# Wire voice input callback so the adapter can deliver transcripts
if hasattr(adapter, "_voice_input_callback"):
adapter._voice_input_callback = self._handle_voice_channel_input
return (
f"Joined voice channel **{voice_channel.name}**.\n"
f"I'll speak my replies here. Use /voice leave to disconnect."
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
)
return "Failed to join voice channel. Check bot permissions (Connect + Speak)."
@@ -2223,6 +2229,49 @@ class GatewayRunner:
self._save_voice_modes()
return "Left voice channel."
async def _handle_voice_channel_input(
self, guild_id: int, user_id: int, transcript: str
):
"""Handle transcribed voice from a user in a voice channel.
Creates a synthetic MessageEvent and processes it through the
adapter's full message pipeline (session, typing, agent, TTS reply).
"""
adapter = self.adapters.get(Platform.DISCORD)
if not adapter:
return
text_ch_id = adapter._voice_text_channels.get(guild_id)
if not text_ch_id:
return
# Show transcript in text channel
try:
channel = adapter._client.get_channel(text_ch_id)
if channel:
await channel.send(f"**[Voice]** <@{user_id}>: {transcript}")
except Exception:
pass
# Build a synthetic MessageEvent and feed through the normal pipeline
source = SessionSource(
platform=Platform.DISCORD,
chat_id=str(text_ch_id),
user_id=str(user_id),
user_name=str(user_id),
)
# Use SimpleNamespace as raw_message so _get_guild_id() can extract
# guild_id and _send_voice_reply() plays audio in the voice channel.
from types import SimpleNamespace
event = MessageEvent(
source=source,
text=transcript,
message_type=MessageType.VOICE,
raw_message=SimpleNamespace(guild_id=guild_id, guild=None),
)
await adapter.handle_message(event)
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
"""Generate TTS audio and send as a voice message before the text reply."""
try: