- Provider resolution: OpenAI priority, Groq fallback, no keys
- Model auto-correction: Groq corrects OpenAI models and vice versa
- Success path: transcription, API errors, whitespace stripping
- 12 new tests, 33 total voice-related tests
200 lines
7.4 KiB
Python
200 lines
7.4 KiB
Python
"""Tests for tools.transcription_tools -- provider resolution and model correction."""
|
|
|
|
import os
|
|
import struct
|
|
import wave
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
# ============================================================================
|
|
# Fixtures
|
|
# ============================================================================
|
|
|
|
@pytest.fixture
def sample_wav(tmp_path):
    """Write a small silent WAV file and return its path as a string.

    The file is ``test.wav`` inside pytest's per-test ``tmp_path`` directory:
    mono, 16-bit samples, 16 kHz frame rate, 16000 frames of zeros
    (one second of silence).
    """
    channel_count = 1
    bytes_per_sample = 2
    frame_rate = 16000
    frame_total = 16000

    # Pack every frame as a little-endian signed 16-bit zero.
    packer = struct.Struct(f"<{frame_total}h")
    pcm_silence = packer.pack(*((0,) * frame_total))

    target = tmp_path / "test.wav"
    with wave.open(str(target), "wb") as writer:
        writer.setnchannels(channel_count)
        writer.setsampwidth(bytes_per_sample)
        writer.setframerate(frame_rate)
        writer.writeframes(pcm_silence)

    return str(target)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clean_env(monkeypatch):
    """Strip the STT API-key variables from the environment for every test.

    Runs automatically (``autouse``) so that keys present in the developer's
    or CI environment cannot influence a test; tests that need a key set it
    themselves through ``monkeypatch.setenv``.
    """
    for key_var in ("VOICE_TOOLS_OPENAI_KEY", "GROQ_API_KEY"):
        # raising=False: the variable may legitimately be absent already.
        monkeypatch.delenv(key_var, raising=False)
|
|
|
|
|
|
# ============================================================================
|
|
# _resolve_stt_provider
|
|
# ============================================================================
|
|
|
|
class TestResolveSTTProvider:
    """Checks on the ``(key, url, provider)`` triple returned by ``_resolve_stt_provider``."""

    def test_openai_preferred_over_groq(self, monkeypatch):
        """With both keys present, the OpenAI key, URL and provider name are returned."""
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")

        from tools.transcription_tools import _resolve_stt_provider

        api_key, base_url, provider_name = _resolve_stt_provider()

        assert provider_name == "openai"
        assert api_key == "sk-test"
        assert "openai.com" in base_url

    def test_groq_fallback(self, monkeypatch):
        """With only the Groq key present, the Groq key, URL and provider name are returned."""
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")

        from tools.transcription_tools import _resolve_stt_provider

        api_key, base_url, provider_name = _resolve_stt_provider()

        assert provider_name == "groq"
        assert api_key == "gsk-test"
        assert "groq.com" in base_url

    def test_no_keys_returns_none(self):
        """With no key present, provider is the string ``"none"`` and key/url are ``None``."""
        from tools.transcription_tools import _resolve_stt_provider

        api_key, base_url, provider_name = _resolve_stt_provider()

        assert provider_name == "none"
        assert api_key is None
        assert base_url is None
|
|
|
|
|
|
# ============================================================================
|
|
# transcribe_audio -- no API key
|
|
# ============================================================================
|
|
|
|
class TestTranscribeAudioNoKey:
    """Failure results from ``transcribe_audio`` when no API call can be made."""

    def test_returns_error_when_no_key(self, sample_wav):
        """Without any API key, the result is a failure naming the missing STT key.

        A real audio file is supplied so that the missing key is the only
        possible cause of failure; the outcome no longer depends on whether a
        fixed host path happens to exist or on the order in which the
        implementation performs its checks.
        """
        from tools.transcription_tools import transcribe_audio

        result = transcribe_audio(sample_wav)

        assert result["success"] is False
        assert "No STT API key" in result["error"]

    def test_returns_error_for_missing_file(self, monkeypatch, tmp_path):
        """With a key configured, a path that does not exist yields a "not found" failure.

        The missing path lives under pytest's fresh ``tmp_path`` directory, so
        it is guaranteed absent on every host.
        """
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        missing_path = tmp_path / "audio.wav"  # never created

        from tools.transcription_tools import transcribe_audio

        result = transcribe_audio(str(missing_path))

        assert result["success"] is False
        assert "not found" in result["error"]
