async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> str:
    """
    Download an audio file from a URL and save it to the local cache.

    Retries on transient failures (timeouts, HTTP 429, HTTP 5xx) with a
    linearly increasing delay (1.5s, 3.0s, ...) so a single slow CDN
    response doesn't lose the media. Every other HTTP error status is
    treated as permanent and raised immediately.

    Args:
        url: The HTTP/HTTPS URL to download from.
        ext: File extension including the dot (e.g. ".ogg", ".mp3").
        retries: Number of retry attempts on transient failures. Values
            below zero are treated as 0 (one attempt, no retries).

    Returns:
        Absolute path to the cached audio file as a string.

    Raises:
        httpx.TimeoutException: If every attempt timed out.
        httpx.HTTPStatusError: On a non-retryable status, or when a
            retryable status persisted through the final attempt.
    """
    import asyncio
    import httpx
    import logging as _logging
    _log = _logging.getLogger(__name__)

    # A negative count would skip the loop entirely and leave nothing to
    # return or raise; clamp so at least one attempt is always made.
    retries = max(0, retries)

    request_headers = {
        "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
        "Accept": "audio/*,*/*;q=0.8",
    }

    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        for attempt in range(retries + 1):
            try:
                response = await client.get(url, headers=request_headers)
                response.raise_for_status()
            except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
                if isinstance(exc, httpx.HTTPStatusError):
                    status = exc.response.status_code
                    # Only 429 (rate limited) and 5xx (server side) are
                    # transient. Other statuses, including 4xx codes above
                    # 429 such as 431 or 451, will not succeed on retry.
                    if status != 429 and status < 500:
                        raise
                if attempt >= retries:
                    # Out of attempts: surface the last transient failure.
                    raise
                wait = 1.5 * (attempt + 1)
                _log.debug("Audio cache retry %d/%d for %s (%.1fs): %s",
                           attempt + 1, retries, url[:80], wait, exc)
                await asyncio.sleep(wait)
            else:
                return cache_audio_from_bytes(response.content, ext)

    # Every loop iteration returns, raises, or sleeps before a further
    # attempt, and the final attempt never sleeps, so this is unreachable.
    raise RuntimeError("cache_audio_from_url: retry loop exited without a result")
# ---------------------------------------------------------------------------
# cache_audio_from_url (base.py)
# ---------------------------------------------------------------------------

class TestCacheAudioFromUrl:
    """Tests for gateway.platforms.base.cache_audio_from_url"""

    @staticmethod
    def _ok_response(payload):
        """Build a fake successful HTTP response carrying *payload* bytes."""
        resp = MagicMock()
        resp.content = payload
        resp.raise_for_status = MagicMock()
        return resp

    @staticmethod
    def _client_with(get_mock):
        """Build an async-context-manager HTTP client whose ``get`` is *get_mock*."""
        client = AsyncMock()
        client.get = get_mock
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return client

    @staticmethod
    def _fetch(client, sleep=None, **kwargs):
        """Run cache_audio_from_url against *client* and return its result.

        ``asyncio.sleep`` is patched with *sleep* only when one is given;
        extra keyword arguments are forwarded to cache_audio_from_url.
        """
        async def _invoke():
            with patch("httpx.AsyncClient", return_value=client):
                from gateway.platforms.base import cache_audio_from_url
                if sleep is None:
                    return await cache_audio_from_url(
                        "http://example.com/voice.ogg", ext=".ogg", **kwargs
                    )
                with patch("asyncio.sleep", sleep):
                    return await cache_audio_from_url(
                        "http://example.com/voice.ogg", ext=".ogg", **kwargs
                    )

        return asyncio.run(_invoke())

    def test_success_on_first_attempt(self, tmp_path, monkeypatch):
        """A clean 200 response caches the audio and returns a path."""
        monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")

        client = self._client_with(
            AsyncMock(return_value=self._ok_response(b"\x00\x01 fake audio"))
        )

        cached = self._fetch(client)

        assert cached.endswith(".ogg")
        client.get.assert_called_once()

    def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
        """A timeout on the first attempt is retried; second attempt succeeds."""
        monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")

        outcomes = [_make_timeout_error(), self._ok_response(b"audio data")]
        client = self._client_with(AsyncMock(side_effect=outcomes))
        sleeper = AsyncMock()

        cached = self._fetch(client, sleep=sleeper, retries=2)

        assert cached.endswith(".ogg")
        assert client.get.call_count == 2
        sleeper.assert_called_once()

    def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch):
        """A 429 response on the first attempt is retried; second attempt succeeds."""
        monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")

        outcomes = [_make_http_status_error(429), self._ok_response(b"audio data")]
        client = self._client_with(AsyncMock(side_effect=outcomes))

        cached = self._fetch(client, sleep=AsyncMock(), retries=2)

        assert cached.endswith(".ogg")
        assert client.get.call_count == 2

    def test_retries_on_500_then_succeeds(self, tmp_path, monkeypatch):
        """A 500 response on the first attempt is retried; second attempt succeeds."""
        monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")

        outcomes = [_make_http_status_error(500), self._ok_response(b"audio data")]
        client = self._client_with(AsyncMock(side_effect=outcomes))

        cached = self._fetch(client, sleep=AsyncMock(), retries=2)

        assert cached.endswith(".ogg")
        assert client.get.call_count == 2

    def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch):
        """Timeout on every attempt raises after all retries are consumed."""
        monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")

        client = self._client_with(AsyncMock(side_effect=_make_timeout_error()))

        with pytest.raises(httpx.TimeoutException):
            self._fetch(client, sleep=AsyncMock(), retries=2)

        # 3 total calls: initial + 2 retries
        assert client.get.call_count == 3

    def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch):
        """A 404 (non-retryable) is raised immediately without any retry."""
        monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")

        sleeper = AsyncMock()
        client = self._client_with(AsyncMock(side_effect=_make_http_status_error(404)))

        with pytest.raises(httpx.HTTPStatusError):
            self._fetch(client, sleep=sleeper, retries=2)

        # Only 1 attempt, no sleep
        assert client.get.call_count == 1
        sleeper.assert_not_called()


# ---------------------------------------------------------------------------
# Slack mock setup (mirrors existing test_slack.py approach)
# ---------------------------------------------------------------------------