merge: salvage PR #327 voice mode branch
Merge contributor branch feature/voice-mode onto current main for follow-up fixes.
This commit is contained in:
@@ -351,6 +351,8 @@ class BasePlatformAdapter(ABC):
|
||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||
self._auto_tts_disabled_chats: set = set()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -537,6 +539,20 @@ class BasePlatformAdapter(ABC):
|
||||
text = f"{caption}\n{text}"
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
async def play_tts(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Play auto-TTS audio for voice replies.
|
||||
|
||||
Override in subclasses for invisible playback (e.g. Web UI).
|
||||
Default falls back to send_voice (shows audio player).
|
||||
"""
|
||||
return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs)
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -724,7 +740,43 @@ class BasePlatformAdapter(ABC):
|
||||
if images:
|
||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
# Auto-TTS: if voice message, generate audio FIRST (before sending text)
|
||||
# Skipped when the chat has voice mode disabled (/voice off)
|
||||
_tts_path = None
|
||||
if (event.message_type == MessageType.VOICE
|
||||
and text_content
|
||||
and not media_files
|
||||
and event.source.chat_id not in self._auto_tts_disabled_chats):
|
||||
try:
|
||||
from tools.tts_tool import text_to_speech_tool, check_tts_requirements
|
||||
if check_tts_requirements():
|
||||
import json as _json
|
||||
speech_text = re.sub(r'[*_`#\[\]()]', '', text_content)[:4000].strip()
|
||||
if not speech_text:
|
||||
raise ValueError("Empty text after markdown cleanup")
|
||||
tts_result_str = await asyncio.to_thread(
|
||||
text_to_speech_tool, text=speech_text
|
||||
)
|
||||
tts_data = _json.loads(tts_result_str)
|
||||
_tts_path = tts_data.get("file_path")
|
||||
except Exception as tts_err:
|
||||
logger.warning("[%s] Auto-TTS failed: %s", self.name, tts_err)
|
||||
|
||||
# Play TTS audio before text (voice-first experience)
|
||||
if _tts_path and Path(_tts_path).exists():
|
||||
try:
|
||||
await self.play_tts(
|
||||
chat_id=event.source.chat_id,
|
||||
audio_path=_tts_path,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
os.remove(_tts_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Send the text portion
|
||||
if text_content:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
@@ -733,7 +785,7 @@ class BasePlatformAdapter(ABC):
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
@@ -746,10 +798,10 @@ class BasePlatformAdapter(ABC):
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
|
||||
# Send extracted images as native attachments
|
||||
if images:
|
||||
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
|
||||
@@ -777,7 +829,7 @@ class BasePlatformAdapter(ABC):
|
||||
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
|
||||
except Exception as img_err:
|
||||
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
|
||||
|
||||
|
||||
# Send extracted media files — route by file type
|
||||
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
|
||||
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.3gp'}
|
||||
|
||||
@@ -10,7 +10,13 @@ Uses discord.py library for:
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
import struct
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -65,6 +71,299 @@ def check_discord_requirements() -> bool:
|
||||
return DISCORD_AVAILABLE
|
||||
|
||||
|
||||
class VoiceReceiver:
|
||||
"""Captures and decodes voice audio from a Discord voice channel.
|
||||
|
||||
Attaches to a VoiceClient's socket listener, decrypts RTP packets
|
||||
(NaCl transport + DAVE E2EE), decodes Opus to PCM, and buffers
|
||||
per-user audio. A polling loop detects silence and delivers
|
||||
completed utterances via a callback.
|
||||
"""
|
||||
|
||||
SILENCE_THRESHOLD = 1.5 # seconds of silence → end of utterance
|
||||
MIN_SPEECH_DURATION = 0.5 # minimum seconds to process (skip noise)
|
||||
SAMPLE_RATE = 48000 # Discord native rate
|
||||
CHANNELS = 2 # Discord sends stereo
|
||||
|
||||
def __init__(self, voice_client):
|
||||
self._vc = voice_client
|
||||
self._running = False
|
||||
|
||||
# Decryption
|
||||
self._secret_key: Optional[bytes] = None
|
||||
self._dave_session = None
|
||||
self._bot_ssrc: int = 0
|
||||
|
||||
# SSRC -> user_id mapping (populated from SPEAKING events)
|
||||
self._ssrc_to_user: Dict[int, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Per-user audio buffers
|
||||
self._buffers: Dict[int, bytearray] = defaultdict(bytearray)
|
||||
self._last_packet_time: Dict[int, float] = {}
|
||||
|
||||
# Opus decoder per SSRC (each user needs own decoder state)
|
||||
self._decoders: Dict[int, object] = {}
|
||||
|
||||
# Pause flag: don't capture while bot is playing TTS
|
||||
self._paused = False
|
||||
|
||||
# Debug logging counter (instance-level to avoid cross-instance races)
|
||||
self._packet_debug_count = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self):
|
||||
"""Start listening for voice packets."""
|
||||
conn = self._vc._connection
|
||||
self._secret_key = bytes(conn.secret_key)
|
||||
self._dave_session = conn.dave_session
|
||||
self._bot_ssrc = conn.ssrc
|
||||
|
||||
self._install_speaking_hook(conn)
|
||||
conn.add_socket_listener(self._on_packet)
|
||||
self._running = True
|
||||
logger.info("VoiceReceiver started (bot_ssrc=%d)", self._bot_ssrc)
|
||||
|
||||
def stop(self):
|
||||
"""Stop listening and clean up."""
|
||||
self._running = False
|
||||
try:
|
||||
self._vc._connection.remove_socket_listener(self._on_packet)
|
||||
except Exception:
|
||||
pass
|
||||
with self._lock:
|
||||
self._buffers.clear()
|
||||
self._last_packet_time.clear()
|
||||
self._decoders.clear()
|
||||
self._ssrc_to_user.clear()
|
||||
logger.info("VoiceReceiver stopped")
|
||||
|
||||
def pause(self):
|
||||
self._paused = True
|
||||
|
||||
def resume(self):
|
||||
self._paused = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SSRC -> user_id mapping via SPEAKING opcode hook
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def map_ssrc(self, ssrc: int, user_id: int):
|
||||
with self._lock:
|
||||
self._ssrc_to_user[ssrc] = user_id
|
||||
|
||||
def _install_speaking_hook(self, conn):
|
||||
"""Wrap the voice websocket hook to capture SPEAKING events (op 5).
|
||||
|
||||
VoiceConnectionState stores the hook as ``conn.hook`` (public attr).
|
||||
It is passed to DiscordVoiceWebSocket on each (re)connect, so we
|
||||
must wrap it on the VoiceConnectionState level AND on the current
|
||||
live websocket instance.
|
||||
"""
|
||||
original_hook = conn.hook
|
||||
receiver_self = self
|
||||
|
||||
async def wrapped_hook(ws, msg):
|
||||
if isinstance(msg, dict) and msg.get("op") == 5:
|
||||
data = msg.get("d", {})
|
||||
ssrc = data.get("ssrc")
|
||||
user_id = data.get("user_id")
|
||||
if ssrc and user_id:
|
||||
logger.info("SPEAKING event: ssrc=%d -> user=%s", ssrc, user_id)
|
||||
receiver_self.map_ssrc(int(ssrc), int(user_id))
|
||||
if original_hook:
|
||||
await original_hook(ws, msg)
|
||||
|
||||
# Set on connection state (for future reconnects)
|
||||
conn.hook = wrapped_hook
|
||||
# Set on the current live websocket (for immediate effect)
|
||||
try:
|
||||
from discord.utils import MISSING
|
||||
if hasattr(conn, 'ws') and conn.ws is not MISSING:
|
||||
conn.ws._hook = wrapped_hook
|
||||
logger.info("Speaking hook installed on live websocket")
|
||||
except Exception as e:
|
||||
logger.warning("Could not install hook on live ws: %s", e)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Packet handler (called from SocketReader thread)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _on_packet(self, data: bytes):
|
||||
if not self._running or self._paused:
|
||||
return
|
||||
|
||||
# Log first few raw packets for debugging
|
||||
self._packet_debug_count += 1
|
||||
if self._packet_debug_count <= 5:
|
||||
logger.debug(
|
||||
"Raw UDP packet: len=%d, first_bytes=%s",
|
||||
len(data), data[:4].hex() if len(data) >= 4 else "short",
|
||||
)
|
||||
|
||||
if len(data) < 16:
|
||||
return
|
||||
|
||||
# RTP version check: top 2 bits must be 10 (version 2).
|
||||
# Lower bits may vary (padding, extension, CSRC count).
|
||||
# Payload type (byte 1 lower 7 bits) = 0x78 (120) for voice.
|
||||
if (data[0] >> 6) != 2 or (data[1] & 0x7F) != 0x78:
|
||||
if self._packet_debug_count <= 5:
|
||||
logger.debug("Skipped non-RTP: byte0=0x%02x byte1=0x%02x", data[0], data[1])
|
||||
return
|
||||
|
||||
first_byte = data[0]
|
||||
_, _, seq, timestamp, ssrc = struct.unpack_from(">BBHII", data, 0)
|
||||
|
||||
# Skip bot's own audio
|
||||
if ssrc == self._bot_ssrc:
|
||||
return
|
||||
|
||||
# Calculate dynamic RTP header size (RFC 9335 / rtpsize mode)
|
||||
cc = first_byte & 0x0F # CSRC count
|
||||
has_extension = bool(first_byte & 0x10) # extension bit
|
||||
header_size = 12 + (4 * cc) + (4 if has_extension else 0)
|
||||
|
||||
if len(data) < header_size + 4: # need at least header + nonce
|
||||
return
|
||||
|
||||
# Read extension length from preamble (for skipping after decrypt)
|
||||
ext_data_len = 0
|
||||
if has_extension:
|
||||
ext_preamble_offset = 12 + (4 * cc)
|
||||
ext_words = struct.unpack_from(">H", data, ext_preamble_offset + 2)[0]
|
||||
ext_data_len = ext_words * 4
|
||||
|
||||
if self._packet_debug_count <= 10:
|
||||
with self._lock:
|
||||
known_user = self._ssrc_to_user.get(ssrc, "unknown")
|
||||
logger.debug(
|
||||
"RTP packet: ssrc=%d, seq=%d, user=%s, hdr=%d, ext_data=%d",
|
||||
ssrc, seq, known_user, header_size, ext_data_len,
|
||||
)
|
||||
|
||||
header = bytes(data[:header_size])
|
||||
payload_with_nonce = data[header_size:]
|
||||
|
||||
# --- NaCl transport decrypt (aead_xchacha20_poly1305_rtpsize) ---
|
||||
if len(payload_with_nonce) < 4:
|
||||
return
|
||||
nonce = bytearray(24)
|
||||
nonce[:4] = payload_with_nonce[-4:]
|
||||
encrypted = bytes(payload_with_nonce[:-4])
|
||||
|
||||
try:
|
||||
import nacl.secret # noqa: delayed import – only in voice path
|
||||
box = nacl.secret.Aead(self._secret_key)
|
||||
decrypted = box.decrypt(encrypted, header, bytes(nonce))
|
||||
except Exception as e:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("NaCl decrypt failed: %s (hdr=%d, enc=%d)", e, header_size, len(encrypted))
|
||||
return
|
||||
|
||||
# Skip encrypted extension data to get the actual opus payload
|
||||
if ext_data_len and len(decrypted) > ext_data_len:
|
||||
decrypted = decrypted[ext_data_len:]
|
||||
|
||||
# --- DAVE E2EE decrypt ---
|
||||
if self._dave_session:
|
||||
with self._lock:
|
||||
user_id = self._ssrc_to_user.get(ssrc, 0)
|
||||
if user_id == 0:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc)
|
||||
return # unknown user, can't DAVE-decrypt
|
||||
try:
|
||||
import davey
|
||||
decrypted = self._dave_session.decrypt(
|
||||
user_id, davey.MediaType.audio, decrypted
|
||||
)
|
||||
except Exception as e:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
|
||||
return
|
||||
|
||||
# --- Opus decode -> PCM ---
|
||||
try:
|
||||
if ssrc not in self._decoders:
|
||||
self._decoders[ssrc] = discord.opus.Decoder()
|
||||
pcm = self._decoders[ssrc].decode(decrypted)
|
||||
with self._lock:
|
||||
self._buffers[ssrc].extend(pcm)
|
||||
self._last_packet_time[ssrc] = time.monotonic()
|
||||
except Exception as e:
|
||||
logger.debug("Opus decode error for SSRC %s: %s", ssrc, e)
|
||||
return
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Silence detection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def check_silence(self) -> list:
|
||||
"""Return list of (user_id, pcm_bytes) for completed utterances."""
|
||||
now = time.monotonic()
|
||||
completed = []
|
||||
|
||||
with self._lock:
|
||||
ssrc_user_map = dict(self._ssrc_to_user)
|
||||
ssrc_list = list(self._buffers.keys())
|
||||
|
||||
for ssrc in ssrc_list:
|
||||
last_time = self._last_packet_time.get(ssrc, now)
|
||||
silence_duration = now - last_time
|
||||
buf = self._buffers[ssrc]
|
||||
# 48kHz, 16-bit, stereo = 192000 bytes/sec
|
||||
buf_duration = len(buf) / (self.SAMPLE_RATE * self.CHANNELS * 2)
|
||||
|
||||
if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION:
|
||||
user_id = ssrc_user_map.get(ssrc, 0)
|
||||
if user_id:
|
||||
completed.append((user_id, bytes(buf)))
|
||||
self._buffers[ssrc] = bytearray()
|
||||
self._last_packet_time.pop(ssrc, None)
|
||||
elif silence_duration >= self.SILENCE_THRESHOLD * 2:
|
||||
# Stale buffer with no valid user — discard
|
||||
self._buffers.pop(ssrc, None)
|
||||
self._last_packet_time.pop(ssrc, None)
|
||||
|
||||
return completed
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# PCM -> WAV conversion (for Whisper STT)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def pcm_to_wav(pcm_data: bytes, output_path: str,
|
||||
src_rate: int = 48000, src_channels: int = 2):
|
||||
"""Convert raw PCM to 16kHz mono WAV via ffmpeg."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".pcm", delete=False) as f:
|
||||
f.write(pcm_data)
|
||||
pcm_path = f.name
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg", "-y", "-loglevel", "error",
|
||||
"-f", "s16le",
|
||||
"-ar", str(src_rate),
|
||||
"-ac", str(src_channels),
|
||||
"-i", pcm_path,
|
||||
"-ar", "16000",
|
||||
"-ac", "1",
|
||||
output_path,
|
||||
],
|
||||
check=True,
|
||||
timeout=10,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
os.unlink(pcm_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
class DiscordAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Discord bot adapter.
|
||||
@@ -82,17 +381,54 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Discord message limits
|
||||
MAX_MESSAGE_LENGTH = 2000
|
||||
|
||||
# Auto-disconnect from voice channel after this many seconds of inactivity
|
||||
VOICE_TIMEOUT = 300
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.DISCORD)
|
||||
self._client: Optional[commands.Bot] = None
|
||||
self._ready_event = asyncio.Event()
|
||||
self._allowed_user_ids: set = set() # For button approval authorization
|
||||
# Voice channel state (per-guild)
|
||||
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
|
||||
self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id
|
||||
self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task
|
||||
# Phase 2: voice listening
|
||||
self._voice_receivers: Dict[int, VoiceReceiver] = {} # guild_id -> VoiceReceiver
|
||||
self._voice_listen_tasks: Dict[int, asyncio.Task] = {} # guild_id -> listen loop
|
||||
self._voice_input_callback: Optional[Callable] = None # set by run.py
|
||||
self._on_voice_disconnect: Optional[Callable] = None # set by run.py
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
logger.error("[%s] discord.py not installed. Run: pip install discord.py", self.name)
|
||||
return False
|
||||
|
||||
# Load opus codec for voice channel support
|
||||
if not discord.opus.is_loaded():
|
||||
import ctypes.util
|
||||
opus_path = ctypes.util.find_library("opus")
|
||||
# ctypes.util.find_library fails on macOS with Homebrew-installed libs,
|
||||
# so fall back to known Homebrew paths if needed.
|
||||
if not opus_path:
|
||||
import sys
|
||||
_homebrew_paths = (
|
||||
"/opt/homebrew/lib/libopus.dylib", # Apple Silicon
|
||||
"/usr/local/lib/libopus.dylib", # Intel Mac
|
||||
)
|
||||
if sys.platform == "darwin":
|
||||
for _hp in _homebrew_paths:
|
||||
if os.path.isfile(_hp):
|
||||
opus_path = _hp
|
||||
break
|
||||
if opus_path:
|
||||
try:
|
||||
discord.opus.load_opus(opus_path)
|
||||
except Exception:
|
||||
logger.warning("Opus codec found at %s but failed to load", opus_path)
|
||||
if not discord.opus.is_loaded():
|
||||
logger.warning("Opus codec not found — voice channel playback disabled")
|
||||
|
||||
if not self.config.token:
|
||||
logger.error("[%s] No bot token configured", self.name)
|
||||
@@ -105,6 +441,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
intents.dm_messages = True
|
||||
intents.guild_messages = True
|
||||
intents.members = True
|
||||
intents.voice_states = True
|
||||
|
||||
# Create bot
|
||||
self._client = commands.Bot(
|
||||
@@ -158,7 +495,40 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# "all" falls through to handle_message
|
||||
|
||||
await self._handle_message(message)
|
||||
|
||||
|
||||
@self._client.event
|
||||
async def on_voice_state_update(member, before, after):
|
||||
"""Track voice channel join/leave events."""
|
||||
# Only track channels where the bot is connected
|
||||
bot_guild_ids = set(adapter_self._voice_clients.keys())
|
||||
if not bot_guild_ids:
|
||||
return
|
||||
guild_id = member.guild.id
|
||||
if guild_id not in bot_guild_ids:
|
||||
return
|
||||
# Ignore the bot itself
|
||||
if member == adapter_self._client.user:
|
||||
return
|
||||
|
||||
joined = before.channel is None and after.channel is not None
|
||||
left = before.channel is not None and after.channel is None
|
||||
switched = (
|
||||
before.channel is not None
|
||||
and after.channel is not None
|
||||
and before.channel != after.channel
|
||||
)
|
||||
|
||||
if joined or left or switched:
|
||||
logger.info(
|
||||
"Voice state: %s (%d) %s (guild %d)",
|
||||
member.display_name,
|
||||
member.id,
|
||||
"joined " + after.channel.name if joined
|
||||
else "left " + before.channel.name if left
|
||||
else f"moved {before.channel.name} -> {after.channel.name}",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
# Register slash commands
|
||||
self._register_slash_commands()
|
||||
|
||||
@@ -180,12 +550,19 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Discord."""
|
||||
# Clean up all active voice connections before closing the client
|
||||
for guild_id in list(self._voice_clients.keys()):
|
||||
try:
|
||||
await self.leave_voice_channel(guild_id)
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.debug("[%s] Error leaving voice channel %s: %s", self.name, guild_id, e)
|
||||
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.warning("[%s] Error during disconnect: %s", self.name, e, exc_info=True)
|
||||
|
||||
|
||||
self._running = False
|
||||
self._client = None
|
||||
self._ready_event.clear()
|
||||
@@ -287,6 +664,23 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
msg = await channel.send(content=caption if caption else None, file=file)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
async def play_tts(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Play auto-TTS audio.
|
||||
|
||||
When the bot is in a voice channel for this chat's guild, skip the
|
||||
file attachment — the gateway runner plays audio in the VC instead.
|
||||
"""
|
||||
for gid, text_ch_id in self._voice_text_channels.items():
|
||||
if str(text_ch_id) == str(chat_id) and self.is_in_voice_channel(gid):
|
||||
logger.debug("[%s] Skipping play_tts for %s — VC playback handled by runner", self.name, chat_id)
|
||||
return SendResult(success=True)
|
||||
return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs)
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -294,16 +688,356 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send audio as a Discord file attachment."""
|
||||
try:
|
||||
return await self._send_file_attachment(chat_id, audio_path, caption)
|
||||
except FileNotFoundError:
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
import io
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
filename = os.path.basename(audio_path)
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
|
||||
# Try sending as a native voice message via raw API (flags=8192).
|
||||
try:
|
||||
import base64
|
||||
|
||||
duration_secs = 5.0
|
||||
try:
|
||||
from mutagen.oggopus import OggOpus
|
||||
info = OggOpus(audio_path)
|
||||
duration_secs = info.info.length
|
||||
except Exception:
|
||||
duration_secs = max(1.0, len(file_data) / 2000.0)
|
||||
|
||||
waveform_bytes = bytes([128] * 256)
|
||||
waveform_b64 = base64.b64encode(waveform_bytes).decode()
|
||||
|
||||
import json as _json
|
||||
payload = _json.dumps({
|
||||
"flags": 8192,
|
||||
"attachments": [{
|
||||
"id": "0",
|
||||
"filename": "voice-message.ogg",
|
||||
"duration_secs": round(duration_secs, 2),
|
||||
"waveform": waveform_b64,
|
||||
}],
|
||||
})
|
||||
form = [
|
||||
{"name": "payload_json", "value": payload},
|
||||
{
|
||||
"name": "files[0]",
|
||||
"value": file_data,
|
||||
"filename": "voice-message.ogg",
|
||||
"content_type": "audio/ogg",
|
||||
},
|
||||
]
|
||||
msg_data = await self._client.http.request(
|
||||
discord.http.Route("POST", "/channels/{channel_id}/messages", channel_id=channel.id),
|
||||
form=form,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg_data["id"]))
|
||||
except Exception as voice_err:
|
||||
logger.debug("Voice message flag failed, falling back to file: %s", voice_err)
|
||||
file = discord.File(io.BytesIO(file_data), filename=filename)
|
||||
msg = await channel.send(file=file)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to send audio, falling back to base adapter: %s", self.name, e, exc_info=True)
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to, metadata=metadata)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Voice channel methods (join / leave / play)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def join_voice_channel(self, channel) -> bool:
|
||||
"""Join a Discord voice channel. Returns True on success."""
|
||||
if not self._client or not DISCORD_AVAILABLE:
|
||||
return False
|
||||
guild_id = channel.guild.id
|
||||
|
||||
# Already connected in this guild?
|
||||
existing = self._voice_clients.get(guild_id)
|
||||
if existing and existing.is_connected():
|
||||
if existing.channel.id == channel.id:
|
||||
self._reset_voice_timeout(guild_id)
|
||||
return True
|
||||
await existing.move_to(channel)
|
||||
self._reset_voice_timeout(guild_id)
|
||||
return True
|
||||
|
||||
vc = await channel.connect()
|
||||
self._voice_clients[guild_id] = vc
|
||||
self._reset_voice_timeout(guild_id)
|
||||
|
||||
# Start voice receiver (Phase 2: listen to users)
|
||||
try:
|
||||
receiver = VoiceReceiver(vc)
|
||||
receiver.start()
|
||||
self._voice_receivers[guild_id] = receiver
|
||||
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
|
||||
self._voice_listen_loop(guild_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Voice receiver failed to start: %s", e)
|
||||
|
||||
return True
|
||||
|
||||
async def leave_voice_channel(self, guild_id: int) -> None:
|
||||
"""Disconnect from the voice channel in a guild."""
|
||||
# Stop voice receiver first
|
||||
receiver = self._voice_receivers.pop(guild_id, None)
|
||||
if receiver:
|
||||
receiver.stop()
|
||||
listen_task = self._voice_listen_tasks.pop(guild_id, None)
|
||||
if listen_task:
|
||||
listen_task.cancel()
|
||||
|
||||
vc = self._voice_clients.pop(guild_id, None)
|
||||
if vc and vc.is_connected():
|
||||
await vc.disconnect()
|
||||
task = self._voice_timeout_tasks.pop(guild_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
self._voice_text_channels.pop(guild_id, None)
|
||||
|
||||
# Maximum seconds to wait for voice playback before giving up
|
||||
PLAYBACK_TIMEOUT = 120
|
||||
|
||||
async def play_in_voice_channel(self, guild_id: int, audio_path: str) -> bool:
|
||||
"""Play an audio file in the connected voice channel."""
|
||||
vc = self._voice_clients.get(guild_id)
|
||||
if not vc or not vc.is_connected():
|
||||
return False
|
||||
|
||||
# Pause voice receiver while playing (echo prevention)
|
||||
receiver = self._voice_receivers.get(guild_id)
|
||||
if receiver:
|
||||
receiver.pause()
|
||||
|
||||
try:
|
||||
# Wait for current playback to finish (with timeout)
|
||||
wait_start = time.monotonic()
|
||||
while vc.is_playing():
|
||||
if time.monotonic() - wait_start > self.PLAYBACK_TIMEOUT:
|
||||
logger.warning("Timed out waiting for previous playback to finish")
|
||||
vc.stop()
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
done = asyncio.Event()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _after(error):
|
||||
if error:
|
||||
logger.error("Voice playback error: %s", error)
|
||||
loop.call_soon_threadsafe(done.set)
|
||||
|
||||
source = discord.FFmpegPCMAudio(audio_path)
|
||||
source = discord.PCMVolumeTransformer(source, volume=1.0)
|
||||
vc.play(source, after=_after)
|
||||
try:
|
||||
await asyncio.wait_for(done.wait(), timeout=self.PLAYBACK_TIMEOUT)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Voice playback timed out after %ds", self.PLAYBACK_TIMEOUT)
|
||||
vc.stop()
|
||||
self._reset_voice_timeout(guild_id)
|
||||
return True
|
||||
finally:
|
||||
if receiver:
|
||||
receiver.resume()
|
||||
|
||||
async def get_user_voice_channel(self, guild_id: int, user_id: str):
|
||||
"""Return the voice channel the user is currently in, or None."""
|
||||
if not self._client:
|
||||
return None
|
||||
guild = self._client.get_guild(guild_id)
|
||||
if not guild:
|
||||
return None
|
||||
member = guild.get_member(int(user_id))
|
||||
if not member or not member.voice:
|
||||
return None
|
||||
return member.voice.channel
|
||||
|
||||
def _reset_voice_timeout(self, guild_id: int) -> None:
|
||||
"""Reset the auto-disconnect inactivity timer."""
|
||||
task = self._voice_timeout_tasks.pop(guild_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
self._voice_timeout_tasks[guild_id] = asyncio.ensure_future(
|
||||
self._voice_timeout_handler(guild_id)
|
||||
)
|
||||
|
||||
async def _voice_timeout_handler(self, guild_id: int) -> None:
|
||||
"""Auto-disconnect after VOICE_TIMEOUT seconds of inactivity."""
|
||||
try:
|
||||
await asyncio.sleep(self.VOICE_TIMEOUT)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
text_ch_id = self._voice_text_channels.get(guild_id)
|
||||
await self.leave_voice_channel(guild_id)
|
||||
# Notify the runner so it can clean up voice_mode state
|
||||
if self._on_voice_disconnect and text_ch_id:
|
||||
try:
|
||||
self._on_voice_disconnect(str(text_ch_id))
|
||||
except Exception:
|
||||
pass
|
||||
if text_ch_id and self._client:
|
||||
ch = self._client.get_channel(text_ch_id)
|
||||
if ch:
|
||||
try:
|
||||
await ch.send("Left voice channel (inactivity timeout).")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def is_in_voice_channel(self, guild_id: int) -> bool:
|
||||
"""Check if the bot is connected to a voice channel in this guild."""
|
||||
vc = self._voice_clients.get(guild_id)
|
||||
return vc is not None and vc.is_connected()
|
||||
|
||||
def get_voice_channel_info(self, guild_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""Return voice channel awareness info for the given guild.
|
||||
|
||||
Returns None if the bot is not in a voice channel. Otherwise
|
||||
returns a dict with channel name, member list, count, and
|
||||
currently-speaking user IDs (from SSRC mapping).
|
||||
"""
|
||||
vc = self._voice_clients.get(guild_id)
|
||||
if not vc or not vc.is_connected():
|
||||
return None
|
||||
|
||||
channel = vc.channel
|
||||
if not channel:
|
||||
return None
|
||||
|
||||
# Members currently in the voice channel (includes bot)
|
||||
members_info = []
|
||||
bot_user = self._client.user if self._client else None
|
||||
for m in channel.members:
|
||||
if bot_user and m.id == bot_user.id:
|
||||
continue # skip the bot itself
|
||||
members_info.append({
|
||||
"user_id": m.id,
|
||||
"display_name": m.display_name,
|
||||
"is_bot": m.bot,
|
||||
})
|
||||
|
||||
# Currently speaking users (from SSRC mapping + active buffers)
|
||||
speaking_user_ids: set = set()
|
||||
receiver = self._voice_receivers.get(guild_id)
|
||||
if receiver:
|
||||
import time as _time
|
||||
now = _time.monotonic()
|
||||
with receiver._lock:
|
||||
for ssrc, last_t in receiver._last_packet_time.items():
|
||||
# Consider "speaking" if audio received within last 2 seconds
|
||||
if now - last_t < 2.0:
|
||||
uid = receiver._ssrc_to_user.get(ssrc)
|
||||
if uid:
|
||||
speaking_user_ids.add(uid)
|
||||
|
||||
# Tag speaking status on members
|
||||
for info in members_info:
|
||||
info["is_speaking"] = info["user_id"] in speaking_user_ids
|
||||
|
||||
return {
|
||||
"channel_name": channel.name,
|
||||
"member_count": len(members_info),
|
||||
"members": members_info,
|
||||
"speaking_count": len(speaking_user_ids),
|
||||
}
|
||||
|
||||
def get_voice_channel_context(self, guild_id: int) -> str:
|
||||
"""Return a human-readable voice channel context string.
|
||||
|
||||
Suitable for injection into the system/ephemeral prompt so the
|
||||
agent is always aware of voice channel state.
|
||||
"""
|
||||
info = self.get_voice_channel_info(guild_id)
|
||||
if not info:
|
||||
return ""
|
||||
|
||||
parts = [f"[Voice channel: #{info['channel_name']} — {info['member_count']} participant(s)]"]
|
||||
for m in info["members"]:
|
||||
status = " (speaking)" if m["is_speaking"] else ""
|
||||
parts.append(f" - {m['display_name']}{status}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Voice listening (Phase 2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _voice_listen_loop(self, guild_id: int):
|
||||
"""Periodically check for completed utterances and process them."""
|
||||
receiver = self._voice_receivers.get(guild_id)
|
||||
if not receiver:
|
||||
return
|
||||
try:
|
||||
while receiver._running:
|
||||
await asyncio.sleep(0.2)
|
||||
completed = receiver.check_silence()
|
||||
for user_id, pcm_data in completed:
|
||||
if not self._is_allowed_user(str(user_id)):
|
||||
continue
|
||||
await self._process_voice_input(guild_id, user_id, pcm_data)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("Voice listen loop error: %s", e, exc_info=True)
|
||||
|
||||
async def _process_voice_input(self, guild_id: int, user_id: int, pcm_data: bytes):
|
||||
"""Convert PCM -> WAV -> STT -> callback."""
|
||||
from tools.voice_mode import is_whisper_hallucination
|
||||
|
||||
tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", prefix="vc_listen_", delete=False)
|
||||
wav_path = tmp_f.name
|
||||
tmp_f.close()
|
||||
try:
|
||||
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)
|
||||
|
||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||
stt_model = get_stt_model_from_config()
|
||||
result = await asyncio.to_thread(transcribe_audio, wav_path, model=stt_model)
|
||||
|
||||
if not result.get("success"):
|
||||
return
|
||||
transcript = result.get("transcript", "").strip()
|
||||
if not transcript or is_whisper_hallucination(transcript):
|
||||
return
|
||||
|
||||
logger.info("Voice input from user %d: %s", user_id, transcript[:100])
|
||||
|
||||
if self._voice_input_callback:
|
||||
await self._voice_input_callback(
|
||||
guild_id=guild_id,
|
||||
user_id=user_id,
|
||||
transcript=transcript,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Voice input processing failed: %s", e, exc_info=True)
|
||||
finally:
|
||||
try:
|
||||
os.unlink(wav_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _is_allowed_user(self, user_id: str) -> bool:
|
||||
"""Check if user is in DISCORD_ALLOWED_USERS."""
|
||||
if not self._allowed_user_ids:
|
||||
return True
|
||||
return user_id in self._allowed_user_ids
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -627,6 +1361,25 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
async def slash_reload_mcp(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/reload-mcp")
|
||||
|
||||
@tree.command(name="voice", description="Toggle voice reply mode")
|
||||
@discord.app_commands.describe(mode="Voice mode: on, off, tts, channel, leave, or status")
|
||||
@discord.app_commands.choices(mode=[
|
||||
discord.app_commands.Choice(name="channel — join your voice channel", value="channel"),
|
||||
discord.app_commands.Choice(name="leave — leave voice channel", value="leave"),
|
||||
discord.app_commands.Choice(name="on — voice reply to voice messages", value="on"),
|
||||
discord.app_commands.Choice(name="tts — voice reply to all messages", value="tts"),
|
||||
discord.app_commands.Choice(name="off — text only", value="off"),
|
||||
discord.app_commands.Choice(name="status — show current mode", value="status"),
|
||||
])
|
||||
async def slash_voice(interaction: discord.Interaction, mode: str = ""):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/voice {mode}".strip())
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="update", description="Update Hermes Agent to the latest version")
|
||||
async def slash_update(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/update", "Update initiated~")
|
||||
|
||||
@@ -506,6 +506,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send an audio file to Slack."""
|
||||
try:
|
||||
|
||||
@@ -150,7 +150,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
drop_pending_updates=True,
|
||||
)
|
||||
|
||||
# Register bot commands so Telegram shows a hint menu when users type /
|
||||
try:
|
||||
@@ -174,6 +177,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("voice", "Toggle voice reply mode"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
])
|
||||
except Exception as e:
|
||||
@@ -307,6 +311,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send audio as a native Telegram voice message or audio file."""
|
||||
if not self._bot:
|
||||
|
||||
400
gateway/run.py
400
gateway/run.py
@@ -14,13 +14,16 @@ Usage:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
import signal
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
@@ -281,6 +284,9 @@ class GatewayRunner:
|
||||
from gateway.hooks import HookRegistry
|
||||
self.hooks = HookRegistry()
|
||||
|
||||
# Per-chat voice reply mode: "off" | "voice_only" | "all"
|
||||
self._voice_mode: Dict[str, str] = self._load_voice_modes()
|
||||
|
||||
def _get_or_create_gateway_honcho(self, session_key: str):
|
||||
"""Return a persistent Honcho manager/config pair for this gateway session."""
|
||||
if not hasattr(self, "_honcho_managers"):
|
||||
@@ -336,6 +342,27 @@ class GatewayRunner:
|
||||
for session_key in list(managers.keys()):
|
||||
self._shutdown_gateway_honcho(session_key)
|
||||
|
||||
# -- Voice mode persistence ------------------------------------------
|
||||
|
||||
_VOICE_MODE_PATH = _hermes_home / "gateway_voice_mode.json"
|
||||
|
||||
def _load_voice_modes(self) -> Dict[str, str]:
|
||||
try:
|
||||
return json.loads(self._VOICE_MODE_PATH.read_text())
|
||||
except (FileNotFoundError, json.JSONDecodeError, OSError):
|
||||
return {}
|
||||
|
||||
def _save_voice_modes(self) -> None:
|
||||
try:
|
||||
self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._VOICE_MODE_PATH.write_text(
|
||||
json.dumps(self._voice_mode, indent=2)
|
||||
)
|
||||
except OSError as e:
|
||||
logger.warning("Failed to save voice modes: %s", e)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def _flush_memories_for_session(self, old_session_id: str):
|
||||
"""Prompt the agent to save memories/skills before context is lost.
|
||||
|
||||
@@ -737,7 +764,7 @@ class GatewayRunner:
|
||||
logger.info("Stopping gateway...")
|
||||
self._running = False
|
||||
|
||||
for platform, adapter in self.adapters.items():
|
||||
for platform, adapter in list(self.adapters.items()):
|
||||
try:
|
||||
await adapter.disconnect()
|
||||
logger.info("✓ %s disconnected", platform.value)
|
||||
@@ -897,7 +924,7 @@ class GatewayRunner:
|
||||
7. Return response
|
||||
"""
|
||||
source = event.source
|
||||
|
||||
|
||||
# Check if user is authorized
|
||||
if not self._is_user_authorized(source):
|
||||
logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value)
|
||||
@@ -949,7 +976,7 @@ class GatewayRunner:
|
||||
"personality", "retry", "undo", "sethome", "set-home",
|
||||
"compress", "usage", "insights", "reload-mcp", "reload_mcp",
|
||||
"update", "title", "resume", "provider", "rollback",
|
||||
"background", "reasoning"}
|
||||
"background", "reasoning", "voice"}
|
||||
if command and command in _known_commands:
|
||||
await self.hooks.emit(f"command:{command}", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
@@ -1020,7 +1047,10 @@ class GatewayRunner:
|
||||
|
||||
if command == "reasoning":
|
||||
return await self._handle_reasoning_command(event)
|
||||
|
||||
|
||||
if command == "voice":
|
||||
return await self._handle_voice_command(event)
|
||||
|
||||
# User-defined quick commands (bypass agent loop, no LLM call)
|
||||
if command:
|
||||
if isinstance(self.config, dict):
|
||||
@@ -1377,6 +1407,19 @@ class GatewayRunner:
|
||||
f"or ignore to skip."
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Voice channel awareness — inject current voice channel state
|
||||
# into context so the agent knows who is in the channel and who
|
||||
# is speaking, without needing a separate tool call.
|
||||
# -----------------------------------------------------------------
|
||||
if source.platform == Platform.DISCORD:
|
||||
adapter = self.adapters.get(Platform.DISCORD)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if guild_id and adapter and hasattr(adapter, "get_voice_channel_context"):
|
||||
vc_context = adapter.get_voice_channel_context(guild_id)
|
||||
if vc_context:
|
||||
context_prompt += f"\n\n{vc_context}"
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Auto-analyze images sent by the user
|
||||
#
|
||||
@@ -1583,7 +1626,11 @@ class GatewayRunner:
|
||||
session_entry.session_key,
|
||||
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
|
||||
)
|
||||
|
||||
|
||||
# Auto voice reply: send TTS audio before the text response
|
||||
if self._should_send_voice_reply(event, response, agent_messages):
|
||||
await self._send_voice_reply(event, response)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -1692,6 +1739,7 @@ class GatewayRunner:
|
||||
"`/reasoning [level|show|hide]` — Set reasoning effort or toggle display",
|
||||
"`/rollback [number]` — List or restore filesystem checkpoints",
|
||||
"`/background <prompt>` — Run a prompt in a separate background session",
|
||||
"`/voice [on|off|tts|status]` — Toggle voice reply mode",
|
||||
"`/reload-mcp` — Reload MCP servers from config",
|
||||
"`/update` — Update Hermes Agent to the latest version",
|
||||
"`/help` — Show this message",
|
||||
@@ -2067,6 +2115,334 @@ class GatewayRunner:
|
||||
f"Cron jobs and cross-platform messages will be delivered here."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_guild_id(event: MessageEvent) -> Optional[int]:
|
||||
"""Extract Discord guild_id from the raw message object."""
|
||||
raw = getattr(event, "raw_message", None)
|
||||
if raw is None:
|
||||
return None
|
||||
# Slash command interaction
|
||||
if hasattr(raw, "guild_id") and raw.guild_id:
|
||||
return int(raw.guild_id)
|
||||
# Regular message
|
||||
if hasattr(raw, "guild") and raw.guild:
|
||||
return raw.guild.id
|
||||
return None
|
||||
|
||||
async def _handle_voice_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /voice [on|off|tts|channel|leave|status] command."""
|
||||
args = event.get_command_args().strip().lower()
|
||||
chat_id = event.source.chat_id
|
||||
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
|
||||
if args in ("on", "enable"):
|
||||
self._voice_mode[chat_id] = "voice_only"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.discard(chat_id)
|
||||
return (
|
||||
"Voice mode enabled.\n"
|
||||
"I'll reply with voice when you send voice messages.\n"
|
||||
"Use /voice tts to get voice replies for all messages."
|
||||
)
|
||||
elif args in ("off", "disable"):
|
||||
self._voice_mode.pop(chat_id, None)
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.add(chat_id)
|
||||
return "Voice mode disabled. Text-only replies."
|
||||
elif args == "tts":
|
||||
self._voice_mode[chat_id] = "all"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.discard(chat_id)
|
||||
return (
|
||||
"Auto-TTS enabled.\n"
|
||||
"All replies will include a voice message."
|
||||
)
|
||||
elif args in ("channel", "join"):
|
||||
return await self._handle_voice_channel_join(event)
|
||||
elif args == "leave":
|
||||
return await self._handle_voice_channel_leave(event)
|
||||
elif args == "status":
|
||||
mode = self._voice_mode.get(chat_id, "off")
|
||||
labels = {
|
||||
"off": "Off (text only)",
|
||||
"voice_only": "On (voice reply to voice messages)",
|
||||
"all": "TTS (voice reply to all messages)",
|
||||
}
|
||||
# Append voice channel info if connected
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if guild_id and hasattr(adapter, "get_voice_channel_info"):
|
||||
info = adapter.get_voice_channel_info(guild_id)
|
||||
if info:
|
||||
lines = [
|
||||
f"Voice mode: {labels.get(mode, mode)}",
|
||||
f"Voice channel: #{info['channel_name']}",
|
||||
f"Participants: {info['member_count']}",
|
||||
]
|
||||
for m in info["members"]:
|
||||
status = " (speaking)" if m.get("is_speaking") else ""
|
||||
lines.append(f" - {m['display_name']}{status}")
|
||||
return "\n".join(lines)
|
||||
return f"Voice mode: {labels.get(mode, mode)}"
|
||||
else:
|
||||
# Toggle: off → on, on/all → off
|
||||
current = self._voice_mode.get(chat_id, "off")
|
||||
if current == "off":
|
||||
self._voice_mode[chat_id] = "voice_only"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.discard(chat_id)
|
||||
return "Voice mode enabled."
|
||||
else:
|
||||
self._voice_mode.pop(chat_id, None)
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.add(chat_id)
|
||||
return "Voice mode disabled."
|
||||
|
||||
async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
|
||||
"""Join the user's current Discord voice channel."""
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
if not hasattr(adapter, "join_voice_channel"):
|
||||
return "Voice channels are not supported on this platform."
|
||||
|
||||
guild_id = self._get_guild_id(event)
|
||||
if not guild_id:
|
||||
return "This command only works in a Discord server."
|
||||
|
||||
voice_channel = await adapter.get_user_voice_channel(
|
||||
guild_id, event.source.user_id
|
||||
)
|
||||
if not voice_channel:
|
||||
return "You need to be in a voice channel first."
|
||||
|
||||
# Wire callbacks BEFORE join so voice input arriving immediately
|
||||
# after connection is not lost.
|
||||
if hasattr(adapter, "_voice_input_callback"):
|
||||
adapter._voice_input_callback = self._handle_voice_channel_input
|
||||
if hasattr(adapter, "_on_voice_disconnect"):
|
||||
adapter._on_voice_disconnect = self._handle_voice_timeout_cleanup
|
||||
|
||||
try:
|
||||
success = await adapter.join_voice_channel(voice_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to join voice channel: %s", e)
|
||||
adapter._voice_input_callback = None
|
||||
return f"Failed to join voice channel: {e}"
|
||||
|
||||
if success:
|
||||
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
||||
self._voice_mode[event.source.chat_id] = "all"
|
||||
self._save_voice_modes()
|
||||
adapter._auto_tts_disabled_chats.discard(event.source.chat_id)
|
||||
return (
|
||||
f"Joined voice channel **{voice_channel.name}**.\n"
|
||||
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
|
||||
)
|
||||
# Join failed — clear callback
|
||||
adapter._voice_input_callback = None
|
||||
return "Failed to join voice channel. Check bot permissions (Connect + Speak)."
|
||||
|
||||
async def _handle_voice_channel_leave(self, event: MessageEvent) -> str:
|
||||
"""Leave the Discord voice channel."""
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
guild_id = self._get_guild_id(event)
|
||||
|
||||
if not guild_id or not hasattr(adapter, "leave_voice_channel"):
|
||||
return "Not in a voice channel."
|
||||
|
||||
if not hasattr(adapter, "is_in_voice_channel") or not adapter.is_in_voice_channel(guild_id):
|
||||
return "Not in a voice channel."
|
||||
|
||||
try:
|
||||
await adapter.leave_voice_channel(guild_id)
|
||||
except Exception as e:
|
||||
logger.warning("Error leaving voice channel: %s", e)
|
||||
# Always clean up state even if leave raised an exception
|
||||
self._voice_mode.pop(event.source.chat_id, None)
|
||||
self._save_voice_modes()
|
||||
if hasattr(adapter, "_voice_input_callback"):
|
||||
adapter._voice_input_callback = None
|
||||
return "Left voice channel."
|
||||
|
||||
def _handle_voice_timeout_cleanup(self, chat_id: str) -> None:
|
||||
"""Called by the adapter when a voice channel times out.
|
||||
|
||||
Cleans up runner-side voice_mode state that the adapter cannot reach.
|
||||
"""
|
||||
self._voice_mode.pop(chat_id, None)
|
||||
self._save_voice_modes()
|
||||
|
||||
async def _handle_voice_channel_input(
|
||||
self, guild_id: int, user_id: int, transcript: str
|
||||
):
|
||||
"""Handle transcribed voice from a user in a voice channel.
|
||||
|
||||
Creates a synthetic MessageEvent and processes it through the
|
||||
adapter's full message pipeline (session, typing, agent, TTS reply).
|
||||
"""
|
||||
adapter = self.adapters.get(Platform.DISCORD)
|
||||
if not adapter:
|
||||
return
|
||||
|
||||
text_ch_id = adapter._voice_text_channels.get(guild_id)
|
||||
if not text_ch_id:
|
||||
return
|
||||
|
||||
# Check authorization before processing voice input
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id=str(text_ch_id),
|
||||
user_id=str(user_id),
|
||||
user_name=str(user_id),
|
||||
chat_type="channel",
|
||||
)
|
||||
if not self._is_user_authorized(source):
|
||||
logger.debug("Unauthorized voice input from user %d, ignoring", user_id)
|
||||
return
|
||||
|
||||
# Show transcript in text channel (after auth, with mention sanitization)
|
||||
try:
|
||||
channel = adapter._client.get_channel(text_ch_id)
|
||||
if channel:
|
||||
safe_text = transcript[:2000].replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
|
||||
await channel.send(f"**[Voice]** <@{user_id}>: {safe_text}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build a synthetic MessageEvent and feed through the normal pipeline
|
||||
# Use SimpleNamespace as raw_message so _get_guild_id() can extract
|
||||
# guild_id and _send_voice_reply() plays audio in the voice channel.
|
||||
from types import SimpleNamespace
|
||||
event = MessageEvent(
|
||||
source=source,
|
||||
text=transcript,
|
||||
message_type=MessageType.VOICE,
|
||||
raw_message=SimpleNamespace(guild_id=guild_id, guild=None),
|
||||
)
|
||||
|
||||
await adapter.handle_message(event)
|
||||
|
||||
def _should_send_voice_reply(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
response: str,
|
||||
agent_messages: list,
|
||||
) -> bool:
|
||||
"""Decide whether the runner should send a TTS voice reply.
|
||||
|
||||
Returns False when:
|
||||
- voice_mode is off for this chat
|
||||
- response is empty or an error
|
||||
- agent already called text_to_speech tool (dedup)
|
||||
- voice input and base adapter auto-TTS already handled it (skip_double)
|
||||
Exception: Discord voice channel — base play_tts is a no-op there,
|
||||
so the runner must handle VC playback.
|
||||
"""
|
||||
if not response or response.startswith("Error:"):
|
||||
return False
|
||||
|
||||
chat_id = event.source.chat_id
|
||||
voice_mode = self._voice_mode.get(chat_id, "off")
|
||||
is_voice_input = (event.message_type == MessageType.VOICE)
|
||||
|
||||
should = (
|
||||
(voice_mode == "all")
|
||||
or (voice_mode == "voice_only" and is_voice_input)
|
||||
)
|
||||
if not should:
|
||||
return False
|
||||
|
||||
# Dedup: agent already called TTS tool
|
||||
has_agent_tts = any(
|
||||
msg.get("role") == "assistant"
|
||||
and any(
|
||||
tc.get("function", {}).get("name") == "text_to_speech"
|
||||
for tc in (msg.get("tool_calls") or [])
|
||||
)
|
||||
for msg in agent_messages
|
||||
)
|
||||
if has_agent_tts:
|
||||
return False
|
||||
|
||||
# Dedup: base adapter auto-TTS already handles voice input.
|
||||
# Exception: Discord voice channel — play_tts override is a no-op,
|
||||
# so the runner must handle VC playback.
|
||||
skip_double = is_voice_input
|
||||
if skip_double:
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if (guild_id and adapter
|
||||
and hasattr(adapter, "is_in_voice_channel")
|
||||
and adapter.is_in_voice_channel(guild_id)):
|
||||
skip_double = False
|
||||
if skip_double:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
|
||||
"""Generate TTS audio and send as a voice message before the text reply."""
|
||||
import uuid as _uuid
|
||||
audio_path = None
|
||||
actual_path = None
|
||||
try:
|
||||
from tools.tts_tool import text_to_speech_tool, _strip_markdown_for_tts
|
||||
|
||||
tts_text = _strip_markdown_for_tts(text[:4000])
|
||||
if not tts_text:
|
||||
return
|
||||
|
||||
# Use .mp3 extension so edge-tts conversion to opus works correctly.
|
||||
# The TTS tool may convert to .ogg — use file_path from result.
|
||||
audio_path = os.path.join(
|
||||
tempfile.gettempdir(), "hermes_voice",
|
||||
f"tts_reply_{_uuid.uuid4().hex[:12]}.mp3",
|
||||
)
|
||||
os.makedirs(os.path.dirname(audio_path), exist_ok=True)
|
||||
|
||||
result_json = await asyncio.to_thread(
|
||||
text_to_speech_tool, text=tts_text, output_path=audio_path
|
||||
)
|
||||
result = json.loads(result_json)
|
||||
|
||||
# Use the actual file path from result (may differ after opus conversion)
|
||||
actual_path = result.get("file_path", audio_path)
|
||||
if not result.get("success") or not os.path.isfile(actual_path):
|
||||
logger.warning("Auto voice reply TTS failed: %s", result.get("error"))
|
||||
return
|
||||
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
|
||||
# If connected to a voice channel, play there instead of sending a file
|
||||
guild_id = self._get_guild_id(event)
|
||||
if (guild_id
|
||||
and hasattr(adapter, "play_in_voice_channel")
|
||||
and hasattr(adapter, "is_in_voice_channel")
|
||||
and adapter.is_in_voice_channel(guild_id)):
|
||||
await adapter.play_in_voice_channel(guild_id, actual_path)
|
||||
elif adapter and hasattr(adapter, "send_voice"):
|
||||
send_kwargs: Dict[str, Any] = {
|
||||
"chat_id": event.source.chat_id,
|
||||
"audio_path": actual_path,
|
||||
"reply_to": event.message_id,
|
||||
}
|
||||
if event.source.thread_id:
|
||||
send_kwargs["metadata"] = {"thread_id": event.source.thread_id}
|
||||
await adapter.send_voice(**send_kwargs)
|
||||
except Exception as e:
|
||||
logger.warning("Auto voice reply failed: %s", e, exc_info=True)
|
||||
finally:
|
||||
for p in {audio_path, actual_path} - {None}:
|
||||
try:
|
||||
os.unlink(p)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def _handle_rollback_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /rollback command — list or restore filesystem checkpoints."""
|
||||
from tools.checkpoint_manager import CheckpointManager, format_checkpoint_list
|
||||
@@ -3011,14 +3387,16 @@ class GatewayRunner:
|
||||
Returns:
|
||||
The enriched message string with transcriptions prepended.
|
||||
"""
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||
import asyncio
|
||||
|
||||
stt_model = get_stt_model_from_config()
|
||||
|
||||
enriched_parts = []
|
||||
for path in audio_paths:
|
||||
try:
|
||||
logger.debug("Transcribing user voice: %s", path)
|
||||
result = await asyncio.to_thread(transcribe_audio, path)
|
||||
result = await asyncio.to_thread(transcribe_audio, path, model=stt_model)
|
||||
if result["success"]:
|
||||
transcript = result["transcript"]
|
||||
enriched_parts.append(
|
||||
@@ -3027,10 +3405,10 @@ class GatewayRunner:
|
||||
)
|
||||
else:
|
||||
error = result.get("error", "unknown error")
|
||||
if "OPENAI_API_KEY" in error or "VOICE_TOOLS_OPENAI_KEY" in error:
|
||||
if "No STT provider" in error or "not set" in error:
|
||||
enriched_parts.append(
|
||||
"[The user sent a voice message but I can't listen "
|
||||
"to it right now~ VOICE_TOOLS_OPENAI_KEY isn't set up yet "
|
||||
"to it right now~ No STT provider is configured "
|
||||
"(';w;') Let them know!]"
|
||||
)
|
||||
else:
|
||||
@@ -3180,7 +3558,7 @@ class GatewayRunner:
|
||||
Platform.HOMEASSISTANT: "hermes-homeassistant",
|
||||
Platform.EMAIL: "hermes-email",
|
||||
}
|
||||
|
||||
|
||||
# Try to load platform_toolsets from config
|
||||
platform_toolsets_config = {}
|
||||
try:
|
||||
@@ -3192,7 +3570,7 @@ class GatewayRunner:
|
||||
platform_toolsets_config = user_config.get("platform_toolsets", {})
|
||||
except Exception as e:
|
||||
logger.debug("Could not load platform_toolsets config: %s", e)
|
||||
|
||||
|
||||
# Map platform enum to config key
|
||||
platform_config_key = {
|
||||
Platform.LOCAL: "cli",
|
||||
|
||||
@@ -383,7 +383,11 @@ class SessionStore:
|
||||
with open(sessions_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
for key, entry_data in data.items():
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
try:
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
except (ValueError, KeyError):
|
||||
# Skip entries with unknown/removed platform values
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load sessions: {e}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user