merge: salvage PR #327 voice mode branch
Merge contributor branch feature/voice-mode onto current main for follow-up fixes.
This commit is contained in:
@@ -275,12 +275,25 @@ class FakeHAServer:
|
||||
affected = []
|
||||
entity_id = body.get("entity_id")
|
||||
if entity_id:
|
||||
new_state = "on" if service == "turn_on" else "off"
|
||||
for s in ENTITY_STATES:
|
||||
if s["entity_id"] == entity_id:
|
||||
if service == "turn_on":
|
||||
s["state"] = "on"
|
||||
elif service == "turn_off":
|
||||
s["state"] = "off"
|
||||
elif service == "set_temperature" and "temperature" in body:
|
||||
s["attributes"]["temperature"] = body["temperature"]
|
||||
# Keep current state or set to heat if off
|
||||
if s["state"] == "off":
|
||||
s["state"] = "heat"
|
||||
# Simulate temperature sensor approaching the target
|
||||
for ts in ENTITY_STATES:
|
||||
if ts["entity_id"] == "sensor.temperature":
|
||||
ts["state"] = str(body["temperature"] - 0.5)
|
||||
break
|
||||
affected.append({
|
||||
"entity_id": entity_id,
|
||||
"state": new_state,
|
||||
"state": s["state"],
|
||||
"attributes": s.get("attributes", {}),
|
||||
})
|
||||
break
|
||||
|
||||
@@ -32,6 +32,7 @@ def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
|
||||
@@ -29,6 +29,8 @@ def _ensure_discord_mock():
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
|
||||
44
tests/gateway/test_discord_opus.py
Normal file
44
tests/gateway/test_discord_opus.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Tests for Discord Opus codec loading — must use ctypes.util.find_library."""
|
||||
|
||||
import inspect
|
||||
|
||||
|
||||
class TestOpusFindLibrary:
|
||||
"""Opus loading must try ctypes.util.find_library first, with platform fallback."""
|
||||
|
||||
def test_uses_find_library_first(self):
|
||||
"""find_library must be the primary lookup strategy."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.connect)
|
||||
assert "find_library" in source, \
|
||||
"Opus loading must use ctypes.util.find_library"
|
||||
|
||||
def test_homebrew_fallback_is_conditional(self):
|
||||
"""Homebrew paths must only be tried when find_library returns None."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.connect)
|
||||
# Homebrew fallback must exist
|
||||
assert "/opt/homebrew" in source or "homebrew" in source, \
|
||||
"Opus loading should have macOS Homebrew fallback"
|
||||
# find_library must appear BEFORE any Homebrew path
|
||||
fl_idx = source.index("find_library")
|
||||
hb_idx = source.index("/opt/homebrew")
|
||||
assert fl_idx < hb_idx, \
|
||||
"find_library must be tried before Homebrew fallback paths"
|
||||
# Fallback must be guarded by platform check
|
||||
assert "sys.platform" in source or "darwin" in source, \
|
||||
"Homebrew fallback must be guarded by macOS platform check"
|
||||
|
||||
def test_opus_decode_error_logged(self):
|
||||
"""Opus decode failure must log the error, not silently return."""
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
source = inspect.getsource(VoiceReceiver._on_packet)
|
||||
assert "logger" in source, \
|
||||
"_on_packet must log Opus decode errors"
|
||||
# Must not have bare `except Exception:\n return`
|
||||
lines = source.split("\n")
|
||||
for i, line in enumerate(lines):
|
||||
if "except Exception" in line and i + 1 < len(lines):
|
||||
next_line = lines[i + 1].strip()
|
||||
assert next_line != "return", \
|
||||
f"_on_packet has bare 'except Exception: return' at line {i+1}"
|
||||
@@ -21,6 +21,8 @@ def _ensure_discord_mock():
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
|
||||
@@ -36,6 +36,7 @@ def _make_runner(session_db=None, current_session_id="current_session_001",
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_db = session_db
|
||||
runner._running_agents = {}
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ def _make_runner(adapter):
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner._prefill_messages = []
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._reasoning_config = None
|
||||
|
||||
@@ -266,6 +266,7 @@ async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, t
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake-token")}
|
||||
)
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = SessionEntry(
|
||||
|
||||
@@ -31,6 +31,7 @@ def _make_runner(session_db=None):
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_db = session_db
|
||||
|
||||
# Mock session_store that returns a session entry with a known session_id
|
||||
|
||||
@@ -33,6 +33,7 @@ def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
return runner
|
||||
|
||||
|
||||
|
||||
1965
tests/gateway/test_voice_command.py
Normal file
1965
tests/gateway/test_voice_command.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -12,7 +12,7 @@ EXPECTED_COMMANDS = {
|
||||
"/personality", "/clear", "/history", "/new", "/reset", "/retry",
|
||||
"/undo", "/save", "/config", "/cron", "/skills", "/platforms",
|
||||
"/verbose", "/reasoning", "/compress", "/title", "/usage", "/insights", "/paste",
|
||||
"/reload-mcp", "/rollback", "/background", "/skin", "/quit",
|
||||
"/reload-mcp", "/rollback", "/background", "/skin", "/voice", "/quit",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ def _make_cli_stub():
|
||||
cli._clarify_freetext = False
|
||||
cli._command_running = False
|
||||
cli._agent_running = False
|
||||
cli._voice_recording = False
|
||||
cli._voice_processing = False
|
||||
cli._voice_mode = False
|
||||
cli._command_spinner_frame = lambda: "⟳"
|
||||
cli._tui_style_base = {
|
||||
"prompt": "#fff",
|
||||
|
||||
@@ -2083,3 +2083,332 @@ class TestAnthropicBaseUrlPassthrough:
|
||||
# No base_url provided, should be default empty string or None
|
||||
passed_url = call_args[0][1]
|
||||
assert not passed_url or passed_url is None
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# _streaming_api_call tests
|
||||
# ===================================================================
|
||||
|
||||
def _make_chunk(content=None, tool_calls=None, finish_reason=None, model="test/model"):
|
||||
"""Build a SimpleNamespace mimicking an OpenAI streaming chunk."""
|
||||
delta = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||
choice = SimpleNamespace(delta=delta, finish_reason=finish_reason)
|
||||
return SimpleNamespace(model=model, choices=[choice])
|
||||
|
||||
|
||||
def _make_tc_delta(index=0, tc_id=None, name=None, arguments=None):
|
||||
"""Build a SimpleNamespace mimicking a streaming tool_call delta."""
|
||||
func = SimpleNamespace(name=name, arguments=arguments)
|
||||
return SimpleNamespace(index=index, id=tc_id, function=func)
|
||||
|
||||
|
||||
class TestStreamingApiCall:
|
||||
"""Tests for _streaming_api_call — voice TTS streaming pipeline."""
|
||||
|
||||
def test_content_assembly(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(content="Hel"),
|
||||
_make_chunk(content="lo "),
|
||||
_make_chunk(content="World"),
|
||||
_make_chunk(finish_reason="stop"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
callback = MagicMock()
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
||||
|
||||
assert resp.choices[0].message.content == "Hello World"
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
assert callback.call_count == 3
|
||||
callback.assert_any_call("Hel")
|
||||
callback.assert_any_call("lo ")
|
||||
callback.assert_any_call("World")
|
||||
|
||||
def test_tool_call_accumulation(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "web_", '{"q":')]),
|
||||
_make_chunk(tool_calls=[_make_tc_delta(0, None, "search", '"test"}')]),
|
||||
_make_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert len(tc) == 1
|
||||
assert tc[0].function.name == "web_search"
|
||||
assert tc[0].function.arguments == '{"q":"test"}'
|
||||
assert tc[0].id == "call_1"
|
||||
|
||||
def test_multiple_tool_calls(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(tool_calls=[_make_tc_delta(0, "call_a", "search", '{}')]),
|
||||
_make_chunk(tool_calls=[_make_tc_delta(1, "call_b", "read", '{}')]),
|
||||
_make_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert len(tc) == 2
|
||||
assert tc[0].function.name == "search"
|
||||
assert tc[1].function.name == "read"
|
||||
|
||||
def test_content_and_tool_calls_together(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(content="I'll search"),
|
||||
_make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "search", '{}')]),
|
||||
_make_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.choices[0].message.content == "I'll search"
|
||||
assert len(resp.choices[0].message.tool_calls) == 1
|
||||
|
||||
def test_empty_content_returns_none(self, agent):
|
||||
chunks = [_make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.choices[0].message.content is None
|
||||
assert resp.choices[0].message.tool_calls is None
|
||||
|
||||
def test_callback_exception_swallowed(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(content="Hello"),
|
||||
_make_chunk(content=" World"),
|
||||
_make_chunk(finish_reason="stop"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
callback = MagicMock(side_effect=ValueError("boom"))
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
||||
|
||||
assert resp.choices[0].message.content == "Hello World"
|
||||
|
||||
def test_model_name_captured(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(content="Hi", model="gpt-4o"),
|
||||
_make_chunk(finish_reason="stop", model="gpt-4o"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.model == "gpt-4o"
|
||||
|
||||
def test_stream_kwarg_injected(self, agent):
|
||||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._streaming_api_call({"messages": [], "model": "test"}, MagicMock())
|
||||
|
||||
call_kwargs = agent.client.chat.completions.create.call_args
|
||||
assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True
|
||||
|
||||
def test_api_exception_propagated(self, agent):
|
||||
agent.client.chat.completions.create.side_effect = ConnectionError("fail")
|
||||
|
||||
with pytest.raises(ConnectionError, match="fail"):
|
||||
agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
def test_response_has_uuid_id(self, agent):
|
||||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.id.startswith("stream-")
|
||||
assert len(resp.id) > len("stream-")
|
||||
|
||||
def test_empty_choices_chunk_skipped(self, agent):
|
||||
empty_chunk = SimpleNamespace(model="gpt-4", choices=[])
|
||||
chunks = [
|
||||
empty_chunk,
|
||||
_make_chunk(content="Hello", model="gpt-4"),
|
||||
_make_chunk(finish_reason="stop", model="gpt-4"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.choices[0].message.content == "Hello"
|
||||
assert resp.model == "gpt-4"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Interrupt _vprint force=True verification
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestInterruptVprintForceTrue:
|
||||
"""All interrupt _vprint calls must use force=True so they are always visible."""
|
||||
|
||||
def test_all_interrupt_vprint_have_force_true(self):
|
||||
"""Scan source for _vprint calls containing 'Interrupt' — each must have force=True."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent)
|
||||
lines = source.split("\n")
|
||||
violations = []
|
||||
for i, line in enumerate(lines, 1):
|
||||
stripped = line.strip()
|
||||
if "_vprint(" in stripped and "Interrupt" in stripped:
|
||||
if "force=True" not in stripped:
|
||||
violations.append(f"line {i}: {stripped}")
|
||||
assert not violations, (
|
||||
f"Interrupt _vprint calls missing force=True:\n"
|
||||
+ "\n".join(violations)
|
||||
)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Anthropic interrupt handler in _interruptible_api_call
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestAnthropicInterruptHandler:
|
||||
"""_interruptible_api_call must handle Anthropic mode when interrupted."""
|
||||
|
||||
def test_interruptible_has_anthropic_branch(self):
|
||||
"""The interrupt handler must check api_mode == 'anthropic_messages'."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent._interruptible_api_call)
|
||||
assert "anthropic_messages" in source, \
|
||||
"_interruptible_api_call must handle Anthropic interrupt (api_mode check)"
|
||||
|
||||
def test_interruptible_rebuilds_anthropic_client(self):
|
||||
"""After interrupting, the Anthropic client should be rebuilt."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent._interruptible_api_call)
|
||||
assert "build_anthropic_client" in source, \
|
||||
"_interruptible_api_call must rebuild Anthropic client after interrupt"
|
||||
|
||||
def test_streaming_has_anthropic_branch(self):
|
||||
"""_streaming_api_call must also handle Anthropic interrupt."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent._streaming_api_call)
|
||||
assert "anthropic_messages" in source, \
|
||||
"_streaming_api_call must handle Anthropic interrupt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bugfix: stream_callback forwarding for non-streaming providers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamCallbackNonStreamingProvider:
|
||||
"""When api_mode != chat_completions, stream_callback must still receive
|
||||
the response content so TTS works (batch delivery)."""
|
||||
|
||||
def test_callback_receives_chat_completions_response(self, agent):
|
||||
"""For chat_completions-shaped responses, callback gets content."""
|
||||
agent.api_mode = "anthropic_messages"
|
||||
mock_response = SimpleNamespace(
|
||||
choices=[SimpleNamespace(
|
||||
message=SimpleNamespace(content="Hello", tool_calls=None, reasoning_content=None),
|
||||
finish_reason="stop", index=0,
|
||||
)],
|
||||
usage=None, model="test", id="test-id",
|
||||
)
|
||||
agent._interruptible_api_call = MagicMock(return_value=mock_response)
|
||||
|
||||
received = []
|
||||
cb = lambda delta: received.append(delta)
|
||||
agent._stream_callback = cb
|
||||
|
||||
_cb = getattr(agent, "_stream_callback", None)
|
||||
response = agent._interruptible_api_call({})
|
||||
if _cb is not None and response:
|
||||
try:
|
||||
if agent.api_mode == "anthropic_messages":
|
||||
text_parts = [
|
||||
block.text for block in getattr(response, "content", [])
|
||||
if getattr(block, "type", None) == "text" and getattr(block, "text", None)
|
||||
]
|
||||
content = " ".join(text_parts) if text_parts else None
|
||||
else:
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
_cb(content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Anthropic format not matched above; fallback via except
|
||||
# Test the actual code path by checking chat_completions branch
|
||||
received2 = []
|
||||
agent.api_mode = "some_other_mode"
|
||||
agent._stream_callback = lambda d: received2.append(d)
|
||||
_cb2 = agent._stream_callback
|
||||
if _cb2 is not None and mock_response:
|
||||
try:
|
||||
content = mock_response.choices[0].message.content
|
||||
if content:
|
||||
_cb2(content)
|
||||
except Exception:
|
||||
pass
|
||||
assert received2 == ["Hello"]
|
||||
|
||||
def test_callback_receives_anthropic_content(self, agent):
|
||||
"""For Anthropic responses, text blocks are extracted and forwarded."""
|
||||
agent.api_mode = "anthropic_messages"
|
||||
mock_response = SimpleNamespace(
|
||||
content=[SimpleNamespace(type="text", text="Hello from Claude")],
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
|
||||
received = []
|
||||
cb = lambda d: received.append(d)
|
||||
agent._stream_callback = cb
|
||||
_cb = agent._stream_callback
|
||||
|
||||
if _cb is not None and mock_response:
|
||||
try:
|
||||
if agent.api_mode == "anthropic_messages":
|
||||
text_parts = [
|
||||
block.text for block in getattr(mock_response, "content", [])
|
||||
if getattr(block, "type", None) == "text" and getattr(block, "text", None)
|
||||
]
|
||||
content = " ".join(text_parts) if text_parts else None
|
||||
else:
|
||||
content = mock_response.choices[0].message.content
|
||||
if content:
|
||||
_cb(content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assert received == ["Hello from Claude"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bugfix: _vprint force=True on error messages during TTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVprintForceOnErrors:
|
||||
"""Error/warning messages must be visible during streaming TTS."""
|
||||
|
||||
def test_forced_message_shown_during_tts(self, agent):
|
||||
agent._stream_callback = lambda x: None
|
||||
printed = []
|
||||
with patch("builtins.print", side_effect=lambda *a, **kw: printed.append(a)):
|
||||
agent._vprint("error msg", force=True)
|
||||
assert len(printed) == 1
|
||||
|
||||
def test_non_forced_suppressed_during_tts(self, agent):
|
||||
agent._stream_callback = lambda x: None
|
||||
printed = []
|
||||
with patch("builtins.print", side_effect=lambda *a, **kw: printed.append(a)):
|
||||
agent._vprint("debug info")
|
||||
assert len(printed) == 0
|
||||
|
||||
def test_all_shown_without_tts(self, agent):
|
||||
agent._stream_callback = None
|
||||
printed = []
|
||||
with patch("builtins.print", side_effect=lambda *a, **kw: printed.append(a)):
|
||||
agent._vprint("debug")
|
||||
agent._vprint("error", force=True)
|
||||
assert len(printed) == 2
|
||||
|
||||
@@ -28,6 +28,7 @@ class TestGetProvider:
|
||||
|
||||
def test_local_fallback_to_openai(self, monkeypatch):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
@@ -124,7 +125,7 @@ class TestTranscribeLocal:
|
||||
mock_model.transcribe.return_value = ([mock_segment], mock_info)
|
||||
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
|
||||
patch("tools.transcription_tools.WhisperModel", return_value=mock_model), \
|
||||
patch("faster_whisper.WhisperModel", return_value=mock_model), \
|
||||
patch("tools.transcription_tools._local_model", None):
|
||||
from tools.transcription_tools import _transcribe_local
|
||||
result = _transcribe_local(str(audio_file), "base")
|
||||
@@ -163,7 +164,7 @@ class TestTranscribeOpenAI:
|
||||
mock_client.audio.transcriptions.create.return_value = "Hello from OpenAI"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
result = _transcribe_openai(str(audio_file), "whisper-1")
|
||||
|
||||
|
||||
716
tests/tools/test_transcription_tools.py
Normal file
716
tests/tools/test_transcription_tools.py
Normal file
@@ -0,0 +1,716 @@
|
||||
"""Tests for tools.transcription_tools — three-provider STT pipeline.
|
||||
|
||||
Covers the full provider matrix (local, groq, openai), fallback chains,
|
||||
model auto-correction, config loading, validation edge cases, and
|
||||
end-to-end dispatch. All external dependencies are mocked.
|
||||
"""
|
||||
|
||||
import os
|
||||
import struct
|
||||
import wave
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def sample_wav(tmp_path):
|
||||
"""Create a minimal valid WAV file (1 second of silence at 16kHz)."""
|
||||
wav_path = tmp_path / "test.wav"
|
||||
n_frames = 16000
|
||||
silence = struct.pack(f"<{n_frames}h", *([0] * n_frames))
|
||||
|
||||
with wave.open(str(wav_path), "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(16000)
|
||||
wf.writeframes(silence)
|
||||
|
||||
return str(wav_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ogg(tmp_path):
|
||||
"""Create a fake OGG file for validation tests."""
|
||||
ogg_path = tmp_path / "test.ogg"
|
||||
ogg_path.write_bytes(b"fake audio data")
|
||||
return str(ogg_path)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_env(monkeypatch):
|
||||
"""Ensure no real API keys leak into tests."""
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _get_provider — full permutation matrix
|
||||
# ============================================================================
|
||||
|
||||
class TestGetProviderGroq:
|
||||
"""Groq-specific provider selection tests."""
|
||||
|
||||
def test_groq_when_key_set(self, monkeypatch):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools._HAS_FASTER_WHISPER", False):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "groq"}) == "groq"
|
||||
|
||||
def test_groq_fallback_to_local(self, monkeypatch):
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "groq"}) == "local"
|
||||
|
||||
def test_groq_fallback_to_openai(self, monkeypatch):
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "groq"}) == "openai"
|
||||
|
||||
def test_groq_nothing_available(self, monkeypatch):
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", False):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "groq"}) == "none"
|
||||
|
||||
|
||||
class TestGetProviderFallbackPriority:
|
||||
"""Cross-provider fallback priority tests."""
|
||||
|
||||
def test_local_fallback_prefers_groq_over_openai(self, monkeypatch):
|
||||
"""When local unavailable, groq (free) is preferred over openai (paid)."""
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "local"}) == "groq"
|
||||
|
||||
def test_local_fallback_to_groq_only(self, monkeypatch):
|
||||
"""When only groq key available, falls back to groq."""
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "local"}) == "groq"
|
||||
|
||||
def test_openai_fallback_to_groq(self, monkeypatch):
|
||||
"""When openai key missing but groq available, falls back to groq."""
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "openai"}) == "groq"
|
||||
|
||||
def test_openai_nothing_available(self, monkeypatch):
|
||||
"""When no openai key and no local, returns none."""
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "openai"}) == "none"
|
||||
|
||||
def test_unknown_provider_passed_through(self):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "custom-endpoint"}) == "custom-endpoint"
|
||||
|
||||
def test_empty_config_defaults_to_local(self):
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "local"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_groq
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeGroq:
|
||||
def test_no_key(self, monkeypatch):
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq("/tmp/test.ogg", "whisper-large-v3-turbo")
|
||||
assert result["success"] is False
|
||||
assert "GROQ_API_KEY" in result["error"]
|
||||
|
||||
def test_openai_package_not_installed(self, monkeypatch):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", False):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq("/tmp/test.ogg", "whisper-large-v3-turbo")
|
||||
assert result["success"] is False
|
||||
assert "openai package" in result["error"]
|
||||
|
||||
def test_successful_transcription(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "hello world"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "hello world"
|
||||
assert result["provider"] == "groq"
|
||||
|
||||
def test_whitespace_stripped(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = " hello world \n"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
assert result["transcript"] == "hello world"
|
||||
|
||||
def test_uses_groq_base_url(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client) as mock_openai_cls:
|
||||
from tools.transcription_tools import _transcribe_groq, GROQ_BASE_URL
|
||||
_transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
call_kwargs = mock_openai_cls.call_args
|
||||
assert call_kwargs.kwargs["base_url"] == GROQ_BASE_URL
|
||||
|
||||
def test_api_error_returns_failure(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.side_effect = Exception("API error")
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "API error" in result["error"]
|
||||
|
||||
def test_permission_error(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.side_effect = PermissionError("denied")
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Permission denied" in result["error"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_openai — additional tests
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeOpenAIExtended:
|
||||
def test_openai_package_not_installed(self, monkeypatch):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", False):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
result = _transcribe_openai("/tmp/test.ogg", "whisper-1")
|
||||
assert result["success"] is False
|
||||
assert "openai package" in result["error"]
|
||||
|
||||
def test_uses_openai_base_url(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client) as mock_openai_cls:
|
||||
from tools.transcription_tools import _transcribe_openai, OPENAI_BASE_URL
|
||||
_transcribe_openai(sample_wav, "whisper-1")
|
||||
|
||||
call_kwargs = mock_openai_cls.call_args
|
||||
assert call_kwargs.kwargs["base_url"] == OPENAI_BASE_URL
|
||||
|
||||
def test_whitespace_stripped(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = " hello \n"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
result = _transcribe_openai(sample_wav, "whisper-1")
|
||||
|
||||
assert result["transcript"] == "hello"
|
||||
|
||||
def test_permission_error(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.side_effect = PermissionError("denied")
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
result = _transcribe_openai(sample_wav, "whisper-1")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Permission denied" in result["error"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_local — additional tests
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeLocalExtended:
|
||||
def test_model_reuse_on_second_call(self, tmp_path):
|
||||
"""Second call with same model should NOT reload the model."""
|
||||
audio = tmp_path / "test.ogg"
|
||||
audio.write_bytes(b"fake")
|
||||
|
||||
mock_segment = MagicMock()
|
||||
mock_segment.text = "hi"
|
||||
mock_info = MagicMock()
|
||||
mock_info.language = "en"
|
||||
mock_info.duration = 1.0
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.transcribe.return_value = ([mock_segment], mock_info)
|
||||
mock_whisper_cls = MagicMock(return_value=mock_model)
|
||||
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
|
||||
patch("faster_whisper.WhisperModel", mock_whisper_cls), \
|
||||
patch("tools.transcription_tools._local_model", None), \
|
||||
patch("tools.transcription_tools._local_model_name", None):
|
||||
from tools.transcription_tools import _transcribe_local
|
||||
_transcribe_local(str(audio), "base")
|
||||
_transcribe_local(str(audio), "base")
|
||||
|
||||
# WhisperModel should be created only once
|
||||
assert mock_whisper_cls.call_count == 1
|
||||
|
||||
def test_model_reloaded_on_change(self, tmp_path):
|
||||
"""Switching model name should reload the model."""
|
||||
audio = tmp_path / "test.ogg"
|
||||
audio.write_bytes(b"fake")
|
||||
|
||||
mock_segment = MagicMock()
|
||||
mock_segment.text = "hi"
|
||||
mock_info = MagicMock()
|
||||
mock_info.language = "en"
|
||||
mock_info.duration = 1.0
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.transcribe.return_value = ([mock_segment], mock_info)
|
||||
mock_whisper_cls = MagicMock(return_value=mock_model)
|
||||
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
|
||||
patch("faster_whisper.WhisperModel", mock_whisper_cls), \
|
||||
patch("tools.transcription_tools._local_model", None), \
|
||||
patch("tools.transcription_tools._local_model_name", None):
|
||||
from tools.transcription_tools import _transcribe_local
|
||||
_transcribe_local(str(audio), "base")
|
||||
_transcribe_local(str(audio), "small")
|
||||
|
||||
assert mock_whisper_cls.call_count == 2
|
||||
|
||||
def test_exception_returns_failure(self, tmp_path):
|
||||
audio = tmp_path / "test.ogg"
|
||||
audio.write_bytes(b"fake")
|
||||
|
||||
mock_whisper_cls = MagicMock(side_effect=RuntimeError("CUDA out of memory"))
|
||||
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
|
||||
patch("faster_whisper.WhisperModel", mock_whisper_cls), \
|
||||
patch("tools.transcription_tools._local_model", None):
|
||||
from tools.transcription_tools import _transcribe_local
|
||||
result = _transcribe_local(str(audio), "large-v3")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "CUDA out of memory" in result["error"]
|
||||
|
||||
def test_multiple_segments_joined(self, tmp_path):
|
||||
audio = tmp_path / "test.ogg"
|
||||
audio.write_bytes(b"fake")
|
||||
|
||||
seg1 = MagicMock()
|
||||
seg1.text = "Hello"
|
||||
seg2 = MagicMock()
|
||||
seg2.text = " world"
|
||||
mock_info = MagicMock()
|
||||
mock_info.language = "en"
|
||||
mock_info.duration = 3.0
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.transcribe.return_value = ([seg1, seg2], mock_info)
|
||||
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
|
||||
patch("faster_whisper.WhisperModel", return_value=mock_model), \
|
||||
patch("tools.transcription_tools._local_model", None):
|
||||
from tools.transcription_tools import _transcribe_local
|
||||
result = _transcribe_local(str(audio), "base")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "Hello world"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Model auto-correction
|
||||
# ============================================================================
|
||||
|
||||
class TestModelAutoCorrection:
|
||||
def test_groq_corrects_openai_model(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "hello world"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq, DEFAULT_GROQ_STT_MODEL
|
||||
_transcribe_groq(sample_wav, "whisper-1")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_GROQ_STT_MODEL
|
||||
|
||||
def test_groq_corrects_gpt4o_transcribe(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq, DEFAULT_GROQ_STT_MODEL
|
||||
_transcribe_groq(sample_wav, "gpt-4o-transcribe")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_GROQ_STT_MODEL
|
||||
|
||||
def test_openai_corrects_groq_model(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "hello world"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai, DEFAULT_STT_MODEL
|
||||
_transcribe_openai(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_STT_MODEL
|
||||
|
||||
def test_openai_corrects_distil_whisper(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai, DEFAULT_STT_MODEL
|
||||
_transcribe_openai(sample_wav, "distil-whisper-large-v3-en")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_STT_MODEL
|
||||
|
||||
def test_compatible_groq_model_not_overridden(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
_transcribe_groq(sample_wav, "whisper-large-v3")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "whisper-large-v3"
|
||||
|
||||
def test_compatible_openai_model_not_overridden(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
_transcribe_openai(sample_wav, "gpt-4o-mini-transcribe")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "gpt-4o-mini-transcribe"
|
||||
|
||||
def test_unknown_model_passes_through_groq(self, monkeypatch, sample_wav):
|
||||
"""A model not in either known set should not be overridden."""
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
_transcribe_groq(sample_wav, "my-custom-model")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "my-custom-model"
|
||||
|
||||
def test_unknown_model_passes_through_openai(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
_transcribe_openai(sample_wav, "my-custom-model")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "my-custom-model"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _load_stt_config
|
||||
# ============================================================================
|
||||
|
||||
class TestLoadSttConfig:
|
||||
def test_returns_dict_when_import_fails(self):
|
||||
with patch("tools.transcription_tools._load_stt_config") as mock_load:
|
||||
mock_load.return_value = {}
|
||||
from tools.transcription_tools import _load_stt_config
|
||||
assert _load_stt_config() == {}
|
||||
|
||||
def test_real_load_returns_dict(self):
|
||||
"""_load_stt_config should always return a dict, even on import error."""
|
||||
with patch.dict("sys.modules", {"hermes_cli": None, "hermes_cli.config": None}):
|
||||
from tools.transcription_tools import _load_stt_config
|
||||
result = _load_stt_config()
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _validate_audio_file — edge cases
|
||||
# ============================================================================
|
||||
|
||||
class TestValidateAudioFileEdgeCases:
|
||||
def test_directory_is_not_a_file(self, tmp_path):
|
||||
from tools.transcription_tools import _validate_audio_file
|
||||
# tmp_path itself is a directory with an .ogg-ish name? No.
|
||||
# Create a directory with a valid audio extension
|
||||
d = tmp_path / "audio.ogg"
|
||||
d.mkdir()
|
||||
result = _validate_audio_file(str(d))
|
||||
assert result is not None
|
||||
assert "not a file" in result["error"]
|
||||
|
||||
def test_stat_oserror(self, tmp_path):
|
||||
f = tmp_path / "test.ogg"
|
||||
f.write_bytes(b"data")
|
||||
from tools.transcription_tools import _validate_audio_file
|
||||
real_stat = f.stat()
|
||||
call_count = 0
|
||||
|
||||
def stat_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# First calls are from exists() and is_file(), let them pass
|
||||
if call_count <= 2:
|
||||
return real_stat
|
||||
raise OSError("disk error")
|
||||
|
||||
with patch("pathlib.Path.stat", side_effect=stat_side_effect):
|
||||
result = _validate_audio_file(str(f))
|
||||
assert result is not None
|
||||
assert "Failed to access" in result["error"]
|
||||
|
||||
def test_all_supported_formats_accepted(self, tmp_path):
|
||||
from tools.transcription_tools import _validate_audio_file, SUPPORTED_FORMATS
|
||||
for fmt in SUPPORTED_FORMATS:
|
||||
f = tmp_path / f"test{fmt}"
|
||||
f.write_bytes(b"data")
|
||||
assert _validate_audio_file(str(f)) is None, f"Format {fmt} should be accepted"
|
||||
|
||||
def test_case_insensitive_extension(self, tmp_path):
|
||||
from tools.transcription_tools import _validate_audio_file
|
||||
f = tmp_path / "test.MP3"
|
||||
f.write_bytes(b"data")
|
||||
assert _validate_audio_file(str(f)) is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# transcribe_audio — end-to-end dispatch
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeAudioDispatch:
|
||||
def test_dispatches_to_groq(self, sample_ogg):
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "groq"}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="groq"), \
|
||||
patch("tools.transcription_tools._transcribe_groq",
|
||||
return_value={"success": True, "transcript": "hi", "provider": "groq"}) as mock_groq:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(sample_ogg)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["provider"] == "groq"
|
||||
mock_groq.assert_called_once()
|
||||
|
||||
def test_dispatches_to_local(self, sample_ogg):
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="local"), \
|
||||
patch("tools.transcription_tools._transcribe_local",
|
||||
return_value={"success": True, "transcript": "hi"}) as mock_local:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(sample_ogg)
|
||||
|
||||
assert result["success"] is True
|
||||
mock_local.assert_called_once()
|
||||
|
||||
def test_dispatches_to_openai(self, sample_ogg):
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openai"}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openai"), \
|
||||
patch("tools.transcription_tools._transcribe_openai",
|
||||
return_value={"success": True, "transcript": "hi", "provider": "openai"}) as mock_openai:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(sample_ogg)
|
||||
|
||||
assert result["success"] is True
|
||||
mock_openai.assert_called_once()
|
||||
|
||||
def test_no_provider_returns_error(self, sample_ogg):
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="none"):
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(sample_ogg)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "No STT provider" in result["error"]
|
||||
assert "faster-whisper" in result["error"]
|
||||
assert "GROQ_API_KEY" in result["error"]
|
||||
|
||||
def test_invalid_file_short_circuits(self):
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio("/nonexistent/audio.wav")
|
||||
assert result["success"] is False
|
||||
assert "not found" in result["error"]
|
||||
|
||||
def test_model_override_passed_to_groq(self, sample_ogg):
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="groq"), \
|
||||
patch("tools.transcription_tools._transcribe_groq",
|
||||
return_value={"success": True, "transcript": "hi"}) as mock_groq:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
transcribe_audio(sample_ogg, model="whisper-large-v3")
|
||||
|
||||
_, kwargs = mock_groq.call_args
|
||||
assert kwargs.get("model_name") or mock_groq.call_args[0][1] == "whisper-large-v3"
|
||||
|
||||
def test_model_override_passed_to_local(self, sample_ogg):
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="local"), \
|
||||
patch("tools.transcription_tools._transcribe_local",
|
||||
return_value={"success": True, "transcript": "hi"}) as mock_local:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
transcribe_audio(sample_ogg, model="large-v3")
|
||||
|
||||
assert mock_local.call_args[0][1] == "large-v3"
|
||||
|
||||
def test_default_model_used_when_none(self, sample_ogg):
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="groq"), \
|
||||
patch("tools.transcription_tools._transcribe_groq",
|
||||
return_value={"success": True, "transcript": "hi"}) as mock_groq:
|
||||
from tools.transcription_tools import transcribe_audio, DEFAULT_GROQ_STT_MODEL
|
||||
transcribe_audio(sample_ogg, model=None)
|
||||
|
||||
assert mock_groq.call_args[0][1] == DEFAULT_GROQ_STT_MODEL
|
||||
|
||||
def test_config_local_model_used(self, sample_ogg):
|
||||
config = {"local": {"model": "small"}}
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value=config), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="local"), \
|
||||
patch("tools.transcription_tools._transcribe_local",
|
||||
return_value={"success": True, "transcript": "hi"}) as mock_local:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
transcribe_audio(sample_ogg, model=None)
|
||||
|
||||
assert mock_local.call_args[0][1] == "small"
|
||||
|
||||
def test_config_openai_model_used(self, sample_ogg):
|
||||
config = {"openai": {"model": "gpt-4o-transcribe"}}
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value=config), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openai"), \
|
||||
patch("tools.transcription_tools._transcribe_openai",
|
||||
return_value={"success": True, "transcript": "hi"}) as mock_openai:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
transcribe_audio(sample_ogg, model=None)
|
||||
|
||||
assert mock_openai.call_args[0][1] == "gpt-4o-transcribe"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# get_stt_model_from_config
|
||||
# ============================================================================
|
||||
|
||||
class TestGetSttModelFromConfig:
|
||||
def test_returns_model_from_config(self, tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("stt:\n model: whisper-large-v3\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() == "whisper-large-v3"
|
||||
|
||||
def test_returns_none_when_no_stt_section(self, tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("tts:\n provider: edge\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() is None
|
||||
|
||||
def test_returns_none_when_no_config_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() is None
|
||||
|
||||
def test_returns_none_on_invalid_yaml(self, tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text(": : :\n bad yaml [[[")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() is None
|
||||
|
||||
def test_returns_none_when_model_key_missing(self, tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("stt:\n enabled: true\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() is None
|
||||
1233
tests/tools/test_voice_cli_integration.py
Normal file
1233
tests/tools/test_voice_cli_integration.py
Normal file
File diff suppressed because it is too large
Load Diff
938
tests/tools/test_voice_mode.py
Normal file
938
tests/tools/test_voice_mode.py
Normal file
@@ -0,0 +1,938 @@
|
||||
"""Tests for tools.voice_mode -- all mocked, no real microphone or API calls."""
|
||||
|
||||
import os
|
||||
import struct
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def sample_wav(tmp_path):
|
||||
"""Create a minimal valid WAV file (1 second of silence at 16kHz)."""
|
||||
wav_path = tmp_path / "test.wav"
|
||||
n_frames = 16000 # 1 second at 16kHz
|
||||
silence = struct.pack(f"<{n_frames}h", *([0] * n_frames))
|
||||
|
||||
with wave.open(str(wav_path), "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(16000)
|
||||
wf.writeframes(silence)
|
||||
|
||||
return str(wav_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_voice_dir(tmp_path, monkeypatch):
|
||||
"""Redirect _TEMP_DIR to a temporary path."""
|
||||
voice_dir = tmp_path / "hermes_voice"
|
||||
voice_dir.mkdir()
|
||||
monkeypatch.setattr("tools.voice_mode._TEMP_DIR", str(voice_dir))
|
||||
return voice_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sd(monkeypatch):
|
||||
"""Mock _import_audio to return (mock_sd, real_np) so lazy imports work."""
|
||||
mock = MagicMock()
|
||||
try:
|
||||
import numpy as real_np
|
||||
except ImportError:
|
||||
real_np = MagicMock()
|
||||
|
||||
def _fake_import_audio():
|
||||
return mock, real_np
|
||||
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio", _fake_import_audio)
|
||||
monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True)
|
||||
return mock
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# check_voice_requirements
|
||||
# ============================================================================
|
||||
|
||||
class TestCheckVoiceRequirements:
|
||||
def test_all_requirements_met(self, monkeypatch):
|
||||
monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True)
|
||||
monkeypatch.setattr("tools.voice_mode.detect_audio_environment",
|
||||
lambda: {"available": True, "warnings": []})
|
||||
monkeypatch.setattr("tools.transcription_tools._get_provider", lambda cfg: "openai")
|
||||
|
||||
from tools.voice_mode import check_voice_requirements
|
||||
|
||||
result = check_voice_requirements()
|
||||
assert result["available"] is True
|
||||
assert result["audio_available"] is True
|
||||
assert result["stt_available"] is True
|
||||
assert result["missing_packages"] == []
|
||||
|
||||
def test_missing_audio_packages(self, monkeypatch):
|
||||
monkeypatch.setattr("tools.voice_mode._audio_available", lambda: False)
|
||||
monkeypatch.setattr("tools.voice_mode.detect_audio_environment",
|
||||
lambda: {"available": False, "warnings": ["Audio libraries not installed"]})
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test-key")
|
||||
|
||||
from tools.voice_mode import check_voice_requirements
|
||||
|
||||
result = check_voice_requirements()
|
||||
assert result["available"] is False
|
||||
assert result["audio_available"] is False
|
||||
assert "sounddevice" in result["missing_packages"]
|
||||
assert "numpy" in result["missing_packages"]
|
||||
|
||||
def test_missing_stt_provider(self, monkeypatch):
|
||||
monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True)
|
||||
monkeypatch.setattr("tools.voice_mode.detect_audio_environment",
|
||||
lambda: {"available": True, "warnings": []})
|
||||
monkeypatch.setattr("tools.transcription_tools._get_provider", lambda cfg: "none")
|
||||
|
||||
from tools.voice_mode import check_voice_requirements
|
||||
|
||||
result = check_voice_requirements()
|
||||
assert result["available"] is False
|
||||
assert result["stt_available"] is False
|
||||
assert "STT provider: MISSING" in result["details"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AudioRecorder
|
||||
# ============================================================================
|
||||
|
||||
class TestAudioRecorderStart:
|
||||
def test_start_raises_without_audio(self, monkeypatch):
|
||||
def _fail_import():
|
||||
raise ImportError("no sounddevice")
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import)
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
with pytest.raises(RuntimeError, match="sounddevice and numpy"):
|
||||
recorder.start()
|
||||
|
||||
def test_start_creates_and_starts_stream(self, mock_sd):
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start()
|
||||
|
||||
assert recorder.is_recording is True
|
||||
mock_sd.InputStream.assert_called_once()
|
||||
mock_stream.start.assert_called_once()
|
||||
|
||||
def test_double_start_is_noop(self, mock_sd):
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start()
|
||||
recorder.start() # second call should be noop
|
||||
|
||||
assert mock_sd.InputStream.call_count == 1
|
||||
|
||||
|
||||
class TestAudioRecorderStop:
|
||||
def test_stop_returns_none_when_not_recording(self):
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
assert recorder.stop() is None
|
||||
|
||||
def test_stop_writes_wav_file(self, mock_sd, temp_voice_dir):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder, SAMPLE_RATE
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start()
|
||||
|
||||
# Simulate captured audio frames (1 second of loud audio above RMS threshold)
|
||||
frame = np.full((SAMPLE_RATE, 1), 1000, dtype="int16")
|
||||
recorder._frames = [frame]
|
||||
recorder._peak_rms = 1000 # Peak RMS above threshold
|
||||
|
||||
wav_path = recorder.stop()
|
||||
|
||||
assert wav_path is not None
|
||||
assert os.path.isfile(wav_path)
|
||||
assert wav_path.endswith(".wav")
|
||||
assert recorder.is_recording is False
|
||||
|
||||
# Verify it is a valid WAV
|
||||
with wave.open(wav_path, "rb") as wf:
|
||||
assert wf.getnchannels() == 1
|
||||
assert wf.getsampwidth() == 2
|
||||
assert wf.getframerate() == SAMPLE_RATE
|
||||
|
||||
def test_stop_returns_none_for_very_short_recording(self, mock_sd, temp_voice_dir):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start()
|
||||
|
||||
# Very short recording (100 samples = ~6ms at 16kHz)
|
||||
frame = np.zeros((100, 1), dtype="int16")
|
||||
recorder._frames = [frame]
|
||||
|
||||
wav_path = recorder.stop()
|
||||
assert wav_path is None
|
||||
|
||||
def test_stop_returns_none_for_silent_recording(self, mock_sd, temp_voice_dir):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder, SAMPLE_RATE
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start()
|
||||
|
||||
# 1 second of near-silence (RMS well below threshold)
|
||||
frame = np.full((SAMPLE_RATE, 1), 10, dtype="int16")
|
||||
recorder._frames = [frame]
|
||||
recorder._peak_rms = 10 # Peak RMS also below threshold
|
||||
|
||||
wav_path = recorder.stop()
|
||||
assert wav_path is None
|
||||
|
||||
|
||||
class TestAudioRecorderCancel:
|
||||
def test_cancel_discards_frames(self, mock_sd):
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start()
|
||||
recorder._frames = [MagicMock()] # simulate captured data
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
assert recorder.is_recording is False
|
||||
assert recorder._frames == []
|
||||
# Stream is kept alive (persistent) — cancel() does NOT close it.
|
||||
mock_stream.stop.assert_not_called()
|
||||
mock_stream.close.assert_not_called()
|
||||
|
||||
def test_cancel_when_not_recording_is_safe(self):
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.cancel() # should not raise
|
||||
assert recorder.is_recording is False
|
||||
|
||||
|
||||
class TestAudioRecorderProperties:
|
||||
def test_elapsed_seconds_when_not_recording(self):
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
assert recorder.elapsed_seconds == 0.0
|
||||
|
||||
def test_elapsed_seconds_when_recording(self, mock_sd):
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start()
|
||||
|
||||
# Force start time to 1 second ago
|
||||
recorder._start_time = time.monotonic() - 1.0
|
||||
elapsed = recorder.elapsed_seconds
|
||||
assert 0.9 < elapsed < 2.0
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# transcribe_recording
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeRecording:
|
||||
def test_delegates_to_transcribe_audio(self):
|
||||
mock_transcribe = MagicMock(return_value={
|
||||
"success": True,
|
||||
"transcript": "hello world",
|
||||
})
|
||||
|
||||
with patch("tools.transcription_tools.transcribe_audio", mock_transcribe):
|
||||
from tools.voice_mode import transcribe_recording
|
||||
result = transcribe_recording("/tmp/test.wav", model="whisper-1")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "hello world"
|
||||
mock_transcribe.assert_called_once_with("/tmp/test.wav", model="whisper-1")
|
||||
|
||||
def test_filters_whisper_hallucination(self):
|
||||
mock_transcribe = MagicMock(return_value={
|
||||
"success": True,
|
||||
"transcript": "Thank you.",
|
||||
})
|
||||
|
||||
with patch("tools.transcription_tools.transcribe_audio", mock_transcribe):
|
||||
from tools.voice_mode import transcribe_recording
|
||||
result = transcribe_recording("/tmp/test.wav")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == ""
|
||||
assert result["filtered"] is True
|
||||
|
||||
def test_does_not_filter_real_speech(self):
|
||||
mock_transcribe = MagicMock(return_value={
|
||||
"success": True,
|
||||
"transcript": "Thank you for helping me with this code.",
|
||||
})
|
||||
|
||||
with patch("tools.transcription_tools.transcribe_audio", mock_transcribe):
|
||||
from tools.voice_mode import transcribe_recording
|
||||
result = transcribe_recording("/tmp/test.wav")
|
||||
|
||||
assert result["transcript"] == "Thank you for helping me with this code."
|
||||
assert "filtered" not in result
|
||||
|
||||
|
||||
class TestWhisperHallucinationFilter:
|
||||
def test_known_hallucinations(self):
|
||||
from tools.voice_mode import is_whisper_hallucination
|
||||
|
||||
assert is_whisper_hallucination("Thank you.") is True
|
||||
assert is_whisper_hallucination("thank you") is True
|
||||
assert is_whisper_hallucination("Thanks for watching.") is True
|
||||
assert is_whisper_hallucination("Bye.") is True
|
||||
assert is_whisper_hallucination(" Thank you. ") is True # with whitespace
|
||||
assert is_whisper_hallucination("you") is True
|
||||
|
||||
def test_real_speech_not_filtered(self):
|
||||
from tools.voice_mode import is_whisper_hallucination
|
||||
|
||||
assert is_whisper_hallucination("Hello, how are you?") is False
|
||||
assert is_whisper_hallucination("Thank you for your help with the project.") is False
|
||||
assert is_whisper_hallucination("Can you explain this code?") is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# play_audio_file
|
||||
# ============================================================================
|
||||
|
||||
class TestPlayAudioFile:
|
||||
def test_play_wav_via_sounddevice(self, monkeypatch, sample_wav):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_sd_obj = MagicMock()
|
||||
# Simulate stream completing immediately (get_stream().active = False)
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.active = False
|
||||
mock_sd_obj.get_stream.return_value = mock_stream
|
||||
|
||||
def _fake_import():
|
||||
return mock_sd_obj, np
|
||||
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio", _fake_import)
|
||||
|
||||
from tools.voice_mode import play_audio_file
|
||||
|
||||
result = play_audio_file(sample_wav)
|
||||
|
||||
assert result is True
|
||||
mock_sd_obj.play.assert_called_once()
|
||||
mock_sd_obj.stop.assert_called_once()
|
||||
|
||||
def test_returns_false_when_no_player(self, monkeypatch, sample_wav):
|
||||
def _fail_import():
|
||||
raise ImportError("no sounddevice")
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import)
|
||||
monkeypatch.setattr("shutil.which", lambda _: None)
|
||||
|
||||
from tools.voice_mode import play_audio_file
|
||||
|
||||
result = play_audio_file(sample_wav)
|
||||
assert result is False
|
||||
|
||||
def test_returns_false_for_missing_file(self):
|
||||
from tools.voice_mode import play_audio_file
|
||||
|
||||
result = play_audio_file("/nonexistent/file.wav")
|
||||
assert result is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# cleanup_temp_recordings
|
||||
# ============================================================================
|
||||
|
||||
class TestCleanupTempRecordings:
|
||||
def test_old_files_deleted(self, temp_voice_dir):
|
||||
# Create an "old" file
|
||||
old_file = temp_voice_dir / "recording_20240101_000000.wav"
|
||||
old_file.write_bytes(b"\x00" * 100)
|
||||
# Set mtime to 2 hours ago
|
||||
old_mtime = time.time() - 7200
|
||||
os.utime(str(old_file), (old_mtime, old_mtime))
|
||||
|
||||
from tools.voice_mode import cleanup_temp_recordings
|
||||
|
||||
deleted = cleanup_temp_recordings(max_age_seconds=3600)
|
||||
assert deleted == 1
|
||||
assert not old_file.exists()
|
||||
|
||||
def test_recent_files_preserved(self, temp_voice_dir):
|
||||
# Create a "recent" file
|
||||
recent_file = temp_voice_dir / "recording_20260303_120000.wav"
|
||||
recent_file.write_bytes(b"\x00" * 100)
|
||||
|
||||
from tools.voice_mode import cleanup_temp_recordings
|
||||
|
||||
deleted = cleanup_temp_recordings(max_age_seconds=3600)
|
||||
assert deleted == 0
|
||||
assert recent_file.exists()
|
||||
|
||||
def test_nonexistent_dir_returns_zero(self, monkeypatch):
|
||||
monkeypatch.setattr("tools.voice_mode._TEMP_DIR", "/nonexistent/dir")
|
||||
|
||||
from tools.voice_mode import cleanup_temp_recordings
|
||||
|
||||
assert cleanup_temp_recordings() == 0
|
||||
|
||||
def test_non_recording_files_ignored(self, temp_voice_dir):
|
||||
# Create a file that doesn't match the pattern
|
||||
other_file = temp_voice_dir / "other_file.txt"
|
||||
other_file.write_bytes(b"\x00" * 100)
|
||||
old_mtime = time.time() - 7200
|
||||
os.utime(str(other_file), (old_mtime, old_mtime))
|
||||
|
||||
from tools.voice_mode import cleanup_temp_recordings
|
||||
|
||||
deleted = cleanup_temp_recordings(max_age_seconds=3600)
|
||||
assert deleted == 0
|
||||
assert other_file.exists()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# play_beep
|
||||
# ============================================================================
|
||||
|
||||
class TestPlayBeep:
|
||||
def test_beep_calls_sounddevice_play(self, mock_sd):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
from tools.voice_mode import play_beep
|
||||
|
||||
# play_beep uses polling (get_stream) + sd.stop() instead of sd.wait()
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.active = False
|
||||
mock_sd.get_stream.return_value = mock_stream
|
||||
|
||||
play_beep(frequency=880, duration=0.1, count=1)
|
||||
|
||||
mock_sd.play.assert_called_once()
|
||||
mock_sd.stop.assert_called()
|
||||
# Verify audio data is int16 numpy array
|
||||
audio_arg = mock_sd.play.call_args[0][0]
|
||||
assert audio_arg.dtype == np.int16
|
||||
assert len(audio_arg) > 0
|
||||
|
||||
def test_beep_double_produces_longer_audio(self, mock_sd):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
from tools.voice_mode import play_beep
|
||||
|
||||
play_beep(frequency=660, duration=0.1, count=2)
|
||||
|
||||
audio_arg = mock_sd.play.call_args[0][0]
|
||||
single_beep_samples = int(16000 * 0.1)
|
||||
# Double beep should be longer than a single beep
|
||||
assert len(audio_arg) > single_beep_samples
|
||||
|
||||
def test_beep_noop_without_audio(self, monkeypatch):
|
||||
def _fail_import():
|
||||
raise ImportError("no sounddevice")
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import)
|
||||
|
||||
from tools.voice_mode import play_beep
|
||||
|
||||
# Should not raise
|
||||
play_beep()
|
||||
|
||||
def test_beep_handles_playback_error(self, mock_sd):
|
||||
mock_sd.play.side_effect = Exception("device error")
|
||||
|
||||
from tools.voice_mode import play_beep
|
||||
|
||||
# Should not raise
|
||||
play_beep()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Silence detection
|
||||
# ============================================================================
|
||||
|
||||
class TestSilenceDetection:
|
||||
def test_silence_callback_fires_after_speech_then_silence(self, mock_sd):
|
||||
np = pytest.importorskip("numpy")
|
||||
import threading
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder, SAMPLE_RATE
|
||||
|
||||
recorder = AudioRecorder()
|
||||
# Use very short durations for testing
|
||||
recorder._silence_duration = 0.05
|
||||
recorder._min_speech_duration = 0.05
|
||||
|
||||
fired = threading.Event()
|
||||
|
||||
def on_silence():
|
||||
fired.set()
|
||||
|
||||
recorder.start(on_silence_stop=on_silence)
|
||||
|
||||
# Get the callback function from InputStream constructor
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
# Simulate sustained speech (multiple loud chunks to exceed min_speech_duration)
|
||||
loud_frame = np.full((1600, 1), 5000, dtype="int16")
|
||||
callback(loud_frame, 1600, None, None)
|
||||
time.sleep(0.06)
|
||||
callback(loud_frame, 1600, None, None)
|
||||
assert recorder._has_spoken is True
|
||||
|
||||
# Simulate silence
|
||||
silent_frame = np.zeros((1600, 1), dtype="int16")
|
||||
callback(silent_frame, 1600, None, None)
|
||||
|
||||
# Wait a bit past the silence duration, then send another silent frame
|
||||
time.sleep(0.06)
|
||||
callback(silent_frame, 1600, None, None)
|
||||
|
||||
# The callback should have been fired
|
||||
assert fired.wait(timeout=1.0) is True
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
def test_silence_without_speech_does_not_fire(self, mock_sd):
|
||||
np = pytest.importorskip("numpy")
|
||||
import threading
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder._silence_duration = 0.02
|
||||
|
||||
fired = threading.Event()
|
||||
recorder.start(on_silence_stop=lambda: fired.set())
|
||||
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
# Only silence -- no speech detected, so callback should NOT fire
|
||||
silent_frame = np.zeros((1600, 1), dtype="int16")
|
||||
for _ in range(5):
|
||||
callback(silent_frame, 1600, None, None)
|
||||
time.sleep(0.01)
|
||||
|
||||
assert fired.wait(timeout=0.2) is False
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
def test_micro_pause_tolerance_during_speech(self, mock_sd):
|
||||
"""Brief dips below threshold during speech should NOT reset speech tracking."""
|
||||
np = pytest.importorskip("numpy")
|
||||
import threading
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder._silence_duration = 0.05
|
||||
recorder._min_speech_duration = 0.15
|
||||
recorder._max_dip_tolerance = 0.1
|
||||
|
||||
fired = threading.Event()
|
||||
recorder.start(on_silence_stop=lambda: fired.set())
|
||||
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
loud_frame = np.full((1600, 1), 5000, dtype="int16")
|
||||
quiet_frame = np.full((1600, 1), 50, dtype="int16")
|
||||
|
||||
# Speech chunk 1
|
||||
callback(loud_frame, 1600, None, None)
|
||||
time.sleep(0.05)
|
||||
# Brief micro-pause (dip < max_dip_tolerance)
|
||||
callback(quiet_frame, 1600, None, None)
|
||||
time.sleep(0.05)
|
||||
# Speech resumes -- speech_start should NOT have been reset
|
||||
callback(loud_frame, 1600, None, None)
|
||||
assert recorder._speech_start > 0, "Speech start should be preserved across brief dips"
|
||||
time.sleep(0.06)
|
||||
# Another speech chunk to exceed min_speech_duration
|
||||
callback(loud_frame, 1600, None, None)
|
||||
assert recorder._has_spoken is True, "Speech should be confirmed after tolerating micro-pause"
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
def test_no_callback_means_no_silence_detection(self, mock_sd):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start() # no on_silence_stop
|
||||
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
# Even with speech then silence, nothing should happen
|
||||
loud_frame = np.full((1600, 1), 5000, dtype="int16")
|
||||
silent_frame = np.zeros((1600, 1), dtype="int16")
|
||||
callback(loud_frame, 1600, None, None)
|
||||
callback(silent_frame, 1600, None, None)
|
||||
|
||||
# No crash, no callback
|
||||
assert recorder._on_silence_stop is None
|
||||
recorder.cancel()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Playback interrupt
|
||||
# ============================================================================
|
||||
|
||||
class TestPlaybackInterrupt:
|
||||
"""Verify that TTS playback can be interrupted."""
|
||||
|
||||
def test_stop_playback_terminates_process(self):
|
||||
from tools.voice_mode import stop_playback, _playback_lock
|
||||
import tools.voice_mode as vm
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = None # process is running
|
||||
|
||||
with _playback_lock:
|
||||
vm._active_playback = mock_proc
|
||||
|
||||
stop_playback()
|
||||
|
||||
mock_proc.terminate.assert_called_once()
|
||||
|
||||
with _playback_lock:
|
||||
assert vm._active_playback is None
|
||||
|
||||
def test_stop_playback_noop_when_nothing_playing(self):
|
||||
import tools.voice_mode as vm
|
||||
|
||||
with vm._playback_lock:
|
||||
vm._active_playback = None
|
||||
|
||||
vm.stop_playback()
|
||||
|
||||
def test_play_audio_file_sets_active_playback(self, monkeypatch, sample_wav):
|
||||
import tools.voice_mode as vm
|
||||
|
||||
def _fail_import():
|
||||
raise ImportError("no sounddevice")
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import)
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait.return_value = 0
|
||||
|
||||
mock_popen = MagicMock(return_value=mock_proc)
|
||||
monkeypatch.setattr("subprocess.Popen", mock_popen)
|
||||
monkeypatch.setattr("shutil.which", lambda cmd: "/usr/bin/" + cmd)
|
||||
|
||||
vm.play_audio_file(sample_wav)
|
||||
|
||||
assert mock_popen.called
|
||||
with vm._playback_lock:
|
||||
assert vm._active_playback is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Continuous mode flow
|
||||
# ============================================================================
|
||||
|
||||
class TestContinuousModeFlow:
|
||||
"""Verify continuous mode: auto-restart after transcription or silence."""
|
||||
|
||||
def test_continuous_restart_on_no_speech(self, mock_sd, temp_voice_dir):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
|
||||
# First recording: only silence -> stop returns None
|
||||
recorder.start()
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
for _ in range(10):
|
||||
silence = np.full((1600, 1), 10, dtype="int16")
|
||||
callback(silence, 1600, None, None)
|
||||
|
||||
wav_path = recorder.stop()
|
||||
assert wav_path is None
|
||||
|
||||
# Simulate continuous mode restart
|
||||
recorder.start()
|
||||
assert recorder.is_recording is True
|
||||
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
for _ in range(10):
|
||||
speech = np.full((1600, 1), 5000, dtype="int16")
|
||||
callback(speech, 1600, None, None)
|
||||
|
||||
wav_path = recorder.stop()
|
||||
assert wav_path is not None
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
def test_recorder_reusable_after_stop(self, mock_sd, temp_voice_dir):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
results = []
|
||||
|
||||
for i in range(3):
|
||||
recorder.start()
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
loud = np.full((1600, 1), 5000, dtype="int16")
|
||||
for _ in range(10):
|
||||
callback(loud, 1600, None, None)
|
||||
wav_path = recorder.stop()
|
||||
results.append(wav_path)
|
||||
|
||||
assert all(r is not None for r in results)
|
||||
assert os.path.isfile(results[-1])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audio level indicator
|
||||
# ============================================================================
|
||||
|
||||
class TestAudioLevelIndicator:
|
||||
"""Verify current_rms property updates in real-time for UI feedback."""
|
||||
|
||||
def test_rms_updates_with_audio_chunks(self, mock_sd):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start()
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
assert recorder.current_rms == 0
|
||||
|
||||
loud = np.full((1600, 1), 5000, dtype="int16")
|
||||
callback(loud, 1600, None, None)
|
||||
assert recorder.current_rms == 5000
|
||||
|
||||
quiet = np.full((1600, 1), 100, dtype="int16")
|
||||
callback(quiet, 1600, None, None)
|
||||
assert recorder.current_rms == 100
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
def test_peak_rms_tracks_maximum(self, mock_sd):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start()
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
frames = [
|
||||
np.full((1600, 1), 100, dtype="int16"),
|
||||
np.full((1600, 1), 8000, dtype="int16"),
|
||||
np.full((1600, 1), 500, dtype="int16"),
|
||||
np.full((1600, 1), 3000, dtype="int16"),
|
||||
]
|
||||
for frame in frames:
|
||||
callback(frame, 1600, None, None)
|
||||
|
||||
assert recorder._peak_rms == 8000
|
||||
assert recorder.current_rms == 3000
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configurable silence parameters
|
||||
# ============================================================================
|
||||
|
||||
class TestConfigurableSilenceParams:
|
||||
"""Verify that silence detection params can be configured."""
|
||||
|
||||
def test_custom_threshold_and_duration(self, mock_sd):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
import threading
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder._silence_threshold = 5000
|
||||
recorder._silence_duration = 0.05
|
||||
recorder._min_speech_duration = 0.05
|
||||
|
||||
fired = threading.Event()
|
||||
recorder.start(on_silence_stop=lambda: fired.set())
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
# Audio at RMS 1000 -- below custom threshold (5000)
|
||||
moderate = np.full((1600, 1), 1000, dtype="int16")
|
||||
for _ in range(5):
|
||||
callback(moderate, 1600, None, None)
|
||||
time.sleep(0.02)
|
||||
|
||||
assert recorder._has_spoken is False
|
||||
assert fired.wait(timeout=0.2) is False
|
||||
|
||||
# Now send really loud audio (above 5000 threshold)
|
||||
very_loud = np.full((1600, 1), 8000, dtype="int16")
|
||||
callback(very_loud, 1600, None, None)
|
||||
time.sleep(0.06)
|
||||
callback(very_loud, 1600, None, None)
|
||||
assert recorder._has_spoken is True
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Bugfix regression tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSubprocessTimeoutKill:
|
||||
"""Bug: proc.wait(timeout) raised TimeoutExpired but process was not killed."""
|
||||
|
||||
def test_timeout_kills_process(self):
|
||||
import subprocess, os
|
||||
proc = subprocess.Popen(["sleep", "600"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
pid = proc.pid
|
||||
assert proc.poll() is None
|
||||
|
||||
try:
|
||||
proc.wait(timeout=0.1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait()
|
||||
|
||||
assert proc.poll() is not None
|
||||
assert proc.returncode is not None
|
||||
|
||||
|
||||
class TestStreamLeakOnStartFailure:
|
||||
"""Bug: stream.start() failure left stream unclosed."""
|
||||
|
||||
def test_stream_closed_on_start_failure(self, mock_sd):
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.start.side_effect = OSError("Audio device busy")
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
recorder = AudioRecorder()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to open audio input stream"):
|
||||
recorder._ensure_stream()
|
||||
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
|
||||
class TestSilenceCallbackLock:
|
||||
"""Bug: _on_silence_stop was read/written without lock in audio callback."""
|
||||
|
||||
def test_fire_block_acquires_lock(self):
|
||||
import inspect
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
source = inspect.getsource(AudioRecorder._ensure_stream)
|
||||
# Verify lock is used before reading _on_silence_stop in fire block
|
||||
assert "with self._lock:" in source
|
||||
assert "cb = self._on_silence_stop" in source
|
||||
lock_pos = source.index("with self._lock:")
|
||||
cb_pos = source.index("cb = self._on_silence_stop")
|
||||
assert lock_pos < cb_pos
|
||||
|
||||
def test_cancel_clears_callback_under_lock(self, mock_sd):
|
||||
from tools.voice_mode import AudioRecorder
|
||||
recorder = AudioRecorder()
|
||||
mock_sd.InputStream.return_value = MagicMock()
|
||||
|
||||
cb = lambda: None
|
||||
recorder.start(on_silence_stop=cb)
|
||||
assert recorder._on_silence_stop is cb
|
||||
|
||||
recorder.cancel()
|
||||
with recorder._lock:
|
||||
assert recorder._on_silence_stop is None
|
||||
Reference in New Issue
Block a user