Adds vLLM (high-throughput OpenAI-compatible inference server) as a selectable backend alongside the existing Ollama and vllm-mlx backends. vLLM's continuous batching gives 3-10x throughput for agentic workloads. Changes: - config.py: add `vllm` to timmy_model_backend Literal; add vllm_url / vllm_model settings (VLLM_URL / VLLM_MODEL env vars) - cascade.py: add vllm provider type with _check_provider_available (hits /health) and _call_vllm (OpenAI-compatible completions) - providers.yaml: add disabled-by-default vllm-local provider (priority 3, port 8001); bump OpenAI/Anthropic backup priorities to 4/5 - health.py: add _check_vllm/_check_vllm_sync with 30-second TTL cache; /health and /health/sovereignty reflect vLLM status when it is the active backend - docker-compose.yml: add vllm service behind 'vllm' profile (GPU passthrough commented-out template included); add vllm-cache volume - CLAUDE.md: add vLLM row to Service Fallback Matrix - tests: 26 new unit tests covering availability checks, _call_vllm, providers.yaml validation, config options, and health helpers Graceful fallback: if vLLM is unavailable the cascade router automatically falls back to Ollama. The app never crashes. Fixes #1281 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
412 lines
16 KiB
Python
"""Unit tests for the vLLM inference backend (issue #1281).
|
|
|
|
Covers:
|
|
- vllm provider type in CascadeRouter availability check
|
|
- _call_vllm method (mocked OpenAI client)
|
|
- providers.yaml loads vllm-local entry
|
|
- vLLM health check helpers in dashboard routes
|
|
- config.py has vllm backend option
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from infrastructure.router.cascade import CascadeRouter, Provider, ProviderStatus
|
|
|
|
|
|
# ── Provider availability checks ────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestVllmProviderAvailability:
    """Test _check_provider_available for the ``vllm`` provider type.

    The router probes vLLM's root ``/health`` endpoint (the ``/v1`` suffix of
    the OpenAI-compatible base URL is stripped first) and treats any non-200
    response or connection failure as "unavailable".  If the ``requests``
    library is absent the router optimistically assumes availability.
    """

    def _make_vllm_provider(self, url: str = "http://localhost:8001/v1") -> Provider:
        """Build a minimal enabled vllm Provider pointing at *url*."""
        return Provider(
            name="vllm-local",
            type="vllm",
            enabled=True,
            priority=3,
            base_url=url,
            models=[{"name": "Qwen/Qwen2.5-14B-Instruct", "default": True}],
        )

    @staticmethod
    def _mock_response(status_code: int) -> MagicMock:
        """Return a mock HTTP response carrying *status_code*."""
        response = MagicMock()
        response.status_code = status_code
        return response

    def test_available_when_health_200(self, tmp_path):
        """Provider is available when /health returns 200."""
        provider = self._make_vllm_provider()
        router = CascadeRouter(config_path=tmp_path / "none.yaml")

        with patch("infrastructure.router.cascade.requests") as mock_requests:
            mock_requests.get.return_value = self._mock_response(200)
            available = router._check_provider_available(provider)

        assert available is True
        # Verify the health endpoint was called (root, not /v1)
        called_url = mock_requests.get.call_args[0][0]
        assert called_url.endswith("/health")
        assert "/v1" not in called_url

    def test_unavailable_when_health_non_200(self, tmp_path):
        """Provider is unavailable when /health returns non-200."""
        provider = self._make_vllm_provider()
        router = CascadeRouter(config_path=tmp_path / "none.yaml")

        with patch("infrastructure.router.cascade.requests") as mock_requests:
            mock_requests.get.return_value = self._mock_response(503)
            available = router._check_provider_available(provider)

        assert available is False

    def test_unavailable_on_connection_error(self, tmp_path):
        """Provider is unavailable when the connection fails outright."""
        provider = self._make_vllm_provider()
        router = CascadeRouter(config_path=tmp_path / "none.yaml")

        with patch("infrastructure.router.cascade.requests") as mock_requests:
            mock_requests.get.side_effect = ConnectionError("refused")
            available = router._check_provider_available(provider)

        assert available is False

    def test_strips_v1_suffix_for_health_check(self, tmp_path):
        """Health check URL strips /v1 before appending /health."""
        provider = self._make_vllm_provider(url="http://localhost:8001/v1")
        router = CascadeRouter(config_path=tmp_path / "none.yaml")

        with patch("infrastructure.router.cascade.requests") as mock_requests:
            mock_requests.get.return_value = self._mock_response(200)
            router._check_provider_available(provider)

        called_url = mock_requests.get.call_args[0][0]
        assert called_url == "http://localhost:8001/health"

    def test_assumes_available_when_requests_none(self, tmp_path):
        """Gracefully assumes available when the requests library is absent."""
        provider = self._make_vllm_provider()
        router = CascadeRouter(config_path=tmp_path / "none.yaml")

        with patch("infrastructure.router.cascade.requests", None):
            available = router._check_provider_available(provider)

        assert available is True
|
|
|
|
|
|
# ── _call_vllm method ────────────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestCallVllm:
    """Test CascadeRouter._call_vllm (OpenAI-compatible chat completions).

    The mock-client scaffolding (choice → response → AsyncOpenAI client) is
    shared by every test, so it lives in :meth:`_make_mock_client` instead of
    being repeated inline.
    """

    def _make_router(self, tmp_path: Path) -> CascadeRouter:
        """Router pointed at a nonexistent config so no providers load."""
        return CascadeRouter(config_path=tmp_path / "none.yaml")

    def _make_provider(self, base_url: str = "http://localhost:8001") -> Provider:
        """Minimal enabled vllm Provider pointing at *base_url*."""
        return Provider(
            name="vllm-local",
            type="vllm",
            enabled=True,
            priority=3,
            base_url=base_url,
            models=[{"name": "Qwen/Qwen2.5-14B-Instruct", "default": True}],
        )

    @staticmethod
    def _make_mock_client(content: str = "ok", model: str = "model") -> AsyncMock:
        """AsyncOpenAI client mock whose completion yields *content*/*model*."""
        mock_choice = MagicMock()
        mock_choice.message.content = content
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.model = model

        mock_client = AsyncMock()
        mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
        return mock_client

    @pytest.mark.asyncio
    async def test_returns_content_and_model(self, tmp_path):
        """_call_vllm returns content and model name from API response."""
        router = self._make_router(tmp_path)
        provider = self._make_provider()
        mock_client = self._make_mock_client(
            content="Hello from vLLM!", model="Qwen/Qwen2.5-14B-Instruct"
        )

        with patch("openai.AsyncOpenAI", return_value=mock_client):
            result = await router._call_vllm(
                provider=provider,
                messages=[{"role": "user", "content": "hi"}],
                model="Qwen/Qwen2.5-14B-Instruct",
                temperature=0.7,
                max_tokens=None,
            )

        assert result["content"] == "Hello from vLLM!"
        assert result["model"] == "Qwen/Qwen2.5-14B-Instruct"

    @pytest.mark.asyncio
    async def test_appends_v1_to_base_url(self, tmp_path):
        """_call_vllm always points the OpenAI client at base_url/v1."""
        router = self._make_router(tmp_path)
        provider = self._make_provider(base_url="http://localhost:8001")
        mock_client = self._make_mock_client()

        with patch("openai.AsyncOpenAI", return_value=mock_client) as mock_openai:
            await router._call_vllm(
                provider=provider,
                messages=[{"role": "user", "content": "hi"}],
                model="model",
                temperature=0.0,
                max_tokens=None,
            )

        _, kwargs = mock_openai.call_args
        assert kwargs["base_url"].endswith("/v1")

    @pytest.mark.asyncio
    async def test_does_not_double_v1(self, tmp_path):
        """_call_vllm does not append /v1 if base_url already ends with it."""
        router = self._make_router(tmp_path)
        provider = self._make_provider(base_url="http://localhost:8001/v1")
        mock_client = self._make_mock_client()

        with patch("openai.AsyncOpenAI", return_value=mock_client) as mock_openai:
            await router._call_vllm(
                provider=provider,
                messages=[{"role": "user", "content": "hi"}],
                model="model",
                temperature=0.0,
                max_tokens=None,
            )

        _, kwargs = mock_openai.call_args
        assert kwargs["base_url"] == "http://localhost:8001/v1"

    @pytest.mark.asyncio
    async def test_max_tokens_passed_when_set(self, tmp_path):
        """max_tokens is forwarded to the API when provided."""
        router = self._make_router(tmp_path)
        provider = self._make_provider()
        mock_client = self._make_mock_client()

        with patch("openai.AsyncOpenAI", return_value=mock_client):
            await router._call_vllm(
                provider=provider,
                messages=[{"role": "user", "content": "hi"}],
                model="model",
                temperature=0.0,
                max_tokens=256,
            )

        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert call_kwargs.get("max_tokens") == 256

    @pytest.mark.asyncio
    async def test_max_tokens_omitted_when_none(self, tmp_path):
        """max_tokens key is absent when not provided."""
        router = self._make_router(tmp_path)
        provider = self._make_provider()
        mock_client = self._make_mock_client()

        with patch("openai.AsyncOpenAI", return_value=mock_client):
            await router._call_vllm(
                provider=provider,
                messages=[{"role": "user", "content": "hi"}],
                model="model",
                temperature=0.0,
                max_tokens=None,
            )

        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert "max_tokens" not in call_kwargs
|
|
|
|
|
|
# ── providers.yaml loads vllm-local ─────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestProvidersYamlVllm:
    """Verify providers.yaml contains a valid vllm-local entry.

    Path resolution and YAML parsing are shared by every test, so they are
    factored into :meth:`_config_path` / :meth:`_load_providers` instead of
    being repeated three times.
    """

    @staticmethod
    def _config_path() -> Path:
        """Location of config/providers.yaml relative to this test file."""
        return Path(__file__).parents[2] / "config" / "providers.yaml"

    @classmethod
    def _load_providers(cls) -> list:
        """Parse providers.yaml and return its ``providers`` list."""
        with cls._config_path().open() as f:
            config = yaml.safe_load(f)
        return config.get("providers", [])

    @classmethod
    def _find_vllm_local(cls):
        """Return the vllm-local entry, or None if absent."""
        return next(
            (p for p in cls._load_providers() if p.get("name") == "vllm-local"),
            None,
        )

    def test_vllm_local_entry_exists(self):
        """providers.yaml has a vllm-local provider of type vllm."""
        assert self._config_path().exists(), "config/providers.yaml not found"

        vllm_providers = [p for p in self._load_providers() if p.get("type") == "vllm"]
        assert vllm_providers, "No provider with type=vllm found in providers.yaml"

        vllm_local = next((p for p in vllm_providers if p["name"] == "vllm-local"), None)
        assert vllm_local is not None, "vllm-local provider not found in providers.yaml"

    def test_vllm_local_disabled_by_default(self):
        """vllm-local is disabled by default so the router stays on Ollama."""
        vllm_local = self._find_vllm_local()
        assert vllm_local is not None
        assert vllm_local.get("enabled") is False, "vllm-local should be disabled by default"

    def test_vllm_local_has_default_model(self):
        """vllm-local has at least one model with a context window."""
        vllm_local = self._find_vllm_local()
        assert vllm_local is not None

        models = vllm_local.get("models", [])
        assert models, "vllm-local must declare at least one model"
        default_models = [m for m in models if m.get("default")]
        assert default_models, "vllm-local must have a model marked default: true"
|
|
|
|
|
|
# ── config.py backend option ─────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestConfigVllmBackend:
    """Verify config.py exposes the vllm backend option.

    Each test imports Settings locally so the whole module stays importable
    even if config has issues unrelated to these checks.
    """

    def test_vllm_is_valid_backend(self):
        """timmy_model_backend accepts 'vllm' without validation errors."""
        from config import Settings

        settings = Settings(timmy_model_backend="vllm")
        assert settings.timmy_model_backend == "vllm"

    def test_vllm_url_default(self):
        """vllm_url has a sensible default."""
        from config import Settings

        settings = Settings()
        assert settings.vllm_url.startswith("http://")

    def test_vllm_model_default(self):
        """vllm_model has a sensible default."""
        from config import Settings

        settings = Settings()
        # Any non-empty string is acceptable as the default model name.
        assert settings.vllm_model
|
|
|
|
|
|
# ── Health check helpers ─────────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestVllmHealthCheck:
    """Test _check_vllm_sync and _check_vllm.

    The async tests mutate the module-level cache globals (``_vllm_cache`` /
    ``_vllm_cache_ts``); they now save and restore those globals so a failure
    in one test cannot leak stale cache state into the rest of the session.
    """

    def test_sync_returns_healthy_on_200(self):
        """_check_vllm_sync returns 'healthy' when server responds 200."""
        import urllib.request

        from dashboard.routes.health import _check_vllm_sync

        mock_response = MagicMock()
        mock_response.status = 200
        # _check_vllm_sync uses urlopen as a context manager.
        mock_response.__enter__ = lambda s: s
        mock_response.__exit__ = MagicMock(return_value=False)

        with patch.object(urllib.request, "urlopen", return_value=mock_response):
            result = _check_vllm_sync()

        assert result.status == "healthy"
        assert result.name == "vLLM"

    def test_sync_returns_unavailable_on_connection_error(self):
        """_check_vllm_sync returns 'unavailable' when server is unreachable."""
        import urllib.error
        import urllib.request

        from dashboard.routes.health import _check_vllm_sync

        with patch.object(urllib.request, "urlopen", side_effect=urllib.error.URLError("refused")):
            result = _check_vllm_sync()

        assert result.status == "unavailable"
        assert result.name == "vLLM"

    @pytest.mark.asyncio
    async def test_async_caches_result(self):
        """_check_vllm caches the result for _VLLM_CACHE_TTL seconds."""
        import dashboard.routes.health as health_module
        from dashboard.routes.health import _check_vllm

        saved = (health_module._vllm_cache, health_module._vllm_cache_ts)
        # Reset cache so the first call performs a real (mocked) check.
        health_module._vllm_cache = None
        health_module._vllm_cache_ts = 0.0

        mock_dep = MagicMock()
        mock_dep.status = "healthy"

        try:
            with patch("dashboard.routes.health._check_vllm_sync", return_value=mock_dep):
                result1 = await _check_vllm()
                result2 = await _check_vllm()  # should hit cache
        finally:
            # Restore module state so other tests see an untouched cache.
            health_module._vllm_cache, health_module._vllm_cache_ts = saved

        assert result1 is result2  # same object returned from cache

    @pytest.mark.asyncio
    async def test_async_refreshes_after_ttl(self):
        """_check_vllm refreshes the cache after the TTL expires."""
        import dashboard.routes.health as health_module
        from dashboard.routes.health import _VLLM_CACHE_TTL, _check_vllm

        saved = (health_module._vllm_cache, health_module._vllm_cache_ts)
        # Expire the cache so the next call must re-check.
        health_module._vllm_cache = None
        health_module._vllm_cache_ts = time.monotonic() - _VLLM_CACHE_TTL - 1

        mock_dep = MagicMock()
        mock_dep.status = "unavailable"

        try:
            with patch("dashboard.routes.health._check_vllm_sync", return_value=mock_dep) as mock_fn:
                await _check_vllm()
        finally:
            # Restore module state so other tests see an untouched cache.
            health_module._vllm_cache, health_module._vllm_cache_ts = saved

        mock_fn.assert_called_once()
|