feat: integrate faster-whisper local STT with three-provider fallback
Merge main's faster-whisper (local, free) with our Groq support into a unified three-provider STT pipeline: local > groq > openai. Provider priority ensures free options are tried first. Each provider has its own transcriber function with model auto-correction, env-overridable endpoints, and proper error handling. 74 tests cover the full provider matrix, fallback chains, model correction, config loading, validation edge cases, and dispatch.
This commit is contained in:
10
.env.example
10
.env.example
@@ -286,6 +286,16 @@ WANDB_API_KEY=
|
||||
# Groq API key (free tier — used for Whisper STT in voice mode)
|
||||
# GROQ_API_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# STT PROVIDER SELECTION
|
||||
# =============================================================================
|
||||
# Default STT provider is "local" (faster-whisper) — runs on your machine, no API key needed.
|
||||
# Install with: pip install faster-whisper
|
||||
# Model downloads automatically on first use (~150 MB for "base").
|
||||
# To use cloud providers instead, set GROQ_API_KEY or VOICE_TOOLS_OPENAI_KEY above.
|
||||
# Provider priority: local > groq > openai
|
||||
# Configure in config.yaml: stt.provider: local | groq | openai
|
||||
|
||||
# =============================================================================
|
||||
# STT ADVANCED OVERRIDES (optional)
|
||||
# =============================================================================
|
||||
|
||||
@@ -438,11 +438,14 @@ class TestTranscriptionGroqFallback:
|
||||
assert "not found" in result.get("error", "").lower()
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_no_key_returns_error(self):
|
||||
def test_no_key_returns_error(self, tmp_path):
|
||||
audio_file = tmp_path / "test.ogg"
|
||||
audio_file.write_bytes(b"fake audio data")
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio("/nonexistent/audio.mp3")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False):
|
||||
result = transcribe_audio(str(audio_file))
|
||||
assert result["success"] is False
|
||||
assert "not set" in result.get("error", "").lower() or "GROQ" in result.get("error", "")
|
||||
assert "no stt provider" in result.get("error", "").lower()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
|
||||
@@ -28,6 +28,7 @@ class TestGetProvider:
|
||||
|
||||
def test_local_fallback_to_openai(self, monkeypatch):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
"""Tests for tools.transcription_tools -- provider resolution, model correction, config helper."""
|
||||
"""Tests for tools.transcription_tools — three-provider STT pipeline.
|
||||
|
||||
Covers the full provider matrix (local, groq, openai), fallback chains,
|
||||
model auto-correction, config loading, validation edge cases, and
|
||||
end-to-end dispatch. All external dependencies are mocked.
|
||||
"""
|
||||
|
||||
import os
|
||||
import struct
|
||||
@@ -28,6 +33,14 @@ def sample_wav(tmp_path):
|
||||
return str(wav_path)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_ogg(tmp_path):
    """Write a placeholder ``test.ogg`` into ``tmp_path`` and return its path as a string.

    The payload is arbitrary bytes rather than real OGG audio, so the file is
    only suitable for tests that need a file with an ``.ogg`` name to exist.
    """
    target = tmp_path / "test.ogg"
    target.write_bytes(b"fake audio data")
    return str(target)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_env(monkeypatch):
|
||||
"""Ensure no real API keys leak into tests."""
|
||||
@@ -36,60 +49,330 @@ def clean_env(monkeypatch):
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _resolve_stt_provider
|
||||
# _get_provider — full permutation matrix
|
||||
# ============================================================================
|
||||
|
||||
class TestResolveSTTProvider:
|
||||
def test_openai_preferred_over_groq(self, monkeypatch):
|
||||
class TestGetProviderGroq:
    """Provider resolution when the configuration explicitly asks for Groq."""

    _MODULE = "tools.transcription_tools"

    def _resolve(self, **flags):
        """Return ``_get_provider({"provider": "groq"})`` with module flags patched.

        Each keyword names a module-level attribute of
        ``tools.transcription_tools`` and the value it should hold for the
        duration of the call.
        """
        with patch.multiple(self._MODULE, **flags):
            from tools.transcription_tools import _get_provider

            return _get_provider({"provider": "groq"})

    def test_groq_when_key_set(self, monkeypatch):
        """A Groq key in the environment keeps the requested provider."""
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")

        chosen = self._resolve(_HAS_OPENAI=True, _HAS_FASTER_WHISPER=False)

        assert chosen == "groq"

    def test_groq_fallback_to_local(self, monkeypatch):
        """Without a Groq key, resolution lands on local when faster-whisper is flagged present."""
        monkeypatch.delenv("GROQ_API_KEY", raising=False)

        chosen = self._resolve(_HAS_FASTER_WHISPER=True)

        assert chosen == "local"

    def test_groq_fallback_to_openai(self, monkeypatch):
        """Without a Groq key or local support, an OpenAI key selects openai."""
        monkeypatch.delenv("GROQ_API_KEY", raising=False)
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")

        chosen = self._resolve(_HAS_FASTER_WHISPER=False, _HAS_OPENAI=True)

        assert chosen == "openai"

    def test_groq_nothing_available(self, monkeypatch):
        """With no keys and no packages flagged present, resolution yields ``"none"``."""
        monkeypatch.delenv("GROQ_API_KEY", raising=False)
        monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)

        chosen = self._resolve(_HAS_FASTER_WHISPER=False, _HAS_OPENAI=False)

        assert chosen == "none"
