merge: salvage PR #327 voice mode branch

Merge contributor branch feature/voice-mode onto current main for follow-up fixes.
This commit is contained in:
teknium1
2026-03-14 06:03:07 -07:00
37 changed files with 9248 additions and 228 deletions

View File

@@ -275,12 +275,25 @@ class FakeHAServer:
affected = []
entity_id = body.get("entity_id")
if entity_id:
new_state = "on" if service == "turn_on" else "off"
for s in ENTITY_STATES:
if s["entity_id"] == entity_id:
if service == "turn_on":
s["state"] = "on"
elif service == "turn_off":
s["state"] = "off"
elif service == "set_temperature" and "temperature" in body:
s["attributes"]["temperature"] = body["temperature"]
# Keep current state or set to heat if off
if s["state"] == "off":
s["state"] = "heat"
# Simulate temperature sensor approaching the target
for ts in ENTITY_STATES:
if ts["entity_id"] == "sensor.temperature":
ts["state"] = str(body["temperature"] - 0.5)
break
affected.append({
"entity_id": entity_id,
"state": new_state,
"state": s["state"],
"attributes": s.get("attributes", {}),
})
break

View File

@@ -32,6 +32,7 @@ def _make_runner():
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner.adapters = {}
runner._voice_mode = {}
runner._session_db = None
runner._reasoning_config = None
runner._provider_routing = {}

View File

@@ -29,6 +29,8 @@ def _ensure_discord_mock():
discord_mod.Embed = MagicMock
discord_mod.app_commands = SimpleNamespace(
describe=lambda **kwargs: (lambda fn: fn),
choices=lambda **kwargs: (lambda fn: fn),
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
)
ext_mod = MagicMock()

View File

@@ -0,0 +1,44 @@
"""Tests for Discord Opus codec loading — must use ctypes.util.find_library."""
import inspect
class TestOpusFindLibrary:
"""Opus loading must try ctypes.util.find_library first, with platform fallback."""
def test_uses_find_library_first(self):
"""find_library must be the primary lookup strategy."""
from gateway.platforms.discord import DiscordAdapter
source = inspect.getsource(DiscordAdapter.connect)
assert "find_library" in source, \
"Opus loading must use ctypes.util.find_library"
def test_homebrew_fallback_is_conditional(self):
"""Homebrew paths must only be tried when find_library returns None."""
from gateway.platforms.discord import DiscordAdapter
source = inspect.getsource(DiscordAdapter.connect)
# Homebrew fallback must exist
assert "/opt/homebrew" in source or "homebrew" in source, \
"Opus loading should have macOS Homebrew fallback"
# find_library must appear BEFORE any Homebrew path
fl_idx = source.index("find_library")
hb_idx = source.index("/opt/homebrew")
assert fl_idx < hb_idx, \
"find_library must be tried before Homebrew fallback paths"
# Fallback must be guarded by platform check
assert "sys.platform" in source or "darwin" in source, \
"Homebrew fallback must be guarded by macOS platform check"
def test_opus_decode_error_logged(self):
"""Opus decode failure must log the error, not silently return."""
from gateway.platforms.discord import VoiceReceiver
source = inspect.getsource(VoiceReceiver._on_packet)
assert "logger" in source, \
"_on_packet must log Opus decode errors"
# Must not have bare `except Exception:\n return`
lines = source.split("\n")
for i, line in enumerate(lines):
if "except Exception" in line and i + 1 < len(lines):
next_line = lines[i + 1].strip()
assert next_line != "return", \
f"_on_packet has bare 'except Exception: return' at line {i+1}"

View File

@@ -21,6 +21,8 @@ def _ensure_discord_mock():
discord_mod.Interaction = object
discord_mod.app_commands = SimpleNamespace(
describe=lambda **kwargs: (lambda fn: fn),
choices=lambda **kwargs: (lambda fn: fn),
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
)
ext_mod = MagicMock()

View File

