From 7b10881b9e2ae7b6f52d39666a25521f15ef0711 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Sat, 14 Mar 2026 06:14:22 -0700 Subject: [PATCH] fix: persist clean voice transcripts and /voice off state - keep CLI voice prefixes API-local while storing the original user text - persist explicit gateway off state and restore adapter auto-TTS suppression on restart - add regression coverage for both behaviors --- cli.py | 13 ++--- gateway/run.py | 56 ++++++++++++++++---- run_agent.py | 37 ++++++++++++- tests/gateway/test_voice_command.py | 80 ++++++++++++++++++++++++++--- tests/test_run_agent.py | 35 +++++++++++++ 5 files changed, 192 insertions(+), 29 deletions(-) diff --git a/cli.py b/cli.py index 2e7ffd51a..7bd455bd0 100755 --- a/cli.py +++ b/cli.py @@ -4218,9 +4218,8 @@ class HermesCLI: text_queue.put(delta) # When voice mode is active, prepend a brief instruction so the - # model responds concisely. The prefix is API-call-local only — - # we strip it from the returned history so it never persists to - # session DB or resumed sessions. + # model responds concisely. The prefix is API-call-local only — + # run_conversation persists the original clean user message. _voice_prefix = "" if self._voice_mode and isinstance(message, str): _voice_prefix = ( @@ -4236,6 +4235,7 @@ class HermesCLI: conversation_history=self.conversation_history[:-1], # Exclude the message we just added stream_callback=stream_callback, task_id=self.session_id, + persist_user_message=message if _voice_prefix else None, ) # Start agent in background thread @@ -4302,13 +4302,6 @@ class HermesCLI: # Update history with full conversation self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history - # Strip voice prefix from history so it never persists - if _voice_prefix and self.conversation_history: - for msg in self.conversation_history: - if msg.get("role") == "user" and isinstance(msg.get("content"), str): - if msg["content"].startswith(_voice_prefix): - msg["content"] = msg["content"][len(_voice_prefix):] - # Get the final response response = result.get("final_response", "") if result else "" diff --git a/gateway/run.py b/gateway/run.py index fecf4cef8..6795610a8 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -348,10 +348,20 @@ class GatewayRunner: def _load_voice_modes(self) -> Dict[str, str]: try: - return json.loads(self._VOICE_MODE_PATH.read_text()) + data = json.loads(self._VOICE_MODE_PATH.read_text()) except (FileNotFoundError, json.JSONDecodeError, OSError): return {} + if not isinstance(data, dict): + return {} + + valid_modes = {"off", "voice_only", "all"} + return { + str(chat_id): mode + for chat_id, mode in data.items() + if mode in valid_modes + } + def _save_voice_modes(self) -> None: try: self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True) @@ -361,6 +371,26 @@ class GatewayRunner: except OSError as e: logger.warning("Failed to save voice modes: %s", e) + def _set_adapter_auto_tts_disabled(self, adapter, chat_id: str, disabled: bool) -> None: + """Update an adapter's in-memory auto-TTS suppression set if present.""" + disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None) + if not isinstance(disabled_chats, set): + return + if disabled: + disabled_chats.add(chat_id) + else: + disabled_chats.discard(chat_id) + + def _sync_voice_mode_state_to_adapter(self, adapter) -> None: + """Restore persisted /voice off state into a live platform adapter.""" + disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None) + if not isinstance(disabled_chats, set): + return + disabled_chats.clear() + disabled_chats.update( + chat_id for chat_id, mode in self._voice_mode.items() if mode == "off" + ) + # ----------------------------------------------------------------- def _flush_memories_for_session(self, old_session_id: str): @@ -666,6 +696,7 @@ class GatewayRunner: success = await adapter.connect() if success: self.adapters[platform] = adapter + self._sync_voice_mode_state_to_adapter(adapter) connected_count += 1 logger.info("✓ %s connected", platform.value) else: @@ -2140,23 +2171,23 @@ class GatewayRunner: self._voice_mode[chat_id] = "voice_only" self._save_voice_modes() if adapter: - adapter._auto_tts_disabled_chats.discard(chat_id) + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False) return ( "Voice mode enabled.\n" "I'll reply with voice when you send voice messages.\n" "Use /voice tts to get voice replies for all messages." ) elif args in ("off", "disable"): - self._voice_mode.pop(chat_id, None) + self._voice_mode[chat_id] = "off" self._save_voice_modes() if adapter: - adapter._auto_tts_disabled_chats.add(chat_id) + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True) return "Voice mode disabled. Text-only replies." elif args == "tts": self._voice_mode[chat_id] = "all" self._save_voice_modes() if adapter: - adapter._auto_tts_disabled_chats.discard(chat_id) + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False) return ( "Auto-TTS enabled.\n" "All replies will include a voice message." @@ -2195,13 +2226,13 @@ class GatewayRunner: self._voice_mode[chat_id] = "voice_only" self._save_voice_modes() if adapter: - adapter._auto_tts_disabled_chats.discard(chat_id) + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False) return "Voice mode enabled." else: - self._voice_mode.pop(chat_id, None) + self._voice_mode[chat_id] = "off" self._save_voice_modes() if adapter: - adapter._auto_tts_disabled_chats.add(chat_id) + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True) return "Voice mode disabled." async def _handle_voice_channel_join(self, event: MessageEvent) -> str: @@ -2238,7 +2269,7 @@ class GatewayRunner: adapter._voice_text_channels[guild_id] = int(event.source.chat_id) self._voice_mode[event.source.chat_id] = "all" self._save_voice_modes() - adapter._auto_tts_disabled_chats.discard(event.source.chat_id) + self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False) return ( f"Joined voice channel **{voice_channel.name}**.\n" f"I'll speak my replies and listen to you. Use /voice leave to disconnect." @@ -2263,8 +2294,9 @@ class GatewayRunner: except Exception as e: logger.warning("Error leaving voice channel: %s", e) # Always clean up state even if leave raised an exception - self._voice_mode.pop(event.source.chat_id, None) + self._voice_mode[event.source.chat_id] = "off" self._save_voice_modes() + self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True) if hasattr(adapter, "_voice_input_callback"): adapter._voice_input_callback = None return "Left voice channel." @@ -2274,8 +2306,10 @@ class GatewayRunner: Cleans up runner-side voice_mode state that the adapter cannot reach. """ - self._voice_mode.pop(chat_id, None) + self._voice_mode[chat_id] = "off" self._save_voice_modes() + adapter = self.adapters.get(Platform.DISCORD) + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True) async def _handle_voice_channel_input( self, guild_id: int, user_id: int, transcript: str diff --git a/run_agent.py b/run_agent.py index 405fd8e37..bdf049655 100644 --- a/run_agent.py +++ b/run_agent.py @@ -497,6 +497,12 @@ class AIAgent: # Initialized here so _vprint can reference it before run_conversation. self._stream_callback = None + # Optional current-turn user-message override used when the API-facing + # user message intentionally differs from the persisted transcript + # (e.g. CLI voice mode adds a temporary prefix for the live call only). + self._persist_user_message_idx = None + self._persist_user_message_override = None + # Initialize LLM client via centralized provider router. # The router handles auth resolution, base URL, headers, and # Codex/Anthropic wrapping for all known providers. @@ -998,11 +1004,30 @@ class AIAgent: if self.verbose_logging: logging.warning(f"Failed to cleanup browser for task {task_id}: {e}") + def _apply_persist_user_message_override(self, messages: List[Dict]) -> None: + """Rewrite the current-turn user message before persistence/return. + + Some call paths need an API-only user-message variant without letting + that synthetic text leak into persisted transcripts or resumed session + history. When an override is configured for the active turn, mutate the + in-memory messages list in place so both persistence and returned + history stay clean. + """ + idx = getattr(self, "_persist_user_message_idx", None) + override = getattr(self, "_persist_user_message_override", None) + if override is None or idx is None: + return + if 0 <= idx < len(messages): + msg = messages[idx] + if isinstance(msg, dict) and msg.get("role") == "user": + msg["content"] = override + def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None): """Save session state to both JSON log and SQLite on any exit path. Ensures conversations are never lost, even on errors or early returns. """ + self._apply_persist_user_message_override(messages) self._session_messages = messages self._save_session_log(messages) self._flush_messages_to_session_db(messages, conversation_history) @@ -1016,6 +1041,7 @@ class AIAgent: """ if not self._session_db: return + self._apply_persist_user_message_override(messages) try: start_idx = len(conversation_history) if conversation_history else 0 flush_from = max(start_idx, self._last_flushed_db_idx) @@ -4065,6 +4091,7 @@ class AIAgent: conversation_history: List[Dict[str, Any]] = None, task_id: str = None, stream_callback: Optional[callable] = None, + persist_user_message: Optional[str] = None, ) -> Dict[str, Any]: """ Run a complete conversation with tool calling until completion. @@ -4077,6 +4104,9 @@ class AIAgent: stream_callback: Optional callback invoked with each text delta during streaming. Used by the TTS pipeline to start audio generation before the full response. When None (default), API calls use the standard non-streaming path. + persist_user_message: Optional clean user message to store in + transcripts/history when user_message contains API-only + synthetic prefixes. Returns: Dict: Complete conversation result with final response and message history @@ -4087,6 +4117,8 @@ class AIAgent: # Store stream callback for _interruptible_api_call to pick up self._stream_callback = stream_callback + self._persist_user_message_idx = None + self._persist_user_message_override = persist_user_message # Generate unique task_id if not provided to isolate VMs between concurrent tasks effective_task_id = task_id or str(uuid.uuid4()) @@ -4121,7 +4153,7 @@ class AIAgent: # Preserve the original user message before nudge injection. # Honcho should receive the actual user input, not system nudges. - original_user_message = user_message + original_user_message = persist_user_message if persist_user_message is not None else user_message # Periodic memory nudge: remind the model to consider saving memories. # Counter resets whenever the memory tool is actually used. @@ -4159,7 +4191,7 @@ class AIAgent: _recall_mode = (self._honcho_config.recall_mode if self._honcho_config else "hybrid") if self._honcho and self._honcho_session_key and _recall_mode != "tools": try: - prefetched_context = self._honcho_prefetch(user_message) + prefetched_context = self._honcho_prefetch(original_user_message) if prefetched_context: if not conversation_history: self._honcho_context = prefetched_context @@ -4172,6 +4204,7 @@ class AIAgent: user_msg = {"role": "user", "content": user_message} messages.append(user_msg) current_turn_user_idx = len(messages) - 1 + self._persist_user_message_idx = current_turn_user_idx if not self.quiet_mode: print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'") diff --git a/tests/gateway/test_voice_command.py b/tests/gateway/test_voice_command.py index 47aef6595..545f2b28f 100644 --- a/tests/gateway/test_voice_command.py +++ b/tests/gateway/test_voice_command.py @@ -3,12 +3,53 @@ import json import os import queue +import sys import threading import time import pytest from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch + +def _ensure_discord_mock(): + """Install a lightweight discord mock when discord.py isn't available.""" + if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"): + return + + discord_mod = MagicMock() + discord_mod.Intents.default.return_value = MagicMock() + discord_mod.Client = MagicMock + discord_mod.File = MagicMock + discord_mod.DMChannel = type("DMChannel", (), {}) + discord_mod.Thread = type("Thread", (), {}) + discord_mod.ForumChannel = type("ForumChannel", (), {}) + discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object) + discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3) + discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4) + discord_mod.Interaction = object + discord_mod.Embed = MagicMock + discord_mod.app_commands = SimpleNamespace( + describe=lambda **kwargs: (lambda fn: fn), + choices=lambda **kwargs: (lambda fn: fn), + Choice=lambda **kwargs: SimpleNamespace(**kwargs), + ) + discord_mod.opus = SimpleNamespace(is_loaded=lambda: True, load_opus=lambda *_args, **_kwargs: None) + discord_mod.FFmpegPCMAudio = MagicMock + discord_mod.PCMVolumeTransformer = MagicMock + discord_mod.http = SimpleNamespace(Route=MagicMock) + + ext_mod = MagicMock() + commands_mod = MagicMock() + commands_mod.Bot = MagicMock + ext_mod.commands = commands_mod + + sys.modules.setdefault("discord", discord_mod) + sys.modules.setdefault("discord.ext", ext_mod) + sys.modules.setdefault("discord.ext.commands", commands_mod) + + +_ensure_discord_mock() + from gateway.platforms.base import MessageEvent, MessageType, SessionSource @@ -65,7 +106,7 @@ class TestHandleVoiceCommand: event = _make_event("/voice off") result = await runner._handle_voice_command(event) assert "disabled" in result.lower() - assert "123" not in runner._voice_mode + assert runner._voice_mode["123"] == "off" @pytest.mark.asyncio async def test_voice_tts(self, runner): @@ -100,7 +141,7 @@ class TestHandleVoiceCommand: event = _make_event("/voice") result = await runner._handle_voice_command(event) assert "disabled" in result.lower() - assert "123" not in runner._voice_mode + assert runner._voice_mode["123"] == "off" @pytest.mark.asyncio async def test_persistence_saved(self, runner): @@ -116,6 +157,33 @@ class TestHandleVoiceCommand: loaded = runner._load_voice_modes() assert loaded == {"456": "all"} + @pytest.mark.asyncio + async def test_persistence_saved_for_off(self, runner): + event = _make_event("/voice off") + await runner._handle_voice_command(event) + data = json.loads(runner._VOICE_MODE_PATH.read_text()) + assert data["123"] == "off" + + def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner): + runner._voice_mode = {"123": "off", "456": "all"} + adapter = SimpleNamespace(_auto_tts_disabled_chats=set()) + + runner._sync_voice_mode_state_to_adapter(adapter) + + assert adapter._auto_tts_disabled_chats == {"123"} + + def test_restart_restores_voice_off_state(self, runner, tmp_path): + runner._VOICE_MODE_PATH.write_text(json.dumps({"123": "off"})) + + restored_runner = _make_runner(tmp_path) + restored_runner._voice_mode = restored_runner._load_voice_modes() + adapter = SimpleNamespace(_auto_tts_disabled_chats=set()) + + restored_runner._sync_voice_mode_state_to_adapter(adapter) + + assert restored_runner._voice_mode["123"] == "off" + assert adapter._auto_tts_disabled_chats == {"123"} + @pytest.mark.asyncio async def test_per_chat_isolation(self, runner): e1 = _make_event("/voice on", chat_id="aaa") @@ -693,7 +761,7 @@ class TestVoiceChannelCommands: runner._voice_mode["123"] = "all" result = await runner._handle_voice_channel_leave(event) assert "left" in result.lower() - assert "123" not in runner._voice_mode + assert runner._voice_mode["123"] == "off" mock_adapter.leave_voice_channel.assert_called_once_with(111) # -- _handle_voice_channel_input -- @@ -1163,7 +1231,7 @@ class TestLeaveExceptionHandling: result = await runner._handle_voice_channel_leave(event) assert "left" in result.lower() - assert "123" not in runner._voice_mode + assert runner._voice_mode["123"] == "off" assert mock_adapter._voice_input_callback is None @pytest.mark.asyncio @@ -1626,8 +1694,8 @@ class TestVoiceTimeoutCleansRunnerState: runner._handle_voice_timeout_cleanup("999") - assert "999" not in runner._voice_mode, \ - "voice_mode must be removed after timeout cleanup" + assert runner._voice_mode["999"] == "off", \ + "voice_mode must persist explicit off state after timeout cleanup" @pytest.mark.asyncio async def test_timeout_without_callback_does_not_crash(self, adapter): diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index dae905dd7..59c4a052a 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -2383,6 +2383,41 @@ class TestStreamCallbackNonStreamingProvider: assert received == ["Hello from Claude"] +# --------------------------------------------------------------------------- +# Bugfix: API-only user message prefixes must not persist +# --------------------------------------------------------------------------- + + +class TestPersistUserMessageOverride: + """Synthetic API-only user prefixes should never leak into transcripts.""" + + def test_persist_session_rewrites_current_turn_user_message(self, agent): + agent._session_db = MagicMock() + agent.session_id = "session-123" + agent._last_flushed_db_idx = 0 + agent._persist_user_message_idx = 0 + agent._persist_user_message_override = "Hello there" + messages = [ + { + "role": "user", + "content": ( + "[Voice input — respond concisely and conversationally, " + "2-3 sentences max. No code blocks or markdown.] Hello there" + ), + }, + {"role": "assistant", "content": "Hi!"}, + ] + + with patch.object(agent, "_save_session_log") as mock_save: + agent._persist_session(messages, []) + + assert messages[0]["content"] == "Hello there" + saved_messages = mock_save.call_args.args[0] + assert saved_messages[0]["content"] == "Hello there" + first_db_write = agent._session_db.append_message.call_args_list[0].kwargs + assert first_db_write["content"] == "Hello there" + + # --------------------------------------------------------------------------- # Bugfix: _vprint force=True on error messages during TTS # ---------------------------------------------------------------------------