|
||||
|
||||
|
||||
class TestGetProviderFallbackPriority:
|
||||
"""Cross-provider fallback priority tests."""
|
||||
|
||||
def test_local_fallback_prefers_groq_over_openai(self, monkeypatch):
|
||||
"""When local unavailable, groq (free) is preferred over openai (paid)."""
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "local"}) == "groq"
|
||||
|
||||
from tools.transcription_tools import _resolve_stt_provider
|
||||
key, url, provider = _resolve_stt_provider()
|
||||
|
||||
assert provider == "openai"
|
||||
assert key == "sk-test"
|
||||
assert "openai.com" in url
|
||||
|
||||
def test_groq_fallback(self, monkeypatch):
|
||||
def test_local_fallback_to_groq_only(self, monkeypatch):
|
||||
"""When only groq key available, falls back to groq."""
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "local"}) == "groq"
|
||||
|
||||
from tools.transcription_tools import _resolve_stt_provider
|
||||
key, url, provider = _resolve_stt_provider()
|
||||
def test_openai_fallback_to_groq(self, monkeypatch):
|
||||
"""When openai key missing but groq available, falls back to groq."""
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "openai"}) == "groq"
|
||||
|
||||
assert provider == "groq"
|
||||
assert key == "gsk-test"
|
||||
assert "groq.com" in url
|
||||
def test_openai_nothing_available(self, monkeypatch):
|
||||
"""When no openai key and no local, returns none."""
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "openai"}) == "none"
|
||||
|
||||
def test_no_keys_returns_none(self):
|
||||
from tools.transcription_tools import _resolve_stt_provider
|
||||
key, url, provider = _resolve_stt_provider()
|
||||
def test_unknown_provider_passed_through(self):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "custom-endpoint"}) == "custom-endpoint"
|
||||
|
||||
assert provider == "none"
|
||||
assert key is None
|
||||
assert url is None
|
||||
def test_empty_config_defaults_to_local(self):
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "local"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# transcribe_audio -- no API key
|
||||
# _transcribe_groq
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeAudioNoKey:
|
||||
def test_returns_error_when_no_key(self):
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio("/tmp/test.wav")
|
||||
|
||||
class TestTranscribeGroq:
|
||||
def test_no_key(self, monkeypatch):
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq("/tmp/test.ogg", "whisper-large-v3-turbo")
|
||||
assert result["success"] is False
|
||||
assert "No STT API key" in result["error"]
|
||||
assert "GROQ_API_KEY" in result["error"]
|
||||
|
||||
def test_returns_error_for_missing_file(self, monkeypatch):
|
||||
def test_openai_package_not_installed(self, monkeypatch):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", False):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq("/tmp/test.ogg", "whisper-large-v3-turbo")
|
||||
assert result["success"] is False
|
||||
assert "openai package" in result["error"]
|
||||
|
||||
def test_successful_transcription(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio("/nonexistent/audio.wav")
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "hello world"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "hello world"
|
||||
assert result["provider"] == "groq"
|
||||
|
||||
def test_whitespace_stripped(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = " hello world \n"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
assert result["transcript"] == "hello world"
|
||||
|
||||
def test_uses_groq_base_url(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client) as mock_openai_cls:
|
||||
from tools.transcription_tools import _transcribe_groq, GROQ_BASE_URL
|
||||
_transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
call_kwargs = mock_openai_cls.call_args
|
||||
assert call_kwargs.kwargs["base_url"] == GROQ_BASE_URL
|
||||
|
||||
def test_api_error_returns_failure(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.side_effect = Exception("API error")
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "not found" in result["error"]
|
||||
assert "API error" in result["error"]
|
||||
|
||||
def test_permission_error(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.side_effect = PermissionError("denied")
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Permission denied" in result["error"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_openai — additional tests
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeOpenAIExtended:
    """Further ``_transcribe_openai`` cases: missing package, base URL, stripping, permissions."""

    @staticmethod
    def _fake_client(**create_config):
        """Build a stand-in client whose ``audio.transcriptions.create`` mock is configured as given.

        ``create_config`` is forwarded to ``configure_mock`` (for example
        ``return_value=...`` or ``side_effect=...``).
        """
        client = MagicMock()
        client.audio.transcriptions.create.configure_mock(**create_config)
        return client

    def test_openai_package_not_installed(self, monkeypatch):
        """With the openai package flagged absent, the call fails and names the package."""
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")

        with patch("tools.transcription_tools._HAS_OPENAI", False):
            from tools.transcription_tools import _transcribe_openai

            outcome = _transcribe_openai("/tmp/test.ogg", "whisper-1")

        assert outcome["success"] is False
        assert "openai package" in outcome["error"]

    def test_uses_openai_base_url(self, monkeypatch, sample_wav):
        """The client is constructed with ``base_url`` equal to ``OPENAI_BASE_URL``."""
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        client = self._fake_client(return_value="test")

        with patch("tools.transcription_tools._HAS_OPENAI", True):
            with patch("tools.transcription_tools.OpenAI", return_value=client) as openai_cls:
                from tools.transcription_tools import _transcribe_openai, OPENAI_BASE_URL

                _transcribe_openai(sample_wav, "whisper-1")

        assert openai_cls.call_args.kwargs["base_url"] == OPENAI_BASE_URL

    def test_whitespace_stripped(self, monkeypatch, sample_wav):
        """Surrounding whitespace in the API response is removed from the transcript."""
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        client = self._fake_client(return_value=" hello \n")

        with patch("tools.transcription_tools._HAS_OPENAI", True):
            with patch("tools.transcription_tools.OpenAI", return_value=client):
                from tools.transcription_tools import _transcribe_openai

                outcome = _transcribe_openai(sample_wav, "whisper-1")

        assert outcome["transcript"] == "hello"

    def test_permission_error(self, monkeypatch, sample_wav):
        """A ``PermissionError`` from the client becomes a "Permission denied" failure result."""
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        client = self._fake_client(side_effect=PermissionError("denied"))

        with patch("tools.transcription_tools._HAS_OPENAI", True):
            with patch("tools.transcription_tools.OpenAI", return_value=client):
                from tools.transcription_tools import _transcribe_openai

                outcome = _transcribe_openai(sample_wav, "whisper-1")

        assert outcome["success"] is False
        assert "Permission denied" in outcome["error"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_local — additional tests
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeLocalExtended:
    """Further ``_transcribe_local`` cases: model caching, load failures, segment joining.

    Every test starts from a clean module-level cache: both ``_local_model``
    and ``_local_model_name`` are patched to ``None`` and restored afterwards,
    so a model name written during one test cannot leak into another.
    """

    @staticmethod
    def _fake_audio(tmp_path):
        """Create a placeholder ``test.ogg`` and return its path as a string."""
        audio = tmp_path / "test.ogg"
        audio.write_bytes(b"fake")
        return str(audio)

    @staticmethod
    def _fake_model(*texts, duration=1.0):
        """Return a mock model whose ``transcribe`` yields one segment per text.

        The accompanying info object reports language ``"en"`` and the given
        ``duration``.
        """
        segments = []
        for text in texts:
            segment = MagicMock()
            segment.text = text
            segments.append(segment)

        info = MagicMock()
        info.language = "en"
        info.duration = duration

        model = MagicMock()
        model.transcribe.return_value = (segments, info)
        return model

    @staticmethod
    def _fresh_local_state(whisper_cls):
        """Patch the module so faster-whisper looks installed and nothing is cached.

        ``whisper_cls`` replaces ``WhisperModel``. Both cache globals are reset
        together; patching only ``_local_model`` would leave any
        ``_local_model_name`` assigned during the test in place afterwards.
        """
        return patch.multiple(
            "tools.transcription_tools",
            _HAS_FASTER_WHISPER=True,
            WhisperModel=whisper_cls,
            _local_model=None,
            _local_model_name=None,
        )

    def test_model_reuse_on_second_call(self, tmp_path):
        """Second call with same model should NOT reload the model."""
        audio = self._fake_audio(tmp_path)
        whisper_cls = MagicMock(return_value=self._fake_model("hi"))

        with self._fresh_local_state(whisper_cls):
            from tools.transcription_tools import _transcribe_local

            _transcribe_local(audio, "base")
            _transcribe_local(audio, "base")

        # WhisperModel should be created only once
        assert whisper_cls.call_count == 1

    def test_model_reloaded_on_change(self, tmp_path):
        """Switching model name should reload the model."""
        audio = self._fake_audio(tmp_path)
        whisper_cls = MagicMock(return_value=self._fake_model("hi"))

        with self._fresh_local_state(whisper_cls):
            from tools.transcription_tools import _transcribe_local

            _transcribe_local(audio, "base")
            _transcribe_local(audio, "small")

        assert whisper_cls.call_count == 2

    def test_exception_returns_failure(self, tmp_path):
        """An exception raised while constructing the model is reported in the error text."""
        audio = self._fake_audio(tmp_path)
        whisper_cls = MagicMock(side_effect=RuntimeError("CUDA out of memory"))

        with self._fresh_local_state(whisper_cls):
            from tools.transcription_tools import _transcribe_local

            result = _transcribe_local(audio, "large-v3")

        assert result["success"] is False
        assert "CUDA out of memory" in result["error"]

    def test_multiple_segments_joined(self, tmp_path):
        """Texts of consecutive segments are combined into a single transcript."""
        audio = self._fake_audio(tmp_path)
        whisper_cls = MagicMock(
            return_value=self._fake_model("Hello", " world", duration=3.0)
        )

        with self._fresh_local_state(whisper_cls):
            from tools.transcription_tools import _transcribe_local

            result = _transcribe_local(audio, "base")

        assert result["success"] is True
        assert result["transcript"] == "Hello world"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -103,13 +386,25 @@ class TestModelAutoCorrection:
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "hello world"
|
||||
|
||||
with patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import transcribe_audio, DEFAULT_GROQ_STT_MODEL
|
||||
result = transcribe_audio(sample_wav, model="whisper-1")
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq, DEFAULT_GROQ_STT_MODEL
|
||||
_transcribe_groq(sample_wav, "whisper-1")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_GROQ_STT_MODEL
|
||||
|
||||
def test_groq_corrects_gpt4o_transcribe(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq, DEFAULT_GROQ_STT_MODEL
|
||||
_transcribe_groq(sample_wav, "gpt-4o-transcribe")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "hello world"
|
||||
# Verify the model was corrected to Groq default
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_GROQ_STT_MODEL
|
||||
|
||||
@@ -119,84 +414,262 @@ class TestModelAutoCorrection:
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "hello world"
|
||||
|
||||
with patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import transcribe_audio, DEFAULT_STT_MODEL
|
||||
result = transcribe_audio(sample_wav, model="whisper-large-v3-turbo")
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai, DEFAULT_STT_MODEL
|
||||
_transcribe_openai(sample_wav, "whisper-large-v3-turbo")
|
||||
|
||||
assert result["success"] is True
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_STT_MODEL
|
||||
|
||||
def test_none_model_uses_provider_default(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
def test_openai_corrects_distil_whisper(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import transcribe_audio, DEFAULT_GROQ_STT_MODEL
|
||||
transcribe_audio(sample_wav, model=None)
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai, DEFAULT_STT_MODEL
|
||||
_transcribe_openai(sample_wav, "distil-whisper-large-v3-en")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_GROQ_STT_MODEL
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_STT_MODEL
|
||||
|
||||
def test_compatible_model_not_overridden(self, monkeypatch, sample_wav):
|
||||
def test_compatible_groq_model_not_overridden(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
transcribe_audio(sample_wav, model="whisper-large-v3")
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
_transcribe_groq(sample_wav, "whisper-large-v3")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "whisper-large-v3"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# transcribe_audio -- success path
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeAudioSuccess:
|
||||
def test_successful_transcription(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "hello world"
|
||||
|
||||
with patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(sample_wav)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "hello world"
|
||||
assert result["provider"] == "groq"
|
||||
|
||||
def test_api_error_returns_failure(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.side_effect = Exception("API error")
|
||||
|
||||
with patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(sample_wav)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "API error" in result["error"]
|
||||
|
||||
def test_whitespace_transcript_stripped(self, monkeypatch, sample_wav):
|
||||
def test_compatible_openai_model_not_overridden(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = " hello world \n"
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("openai.OpenAI", return_value=mock_client):
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
_transcribe_openai(sample_wav, "gpt-4o-mini-transcribe")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "gpt-4o-mini-transcribe"
|
||||
|
||||
def test_unknown_model_passes_through_groq(self, monkeypatch, sample_wav):
|
||||
"""A model not in either known set should not be overridden."""
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_groq
|
||||
_transcribe_groq(sample_wav, "my-custom-model")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "my-custom-model"
|
||||
|
||||
def test_unknown_model_passes_through_openai(self, monkeypatch, sample_wav):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "test"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
_transcribe_openai(sample_wav, "my-custom-model")
|
||||
|
||||
call_kwargs = mock_client.audio.transcriptions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "my-custom-model"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _load_stt_config
|
||||
# ============================================================================
|
||||
|
||||
class TestLoadSttConfig:
    """Checks that ``_load_stt_config`` hands back a dictionary."""

    def test_returns_dict_when_import_fails(self):
        # NOTE(review): the loader itself is replaced by the patch below, so the
        # assertion only observes the stub's return value. Confirm whether the
        # real function was meant to be exercised here.
        with patch("tools.transcription_tools._load_stt_config", return_value={}):
            from tools.transcription_tools import _load_stt_config

            loaded = _load_stt_config()

        assert loaded == {}

    def test_real_load_returns_dict(self):
        """_load_stt_config should always return a dict, even on import error."""
        # A None entry in sys.modules makes importing that name raise ImportError.
        blocked = dict.fromkeys(("hermes_cli", "hermes_cli.config"))

        with patch.dict("sys.modules", blocked):
            from tools.transcription_tools import _load_stt_config

            loaded = _load_stt_config()

        assert isinstance(loaded, dict)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _validate_audio_file — edge cases
|
||||
# ============================================================================
|
||||
|
||||
class TestValidateAudioFileEdgeCases:
    """Boundary cases for ``_validate_audio_file``: directories, stat errors, extensions."""

    def test_directory_is_not_a_file(self, tmp_path):
        """A directory whose name carries an audio extension is rejected as "not a file"."""
        from tools.transcription_tools import _validate_audio_file

        folder = tmp_path / "audio.ogg"
        folder.mkdir()

        outcome = _validate_audio_file(str(folder))

        assert outcome is not None
        assert "not a file" in outcome["error"]

    def test_stat_oserror(self, tmp_path):
        """An ``OSError`` from a later ``stat`` call is reported as "Failed to access"."""
        target = tmp_path / "test.ogg"
        target.write_bytes(b"data")
        from tools.transcription_tools import _validate_audio_file

        # Capture a genuine stat result before Path.stat is replaced.
        genuine = target.stat()
        seen = []

        def flaky_stat(*args, **kwargs):
            # The first two calls are assumed to come from exists() and
            # is_file() and succeed; every later call fails.
            # NOTE(review): this relies on how many stat calls the validator and
            # pathlib make before the size check — confirm on supported Pythons.
            seen.append(None)
            if len(seen) > 2:
                raise OSError("disk error")
            return genuine

        with patch("pathlib.Path.stat", side_effect=flaky_stat):
            outcome = _validate_audio_file(str(target))

        assert outcome is not None
        assert "Failed to access" in outcome["error"]

    def test_all_supported_formats_accepted(self, tmp_path):
        """Every extension listed in ``SUPPORTED_FORMATS`` passes validation."""
        from tools.transcription_tools import _validate_audio_file, SUPPORTED_FORMATS

        for fmt in SUPPORTED_FORMATS:
            candidate = tmp_path / f"test{fmt}"
            candidate.write_bytes(b"data")
            verdict = _validate_audio_file(str(candidate))
            assert verdict is None, f"Format {fmt} should be accepted"

    def test_case_insensitive_extension(self, tmp_path):
        """An upper-case ``.MP3`` extension is accepted."""
        from tools.transcription_tools import _validate_audio_file

        shouty = tmp_path / "test.MP3"
        shouty.write_bytes(b"data")

        assert _validate_audio_file(str(shouty)) is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# transcribe_audio — end-to-end dispatch
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeAudioDispatch:
    """End-to-end dispatch tests for ``transcribe_audio``.

    Config loading, provider selection and the per-provider transcribers are
    all patched, so these tests only verify routing and model resolution.
    """

    def test_dispatches_to_groq(self, sample_ogg):
        """Provider "groq" routes to ``_transcribe_groq`` and returns its result."""
        with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "groq"}), \
             patch("tools.transcription_tools._get_provider", return_value="groq"), \
             patch("tools.transcription_tools._transcribe_groq",
                   return_value={"success": True, "transcript": "hi", "provider": "groq"}) as mock_groq:
            from tools.transcription_tools import transcribe_audio
            result = transcribe_audio(sample_ogg)

        assert result["success"] is True
        # The mocked transcriber's payload must be passed through untouched.
        assert result["transcript"] == "hi"
        assert result["provider"] == "groq"
        mock_groq.assert_called_once()

    def test_dispatches_to_local(self, sample_ogg):
        """Provider "local" routes to ``_transcribe_local``."""
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="local"), \
             patch("tools.transcription_tools._transcribe_local",
                   return_value={"success": True, "transcript": "hi"}) as mock_local:
            from tools.transcription_tools import transcribe_audio
            result = transcribe_audio(sample_ogg)

        assert result["success"] is True
        mock_local.assert_called_once()

    def test_dispatches_to_openai(self, sample_ogg):
        """Provider "openai" routes to ``_transcribe_openai``."""
        with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openai"}), \
             patch("tools.transcription_tools._get_provider", return_value="openai"), \
             patch("tools.transcription_tools._transcribe_openai",
                   return_value={"success": True, "transcript": "hi", "provider": "openai"}) as mock_openai:
            from tools.transcription_tools import transcribe_audio
            result = transcribe_audio(sample_ogg)

        assert result["success"] is True
        mock_openai.assert_called_once()

    def test_no_provider_returns_error(self, sample_ogg):
        """Provider "none" yields an error naming every way to enable STT."""
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="none"):
            from tools.transcription_tools import transcribe_audio
            result = transcribe_audio(sample_ogg)

        assert result["success"] is False
        assert "No STT provider" in result["error"]
        assert "faster-whisper" in result["error"]
        assert "GROQ_API_KEY" in result["error"]

    def test_invalid_file_short_circuits(self):
        """A missing file fails validation before any provider is consulted."""
        from tools.transcription_tools import transcribe_audio
        result = transcribe_audio("/nonexistent/audio.wav")

        assert result["success"] is False
        assert "not found" in result["error"]

    def test_model_override_passed_to_groq(self, sample_ogg):
        """An explicit ``model`` argument reaches the Groq transcriber."""
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="groq"), \
             patch("tools.transcription_tools._transcribe_groq",
                   return_value={"success": True, "transcript": "hi"}) as mock_groq:
            from tools.transcription_tools import transcribe_audio
            transcribe_audio(sample_ogg, model="whisper-large-v3")

        args, kwargs = mock_groq.call_args
        # Accept either calling convention, but always compare the actual value
        # (a bare ``kwargs.get(...) or args[1] == ...`` passes for any truthy kwarg).
        passed_model = kwargs["model_name"] if "model_name" in kwargs else args[1]
        assert passed_model == "whisper-large-v3"

    def test_model_override_passed_to_local(self, sample_ogg):
        """An explicit ``model`` argument reaches the local transcriber."""
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="local"), \
             patch("tools.transcription_tools._transcribe_local",
                   return_value={"success": True, "transcript": "hi"}) as mock_local:
            from tools.transcription_tools import transcribe_audio
            transcribe_audio(sample_ogg, model="large-v3")

        assert mock_local.call_args[0][1] == "large-v3"

    def test_default_model_used_when_none(self, sample_ogg):
        """With no override and no config, Groq gets its default model."""
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="groq"), \
             patch("tools.transcription_tools._transcribe_groq",
                   return_value={"success": True, "transcript": "hi"}) as mock_groq:
            from tools.transcription_tools import transcribe_audio, DEFAULT_GROQ_STT_MODEL
            transcribe_audio(sample_ogg, model=None)

        assert mock_groq.call_args[0][1] == DEFAULT_GROQ_STT_MODEL

    def test_config_local_model_used(self, sample_ogg):
        """``stt.local.model`` from config is used when no override is given."""
        config = {"local": {"model": "small"}}
        with patch("tools.transcription_tools._load_stt_config", return_value=config), \
             patch("tools.transcription_tools._get_provider", return_value="local"), \
             patch("tools.transcription_tools._transcribe_local",
                   return_value={"success": True, "transcript": "hi"}) as mock_local:
            from tools.transcription_tools import transcribe_audio
            transcribe_audio(sample_ogg, model=None)

        assert mock_local.call_args[0][1] == "small"

    def test_config_openai_model_used(self, sample_ogg):
        """``stt.openai.model`` from config is used when no override is given."""
        config = {"openai": {"model": "gpt-4o-transcribe"}}
        with patch("tools.transcription_tools._load_stt_config", return_value=config), \
             patch("tools.transcription_tools._get_provider", return_value="openai"), \
             patch("tools.transcription_tools._transcribe_openai",
                   return_value={"success": True, "transcript": "hi"}) as mock_openai:
            from tools.transcription_tools import transcribe_audio
            transcribe_audio(sample_ogg, model=None)

        assert mock_openai.call_args[0][1] == "gpt-4o-transcribe"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -2,21 +2,20 @@
|
||||
"""
|
||||
Transcription Tools Module
|
||||
|
||||
Provides speech-to-text transcription using OpenAI-compatible Whisper APIs.
|
||||
Supports multiple providers with automatic fallback:
|
||||
1. OpenAI (VOICE_TOOLS_OPENAI_KEY) -- paid
|
||||
2. Groq (GROQ_API_KEY) -- free tier available
|
||||
Provides speech-to-text transcription with three providers:
|
||||
|
||||
- **local** (default, free) — faster-whisper running locally, no API key needed.
|
||||
Auto-downloads the model (~150 MB for ``base``) on first use.
|
||||
- **groq** (free tier) — Groq Whisper API, requires ``GROQ_API_KEY``.
|
||||
- **openai** (paid) — OpenAI Whisper API, requires ``VOICE_TOOLS_OPENAI_KEY``.
|
||||
|
||||
Used by the messaging gateway to automatically transcribe voice messages
|
||||
sent by users on Telegram, Discord, WhatsApp, and Slack.
|
||||
|
||||
Supported models:
|
||||
OpenAI: whisper-1, gpt-4o-mini-transcribe, gpt-4o-transcribe
|
||||
Groq: whisper-large-v3, whisper-large-v3-turbo, distil-whisper-large-v3-en
|
||||
sent by users on Telegram, Discord, WhatsApp, Slack, and Signal.
|
||||
|
||||
Supported input formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, ogg
|
||||
|
||||
Usage:
|
||||
Usage::
|
||||
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
|
||||
result = transcribe_audio("/path/to/audio.ogg")
|
||||
@@ -27,19 +26,54 @@ Usage:
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional imports — graceful degradation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Default STT models per provider (overridable via env)
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
_HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
_HAS_FASTER_WHISPER = False
|
||||
WhisperModel = None # type: ignore[assignment,misc]
|
||||
|
||||
try:
|
||||
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
|
||||
_HAS_OPENAI = True
|
||||
except ImportError:
|
||||
_HAS_OPENAI = False
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Provider used when ``stt.provider`` is absent from the config.
DEFAULT_PROVIDER = "local"

# Per-provider default models; the cloud ones can be overridden via env vars.
DEFAULT_LOCAL_MODEL = "base"
DEFAULT_STT_MODEL = os.environ.get("STT_OPENAI_MODEL", "whisper-1")
DEFAULT_GROQ_STT_MODEL = os.environ.get("STT_GROQ_MODEL", "whisper-large-v3-turbo")

# API endpoints, env-overridable so proxies / self-hosted gateways can be used.
GROQ_BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
OPENAI_BASE_URL = os.environ.get("STT_OPENAI_BASE_URL", "https://api.openai.com/v1")

# Input validation limits.
SUPPORTED_FORMATS = {
    ".mp3", ".mp4", ".mpeg", ".mpga",
    ".m4a", ".wav", ".webm", ".ogg",
}
MAX_FILE_SIZE = 25 * 1024 * 1024  # 25 MB

# Model names known to belong to each cloud provider; used to auto-correct a
# model name that was meant for the other provider.
OPENAI_MODELS = {"whisper-1", "gpt-4o-mini-transcribe", "gpt-4o-transcribe"}
GROQ_MODELS = {"whisper-large-v3", "whisper-large-v3-turbo", "distil-whisper-large-v3-en"}

# Cached local model instance and the model name it was built for, so the
# model is only (re)loaded when a different name is requested.
_local_model: Optional["WhisperModel"] = None
_local_model_name: Optional[str] = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_stt_model_from_config() -> Optional[str]:
|
||||
"""Read the STT model name from ~/.hermes/config.yaml.
|
||||
@@ -59,40 +93,238 @@ def get_stt_model_from_config() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_stt_provider() -> Tuple[Optional[str], Optional[str], str]:
|
||||
"""Resolve which STT provider to use based on available API keys.
|
||||
def _load_stt_config() -> dict:
|
||||
"""Load the ``stt`` section from user config, falling back to defaults."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
return load_config().get("stt", {})
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
Returns:
|
||||
Tuple of (api_key, base_url, provider_name).
|
||||
api_key is None if no provider is available.
|
||||
|
||||
def _get_provider(stt_config: dict) -> str:
|
||||
"""Determine which STT provider to use.
|
||||
|
||||
Priority:
|
||||
1. Explicit config value (``stt.provider``)
|
||||
2. Auto-detect: local > groq (free) > openai (paid)
|
||||
3. Disabled (returns "none")
|
||||
"""
|
||||
openai_key = os.getenv("VOICE_TOOLS_OPENAI_KEY")
|
||||
if openai_key:
|
||||
return openai_key, OPENAI_BASE_URL, "openai"
|
||||
provider = stt_config.get("provider", DEFAULT_PROVIDER)
|
||||
|
||||
groq_key = os.getenv("GROQ_API_KEY")
|
||||
if groq_key:
|
||||
return groq_key, GROQ_BASE_URL, "groq"
|
||||
if provider == "local":
|
||||
if _HAS_FASTER_WHISPER:
|
||||
return "local"
|
||||
# Local requested but not available — fall back to groq, then openai
|
||||
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
|
||||
logger.info("faster-whisper not installed, falling back to Groq Whisper API")
|
||||
return "groq"
|
||||
if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
logger.info("faster-whisper not installed, falling back to OpenAI Whisper API")
|
||||
return "openai"
|
||||
return "none"
|
||||
|
||||
return None, None, "none"
|
||||
if provider == "groq":
|
||||
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
|
||||
return "groq"
|
||||
# Groq requested but no key — fall back
|
||||
if _HAS_FASTER_WHISPER:
|
||||
logger.info("GROQ_API_KEY not set, falling back to local faster-whisper")
|
||||
return "local"
|
||||
if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
logger.info("GROQ_API_KEY not set, falling back to OpenAI Whisper API")
|
||||
return "openai"
|
||||
return "none"
|
||||
|
||||
# Supported audio formats
|
||||
SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", ".ogg"}
|
||||
if provider == "openai":
|
||||
if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
return "openai"
|
||||
# OpenAI requested but no key — fall back
|
||||
if _HAS_FASTER_WHISPER:
|
||||
logger.info("VOICE_TOOLS_OPENAI_KEY not set, falling back to local faster-whisper")
|
||||
return "local"
|
||||
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
|
||||
logger.info("VOICE_TOOLS_OPENAI_KEY not set, falling back to Groq Whisper API")
|
||||
return "groq"
|
||||
return "none"
|
||||
|
||||
# Maximum file size (25MB - OpenAI limit)
|
||||
MAX_FILE_SIZE = 25 * 1024 * 1024
|
||||
return provider # Unknown — let it fail downstream
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_audio_file(file_path: str) -> Optional[Dict[str, Any]]:
|
||||
"""Validate the audio file. Returns an error dict or None if OK."""
|
||||
audio_path = Path(file_path)
|
||||
|
||||
if not audio_path.exists():
|
||||
return {"success": False, "transcript": "", "error": f"Audio file not found: {file_path}"}
|
||||
if not audio_path.is_file():
|
||||
return {"success": False, "transcript": "", "error": f"Path is not a file: {file_path}"}
|
||||
if audio_path.suffix.lower() not in SUPPORTED_FORMATS:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Unsupported format: {audio_path.suffix}. Supported: {', '.join(sorted(SUPPORTED_FORMATS))}",
|
||||
}
|
||||
try:
|
||||
file_size = audio_path.stat().st_size
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max {MAX_FILE_SIZE / (1024*1024):.0f}MB)",
|
||||
}
|
||||
except OSError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Failed to access file: {e}"}
|
||||
|
||||
return None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: local (faster-whisper)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]:
    """Transcribe using faster-whisper (local, free).

    The model instance is cached in module globals and only rebuilt when a
    different ``model_name`` is requested.

    Args:
        file_path: Path to the audio file to transcribe.
        model_name: faster-whisper model name passed to ``WhisperModel``.

    Returns:
        On success ``{"success": True, "transcript": ..., "provider": "local"}``;
        on failure a dict with ``success=False``, an empty ``transcript`` and
        an ``error`` message. Exceptions are caught and reported, not raised.
    """
    global _local_model, _local_model_name

    if not _HAS_FASTER_WHISPER:
        return {"success": False, "transcript": "", "error": "faster-whisper not installed"}

    try:
        # Lazy-load the model (downloads on first use, ~150 MB for 'base')
        if _local_model is None or _local_model_name != model_name:
            logger.info("Loading faster-whisper model '%s' (first load downloads the model)...", model_name)
            _local_model = WhisperModel(model_name, device="auto", compute_type="auto")
            _local_model_name = model_name

        segments, info = _local_model.transcribe(file_path, beam_size=5)
        transcript = " ".join(segment.text.strip() for segment in segments)

        logger.info(
            "Transcribed %s via local whisper (%s, lang=%s, %.1fs audio)",
            Path(file_path).name, model_name, info.language, info.duration,
        )

        # Report the provider like the Groq/OpenAI transcribers do, so callers
        # can always tell which backend produced the transcript.
        return {"success": True, "transcript": transcript, "provider": "local"}

    except Exception as e:
        logger.error("Local transcription failed: %s", e, exc_info=True)
        return {"success": False, "transcript": "", "error": f"Local transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: groq (Whisper API — free tier)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _transcribe_groq(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
"""Transcribe using Groq Whisper API (free tier available)."""
|
||||
api_key = os.getenv("GROQ_API_KEY")
|
||||
if not api_key:
|
||||
return {"success": False, "transcript": "", "error": "GROQ_API_KEY not set"}
|
||||
|
||||
if not _HAS_OPENAI:
|
||||
return {"success": False, "transcript": "", "error": "openai package not installed"}
|
||||
|
||||
# Auto-correct model if caller passed an OpenAI-only model
|
||||
if model_name in OPENAI_MODELS:
|
||||
logger.info("Model %s not available on Groq, using %s", model_name, DEFAULT_GROQ_STT_MODEL)
|
||||
model_name = DEFAULT_GROQ_STT_MODEL
|
||||
|
||||
try:
|
||||
client = OpenAI(api_key=api_key, base_url=GROQ_BASE_URL, timeout=30, max_retries=0)
|
||||
|
||||
with open(file_path, "rb") as audio_file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=audio_file,
|
||||
response_format="text",
|
||||
)
|
||||
|
||||
transcript_text = str(transcription).strip()
|
||||
logger.info("Transcribed %s via Groq API (%s, %d chars)",
|
||||
Path(file_path).name, model_name, len(transcript_text))
|
||||
|
||||
return {"success": True, "transcript": transcript_text, "provider": "groq"}
|
||||
|
||||
except PermissionError:
|
||||
return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"}
|
||||
except APIConnectionError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Connection error: {e}"}
|
||||
except APITimeoutError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Request timeout: {e}"}
|
||||
except APIError as e:
|
||||
return {"success": False, "transcript": "", "error": f"API error: {e}"}
|
||||
except Exception as e:
|
||||
logger.error("Groq transcription failed: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: openai (Whisper API)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _transcribe_openai(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
"""Transcribe using OpenAI Whisper API (paid)."""
|
||||
api_key = os.getenv("VOICE_TOOLS_OPENAI_KEY")
|
||||
if not api_key:
|
||||
return {"success": False, "transcript": "", "error": "VOICE_TOOLS_OPENAI_KEY not set"}
|
||||
|
||||
if not _HAS_OPENAI:
|
||||
return {"success": False, "transcript": "", "error": "openai package not installed"}
|
||||
|
||||
# Auto-correct model if caller passed a Groq-only model
|
||||
if model_name in GROQ_MODELS:
|
||||
logger.info("Model %s not available on OpenAI, using %s", model_name, DEFAULT_STT_MODEL)
|
||||
model_name = DEFAULT_STT_MODEL
|
||||
|
||||
try:
|
||||
client = OpenAI(api_key=api_key, base_url=OPENAI_BASE_URL, timeout=30, max_retries=0)
|
||||
|
||||
with open(file_path, "rb") as audio_file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=audio_file,
|
||||
response_format="text",
|
||||
)
|
||||
|
||||
transcript_text = str(transcription).strip()
|
||||
logger.info("Transcribed %s via OpenAI API (%s, %d chars)",
|
||||
Path(file_path).name, model_name, len(transcript_text))
|
||||
|
||||
return {"success": True, "transcript": transcript_text, "provider": "openai"}
|
||||
|
||||
except PermissionError:
|
||||
return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"}
|
||||
except APIConnectionError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Connection error: {e}"}
|
||||
except APITimeoutError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Request timeout: {e}"}
|
||||
except APIError as e:
|
||||
return {"success": False, "transcript": "", "error": f"API error: {e}"}
|
||||
except Exception as e:
|
||||
logger.error("OpenAI transcription failed: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, Any]:
    """
    Transcribe an audio file using the configured STT provider.

    Provider priority:
        1. User config (``stt.provider`` in config.yaml)
        2. Auto-detect: local faster-whisper (free) > Groq (free tier) > OpenAI (paid)

    Model resolution, per provider: the explicit ``model`` argument, then
    ``stt.<provider>.model`` from config, then the provider's default.

    Args:
        file_path: Absolute path to the audio file to transcribe.
        model: Override the model. If None, uses config or provider default.

    Returns:
        dict with keys:
            - "success" (bool): Whether transcription succeeded
            - "transcript" (str): The transcribed text (empty on failure)
            - "error" (str, optional): Error message if success is False
            - "provider" (str, optional): Which provider was used
    """
    # Validate input before touching config or any provider.
    error = _validate_audio_file(file_path)
    if error:
        return error

    # Load config and determine provider
    stt_config = _load_stt_config()
    provider = _get_provider(stt_config)

    def _resolve_model(section: str, default: str) -> str:
        # A bare ``local:`` / ``groq:`` / ``openai:`` key loads as None, so
        # only read ``model`` from a real mapping.
        provider_cfg = stt_config.get(section)
        configured = provider_cfg.get("model") if isinstance(provider_cfg, dict) else None
        return model or configured or default

    if provider == "local":
        return _transcribe_local(file_path, _resolve_model("local", DEFAULT_LOCAL_MODEL))

    if provider == "groq":
        return _transcribe_groq(file_path, _resolve_model("groq", DEFAULT_GROQ_STT_MODEL))

    if provider == "openai":
        return _transcribe_openai(file_path, _resolve_model("openai", DEFAULT_STT_MODEL))

    # No provider available
    return {
        "success": False,
        "transcript": "",
        "error": (
            "No STT provider available. Install faster-whisper for free local "
            "transcription, set GROQ_API_KEY for free Groq Whisper, "
            "or set VOICE_TOOLS_OPENAI_KEY for the OpenAI Whisper API."
        ),
    }
|
||||
|
||||
@@ -77,14 +77,19 @@ sudo apt install portaudio19-dev ffmpeg libopus0
|
||||
Add to `~/.hermes/.env`:
|
||||
|
||||
```bash
|
||||
# Speech-to-Text (at least one required)
|
||||
GROQ_API_KEY=your-key # Groq Whisper — fast, free tier (recommended for most users)
|
||||
VOICE_TOOLS_OPENAI_KEY=your-key # OpenAI Whisper — used first if both keys are set
|
||||
# Speech-to-Text — local provider needs NO key at all
|
||||
# pip install faster-whisper # Free, runs locally, recommended
|
||||
GROQ_API_KEY=your-key # Groq Whisper — fast, free tier (cloud)
|
||||
VOICE_TOOLS_OPENAI_KEY=your-key # OpenAI Whisper — paid (cloud)
|
||||
|
||||
# Text-to-Speech (optional — Edge TTS works without any key)
|
||||
ELEVENLABS_API_KEY=your-key # ElevenLabs — premium quality
|
||||
ELEVENLABS_API_KEY=your-key # ElevenLabs — premium quality
|
||||
```
|
||||
|
||||
:::tip
|
||||
If `faster-whisper` is installed, voice mode works with **zero API keys** for STT. The model (~150 MB for `base`) downloads automatically on first use.
|
||||
:::
|
||||
|
||||
---
|
||||
|
||||
## CLI Voice Mode
|
||||
@@ -293,8 +298,8 @@ The bot auto-loads the codec from:
|
||||
DISCORD_BOT_TOKEN=your-bot-token
|
||||
DISCORD_ALLOWED_USERS=your-user-id
|
||||
|
||||
# STT — at least one required for voice channel listening
|
||||
GROQ_API_KEY=your-key # Recommended (fast, free tier)
|
||||
# STT — local provider needs no key (pip install faster-whisper)
|
||||
# GROQ_API_KEY=your-key # Alternative: cloud-based, fast, free tier
|
||||
|
||||
# TTS — optional, Edge TTS (free) is the default
|
||||
# ELEVENLABS_API_KEY=your-key # Premium quality
|
||||
@@ -329,7 +334,7 @@ When the bot joins a voice channel, it:
|
||||
|
||||
1. **Listens** to each user's audio stream independently
|
||||
2. **Detects silence** — 1.5s of silence after at least 0.5s of speech triggers processing
|
||||
3. **Transcribes** the audio via Whisper STT (Groq or OpenAI)
|
||||
3. **Transcribes** the audio via Whisper STT (local, Groq, or OpenAI)
|
||||
4. **Processes** through the full agent pipeline (session, tools, memory)
|
||||
5. **Speaks** the reply back in the voice channel via TTS
|
||||
|
||||
@@ -371,8 +376,10 @@ voice:
|
||||
|
||||
# Speech-to-Text
|
||||
stt:
|
||||
enabled: true
|
||||
model: "whisper-1" # Or: whisper-large-v3-turbo (Groq)
|
||||
provider: "local" # "local" (free) | "groq" | "openai"
|
||||
local:
|
||||
model: "base" # tiny, base, small, medium, large-v3
|
||||
# model: "whisper-1" # Legacy: used when provider is not set
|
||||
|
||||
# Text-to-Speech
|
||||
tts:
|
||||
@@ -390,9 +397,10 @@ tts:
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Speech-to-Text providers
|
||||
GROQ_API_KEY=... # Groq Whisper (recommended — fast, free tier)
|
||||
VOICE_TOOLS_OPENAI_KEY=... # OpenAI Whisper (used first if both set)
|
||||
# Speech-to-Text providers (local needs no key)
|
||||
# pip install faster-whisper # Free local STT — no API key needed
|
||||
GROQ_API_KEY=... # Groq Whisper (fast, free tier)
|
||||
VOICE_TOOLS_OPENAI_KEY=... # OpenAI Whisper (paid)
|
||||
|
||||
# STT advanced overrides (optional)
|
||||
STT_GROQ_MODEL=whisper-large-v3-turbo # Override default Groq STT model
|
||||
@@ -411,12 +419,17 @@ DISCORD_ALLOWED_USERS=...
|
||||
|
||||
### STT Provider Comparison
|
||||
|
||||
| Provider | Model | Speed | Quality | Cost |
|
||||
|----------|-------|-------|---------|------|
|
||||
| **Groq** | `whisper-large-v3-turbo` | Very fast (~0.5s) | Good | Free tier |
|
||||
| **Groq** | `whisper-large-v3` | Fast (~1s) | Better | Free tier |
|
||||
| **OpenAI** | `whisper-1` | Fast (~1s) | Good | Low |
|
||||
| **OpenAI** | `gpt-4o-transcribe` | Medium (~2s) | Best | Higher |
|
||||
| Provider | Model | Speed | Quality | Cost | API Key |
|
||||
|----------|-------|-------|---------|------|---------|
|
||||
| **Local** | `base` | Fast (depends on CPU/GPU) | Good | Free | No |
|
||||
| **Local** | `small` | Medium | Better | Free | No |
|
||||
| **Local** | `large-v3` | Slow | Best | Free | No |
|
||||
| **Groq** | `whisper-large-v3-turbo` | Very fast (~0.5s) | Good | Free tier | Yes |
|
||||
| **Groq** | `whisper-large-v3` | Fast (~1s) | Better | Free tier | Yes |
|
||||
| **OpenAI** | `whisper-1` | Fast (~1s) | Good | Paid | Yes |
|
||||
| **OpenAI** | `gpt-4o-transcribe` | Medium (~2s) | Best | Paid | Yes |
|
||||
|
||||
Provider priority (automatic fallback): **local** > **groq** > **openai**
|
||||
|
||||
### TTS Provider Comparison
|
||||
|
||||
@@ -455,7 +468,7 @@ The bot requires an @mention by default in server channels. Make sure you:
|
||||
|
||||
### Bot hears me but doesn't respond
|
||||
|
||||
- Verify STT key is set (`GROQ_API_KEY` or `VOICE_TOOLS_OPENAI_KEY`)
|
||||
- Verify STT is available: install `faster-whisper` (no key needed) or set `GROQ_API_KEY` / `VOICE_TOOLS_OPENAI_KEY`
|
||||
- Check the LLM model is configured and accessible
|
||||
- Review gateway logs: `tail -f ~/.hermes/logs/gateway.log`
|
||||
|
||||
|
||||
Reference in New Issue
Block a user