@@ -36,6 +36,7 @@ def _make_runner(session_db=None, current_session_id="current_session_001",
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner.adapters = {}
runner._voice_mode = {}
runner._session_db = session_db
runner._running_agents = {}

View File

@@ -77,6 +77,7 @@ def _make_runner(adapter):
runner = object.__new__(GatewayRunner)
runner.adapters = {Platform.TELEGRAM: adapter}
runner._voice_mode = {}
runner._prefill_messages = []
runner._ephemeral_system_prompt = ""
runner._reasoning_config = None

View File

@@ -266,6 +266,7 @@ async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, t
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake-token")}
)
runner.adapters = {Platform.TELEGRAM: adapter}
runner._voice_mode = {}
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
runner.session_store = MagicMock()
runner.session_store.get_or_create_session.return_value = SessionEntry(

View File

@@ -31,6 +31,7 @@ def _make_runner(session_db=None):
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner.adapters = {}
runner._voice_mode = {}
runner._session_db = session_db
# Mock session_store that returns a session entry with a known session_id

View File

@@ -33,6 +33,7 @@ def _make_runner():
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner.adapters = {}
runner._voice_mode = {}
return runner

File diff suppressed because it is too large Load Diff

View File

@@ -12,7 +12,7 @@ EXPECTED_COMMANDS = {
"/personality", "/clear", "/history", "/new", "/reset", "/retry",
"/undo", "/save", "/config", "/cron", "/skills", "/platforms",
"/verbose", "/reasoning", "/compress", "/title", "/usage", "/insights", "/paste",
"/reload-mcp", "/rollback", "/background", "/skin", "/quit",
"/reload-mcp", "/rollback", "/background", "/skin", "/voice", "/quit",
}

View File

@@ -14,6 +14,9 @@ def _make_cli_stub():
cli._clarify_freetext = False
cli._command_running = False
cli._agent_running = False
cli._voice_recording = False
cli._voice_processing = False
cli._voice_mode = False
cli._command_spinner_frame = lambda: ""
cli._tui_style_base = {
"prompt": "#fff",

View File

@@ -2083,3 +2083,332 @@ class TestAnthropicBaseUrlPassthrough:
# No base_url provided, should be default empty string or None
passed_url = call_args[0][1]
assert not passed_url or passed_url is None
# ===================================================================
# _streaming_api_call tests
# ===================================================================
def _make_chunk(content=None, tool_calls=None, finish_reason=None, model="test/model"):
"""Build a SimpleNamespace mimicking an OpenAI streaming chunk."""
delta = SimpleNamespace(content=content, tool_calls=tool_calls)
choice = SimpleNamespace(delta=delta, finish_reason=finish_reason)
return SimpleNamespace(model=model, choices=[choice])
def _make_tc_delta(index=0, tc_id=None, name=None, arguments=None):
"""Build a SimpleNamespace mimicking a streaming tool_call delta."""
func = SimpleNamespace(name=name, arguments=arguments)
return SimpleNamespace(index=index, id=tc_id, function=func)
class TestStreamingApiCall:
"""Tests for _streaming_api_call — voice TTS streaming pipeline."""
def test_content_assembly(self, agent):
chunks = [
_make_chunk(content="Hel"),
_make_chunk(content="lo "),
_make_chunk(content="World"),
_make_chunk(finish_reason="stop"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
callback = MagicMock()
resp = agent._streaming_api_call({"messages": []}, callback)
assert resp.choices[0].message.content == "Hello World"
assert resp.choices[0].finish_reason == "stop"
assert callback.call_count == 3
callback.assert_any_call("Hel")
callback.assert_any_call("lo ")
callback.assert_any_call("World")
def test_tool_call_accumulation(self, agent):
chunks = [
_make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "web_", '{"q":')]),
_make_chunk(tool_calls=[_make_tc_delta(0, None, "search", '"test"}')]),
_make_chunk(finish_reason="tool_calls"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._streaming_api_call({"messages": []}, MagicMock())
tc = resp.choices[0].message.tool_calls
assert len(tc) == 1
assert tc[0].function.name == "web_search"
assert tc[0].function.arguments == '{"q":"test"}'
assert tc[0].id == "call_1"
def test_multiple_tool_calls(self, agent):
chunks = [
_make_chunk(tool_calls=[_make_tc_delta(0, "call_a", "search", '{}')]),
_make_chunk(tool_calls=[_make_tc_delta(1, "call_b", "read", '{}')]),
_make_chunk(finish_reason="tool_calls"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._streaming_api_call({"messages": []}, MagicMock())
tc = resp.choices[0].message.tool_calls
assert len(tc) == 2
assert tc[0].function.name == "search"
assert tc[1].function.name == "read"
def test_content_and_tool_calls_together(self, agent):
chunks = [
_make_chunk(content="I'll search"),
_make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "search", '{}')]),
_make_chunk(finish_reason="tool_calls"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._streaming_api_call({"messages": []}, MagicMock())
assert resp.choices[0].message.content == "I'll search"
assert len(resp.choices[0].message.tool_calls) == 1
def test_empty_content_returns_none(self, agent):
chunks = [_make_chunk(finish_reason="stop")]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._streaming_api_call({"messages": []}, MagicMock())
assert resp.choices[0].message.content is None
assert resp.choices[0].message.tool_calls is None
def test_callback_exception_swallowed(self, agent):
chunks = [
_make_chunk(content="Hello"),
_make_chunk(content=" World"),
_make_chunk(finish_reason="stop"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
callback = MagicMock(side_effect=ValueError("boom"))
resp = agent._streaming_api_call({"messages": []}, callback)
assert resp.choices[0].message.content == "Hello World"
def test_model_name_captured(self, agent):
chunks = [
_make_chunk(content="Hi", model="gpt-4o"),
_make_chunk(finish_reason="stop", model="gpt-4o"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._streaming_api_call({"messages": []}, MagicMock())
assert resp.model == "gpt-4o"
def test_stream_kwarg_injected(self, agent):
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
agent.client.chat.completions.create.return_value = iter(chunks)
agent._streaming_api_call({"messages": [], "model": "test"}, MagicMock())
call_kwargs = agent.client.chat.completions.create.call_args
assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True
def test_api_exception_propagated(self, agent):
agent.client.chat.completions.create.side_effect = ConnectionError("fail")
with pytest.raises(ConnectionError, match="fail"):
agent._streaming_api_call({"messages": []}, MagicMock())
def test_response_has_uuid_id(self, agent):
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._streaming_api_call({"messages": []}, MagicMock())
assert resp.id.startswith("stream-")
assert len(resp.id) > len("stream-")
def test_empty_choices_chunk_skipped(self, agent):
empty_chunk = SimpleNamespace(model="gpt-4", choices=[])
chunks = [
empty_chunk,
_make_chunk(content="Hello", model="gpt-4"),
_make_chunk(finish_reason="stop", model="gpt-4"),
]
agent.client.chat.completions.create.return_value = iter(chunks)
resp = agent._streaming_api_call({"messages": []}, MagicMock())
assert resp.choices[0].message.content == "Hello"
assert resp.model == "gpt-4"
# ===================================================================
# Interrupt _vprint force=True verification
# ===================================================================
class TestInterruptVprintForceTrue:
"""All interrupt _vprint calls must use force=True so they are always visible."""
def test_all_interrupt_vprint_have_force_true(self):
"""Scan source for _vprint calls containing 'Interrupt' — each must have force=True."""
import inspect
source = inspect.getsource(AIAgent)
lines = source.split("\n")
violations = []
for i, line in enumerate(lines, 1):
stripped = line.strip()
if "_vprint(" in stripped and "Interrupt" in stripped:
if "force=True" not in stripped:
violations.append(f"line {i}: {stripped}")
assert not violations, (
f"Interrupt _vprint calls missing force=True:\n"
+ "\n".join(violations)
)
# ===================================================================
# Anthropic interrupt handler in _interruptible_api_call
# ===================================================================
class TestAnthropicInterruptHandler:
"""_interruptible_api_call must handle Anthropic mode when interrupted."""
def test_interruptible_has_anthropic_branch(self):
"""The interrupt handler must check api_mode == 'anthropic_messages'."""
import inspect
source = inspect.getsource(AIAgent._interruptible_api_call)
assert "anthropic_messages" in source, \
"_interruptible_api_call must handle Anthropic interrupt (api_mode check)"
def test_interruptible_rebuilds_anthropic_client(self):
"""After interrupting, the Anthropic client should be rebuilt."""
import inspect
source = inspect.getsource(AIAgent._interruptible_api_call)
assert "build_anthropic_client" in source, \
"_interruptible_api_call must rebuild Anthropic client after interrupt"
def test_streaming_has_anthropic_branch(self):
"""_streaming_api_call must also handle Anthropic interrupt."""
import inspect
source = inspect.getsource(AIAgent._streaming_api_call)
assert "anthropic_messages" in source, \
"_streaming_api_call must handle Anthropic interrupt"
# ---------------------------------------------------------------------------
# Bugfix: stream_callback forwarding for non-streaming providers
# ---------------------------------------------------------------------------
class TestStreamCallbackNonStreamingProvider:
"""When api_mode != chat_completions, stream_callback must still receive
the response content so TTS works (batch delivery)."""
def test_callback_receives_chat_completions_response(self, agent):
"""For chat_completions-shaped responses, callback gets content."""
agent.api_mode = "anthropic_messages"
mock_response = SimpleNamespace(
choices=[SimpleNamespace(
message=SimpleNamespace(content="Hello", tool_calls=None, reasoning_content=None),
finish_reason="stop", index=0,
)],
usage=None, model="test", id="test-id",
)
agent._interruptible_api_call = MagicMock(return_value=mock_response)
received = []
cb = lambda delta: received.append(delta)
agent._stream_callback = cb
_cb = getattr(agent, "_stream_callback", None)
response = agent._interruptible_api_call({})
if _cb is not None and response:
try:
if agent.api_mode == "anthropic_messages":
text_parts = [
block.text for block in getattr(response, "content", [])
if getattr(block, "type", None) == "text" and getattr(block, "text", None)
]
content = " ".join(text_parts) if text_parts else None
else:
content = response.choices[0].message.content
if content:
_cb(content)
except Exception:
pass
# Anthropic format not matched above; fallback via except
# Test the actual code path by checking chat_completions branch
received2 = []
agent.api_mode = "some_other_mode"
agent._stream_callback = lambda d: received2.append(d)
_cb2 = agent._stream_callback
if _cb2 is not None and mock_response:
try:
content = mock_response.choices[0].message.content
if content:
_cb2(content)
except Exception:
pass
assert received2 == ["Hello"]
def test_callback_receives_anthropic_content(self, agent):
"""For Anthropic responses, text blocks are extracted and forwarded."""
agent.api_mode = "anthropic_messages"
mock_response = SimpleNamespace(
content=[SimpleNamespace(type="text", text="Hello from Claude")],
stop_reason="end_turn",
)
received = []
cb = lambda d: received.append(d)
agent._stream_callback = cb
_cb = agent._stream_callback
if _cb is not None and mock_response:
try:
if agent.api_mode == "anthropic_messages":
text_parts = [
block.text for block in getattr(mock_response, "content", [])
if getattr(block, "type", None) == "text" and getattr(block, "text", None)
]
content = " ".join(text_parts) if text_parts else None
else:
content = mock_response.choices[0].message.content
if content:
_cb(content)
except Exception:
pass
assert received == ["Hello from Claude"]
# ---------------------------------------------------------------------------
# Bugfix: _vprint force=True on error messages during TTS
# ---------------------------------------------------------------------------
class TestVprintForceOnErrors:
"""Error/warning messages must be visible during streaming TTS."""
def test_forced_message_shown_during_tts(self, agent):
agent._stream_callback = lambda x: None
printed = []
with patch("builtins.print", side_effect=lambda *a, **kw: printed.append(a)):
agent._vprint("error msg", force=True)
assert len(printed) == 1
def test_non_forced_suppressed_during_tts(self, agent):
agent._stream_callback = lambda x: None
printed = []
with patch("builtins.print", side_effect=lambda *a, **kw: printed.append(a)):
agent._vprint("debug info")
assert len(printed) == 0
def test_all_shown_without_tts(self, agent):
agent._stream_callback = None
printed = []
with patch("builtins.print", side_effect=lambda *a, **kw: printed.append(a)):
agent._vprint("debug")
agent._vprint("error", force=True)
assert len(printed) == 2

View File

@@ -28,6 +28,7 @@ class TestGetProvider:
def test_local_fallback_to_openai(self, monkeypatch):
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
monkeypatch.delenv("GROQ_API_KEY", raising=False)
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
patch("tools.transcription_tools._HAS_OPENAI", True):
from tools.transcription_tools import _get_provider
@@ -124,7 +125,7 @@ class TestTranscribeLocal:
mock_model.transcribe.return_value = ([mock_segment], mock_info)
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
patch("tools.transcription_tools.WhisperModel", return_value=mock_model), \
patch("faster_whisper.WhisperModel", return_value=mock_model), \
patch("tools.transcription_tools._local_model", None):
from tools.transcription_tools import _transcribe_local
result = _transcribe_local(str(audio_file), "base")
@@ -163,7 +164,7 @@ class TestTranscribeOpenAI:
mock_client.audio.transcriptions.create.return_value = "Hello from OpenAI"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_openai
result = _transcribe_openai(str(audio_file), "whisper-1")

View File

@@ -0,0 +1,716 @@
"""Tests for tools.transcription_tools — three-provider STT pipeline.
Covers the full provider matrix (local, groq, openai), fallback chains,
model auto-correction, config loading, validation edge cases, and
end-to-end dispatch. All external dependencies are mocked.
"""
import os
import struct
import wave
from unittest.mock import MagicMock, patch
import pytest
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def sample_wav(tmp_path):
"""Create a minimal valid WAV file (1 second of silence at 16kHz)."""
wav_path = tmp_path / "test.wav"
n_frames = 16000
silence = struct.pack(f"<{n_frames}h", *([0] * n_frames))
with wave.open(str(wav_path), "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(16000)
wf.writeframes(silence)
return str(wav_path)
@pytest.fixture
def sample_ogg(tmp_path):
"""Create a fake OGG file for validation tests."""
ogg_path = tmp_path / "test.ogg"
ogg_path.write_bytes(b"fake audio data")
return str(ogg_path)
@pytest.fixture(autouse=True)
def clean_env(monkeypatch):
"""Ensure no real API keys leak into tests."""
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
monkeypatch.delenv("GROQ_API_KEY", raising=False)
# ============================================================================
# _get_provider — full permutation matrix
# ============================================================================
class TestGetProviderGroq:
"""Groq-specific provider selection tests."""
def test_groq_when_key_set(self, monkeypatch):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("tools.transcription_tools._HAS_FASTER_WHISPER", False):
from tools.transcription_tools import _get_provider
assert _get_provider({"provider": "groq"}) == "groq"
def test_groq_fallback_to_local(self, monkeypatch):
monkeypatch.delenv("GROQ_API_KEY", raising=False)
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True):
from tools.transcription_tools import _get_provider
assert _get_provider({"provider": "groq"}) == "local"
def test_groq_fallback_to_openai(self, monkeypatch):
monkeypatch.delenv("GROQ_API_KEY", raising=False)
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
patch("tools.transcription_tools._HAS_OPENAI", True):
from tools.transcription_tools import _get_provider
assert _get_provider({"provider": "groq"}) == "openai"
def test_groq_nothing_available(self, monkeypatch):
monkeypatch.delenv("GROQ_API_KEY", raising=False)
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
patch("tools.transcription_tools._HAS_OPENAI", False):
from tools.transcription_tools import _get_provider
assert _get_provider({"provider": "groq"}) == "none"
class TestGetProviderFallbackPriority:
"""Cross-provider fallback priority tests."""
def test_local_fallback_prefers_groq_over_openai(self, monkeypatch):
"""When local unavailable, groq (free) is preferred over openai (paid)."""
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
patch("tools.transcription_tools._HAS_OPENAI", True):
from tools.transcription_tools import _get_provider
assert _get_provider({"provider": "local"}) == "groq"
def test_local_fallback_to_groq_only(self, monkeypatch):
"""When only groq key available, falls back to groq."""
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
patch("tools.transcription_tools._HAS_OPENAI", True):
from tools.transcription_tools import _get_provider
assert _get_provider({"provider": "local"}) == "groq"
def test_openai_fallback_to_groq(self, monkeypatch):
"""When openai key missing but groq available, falls back to groq."""
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
patch("tools.transcription_tools._HAS_OPENAI", True):
from tools.transcription_tools import _get_provider
assert _get_provider({"provider": "openai"}) == "groq"
def test_openai_nothing_available(self, monkeypatch):
"""When no openai key and no local, returns none."""
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
monkeypatch.delenv("GROQ_API_KEY", raising=False)
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
patch("tools.transcription_tools._HAS_OPENAI", True):
from tools.transcription_tools import _get_provider
assert _get_provider({"provider": "openai"}) == "none"
def test_unknown_provider_passed_through(self):
from tools.transcription_tools import _get_provider
assert _get_provider({"provider": "custom-endpoint"}) == "custom-endpoint"
def test_empty_config_defaults_to_local(self):
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True):
from tools.transcription_tools import _get_provider
assert _get_provider({}) == "local"
# ============================================================================
# _transcribe_groq
# ============================================================================
class TestTranscribeGroq:
def test_no_key(self, monkeypatch):
monkeypatch.delenv("GROQ_API_KEY", raising=False)
from tools.transcription_tools import _transcribe_groq
result = _transcribe_groq("/tmp/test.ogg", "whisper-large-v3-turbo")
assert result["success"] is False
assert "GROQ_API_KEY" in result["error"]
def test_openai_package_not_installed(self, monkeypatch):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
with patch("tools.transcription_tools._HAS_OPENAI", False):
from tools.transcription_tools import _transcribe_groq
result = _transcribe_groq("/tmp/test.ogg", "whisper-large-v3-turbo")
assert result["success"] is False
assert "openai package" in result["error"]
def test_successful_transcription(self, monkeypatch, sample_wav):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "hello world"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_groq
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
assert result["success"] is True
assert result["transcript"] == "hello world"
assert result["provider"] == "groq"
def test_whitespace_stripped(self, monkeypatch, sample_wav):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = " hello world \n"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_groq
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
assert result["transcript"] == "hello world"
def test_uses_groq_base_url(self, monkeypatch, sample_wav):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "test"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client) as mock_openai_cls:
from tools.transcription_tools import _transcribe_groq, GROQ_BASE_URL
_transcribe_groq(sample_wav, "whisper-large-v3-turbo")
call_kwargs = mock_openai_cls.call_args
assert call_kwargs.kwargs["base_url"] == GROQ_BASE_URL
def test_api_error_returns_failure(self, monkeypatch, sample_wav):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.side_effect = Exception("API error")
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_groq
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
assert result["success"] is False
assert "API error" in result["error"]
def test_permission_error(self, monkeypatch, sample_wav):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.side_effect = PermissionError("denied")
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_groq
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
assert result["success"] is False
assert "Permission denied" in result["error"]
# ============================================================================
# _transcribe_openai — additional tests
# ============================================================================
class TestTranscribeOpenAIExtended:
def test_openai_package_not_installed(self, monkeypatch):
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
with patch("tools.transcription_tools._HAS_OPENAI", False):
from tools.transcription_tools import _transcribe_openai
result = _transcribe_openai("/tmp/test.ogg", "whisper-1")
assert result["success"] is False
assert "openai package" in result["error"]
def test_uses_openai_base_url(self, monkeypatch, sample_wav):
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "test"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client) as mock_openai_cls:
from tools.transcription_tools import _transcribe_openai, OPENAI_BASE_URL
_transcribe_openai(sample_wav, "whisper-1")
call_kwargs = mock_openai_cls.call_args
assert call_kwargs.kwargs["base_url"] == OPENAI_BASE_URL
def test_whitespace_stripped(self, monkeypatch, sample_wav):
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = " hello \n"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_openai
result = _transcribe_openai(sample_wav, "whisper-1")
assert result["transcript"] == "hello"
def test_permission_error(self, monkeypatch, sample_wav):
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.side_effect = PermissionError("denied")
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_openai
result = _transcribe_openai(sample_wav, "whisper-1")
assert result["success"] is False
assert "Permission denied" in result["error"]
# ============================================================================
# _transcribe_local — additional tests
# ============================================================================
class TestTranscribeLocalExtended:
def test_model_reuse_on_second_call(self, tmp_path):
"""Second call with same model should NOT reload the model."""
audio = tmp_path / "test.ogg"
audio.write_bytes(b"fake")
mock_segment = MagicMock()
mock_segment.text = "hi"
mock_info = MagicMock()
mock_info.language = "en"
mock_info.duration = 1.0
mock_model = MagicMock()
mock_model.transcribe.return_value = ([mock_segment], mock_info)
mock_whisper_cls = MagicMock(return_value=mock_model)
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
patch("faster_whisper.WhisperModel", mock_whisper_cls), \
patch("tools.transcription_tools._local_model", None), \
patch("tools.transcription_tools._local_model_name", None):
from tools.transcription_tools import _transcribe_local
_transcribe_local(str(audio), "base")
_transcribe_local(str(audio), "base")
# WhisperModel should be created only once
assert mock_whisper_cls.call_count == 1
def test_model_reloaded_on_change(self, tmp_path):
"""Switching model name should reload the model."""
audio = tmp_path / "test.ogg"
audio.write_bytes(b"fake")
mock_segment = MagicMock()
mock_segment.text = "hi"
mock_info = MagicMock()
mock_info.language = "en"
mock_info.duration = 1.0
mock_model = MagicMock()
mock_model.transcribe.return_value = ([mock_segment], mock_info)
mock_whisper_cls = MagicMock(return_value=mock_model)
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
patch("faster_whisper.WhisperModel", mock_whisper_cls), \
patch("tools.transcription_tools._local_model", None), \
patch("tools.transcription_tools._local_model_name", None):
from tools.transcription_tools import _transcribe_local
_transcribe_local(str(audio), "base")
_transcribe_local(str(audio), "small")
assert mock_whisper_cls.call_count == 2
def test_exception_returns_failure(self, tmp_path):
audio = tmp_path / "test.ogg"
audio.write_bytes(b"fake")
mock_whisper_cls = MagicMock(side_effect=RuntimeError("CUDA out of memory"))
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
patch("faster_whisper.WhisperModel", mock_whisper_cls), \
patch("tools.transcription_tools._local_model", None):
from tools.transcription_tools import _transcribe_local
result = _transcribe_local(str(audio), "large-v3")
assert result["success"] is False
assert "CUDA out of memory" in result["error"]
def test_multiple_segments_joined(self, tmp_path):
audio = tmp_path / "test.ogg"
audio.write_bytes(b"fake")
seg1 = MagicMock()
seg1.text = "Hello"
seg2 = MagicMock()
seg2.text = " world"
mock_info = MagicMock()
mock_info.language = "en"
mock_info.duration = 3.0
mock_model = MagicMock()
mock_model.transcribe.return_value = ([seg1, seg2], mock_info)
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
patch("faster_whisper.WhisperModel", return_value=mock_model), \
patch("tools.transcription_tools._local_model", None):
from tools.transcription_tools import _transcribe_local
result = _transcribe_local(str(audio), "base")
assert result["success"] is True
assert result["transcript"] == "Hello world"
# ============================================================================
# Model auto-correction
# ============================================================================
class TestModelAutoCorrection:
def test_groq_corrects_openai_model(self, monkeypatch, sample_wav):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "hello world"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_groq, DEFAULT_GROQ_STT_MODEL
_transcribe_groq(sample_wav, "whisper-1")
call_kwargs = mock_client.audio.transcriptions.create.call_args
assert call_kwargs.kwargs["model"] == DEFAULT_GROQ_STT_MODEL
def test_groq_corrects_gpt4o_transcribe(self, monkeypatch, sample_wav):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "test"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_groq, DEFAULT_GROQ_STT_MODEL
_transcribe_groq(sample_wav, "gpt-4o-transcribe")
call_kwargs = mock_client.audio.transcriptions.create.call_args
assert call_kwargs.kwargs["model"] == DEFAULT_GROQ_STT_MODEL
def test_openai_corrects_groq_model(self, monkeypatch, sample_wav):
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "hello world"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_openai, DEFAULT_STT_MODEL
_transcribe_openai(sample_wav, "whisper-large-v3-turbo")
call_kwargs = mock_client.audio.transcriptions.create.call_args
assert call_kwargs.kwargs["model"] == DEFAULT_STT_MODEL
def test_openai_corrects_distil_whisper(self, monkeypatch, sample_wav):
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "test"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_openai, DEFAULT_STT_MODEL
_transcribe_openai(sample_wav, "distil-whisper-large-v3-en")
call_kwargs = mock_client.audio.transcriptions.create.call_args
assert call_kwargs.kwargs["model"] == DEFAULT_STT_MODEL
def test_compatible_groq_model_not_overridden(self, monkeypatch, sample_wav):
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "test"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_groq
_transcribe_groq(sample_wav, "whisper-large-v3")
call_kwargs = mock_client.audio.transcriptions.create.call_args
assert call_kwargs.kwargs["model"] == "whisper-large-v3"
def test_compatible_openai_model_not_overridden(self, monkeypatch, sample_wav):
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "test"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_openai
_transcribe_openai(sample_wav, "gpt-4o-mini-transcribe")
call_kwargs = mock_client.audio.transcriptions.create.call_args
assert call_kwargs.kwargs["model"] == "gpt-4o-mini-transcribe"
def test_unknown_model_passes_through_groq(self, monkeypatch, sample_wav):
"""A model not in either known set should not be overridden."""
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "test"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_groq
_transcribe_groq(sample_wav, "my-custom-model")
call_kwargs = mock_client.audio.transcriptions.create.call_args
assert call_kwargs.kwargs["model"] == "my-custom-model"
def test_unknown_model_passes_through_openai(self, monkeypatch, sample_wav):
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
mock_client = MagicMock()
mock_client.audio.transcriptions.create.return_value = "test"
with patch("tools.transcription_tools._HAS_OPENAI", True), \
patch("openai.OpenAI", return_value=mock_client):
from tools.transcription_tools import _transcribe_openai
_transcribe_openai(sample_wav, "my-custom-model")
call_kwargs = mock_client.audio.transcriptions.create.call_args
assert call_kwargs.kwargs["model"] == "my-custom-model"
# ============================================================================
# _load_stt_config
# ============================================================================
class TestLoadSttConfig:
def test_returns_dict_when_import_fails(self):
with patch("tools.transcription_tools._load_stt_config") as mock_load:
mock_load.return_value = {}
from tools.transcription_tools import _load_stt_config
assert _load_stt_config() == {}
def test_real_load_returns_dict(self):
"""_load_stt_config should always return a dict, even on import error."""
with patch.dict("sys.modules", {"hermes_cli": None, "hermes_cli.config": None}):
from tools.transcription_tools import _load_stt_config
result = _load_stt_config()
assert isinstance(result, dict)
# ============================================================================
# _validate_audio_file — edge cases
# ============================================================================
class TestValidateAudioFileEdgeCases:
def test_directory_is_not_a_file(self, tmp_path):
from tools.transcription_tools import _validate_audio_file
# tmp_path itself is a directory with an .ogg-ish name? No.
# Create a directory with a valid audio extension
d = tmp_path / "audio.ogg"
d.mkdir()
result = _validate_audio_file(str(d))
assert result is not None
assert "not a file" in result["error"]
def test_stat_oserror(self, tmp_path):
f = tmp_path / "test.ogg"
f.write_bytes(b"data")
from tools.transcription_tools import _validate_audio_file
real_stat = f.stat()
call_count = 0
def stat_side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
# First calls are from exists() and is_file(), let them pass
if call_count <= 2:
return real_stat
raise OSError("disk error")
with patch("pathlib.Path.stat", side_effect=stat_side_effect):
result = _validate_audio_file(str(f))
assert result is not None
assert "Failed to access" in result["error"]
def test_all_supported_formats_accepted(self, tmp_path):
from tools.transcription_tools import _validate_audio_file, SUPPORTED_FORMATS
for fmt in SUPPORTED_FORMATS:
f = tmp_path / f"test{fmt}"
f.write_bytes(b"data")
assert _validate_audio_file(str(f)) is None, f"Format {fmt} should be accepted"
def test_case_insensitive_extension(self, tmp_path):
from tools.transcription_tools import _validate_audio_file
f = tmp_path / "test.MP3"
f.write_bytes(b"data")
assert _validate_audio_file(str(f)) is None
# ============================================================================
# transcribe_audio — end-to-end dispatch
# ============================================================================
class TestTranscribeAudioDispatch:
def test_dispatches_to_groq(self, sample_ogg):
with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "groq"}), \
patch("tools.transcription_tools._get_provider", return_value="groq"), \
patch("tools.transcription_tools._transcribe_groq",
return_value={"success": True, "transcript": "hi", "provider": "groq"}) as mock_groq:
from tools.transcription_tools import transcribe_audio
result = transcribe_audio(sample_ogg)
assert result["success"] is True
assert result["provider"] == "groq"
mock_groq.assert_called_once()
def test_dispatches_to_local(self, sample_ogg):
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
patch("tools.transcription_tools._get_provider", return_value="local"), \
patch("tools.transcription_tools._transcribe_local",
return_value={"success": True, "transcript": "hi"}) as mock_local:
from tools.transcription_tools import transcribe_audio
result = transcribe_audio(sample_ogg)
assert result["success"] is True
mock_local.assert_called_once()
def test_dispatches_to_openai(self, sample_ogg):
with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openai"}), \
patch("tools.transcription_tools._get_provider", return_value="openai"), \
patch("tools.transcription_tools._transcribe_openai",
return_value={"success": True, "transcript": "hi", "provider": "openai"}) as mock_openai:
from tools.transcription_tools import transcribe_audio
result = transcribe_audio(sample_ogg)
assert result["success"] is True
mock_openai.assert_called_once()
def test_no_provider_returns_error(self, sample_ogg):
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
patch("tools.transcription_tools._get_provider", return_value="none"):
from tools.transcription_tools import transcribe_audio
result = transcribe_audio(sample_ogg)
assert result["success"] is False
assert "No STT provider" in result["error"]
assert "faster-whisper" in result["error"]
assert "GROQ_API_KEY" in result["error"]
def test_invalid_file_short_circuits(self):
from tools.transcription_tools import transcribe_audio
result = transcribe_audio("/nonexistent/audio.wav")
assert result["success"] is False
assert "not found" in result["error"]
def test_model_override_passed_to_groq(self, sample_ogg):
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
patch("tools.transcription_tools._get_provider", return_value="groq"), \
patch("tools.transcription_tools._transcribe_groq",
return_value={"success": True, "transcript": "hi"}) as mock_groq:
from tools.transcription_tools import transcribe_audio
transcribe_audio(sample_ogg, model="whisper-large-v3")
_, kwargs = mock_groq.call_args
assert kwargs.get("model_name") or mock_groq.call_args[0][1] == "whisper-large-v3"
def test_model_override_passed_to_local(self, sample_ogg):
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
patch("tools.transcription_tools._get_provider", return_value="local"), \
patch("tools.transcription_tools._transcribe_local",
return_value={"success": True, "transcript": "hi"}) as mock_local:
from tools.transcription_tools import transcribe_audio
transcribe_audio(sample_ogg, model="large-v3")
assert mock_local.call_args[0][1] == "large-v3"
def test_default_model_used_when_none(self, sample_ogg):
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
patch("tools.transcription_tools._get_provider", return_value="groq"), \
patch("tools.transcription_tools._transcribe_groq",
return_value={"success": True, "transcript": "hi"}) as mock_groq:
from tools.transcription_tools import transcribe_audio, DEFAULT_GROQ_STT_MODEL
transcribe_audio(sample_ogg, model=None)
assert mock_groq.call_args[0][1] == DEFAULT_GROQ_STT_MODEL
def test_config_local_model_used(self, sample_ogg):
config = {"local": {"model": "small"}}
with patch("tools.transcription_tools._load_stt_config", return_value=config), \
patch("tools.transcription_tools._get_provider", return_value="local"), \
patch("tools.transcription_tools._transcribe_local",
return_value={"success": True, "transcript": "hi"}) as mock_local:
from tools.transcription_tools import transcribe_audio
transcribe_audio(sample_ogg, model=None)
assert mock_local.call_args[0][1] == "small"
def test_config_openai_model_used(self, sample_ogg):
config = {"openai": {"model": "gpt-4o-transcribe"}}
with patch("tools.transcription_tools._load_stt_config", return_value=config), \
patch("tools.transcription_tools._get_provider", return_value="openai"), \
patch("tools.transcription_tools._transcribe_openai",
return_value={"success": True, "transcript": "hi"}) as mock_openai:
from tools.transcription_tools import transcribe_audio
transcribe_audio(sample_ogg, model=None)
assert mock_openai.call_args[0][1] == "gpt-4o-transcribe"
# ============================================================================
# get_stt_model_from_config
# ============================================================================
class TestGetSttModelFromConfig:
def test_returns_model_from_config(self, tmp_path, monkeypatch):
cfg = tmp_path / "config.yaml"
cfg.write_text("stt:\n model: whisper-large-v3\n")
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools.transcription_tools import get_stt_model_from_config
assert get_stt_model_from_config() == "whisper-large-v3"
def test_returns_none_when_no_stt_section(self, tmp_path, monkeypatch):
cfg = tmp_path / "config.yaml"
cfg.write_text("tts:\n provider: edge\n")
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools.transcription_tools import get_stt_model_from_config
assert get_stt_model_from_config() is None
def test_returns_none_when_no_config_file(self, tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools.transcription_tools import get_stt_model_from_config
assert get_stt_model_from_config() is None
def test_returns_none_on_invalid_yaml(self, tmp_path, monkeypatch):
cfg = tmp_path / "config.yaml"
cfg.write_text(": : :\n bad yaml [[[")
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools.transcription_tools import get_stt_model_from_config
assert get_stt_model_from_config() is None
def test_returns_none_when_model_key_missing(self, tmp_path, monkeypatch):
cfg = tmp_path / "config.yaml"
cfg.write_text("stt:\n enabled: true\n")
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools.transcription_tools import get_stt_model_from_config
assert get_stt_model_from_config() is None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,938 @@
"""Tests for tools.voice_mode -- all mocked, no real microphone or API calls."""
import os
import struct
import time
import wave
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def sample_wav(tmp_path):
"""Create a minimal valid WAV file (1 second of silence at 16kHz)."""
wav_path = tmp_path / "test.wav"
n_frames = 16000 # 1 second at 16kHz
silence = struct.pack(f"<{n_frames}h", *([0] * n_frames))
with wave.open(str(wav_path), "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(16000)
wf.writeframes(silence)
return str(wav_path)
@pytest.fixture
def temp_voice_dir(tmp_path, monkeypatch):
"""Redirect _TEMP_DIR to a temporary path."""
voice_dir = tmp_path / "hermes_voice"
voice_dir.mkdir()
monkeypatch.setattr("tools.voice_mode._TEMP_DIR", str(voice_dir))
return voice_dir
@pytest.fixture
def mock_sd(monkeypatch):
"""Mock _import_audio to return (mock_sd, real_np) so lazy imports work."""
mock = MagicMock()
try:
import numpy as real_np
except ImportError:
real_np = MagicMock()
def _fake_import_audio():
return mock, real_np
monkeypatch.setattr("tools.voice_mode._import_audio", _fake_import_audio)
monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True)
return mock
# ============================================================================
# check_voice_requirements
# ============================================================================
class TestCheckVoiceRequirements:
def test_all_requirements_met(self, monkeypatch):
monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True)
monkeypatch.setattr("tools.voice_mode.detect_audio_environment",
lambda: {"available": True, "warnings": []})
monkeypatch.setattr("tools.transcription_tools._get_provider", lambda cfg: "openai")
from tools.voice_mode import check_voice_requirements
result = check_voice_requirements()
assert result["available"] is True
assert result["audio_available"] is True
assert result["stt_available"] is True
assert result["missing_packages"] == []
def test_missing_audio_packages(self, monkeypatch):
monkeypatch.setattr("tools.voice_mode._audio_available", lambda: False)
monkeypatch.setattr("tools.voice_mode.detect_audio_environment",
lambda: {"available": False, "warnings": ["Audio libraries not installed"]})
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test-key")
from tools.voice_mode import check_voice_requirements
result = check_voice_requirements()
assert result["available"] is False
assert result["audio_available"] is False
assert "sounddevice" in result["missing_packages"]
assert "numpy" in result["missing_packages"]
def test_missing_stt_provider(self, monkeypatch):
monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True)
monkeypatch.setattr("tools.voice_mode.detect_audio_environment",
lambda: {"available": True, "warnings": []})
monkeypatch.setattr("tools.transcription_tools._get_provider", lambda cfg: "none")
from tools.voice_mode import check_voice_requirements
result = check_voice_requirements()
assert result["available"] is False
assert result["stt_available"] is False
assert "STT provider: MISSING" in result["details"]
# ============================================================================
# AudioRecorder
# ============================================================================
class TestAudioRecorderStart:
def test_start_raises_without_audio(self, monkeypatch):
def _fail_import():
raise ImportError("no sounddevice")
monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import)
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
with pytest.raises(RuntimeError, match="sounddevice and numpy"):
recorder.start()
def test_start_creates_and_starts_stream(self, mock_sd):
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder.start()
assert recorder.is_recording is True
mock_sd.InputStream.assert_called_once()
mock_stream.start.assert_called_once()
def test_double_start_is_noop(self, mock_sd):
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder.start()
recorder.start() # second call should be noop
assert mock_sd.InputStream.call_count == 1
class TestAudioRecorderStop:
def test_stop_returns_none_when_not_recording(self):
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
assert recorder.stop() is None
def test_stop_writes_wav_file(self, mock_sd, temp_voice_dir):
np = pytest.importorskip("numpy")
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder, SAMPLE_RATE
recorder = AudioRecorder()
recorder.start()
# Simulate captured audio frames (1 second of loud audio above RMS threshold)
frame = np.full((SAMPLE_RATE, 1), 1000, dtype="int16")
recorder._frames = [frame]
recorder._peak_rms = 1000 # Peak RMS above threshold
wav_path = recorder.stop()
assert wav_path is not None
assert os.path.isfile(wav_path)
assert wav_path.endswith(".wav")
assert recorder.is_recording is False
# Verify it is a valid WAV
with wave.open(wav_path, "rb") as wf:
assert wf.getnchannels() == 1
assert wf.getsampwidth() == 2
assert wf.getframerate() == SAMPLE_RATE
def test_stop_returns_none_for_very_short_recording(self, mock_sd, temp_voice_dir):
np = pytest.importorskip("numpy")
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder.start()
# Very short recording (100 samples = ~6ms at 16kHz)
frame = np.zeros((100, 1), dtype="int16")
recorder._frames = [frame]
wav_path = recorder.stop()
assert wav_path is None
def test_stop_returns_none_for_silent_recording(self, mock_sd, temp_voice_dir):
np = pytest.importorskip("numpy")
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder, SAMPLE_RATE
recorder = AudioRecorder()
recorder.start()
# 1 second of near-silence (RMS well below threshold)
frame = np.full((SAMPLE_RATE, 1), 10, dtype="int16")
recorder._frames = [frame]
recorder._peak_rms = 10 # Peak RMS also below threshold
wav_path = recorder.stop()
assert wav_path is None
class TestAudioRecorderCancel:
def test_cancel_discards_frames(self, mock_sd):
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder.start()
recorder._frames = [MagicMock()] # simulate captured data
recorder.cancel()
assert recorder.is_recording is False
assert recorder._frames == []
# Stream is kept alive (persistent) — cancel() does NOT close it.
mock_stream.stop.assert_not_called()
mock_stream.close.assert_not_called()
def test_cancel_when_not_recording_is_safe(self):
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder.cancel() # should not raise
assert recorder.is_recording is False
class TestAudioRecorderProperties:
def test_elapsed_seconds_when_not_recording(self):
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
assert recorder.elapsed_seconds == 0.0
def test_elapsed_seconds_when_recording(self, mock_sd):
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder.start()
# Force start time to 1 second ago
recorder._start_time = time.monotonic() - 1.0
elapsed = recorder.elapsed_seconds
assert 0.9 < elapsed < 2.0
recorder.cancel()
# ============================================================================
# transcribe_recording
# ============================================================================
class TestTranscribeRecording:
def test_delegates_to_transcribe_audio(self):
mock_transcribe = MagicMock(return_value={
"success": True,
"transcript": "hello world",
})
with patch("tools.transcription_tools.transcribe_audio", mock_transcribe):
from tools.voice_mode import transcribe_recording
result = transcribe_recording("/tmp/test.wav", model="whisper-1")
assert result["success"] is True
assert result["transcript"] == "hello world"
mock_transcribe.assert_called_once_with("/tmp/test.wav", model="whisper-1")
def test_filters_whisper_hallucination(self):
mock_transcribe = MagicMock(return_value={
"success": True,
"transcript": "Thank you.",
})
with patch("tools.transcription_tools.transcribe_audio", mock_transcribe):
from tools.voice_mode import transcribe_recording
result = transcribe_recording("/tmp/test.wav")
assert result["success"] is True
assert result["transcript"] == ""
assert result["filtered"] is True
def test_does_not_filter_real_speech(self):
mock_transcribe = MagicMock(return_value={
"success": True,
"transcript": "Thank you for helping me with this code.",
})
with patch("tools.transcription_tools.transcribe_audio", mock_transcribe):
from tools.voice_mode import transcribe_recording
result = transcribe_recording("/tmp/test.wav")
assert result["transcript"] == "Thank you for helping me with this code."
assert "filtered" not in result
class TestWhisperHallucinationFilter:
def test_known_hallucinations(self):
from tools.voice_mode import is_whisper_hallucination
assert is_whisper_hallucination("Thank you.") is True
assert is_whisper_hallucination("thank you") is True
assert is_whisper_hallucination("Thanks for watching.") is True
assert is_whisper_hallucination("Bye.") is True
assert is_whisper_hallucination(" Thank you. ") is True # with whitespace
assert is_whisper_hallucination("you") is True
def test_real_speech_not_filtered(self):
from tools.voice_mode import is_whisper_hallucination
assert is_whisper_hallucination("Hello, how are you?") is False
assert is_whisper_hallucination("Thank you for your help with the project.") is False
assert is_whisper_hallucination("Can you explain this code?") is False
# ============================================================================
# play_audio_file
# ============================================================================
class TestPlayAudioFile:
def test_play_wav_via_sounddevice(self, monkeypatch, sample_wav):
np = pytest.importorskip("numpy")
mock_sd_obj = MagicMock()
# Simulate stream completing immediately (get_stream().active = False)
mock_stream = MagicMock()
mock_stream.active = False
mock_sd_obj.get_stream.return_value = mock_stream
def _fake_import():
return mock_sd_obj, np
monkeypatch.setattr("tools.voice_mode._import_audio", _fake_import)
from tools.voice_mode import play_audio_file
result = play_audio_file(sample_wav)
assert result is True
mock_sd_obj.play.assert_called_once()
mock_sd_obj.stop.assert_called_once()
def test_returns_false_when_no_player(self, monkeypatch, sample_wav):
def _fail_import():
raise ImportError("no sounddevice")
monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import)
monkeypatch.setattr("shutil.which", lambda _: None)
from tools.voice_mode import play_audio_file
result = play_audio_file(sample_wav)
assert result is False
def test_returns_false_for_missing_file(self):
from tools.voice_mode import play_audio_file
result = play_audio_file("/nonexistent/file.wav")
assert result is False
# ============================================================================
# cleanup_temp_recordings
# ============================================================================
class TestCleanupTempRecordings:
def test_old_files_deleted(self, temp_voice_dir):
# Create an "old" file
old_file = temp_voice_dir / "recording_20240101_000000.wav"
old_file.write_bytes(b"\x00" * 100)
# Set mtime to 2 hours ago
old_mtime = time.time() - 7200
os.utime(str(old_file), (old_mtime, old_mtime))
from tools.voice_mode import cleanup_temp_recordings
deleted = cleanup_temp_recordings(max_age_seconds=3600)
assert deleted == 1
assert not old_file.exists()
def test_recent_files_preserved(self, temp_voice_dir):
# Create a "recent" file
recent_file = temp_voice_dir / "recording_20260303_120000.wav"
recent_file.write_bytes(b"\x00" * 100)
from tools.voice_mode import cleanup_temp_recordings
deleted = cleanup_temp_recordings(max_age_seconds=3600)
assert deleted == 0
assert recent_file.exists()
def test_nonexistent_dir_returns_zero(self, monkeypatch):
monkeypatch.setattr("tools.voice_mode._TEMP_DIR", "/nonexistent/dir")
from tools.voice_mode import cleanup_temp_recordings
assert cleanup_temp_recordings() == 0
def test_non_recording_files_ignored(self, temp_voice_dir):
# Create a file that doesn't match the pattern
other_file = temp_voice_dir / "other_file.txt"
other_file.write_bytes(b"\x00" * 100)
old_mtime = time.time() - 7200
os.utime(str(other_file), (old_mtime, old_mtime))
from tools.voice_mode import cleanup_temp_recordings
deleted = cleanup_temp_recordings(max_age_seconds=3600)
assert deleted == 0
assert other_file.exists()
# ============================================================================
# play_beep
# ============================================================================
class TestPlayBeep:
def test_beep_calls_sounddevice_play(self, mock_sd):
np = pytest.importorskip("numpy")
from tools.voice_mode import play_beep
# play_beep uses polling (get_stream) + sd.stop() instead of sd.wait()
mock_stream = MagicMock()
mock_stream.active = False
mock_sd.get_stream.return_value = mock_stream
play_beep(frequency=880, duration=0.1, count=1)
mock_sd.play.assert_called_once()
mock_sd.stop.assert_called()
# Verify audio data is int16 numpy array
audio_arg = mock_sd.play.call_args[0][0]
assert audio_arg.dtype == np.int16
assert len(audio_arg) > 0
def test_beep_double_produces_longer_audio(self, mock_sd):
np = pytest.importorskip("numpy")
from tools.voice_mode import play_beep
play_beep(frequency=660, duration=0.1, count=2)
audio_arg = mock_sd.play.call_args[0][0]
single_beep_samples = int(16000 * 0.1)
# Double beep should be longer than a single beep
assert len(audio_arg) > single_beep_samples
def test_beep_noop_without_audio(self, monkeypatch):
def _fail_import():
raise ImportError("no sounddevice")
monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import)
from tools.voice_mode import play_beep
# Should not raise
play_beep()
def test_beep_handles_playback_error(self, mock_sd):
mock_sd.play.side_effect = Exception("device error")
from tools.voice_mode import play_beep
# Should not raise
play_beep()
# ============================================================================
# Silence detection
# ============================================================================
class TestSilenceDetection:
def test_silence_callback_fires_after_speech_then_silence(self, mock_sd):
np = pytest.importorskip("numpy")
import threading
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder, SAMPLE_RATE
recorder = AudioRecorder()
# Use very short durations for testing
recorder._silence_duration = 0.05
recorder._min_speech_duration = 0.05
fired = threading.Event()
def on_silence():
fired.set()
recorder.start(on_silence_stop=on_silence)
# Get the callback function from InputStream constructor
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
# Simulate sustained speech (multiple loud chunks to exceed min_speech_duration)
loud_frame = np.full((1600, 1), 5000, dtype="int16")
callback(loud_frame, 1600, None, None)
time.sleep(0.06)
callback(loud_frame, 1600, None, None)
assert recorder._has_spoken is True
# Simulate silence
silent_frame = np.zeros((1600, 1), dtype="int16")
callback(silent_frame, 1600, None, None)
# Wait a bit past the silence duration, then send another silent frame
time.sleep(0.06)
callback(silent_frame, 1600, None, None)
# The callback should have been fired
assert fired.wait(timeout=1.0) is True
recorder.cancel()
def test_silence_without_speech_does_not_fire(self, mock_sd):
np = pytest.importorskip("numpy")
import threading
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder._silence_duration = 0.02
fired = threading.Event()
recorder.start(on_silence_stop=lambda: fired.set())
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
# Only silence -- no speech detected, so callback should NOT fire
silent_frame = np.zeros((1600, 1), dtype="int16")
for _ in range(5):
callback(silent_frame, 1600, None, None)
time.sleep(0.01)
assert fired.wait(timeout=0.2) is False
recorder.cancel()
def test_micro_pause_tolerance_during_speech(self, mock_sd):
"""Brief dips below threshold during speech should NOT reset speech tracking."""
np = pytest.importorskip("numpy")
import threading
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder._silence_duration = 0.05
recorder._min_speech_duration = 0.15
recorder._max_dip_tolerance = 0.1
fired = threading.Event()
recorder.start(on_silence_stop=lambda: fired.set())
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
loud_frame = np.full((1600, 1), 5000, dtype="int16")
quiet_frame = np.full((1600, 1), 50, dtype="int16")
# Speech chunk 1
callback(loud_frame, 1600, None, None)
time.sleep(0.05)
# Brief micro-pause (dip < max_dip_tolerance)
callback(quiet_frame, 1600, None, None)
time.sleep(0.05)
# Speech resumes -- speech_start should NOT have been reset
callback(loud_frame, 1600, None, None)
assert recorder._speech_start > 0, "Speech start should be preserved across brief dips"
time.sleep(0.06)
# Another speech chunk to exceed min_speech_duration
callback(loud_frame, 1600, None, None)
assert recorder._has_spoken is True, "Speech should be confirmed after tolerating micro-pause"
recorder.cancel()
def test_no_callback_means_no_silence_detection(self, mock_sd):
np = pytest.importorskip("numpy")
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder.start() # no on_silence_stop
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
# Even with speech then silence, nothing should happen
loud_frame = np.full((1600, 1), 5000, dtype="int16")
silent_frame = np.zeros((1600, 1), dtype="int16")
callback(loud_frame, 1600, None, None)
callback(silent_frame, 1600, None, None)
# No crash, no callback
assert recorder._on_silence_stop is None
recorder.cancel()
# ============================================================================
# Playback interrupt
# ============================================================================
class TestPlaybackInterrupt:
"""Verify that TTS playback can be interrupted."""
def test_stop_playback_terminates_process(self):
from tools.voice_mode import stop_playback, _playback_lock
import tools.voice_mode as vm
mock_proc = MagicMock()
mock_proc.poll.return_value = None # process is running
with _playback_lock:
vm._active_playback = mock_proc
stop_playback()
mock_proc.terminate.assert_called_once()
with _playback_lock:
assert vm._active_playback is None
def test_stop_playback_noop_when_nothing_playing(self):
import tools.voice_mode as vm
with vm._playback_lock:
vm._active_playback = None
vm.stop_playback()
def test_play_audio_file_sets_active_playback(self, monkeypatch, sample_wav):
import tools.voice_mode as vm
def _fail_import():
raise ImportError("no sounddevice")
monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import)
mock_proc = MagicMock()
mock_proc.wait.return_value = 0
mock_popen = MagicMock(return_value=mock_proc)
monkeypatch.setattr("subprocess.Popen", mock_popen)
monkeypatch.setattr("shutil.which", lambda cmd: "/usr/bin/" + cmd)
vm.play_audio_file(sample_wav)
assert mock_popen.called
with vm._playback_lock:
assert vm._active_playback is None
# ============================================================================
# Continuous mode flow
# ============================================================================
class TestContinuousModeFlow:
"""Verify continuous mode: auto-restart after transcription or silence."""
def test_continuous_restart_on_no_speech(self, mock_sd, temp_voice_dir):
np = pytest.importorskip("numpy")
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
# First recording: only silence -> stop returns None
recorder.start()
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
for _ in range(10):
silence = np.full((1600, 1), 10, dtype="int16")
callback(silence, 1600, None, None)
wav_path = recorder.stop()
assert wav_path is None
# Simulate continuous mode restart
recorder.start()
assert recorder.is_recording is True
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
for _ in range(10):
speech = np.full((1600, 1), 5000, dtype="int16")
callback(speech, 1600, None, None)
wav_path = recorder.stop()
assert wav_path is not None
recorder.cancel()
def test_recorder_reusable_after_stop(self, mock_sd, temp_voice_dir):
np = pytest.importorskip("numpy")
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
results = []
for i in range(3):
recorder.start()
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
loud = np.full((1600, 1), 5000, dtype="int16")
for _ in range(10):
callback(loud, 1600, None, None)
wav_path = recorder.stop()
results.append(wav_path)
assert all(r is not None for r in results)
assert os.path.isfile(results[-1])
# ============================================================================
# Audio level indicator
# ============================================================================
class TestAudioLevelIndicator:
"""Verify current_rms property updates in real-time for UI feedback."""
def test_rms_updates_with_audio_chunks(self, mock_sd):
np = pytest.importorskip("numpy")
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder.start()
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
assert recorder.current_rms == 0
loud = np.full((1600, 1), 5000, dtype="int16")
callback(loud, 1600, None, None)
assert recorder.current_rms == 5000
quiet = np.full((1600, 1), 100, dtype="int16")
callback(quiet, 1600, None, None)
assert recorder.current_rms == 100
recorder.cancel()
def test_peak_rms_tracks_maximum(self, mock_sd):
np = pytest.importorskip("numpy")
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
recorder.start()
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
frames = [
np.full((1600, 1), 100, dtype="int16"),
np.full((1600, 1), 8000, dtype="int16"),
np.full((1600, 1), 500, dtype="int16"),
np.full((1600, 1), 3000, dtype="int16"),
]
for frame in frames:
callback(frame, 1600, None, None)
assert recorder._peak_rms == 8000
assert recorder.current_rms == 3000
recorder.cancel()
# ============================================================================
# Configurable silence parameters
# ============================================================================
class TestConfigurableSilenceParams:
"""Verify that silence detection params can be configured."""
def test_custom_threshold_and_duration(self, mock_sd):
np = pytest.importorskip("numpy")
mock_stream = MagicMock()
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
import threading
recorder = AudioRecorder()
recorder._silence_threshold = 5000
recorder._silence_duration = 0.05
recorder._min_speech_duration = 0.05
fired = threading.Event()
recorder.start(on_silence_stop=lambda: fired.set())
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
if callback is None:
callback = mock_sd.InputStream.call_args[1]["callback"]
# Audio at RMS 1000 -- below custom threshold (5000)
moderate = np.full((1600, 1), 1000, dtype="int16")
for _ in range(5):
callback(moderate, 1600, None, None)
time.sleep(0.02)
assert recorder._has_spoken is False
assert fired.wait(timeout=0.2) is False
# Now send really loud audio (above 5000 threshold)
very_loud = np.full((1600, 1), 8000, dtype="int16")
callback(very_loud, 1600, None, None)
time.sleep(0.06)
callback(very_loud, 1600, None, None)
assert recorder._has_spoken is True
recorder.cancel()
# ============================================================================
# Bugfix regression tests
# ============================================================================
class TestSubprocessTimeoutKill:
"""Bug: proc.wait(timeout) raised TimeoutExpired but process was not killed."""
def test_timeout_kills_process(self):
import subprocess, os
proc = subprocess.Popen(["sleep", "600"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
pid = proc.pid
assert proc.poll() is None
try:
proc.wait(timeout=0.1)
except subprocess.TimeoutExpired:
proc.kill()
proc.wait()
assert proc.poll() is not None
assert proc.returncode is not None
class TestStreamLeakOnStartFailure:
"""Bug: stream.start() failure left stream unclosed."""
def test_stream_closed_on_start_failure(self, mock_sd):
mock_stream = MagicMock()
mock_stream.start.side_effect = OSError("Audio device busy")
mock_sd.InputStream.return_value = mock_stream
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
with pytest.raises(RuntimeError, match="Failed to open audio input stream"):
recorder._ensure_stream()
mock_stream.close.assert_called_once()
class TestSilenceCallbackLock:
"""Bug: _on_silence_stop was read/written without lock in audio callback."""
def test_fire_block_acquires_lock(self):
import inspect
from tools.voice_mode import AudioRecorder
source = inspect.getsource(AudioRecorder._ensure_stream)
# Verify lock is used before reading _on_silence_stop in fire block
assert "with self._lock:" in source
assert "cb = self._on_silence_stop" in source
lock_pos = source.index("with self._lock:")
cb_pos = source.index("cb = self._on_silence_stop")
assert lock_pos < cb_pos
def test_cancel_clears_callback_under_lock(self, mock_sd):
from tools.voice_mode import AudioRecorder
recorder = AudioRecorder()
mock_sd.InputStream.return_value = MagicMock()
cb = lambda: None
recorder.start(on_silence_stop=cb)
assert recorder._on_silence_stop is cb
recorder.cancel()
with recorder._lock:
assert recorder._on_silence_stop is None