|
|
|
|
|
|
# ============================================================================
|
|
# Model auto-correction
|
|
# ============================================================================
|
|
|
|
class TestModelAutoCorrection:
    """Checks which ``model`` keyword reaches the mocked transcription API call."""

    @staticmethod
    def _fake_client(transcript):
        """Return a mock client whose transcription call yields *transcript*."""
        client = MagicMock()
        client.audio.transcriptions.create.return_value = transcript
        return client

    @staticmethod
    def _model_sent(client):
        """Return the ``model`` keyword of the last transcription call made on *client*."""
        last_call = client.audio.transcriptions.create.call_args
        return last_call.kwargs["model"]

    def test_groq_corrects_openai_model(self, monkeypatch, sample_wav):
        """Groq key + ``whisper-1`` request: the call is made with the Groq default model."""
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        client = self._fake_client("hello world")

        # The project import stays inside the patch context, as in every test here.
        with patch("openai.OpenAI", return_value=client):
            from tools.transcription_tools import transcribe_audio, DEFAULT_GROQ_STT_MODEL

            outcome = transcribe_audio(sample_wav, model="whisper-1")

        assert outcome["success"] is True
        assert outcome["transcript"] == "hello world"
        # The requested model must have been replaced by the Groq default.
        assert self._model_sent(client) == DEFAULT_GROQ_STT_MODEL

    def test_openai_corrects_groq_model(self, monkeypatch, sample_wav):
        """OpenAI key + ``whisper-large-v3-turbo`` request: the call uses the OpenAI default model."""
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        client = self._fake_client("hello world")

        with patch("openai.OpenAI", return_value=client):
            from tools.transcription_tools import transcribe_audio, DEFAULT_STT_MODEL

            outcome = transcribe_audio(sample_wav, model="whisper-large-v3-turbo")

        assert outcome["success"] is True
        assert self._model_sent(client) == DEFAULT_STT_MODEL

    def test_none_model_uses_provider_default(self, monkeypatch, sample_wav):
        """Groq key + ``model=None``: the call is made with the Groq default model."""
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        client = self._fake_client("test")

        with patch("openai.OpenAI", return_value=client):
            from tools.transcription_tools import transcribe_audio, DEFAULT_GROQ_STT_MODEL

            transcribe_audio(sample_wav, model=None)

        assert self._model_sent(client) == DEFAULT_GROQ_STT_MODEL

    def test_compatible_model_not_overridden(self, monkeypatch, sample_wav):
        """Groq key + ``whisper-large-v3`` request: the requested model is passed through unchanged."""
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        client = self._fake_client("test")

        with patch("openai.OpenAI", return_value=client):
            from tools.transcription_tools import transcribe_audio

            transcribe_audio(sample_wav, model="whisper-large-v3")

        assert self._model_sent(client) == "whisper-large-v3"
|
|
|
|
|
|
# ============================================================================
|
|
# transcribe_audio -- success path
|
|
# ============================================================================
|
|
|
|
class TestTranscribeAudioSuccess:
    """Result-dictionary checks for ``transcribe_audio`` when an API key is configured."""

    def test_successful_transcription(self, monkeypatch, sample_wav):
        """Groq key + mocked reply: success, transcript echoed, provider reported as ``groq``."""
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        # Dotted keyword configures the nested child mock's return value.
        fake_client = MagicMock(
            **{"audio.transcriptions.create.return_value": "hello world"}
        )

        with patch("openai.OpenAI", return_value=fake_client):
            from tools.transcription_tools import transcribe_audio

            outcome = transcribe_audio(sample_wav)

        assert outcome["success"] is True
        assert outcome["transcript"] == "hello world"
        assert outcome["provider"] == "groq"

    def test_api_error_returns_failure(self, monkeypatch, sample_wav):
        """An exception raised by the API call becomes a failure result carrying its message."""
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        fake_client = MagicMock(
            **{"audio.transcriptions.create.side_effect": Exception("API error")}
        )

        with patch("openai.OpenAI", return_value=fake_client):
            from tools.transcription_tools import transcribe_audio

            outcome = transcribe_audio(sample_wav)

        assert outcome["success"] is False
        assert "API error" in outcome["error"]

    def test_whitespace_transcript_stripped(self, monkeypatch, sample_wav):
        """Leading and trailing whitespace in the API reply is absent from the transcript."""
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        fake_client = MagicMock(
            **{"audio.transcriptions.create.return_value": " hello world \n"}
        )

        with patch("openai.OpenAI", return_value=fake_client):
            from tools.transcription_tools import transcribe_audio

            outcome = transcribe_audio(sample_wav)

        assert outcome["transcript"] == "hello world"
|