fix: persist clean voice transcripts and /voice off state
- keep CLI voice prefixes API-local while storing the original user text - persist explicit gateway off state and restore adapter auto-TTS suppression on restart - add regression coverage for both behaviors
This commit is contained in:
@@ -3,12 +3,53 @@
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a lightweight discord mock when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
discord_mod.opus = SimpleNamespace(is_loaded=lambda: True, load_opus=lambda *_args, **_kwargs: None)
|
||||
discord_mod.FFmpegPCMAudio = MagicMock
|
||||
discord_mod.PCMVolumeTransformer = MagicMock
|
||||
discord_mod.http = SimpleNamespace(Route=MagicMock)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
|
||||
|
||||
|
||||
@@ -65,7 +106,7 @@ class TestHandleVoiceCommand:
|
||||
event = _make_event("/voice off")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "disabled" in result.lower()
|
||||
assert "123" not in runner._voice_mode
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_tts(self, runner):
|
||||
@@ -100,7 +141,7 @@ class TestHandleVoiceCommand:
|
||||
event = _make_event("/voice")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "disabled" in result.lower()
|
||||
assert "123" not in runner._voice_mode
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_saved(self, runner):
|
||||
@@ -116,6 +157,33 @@ class TestHandleVoiceCommand:
|
||||
loaded = runner._load_voice_modes()
|
||||
assert loaded == {"456": "all"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_saved_for_off(self, runner):
|
||||
event = _make_event("/voice off")
|
||||
await runner._handle_voice_command(event)
|
||||
data = json.loads(runner._VOICE_MODE_PATH.read_text())
|
||||
assert data["123"] == "off"
|
||||
|
||||
def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner):
|
||||
runner._voice_mode = {"123": "off", "456": "all"}
|
||||
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
|
||||
|
||||
runner._sync_voice_mode_state_to_adapter(adapter)
|
||||
|
||||
assert adapter._auto_tts_disabled_chats == {"123"}
|
||||
|
||||
def test_restart_restores_voice_off_state(self, runner, tmp_path):
|
||||
runner._VOICE_MODE_PATH.write_text(json.dumps({"123": "off"}))
|
||||
|
||||
restored_runner = _make_runner(tmp_path)
|
||||
restored_runner._voice_mode = restored_runner._load_voice_modes()
|
||||
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
|
||||
|
||||
restored_runner._sync_voice_mode_state_to_adapter(adapter)
|
||||
|
||||
assert restored_runner._voice_mode["123"] == "off"
|
||||
assert adapter._auto_tts_disabled_chats == {"123"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_chat_isolation(self, runner):
|
||||
e1 = _make_event("/voice on", chat_id="aaa")
|
||||
@@ -693,7 +761,7 @@ class TestVoiceChannelCommands:
|
||||
runner._voice_mode["123"] = "all"
|
||||
result = await runner._handle_voice_channel_leave(event)
|
||||
assert "left" in result.lower()
|
||||
assert "123" not in runner._voice_mode
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
mock_adapter.leave_voice_channel.assert_called_once_with(111)
|
||||
|
||||
# -- _handle_voice_channel_input --
|
||||
@@ -1163,7 +1231,7 @@ class TestLeaveExceptionHandling:
|
||||
|
||||
result = await runner._handle_voice_channel_leave(event)
|
||||
assert "left" in result.lower()
|
||||
assert "123" not in runner._voice_mode
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
assert mock_adapter._voice_input_callback is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1626,8 +1694,8 @@ class TestVoiceTimeoutCleansRunnerState:
|
||||
|
||||
runner._handle_voice_timeout_cleanup("999")
|
||||
|
||||
assert "999" not in runner._voice_mode, \
|
||||
"voice_mode must be removed after timeout cleanup"
|
||||
assert runner._voice_mode["999"] == "off", \
|
||||
"voice_mode must persist explicit off state after timeout cleanup"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_without_callback_does_not_crash(self, adapter):
|
||||
|
||||
@@ -2383,6 +2383,41 @@ class TestStreamCallbackNonStreamingProvider:
|
||||
assert received == ["Hello from Claude"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bugfix: API-only user message prefixes must not persist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistUserMessageOverride:
|
||||
"""Synthetic API-only user prefixes should never leak into transcripts."""
|
||||
|
||||
def test_persist_session_rewrites_current_turn_user_message(self, agent):
|
||||
agent._session_db = MagicMock()
|
||||
agent.session_id = "session-123"
|
||||
agent._last_flushed_db_idx = 0
|
||||
agent._persist_user_message_idx = 0
|
||||
agent._persist_user_message_override = "Hello there"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"[Voice input — respond concisely and conversationally, "
|
||||
"2-3 sentences max. No code blocks or markdown.] Hello there"
|
||||
),
|
||||
},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
]
|
||||
|
||||
with patch.object(agent, "_save_session_log") as mock_save:
|
||||
agent._persist_session(messages, [])
|
||||
|
||||
assert messages[0]["content"] == "Hello there"
|
||||
saved_messages = mock_save.call_args.args[0]
|
||||
assert saved_messages[0]["content"] == "Hello there"
|
||||
first_db_write = agent._session_db.append_message.call_args_list[0].kwargs
|
||||
assert first_db_write["content"] == "Hello there"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bugfix: _vprint force=True on error messages during TTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user