fix: persist clean voice transcripts and /voice off state

- keep CLI voice prefixes API-local while storing the original user text
- persist explicit gateway off state and restore adapter auto-TTS suppression on restart
- add regression coverage for both behaviors
This commit is contained in:
teknium1
2026-03-14 06:14:22 -07:00
parent 523a1b6faf
commit 7b10881b9e
5 changed files with 192 additions and 29 deletions

View File

@@ -3,12 +3,53 @@
import json
import os
import queue
import sys
import threading
import time
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
def _ensure_discord_mock():
"""Install a lightweight discord mock when discord.py isn't available."""
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
return
discord_mod = MagicMock()
discord_mod.Intents.default.return_value = MagicMock()
discord_mod.Client = MagicMock
discord_mod.File = MagicMock
discord_mod.DMChannel = type("DMChannel", (), {})
discord_mod.Thread = type("Thread", (), {})
discord_mod.ForumChannel = type("ForumChannel", (), {})
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
discord_mod.Interaction = object
discord_mod.Embed = MagicMock
discord_mod.app_commands = SimpleNamespace(
describe=lambda **kwargs: (lambda fn: fn),
choices=lambda **kwargs: (lambda fn: fn),
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
)
discord_mod.opus = SimpleNamespace(is_loaded=lambda: True, load_opus=lambda *_args, **_kwargs: None)
discord_mod.FFmpegPCMAudio = MagicMock
discord_mod.PCMVolumeTransformer = MagicMock
discord_mod.http = SimpleNamespace(Route=MagicMock)
ext_mod = MagicMock()
commands_mod = MagicMock()
commands_mod.Bot = MagicMock
ext_mod.commands = commands_mod
sys.modules.setdefault("discord", discord_mod)
sys.modules.setdefault("discord.ext", ext_mod)
sys.modules.setdefault("discord.ext.commands", commands_mod)
_ensure_discord_mock()
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
@@ -65,7 +106,7 @@ class TestHandleVoiceCommand:
event = _make_event("/voice off")
result = await runner._handle_voice_command(event)
assert "disabled" in result.lower()
assert "123" not in runner._voice_mode
assert runner._voice_mode["123"] == "off"
@pytest.mark.asyncio
async def test_voice_tts(self, runner):
@@ -100,7 +141,7 @@ class TestHandleVoiceCommand:
event = _make_event("/voice")
result = await runner._handle_voice_command(event)
assert "disabled" in result.lower()
assert "123" not in runner._voice_mode
assert runner._voice_mode["123"] == "off"
@pytest.mark.asyncio
async def test_persistence_saved(self, runner):
@@ -116,6 +157,33 @@ class TestHandleVoiceCommand:
loaded = runner._load_voice_modes()
assert loaded == {"456": "all"}
@pytest.mark.asyncio
async def test_persistence_saved_for_off(self, runner):
event = _make_event("/voice off")
await runner._handle_voice_command(event)
data = json.loads(runner._VOICE_MODE_PATH.read_text())
assert data["123"] == "off"
def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner):
runner._voice_mode = {"123": "off", "456": "all"}
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
runner._sync_voice_mode_state_to_adapter(adapter)
assert adapter._auto_tts_disabled_chats == {"123"}
def test_restart_restores_voice_off_state(self, runner, tmp_path):
runner._VOICE_MODE_PATH.write_text(json.dumps({"123": "off"}))
restored_runner = _make_runner(tmp_path)
restored_runner._voice_mode = restored_runner._load_voice_modes()
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
restored_runner._sync_voice_mode_state_to_adapter(adapter)
assert restored_runner._voice_mode["123"] == "off"
assert adapter._auto_tts_disabled_chats == {"123"}
@pytest.mark.asyncio
async def test_per_chat_isolation(self, runner):
e1 = _make_event("/voice on", chat_id="aaa")
@@ -693,7 +761,7 @@ class TestVoiceChannelCommands:
runner._voice_mode["123"] = "all"
result = await runner._handle_voice_channel_leave(event)
assert "left" in result.lower()
assert "123" not in runner._voice_mode
assert runner._voice_mode["123"] == "off"
mock_adapter.leave_voice_channel.assert_called_once_with(111)
# -- _handle_voice_channel_input --
@@ -1163,7 +1231,7 @@ class TestLeaveExceptionHandling:
result = await runner._handle_voice_channel_leave(event)
assert "left" in result.lower()
assert "123" not in runner._voice_mode
assert runner._voice_mode["123"] == "off"
assert mock_adapter._voice_input_callback is None
@pytest.mark.asyncio
@@ -1626,8 +1694,8 @@ class TestVoiceTimeoutCleansRunnerState:
runner._handle_voice_timeout_cleanup("999")
assert "999" not in runner._voice_mode, \
"voice_mode must be removed after timeout cleanup"
assert runner._voice_mode["999"] == "off", \
"voice_mode must persist explicit off state after timeout cleanup"
@pytest.mark.asyncio
async def test_timeout_without_callback_does_not_crash(self, adapter):

View File

@@ -2383,6 +2383,41 @@ class TestStreamCallbackNonStreamingProvider:
assert received == ["Hello from Claude"]
# ---------------------------------------------------------------------------
# Bugfix: API-only user message prefixes must not persist
# ---------------------------------------------------------------------------
class TestPersistUserMessageOverride:
"""Synthetic API-only user prefixes should never leak into transcripts."""
def test_persist_session_rewrites_current_turn_user_message(self, agent):
agent._session_db = MagicMock()
agent.session_id = "session-123"
agent._last_flushed_db_idx = 0
agent._persist_user_message_idx = 0
agent._persist_user_message_override = "Hello there"
messages = [
{
"role": "user",
"content": (
"[Voice input — respond concisely and conversationally, "
"2-3 sentences max. No code blocks or markdown.] Hello there"
),
},
{"role": "assistant", "content": "Hi!"},
]
with patch.object(agent, "_save_session_log") as mock_save:
agent._persist_session(messages, [])
assert messages[0]["content"] == "Hello there"
saved_messages = mock_save.call_args.args[0]
assert saved_messages[0]["content"] == "Hello there"
first_db_write = agent._session_db.append_message.call_args_list[0].kwargs
assert first_db_write["content"] == "Hello there"
# ---------------------------------------------------------------------------
# Bugfix: _vprint force=True on error messages during TTS
# ---------------------------------------------------------------------------