# Automated salvage commit — agent session ended (exit 124). Work in progress, may need continuation.
"""Tests for Cascade LLM Router."""
|
|
|
|
import time
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from infrastructure.router.cascade import (
|
|
CascadeRouter,
|
|
CircuitState,
|
|
Provider,
|
|
ProviderMetrics,
|
|
ProviderStatus,
|
|
RouterConfig,
|
|
)
|
|
|
|
|
|
class TestProviderMetrics:
    """Test provider metrics tracking."""

    def test_empty_metrics(self):
        """Test metrics with no requests."""
        metrics = ProviderMetrics()
        assert metrics.total_requests == 0
        assert metrics.avg_latency_ms == 0.0
        assert metrics.error_rate == 0.0

    def test_avg_latency_calculation(self):
        """Test average latency calculation."""
        metrics = ProviderMetrics(
            total_requests=4,
            total_latency_ms=1000.0,  # 4 requests, 1000ms total
        )
        assert metrics.avg_latency_ms == 250.0

    def test_error_rate_calculation(self):
        """Test error rate calculation."""
        metrics = ProviderMetrics(
            total_requests=10,
            successful_requests=7,
            failed_requests=3,
        )
        assert metrics.error_rate == 0.3
|
|
|
class TestProvider:
    """Test Provider dataclass."""

    def test_get_default_model(self):
        """Test getting default model."""
        provider = Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
            models=[
                {"name": "llama3", "default": True},
                {"name": "mistral"},
            ],
        )
        assert provider.get_default_model() == "llama3"

    def test_get_default_model_no_default(self):
        """Test getting first model when no default set."""
        provider = Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
            models=[
                {"name": "llama3"},
                {"name": "mistral"},
            ],
        )
        assert provider.get_default_model() == "llama3"

    def test_get_default_model_empty(self):
        """Test with no models."""
        provider = Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
            models=[],
        )
        assert provider.get_default_model() is None
|
|
|
class TestRouterConfig:
    """Test router configuration."""

    def test_default_config(self):
        """Test default configuration values."""
        config = RouterConfig()
        assert config.timeout_seconds == 30
        assert config.max_retries_per_provider == 2
        assert config.retry_delay_seconds == 1
        assert config.circuit_breaker_failure_threshold == 5
|
|
|
class TestCascadeRouterInit:
    """Test CascadeRouter initialization."""

    def test_init_without_config(self, tmp_path):
        """Test initialization without config file."""
        router = CascadeRouter(config_path=tmp_path / "nonexistent.yaml")
        assert len(router.providers) == 0
        assert router.config.timeout_seconds == 30

    def test_init_with_config(self, tmp_path):
        """Test initialization with config file."""
        config = {
            "cascade": {
                "timeout_seconds": 60,
                "max_retries_per_provider": 3,
            },
            "providers": [
                {
                    "name": "test-ollama",
                    "type": "ollama",
                    "enabled": False,  # Disabled to avoid availability check
                    "priority": 1,
                    "url": "http://localhost:11434",
                }
            ],
        }
        config_path = tmp_path / "providers.yaml"
        config_path.write_text(yaml.dump(config))

        router = CascadeRouter(config_path=config_path)
        assert router.config.timeout_seconds == 60
        assert router.config.max_retries_per_provider == 3
        assert len(router.providers) == 0  # Provider is disabled

    def test_env_var_expansion(self, tmp_path, monkeypatch):
        """Test environment variable expansion in config."""
        monkeypatch.setenv("TEST_API_KEY", "secret123")

        config = {
            "cascade": {},
            "providers": [
                {
                    "name": "test-openai",
                    "type": "openai",
                    "enabled": True,
                    "priority": 1,
                    "api_key": "${TEST_API_KEY}",
                }
            ],
        }
        config_path = tmp_path / "providers.yaml"
        config_path.write_text(yaml.dump(config))

        router = CascadeRouter(config_path=config_path)
        assert len(router.providers) == 1
        assert router.providers[0].api_key == "secret123"
|
|
|
class TestCascadeRouterMetrics:
    """Test metrics tracking."""

    def test_record_success(self):
        """Test recording successful request."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)

        router = CascadeRouter(config_path=Path("/nonexistent"))
        router._record_success(provider, 150.0)

        assert provider.metrics.total_requests == 1
        assert provider.metrics.successful_requests == 1
        assert provider.metrics.total_latency_ms == 150.0
        assert provider.metrics.consecutive_failures == 0

    def test_record_failure(self):
        """Test recording failed request."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)

        router = CascadeRouter(config_path=Path("/nonexistent"))
        router._record_failure(provider)

        assert provider.metrics.total_requests == 1
        assert provider.metrics.failed_requests == 1
        assert provider.metrics.consecutive_failures == 1

    def test_circuit_breaker_opens(self):
        """Test circuit breaker opens after failures."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)

        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.config.circuit_breaker_failure_threshold = 3

        # Record 3 failures
        for _ in range(3):
            router._record_failure(provider)

        assert provider.circuit_state == CircuitState.OPEN
        assert provider.status == ProviderStatus.UNHEALTHY
        assert provider.circuit_opened_at is not None

    def test_circuit_breaker_can_close(self):
        """Test circuit breaker can transition to closed."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)

        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.config.circuit_breaker_failure_threshold = 3
        router.config.circuit_breaker_recovery_timeout = 0.1

        # Open the circuit
        for _ in range(3):
            router._record_failure(provider)

        assert provider.circuit_state == CircuitState.OPEN

        # Wait for recovery timeout (reduced for faster tests).
        # `time` is imported at module level; no local import needed.
        time.sleep(0.2)

        # Check if can close
        assert router._can_close_circuit(provider) is True

    def test_half_open_to_closed(self):
        """Test circuit breaker closes after successful test calls."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)

        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.config.circuit_breaker_half_open_max_calls = 2

        # Manually set to half-open
        provider.circuit_state = CircuitState.HALF_OPEN
        provider.half_open_calls = 0

        # Record successful calls
        router._record_success(provider, 100.0)
        assert provider.circuit_state == CircuitState.HALF_OPEN  # Still half-open

        router._record_success(provider, 100.0)
        assert provider.circuit_state == CircuitState.CLOSED  # Now closed
        assert provider.status == ProviderStatus.HEALTHY
|
|
|
class TestCascadeRouterGetMetrics:
    """Test get_metrics method."""

    def test_get_metrics_empty(self):
        """Test getting metrics with no providers."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        metrics = router.get_metrics()

        assert "providers" in metrics
        assert len(metrics["providers"]) == 0

    def test_get_metrics_with_providers(self):
        """Test getting metrics with providers."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        # Add a test provider with pre-seeded traffic counters
        provider = Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
        )
        provider.metrics.total_requests = 10
        provider.metrics.successful_requests = 8
        provider.metrics.failed_requests = 2
        provider.metrics.total_latency_ms = 2000.0

        router.providers = [provider]

        metrics = router.get_metrics()

        assert len(metrics["providers"]) == 1
        p_metrics = metrics["providers"][0]
        assert p_metrics["name"] == "test"
        assert p_metrics["metrics"]["total_requests"] == 10
        assert p_metrics["metrics"]["error_rate"] == 0.2
        assert p_metrics["metrics"]["avg_latency_ms"] == 200.0
|
|
|
class TestCascadeRouterGetStatus:
    """Test get_status method."""

    def test_get_status(self):
        """Test getting router status."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
            models=[{"name": "llama3", "default": True}],
        )
        router.providers = [provider]

        status = router.get_status()

        assert status["total_providers"] == 1
        assert status["healthy_providers"] == 1
        assert status["degraded_providers"] == 0
        assert status["unhealthy_providers"] == 0
        assert len(status["providers"]) == 1
|
|
|
@pytest.mark.asyncio
class TestCascadeRouterComplete:
    """Test complete method with failover."""

    async def test_complete_with_ollama(self):
        """Test successful completion with Ollama."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="ollama-local",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider]

        # Mock the Ollama call. A plain dict return value is enough:
        # patch.object replaces the async method with an AsyncMock, whose
        # return_value is what the awaited call produces. (The original
        # assigned `AsyncMock()()` first — a never-awaited coroutine —
        # and immediately overwrote it; that dead line is removed.)
        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {
                "content": "Hello, world!",
                "model": "llama3.2",
            }

            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

            assert result["content"] == "Hello, world!"
            assert result["provider"] == "ollama-local"
            assert result["model"] == "llama3.2"

    async def test_failover_to_second_provider(self):
        """Test failover when first provider fails."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider1 = Provider(
            name="ollama-failing",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[{"name": "llama3.2", "default": True}],
        )
        provider2 = Provider(
            name="ollama-backup",
            type="ollama",
            enabled=True,
            priority=2,
            url="http://backup:11434",
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider1, provider2]

        # First provider fails, second succeeds
        call_count = [0]

        async def side_effect(*args, **kwargs):
            call_count[0] += 1
            # First 2 retries for provider1 fail, then provider2 succeeds
            if call_count[0] <= router.config.max_retries_per_provider:
                raise RuntimeError("Connection failed")
            return {"content": "Backup response", "model": "llama3.2"}

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.side_effect = side_effect

            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

            assert result["content"] == "Backup response"
            assert result["provider"] == "ollama-backup"

    async def test_all_providers_fail(self):
        """Test error when all providers fail."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="failing",
            type="ollama",
            enabled=True,
            priority=1,
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.side_effect = RuntimeError("Always fails")

            with pytest.raises(RuntimeError) as exc_info:
                await router.complete(messages=[{"role": "user", "content": "Hi"}])

            assert "All providers failed" in str(exc_info.value)

    async def test_skips_unhealthy_provider(self):
        """Test that unhealthy providers are skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider1 = Provider(
            name="unhealthy",
            type="ollama",
            enabled=True,
            priority=1,
            status=ProviderStatus.UNHEALTHY,
            circuit_state=CircuitState.OPEN,
            circuit_opened_at=time.time(),  # Just opened
            models=[{"name": "llama3.2", "default": True}],
        )
        provider2 = Provider(
            name="healthy",
            type="ollama",
            enabled=True,
            priority=2,
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider1, provider2]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "Success", "model": "llama3.2"}

            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

            # Should use the healthy provider
            assert result["provider"] == "healthy"
|
|
|
class TestProviderAvailabilityCheck:
    """Test provider availability checking."""

    def test_check_ollama_without_requests(self):
        """Test Ollama returns True when requests not available (fallback)."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="ollama",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
        )

        # When requests is None, assume available
        import infrastructure.router.cascade as cascade_module

        old_requests = cascade_module.requests
        cascade_module.requests = None
        try:
            assert router._check_provider_available(provider) is True
        finally:
            # Always restore the module attribute to avoid leaking state
            cascade_module.requests = old_requests

    def test_check_openai_with_key(self):
        """Test OpenAI with API key."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="openai",
            type="openai",
            enabled=True,
            priority=1,
            api_key="sk-test123",
        )

        assert router._check_provider_available(provider) is True

    def test_check_openai_without_key(self):
        """Test OpenAI without API key."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="openai",
            type="openai",
            enabled=True,
            priority=1,
            api_key=None,
        )

        assert router._check_provider_available(provider) is False

    def test_check_vllm_mlx_without_requests(self):
        """Test vllm-mlx returns True when requests not available (fallback)."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url="http://localhost:8000/v1",
        )

        import infrastructure.router.cascade as cascade_module

        old_requests = cascade_module.requests
        cascade_module.requests = None
        try:
            assert router._check_provider_available(provider) is True
        finally:
            cascade_module.requests = old_requests

    def test_check_vllm_mlx_server_healthy(self):
        """Test vllm-mlx when health check succeeds."""
        from unittest.mock import MagicMock, patch

        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url="http://localhost:8000/v1",
        )

        mock_response = MagicMock()
        mock_response.status_code = 200

        with patch("infrastructure.router.cascade.requests") as mock_requests:
            mock_requests.get.return_value = mock_response
            result = router._check_provider_available(provider)

        assert result is True
        mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5)

    def test_check_vllm_mlx_server_down(self):
        """Test vllm-mlx when server is not running."""
        from unittest.mock import patch

        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url="http://localhost:8000/v1",
        )

        with patch("infrastructure.router.cascade.requests") as mock_requests:
            mock_requests.get.side_effect = ConnectionRefusedError("Connection refused")
            result = router._check_provider_available(provider)

        assert result is False

    def test_check_vllm_mlx_default_url(self):
        """Test vllm-mlx uses default localhost:8000 when no URL configured."""
        from unittest.mock import MagicMock, patch

        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
        )

        mock_response = MagicMock()
        mock_response.status_code = 200

        with patch("infrastructure.router.cascade.requests") as mock_requests:
            mock_requests.get.return_value = mock_response
            router._check_provider_available(provider)

        mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
|
|
@pytest.mark.asyncio
class TestVllmMlxProvider:
    """Test vllm-mlx provider integration."""

    async def test_complete_with_vllm_mlx(self):
        """Test successful completion via vllm-mlx."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url="http://localhost:8000/v1",
            models=[{"name": "Qwen/Qwen2.5-14B-Instruct-MLX", "default": True}],
        )
        router.providers = [provider]

        with patch.object(router, "_call_vllm_mlx") as mock_call:
            mock_call.return_value = {
                "content": "MLX response",
                "model": "Qwen/Qwen2.5-14B-Instruct-MLX",
            }

            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

            assert result["content"] == "MLX response"
            assert result["provider"] == "vllm-mlx-local"
            assert result["model"] == "Qwen/Qwen2.5-14B-Instruct-MLX"

    async def test_vllm_mlx_base_url_normalization(self):
        """Test _call_vllm_mlx appends /v1 when missing."""
        from unittest.mock import AsyncMock, MagicMock, patch

        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url="http://localhost:8000",  # No /v1
            models=[{"name": "qwen-mlx", "default": True}],
        )

        # Build a fake OpenAI-style chat completion response
        mock_choice = MagicMock()
        mock_choice.message.content = "hello"
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.model = "qwen-mlx"

        async def fake_create(**kwargs):
            return mock_response

        with patch("openai.AsyncOpenAI") as mock_openai_cls:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(side_effect=fake_create)
            mock_openai_cls.return_value = mock_client

            await router._call_vllm_mlx(
                provider=provider,
                messages=[{"role": "user", "content": "hi"}],
                model="qwen-mlx",
                temperature=0.7,
                max_tokens=None,
            )

        # The client must have been constructed with a /v1-suffixed base_url
        call_kwargs = mock_openai_cls.call_args
        base_url_used = call_kwargs.kwargs.get("base_url") or call_kwargs[1].get("base_url")
        assert base_url_used.endswith("/v1")

    async def test_vllm_mlx_is_local_not_cloud(self):
        """Confirm vllm_mlx is not subject to metabolic protocol cloud skip."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url="http://localhost:8000/v1",
            models=[{"name": "qwen-mlx", "default": True}],
        )
        router.providers = [provider]

        # Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried
        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            mock_qm.select_model.return_value = "qwen3:14b"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_vllm_mlx") as mock_call:
                mock_call.return_value = {
                    "content": "Local MLX response",
                    "model": "qwen-mlx",
                }
                result = await router.complete(
                    messages=[{"role": "user", "content": "hi"}],
                )

                assert result["content"] == "Local MLX response"
|
|
|
@pytest.mark.asyncio  # added: sibling async test classes carry this marker; without it the async tests here are not run as coroutines
class TestMetabolicProtocol:
    """Test metabolic protocol: cloud providers skip when quota is ACTIVE/RESTING."""

    def _make_anthropic_provider(self) -> Provider:
        """Build an enabled Anthropic provider with one default model."""
        return Provider(
            name="anthropic-primary",
            type="anthropic",
            enabled=True,
            priority=1,
            api_key="test-key",
            models=[{"name": "claude-sonnet-4-6", "default": True}],
        )

    async def test_cloud_provider_allowed_in_burst_tier(self):
        """BURST tier (quota healthy): cloud provider is tried."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            # select_model returns cloud model → BURST tier
            mock_qm.select_model.return_value = "claude-sonnet-4-6"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_anthropic") as mock_call:
                mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "hard question"}],
                )

                mock_call.assert_called_once()
                assert result["content"] == "Cloud response"

    async def test_cloud_provider_skipped_in_active_tier(self):
        """ACTIVE tier (5-hour >= 50%): cloud provider is skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            # select_model returns local 14B → ACTIVE tier
            mock_qm.select_model.return_value = "qwen3:14b"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_anthropic") as mock_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(
                        messages=[{"role": "user", "content": "question"}],
                    )

                mock_call.assert_not_called()

    async def test_cloud_provider_skipped_in_resting_tier(self):
        """RESTING tier (7-day >= 80%): cloud provider is skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            # select_model returns local 8B → RESTING tier
            mock_qm.select_model.return_value = "qwen3:8b"
            mock_qm.check.return_value = None

            with patch.object(router, "_call_anthropic") as mock_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(
                        messages=[{"role": "user", "content": "simple question"}],
                    )

                mock_call.assert_not_called()

    async def test_local_provider_always_tried_regardless_of_quota(self):
        """Local (ollama/vllm_mlx) providers bypass the metabolic protocol."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="ollama-local",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[{"name": "qwen3:14b", "default": True}],
        )
        router.providers = [provider]

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            mock_qm.select_model.return_value = "qwen3:8b"  # RESTING tier

            with patch.object(router, "_call_ollama") as mock_call:
                mock_call.return_value = {"content": "Local response", "model": "qwen3:14b"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "hi"}],
                )

                mock_call.assert_called_once()
                assert result["content"] == "Local response"

    async def test_no_quota_monitor_allows_cloud(self):
        """When quota monitor is None (unavailable), cloud providers are allowed."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch("infrastructure.router.cascade._quota_monitor", None):
            with patch.object(router, "_call_anthropic") as mock_call:
                mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "question"}],
                )

                mock_call.assert_called_once()
                assert result["content"] == "Cloud response"
|
|
|
class TestCascadeRouterReload:
    """Test hot-reload of providers.yaml."""

    def test_reload_preserves_metrics(self, tmp_path):
        """Test that reload preserves metrics for existing providers."""
        config = {
            "providers": [
                {
                    "name": "test-openai",
                    "type": "openai",
                    "enabled": True,
                    "priority": 1,
                    "api_key": "sk-test",
                }
            ],
        }
        config_path = tmp_path / "providers.yaml"
        config_path.write_text(yaml.dump(config))

        router = CascadeRouter(config_path=config_path)
        assert len(router.providers) == 1

        # Simulate some traffic
        router._record_success(router.providers[0], 150.0)
        router._record_success(router.providers[0], 250.0)
        assert router.providers[0].metrics.total_requests == 2

        # Reload
        result = router.reload_config()

        assert result["total_providers"] == 1
        assert result["preserved"] == 1
        assert result["added"] == []
        assert result["removed"] == []
        # Metrics survived
        assert router.providers[0].metrics.total_requests == 2
        assert router.providers[0].metrics.total_latency_ms == 400.0

    def test_reload_preserves_circuit_breaker(self, tmp_path):
        """Test that reload preserves circuit breaker state."""
        config = {
            "cascade": {"circuit_breaker": {"failure_threshold": 2}},
            "providers": [
                {
                    "name": "test-openai",
                    "type": "openai",
                    "enabled": True,
                    "priority": 1,
                    "api_key": "sk-test",
                }
            ],
        }
        config_path = tmp_path / "providers.yaml"
        config_path.write_text(yaml.dump(config))

        router = CascadeRouter(config_path=config_path)

        # Open circuit breaker
        for _ in range(2):
            router._record_failure(router.providers[0])
        assert router.providers[0].circuit_state == CircuitState.OPEN

        # Reload
        router.reload_config()

        # Circuit breaker state preserved
        assert router.providers[0].circuit_state == CircuitState.OPEN
        assert router.providers[0].status == ProviderStatus.UNHEALTHY

    def test_reload_detects_added_provider(self, tmp_path):
        """Test that reload detects newly added providers."""
        config = {
            "providers": [
                {
                    "name": "openai-1",
                    "type": "openai",
                    "enabled": True,
                    "priority": 1,
                    "api_key": "sk-test",
                }
            ],
        }
        config_path = tmp_path / "providers.yaml"
        config_path.write_text(yaml.dump(config))

        router = CascadeRouter(config_path=config_path)
        assert len(router.providers) == 1

        # Add a second provider to config
        config["providers"].append(
            {
                "name": "anthropic-1",
                "type": "anthropic",
                "enabled": True,
                "priority": 2,
                "api_key": "sk-ant-test",
            }
        )
        config_path.write_text(yaml.dump(config))

        result = router.reload_config()

        assert result["total_providers"] == 2
        assert result["preserved"] == 1
        assert result["added"] == ["anthropic-1"]
        assert result["removed"] == []

    def test_reload_detects_removed_provider(self, tmp_path):
        """Test that reload detects removed providers."""
        config = {
            "providers": [
                {
                    "name": "openai-1",
                    "type": "openai",
                    "enabled": True,
                    "priority": 1,
                    "api_key": "sk-test",
                },
                {
                    "name": "anthropic-1",
                    "type": "anthropic",
                    "enabled": True,
                    "priority": 2,
                    "api_key": "sk-ant-test",
                },
            ],
        }
        config_path = tmp_path / "providers.yaml"
        config_path.write_text(yaml.dump(config))

        router = CascadeRouter(config_path=config_path)
        assert len(router.providers) == 2

        # Remove anthropic
        config["providers"] = [config["providers"][0]]
        config_path.write_text(yaml.dump(config))

        result = router.reload_config()

        assert result["total_providers"] == 1
        assert result["preserved"] == 1
        assert result["removed"] == ["anthropic-1"]

    def test_reload_re_sorts_by_priority(self, tmp_path):
        """Test that providers are re-sorted by priority after reload."""
        config = {
            "providers": [
                {
                    "name": "low-priority",
                    "type": "openai",
                    "enabled": True,
                    "priority": 10,
                    "api_key": "sk-test",
                },
                {
                    "name": "high-priority",
                    "type": "openai",
                    "enabled": True,
                    "priority": 1,
                    "api_key": "sk-test2",
                },
            ],
        }
        config_path = tmp_path / "providers.yaml"
        config_path.write_text(yaml.dump(config))

        router = CascadeRouter(config_path=config_path)
        assert router.providers[0].name == "high-priority"

        # Swap priorities
        config["providers"][0]["priority"] = 1
        config["providers"][1]["priority"] = 10
        config_path.write_text(yaml.dump(config))

        router.reload_config()

        assert router.providers[0].name == "low-priority"
        assert router.providers[1].name == "high-priority"
|
|
|
class TestComplexityRouting:
|
|
"""Tests for Qwen3-8B / Qwen3-14B dual-model routing (issue #1065)."""
|
|
|
|
def _make_dual_model_provider(self) -> Provider:
|
|
"""Build an Ollama provider with both Qwen3 models registered."""
|
|
return Provider(
|
|
name="ollama-local",
|
|
type="ollama",
|
|
enabled=True,
|
|
priority=1,
|
|
url="http://localhost:11434",
|
|
models=[
|
|
{
|
|
"name": "qwen3:8b",
|
|
"capabilities": ["text", "tools", "json", "streaming", "routine"],
|
|
},
|
|
{
|
|
"name": "qwen3:14b",
|
|
"default": True,
|
|
"capabilities": ["text", "tools", "json", "streaming", "complex", "reasoning"],
|
|
},
|
|
],
|
|
)
|
|
|
|
def test_get_model_for_complexity_simple_returns_8b(self):
|
|
"""Simple tasks should select the model with 'routine' capability."""
|
|
from infrastructure.router.classifier import TaskComplexity
|
|
|
|
router = CascadeRouter(config_path=Path("/nonexistent"))
|
|
router.config.fallback_chains = {
|
|
"routine": ["qwen3:8b"],
|
|
"complex": ["qwen3:14b"],
|
|
}
|
|
provider = self._make_dual_model_provider()
|
|
|
|
model = router._get_model_for_complexity(provider, TaskComplexity.SIMPLE)
|
|
assert model == "qwen3:8b"
|
|
|
|
def test_get_model_for_complexity_complex_returns_14b(self):
|
|
"""Complex tasks should select the model with 'complex' capability."""
|
|
from infrastructure.router.classifier import TaskComplexity
|
|
|
|
router = CascadeRouter(config_path=Path("/nonexistent"))
|
|
router.config.fallback_chains = {
|
|
"routine": ["qwen3:8b"],
|
|
"complex": ["qwen3:14b"],
|
|
}
|
|
provider = self._make_dual_model_provider()
|
|
|
|
model = router._get_model_for_complexity(provider, TaskComplexity.COMPLEX)
|
|
assert model == "qwen3:14b"
|
|
|
|
def test_get_model_for_complexity_returns_none_when_no_match(self):
|
|
"""Returns None when provider has no matching model in chain."""
|
|
from infrastructure.router.classifier import TaskComplexity
|
|
|
|
router = CascadeRouter(config_path=Path("/nonexistent"))
|
|
router.config.fallback_chains = {} # empty chains
|
|
|
|
provider = Provider(
|
|
name="test",
|
|
type="ollama",
|
|
enabled=True,
|
|
priority=1,
|
|
models=[{"name": "llama3.2:3b", "default": True, "capabilities": ["text"]}],
|
|
)
|
|
|
|
# No 'routine' or 'complex' model available
|
|
model = router._get_model_for_complexity(provider, TaskComplexity.SIMPLE)
|
|
assert model is None
|
|
|
|
@pytest.mark.asyncio
async def test_complete_with_simple_hint_routes_to_8b(self):
    """An explicit complexity_hint='simple' routes the call to qwen3:8b."""
    cascade = CascadeRouter(config_path=Path("/nonexistent"))
    cascade.config.fallback_chains = {"routine": ["qwen3:8b"], "complex": ["qwen3:14b"]}
    cascade.providers = [self._make_dual_model_provider()]

    # patch.object detects the async target and substitutes an AsyncMock.
    with patch.object(cascade, "_call_ollama") as ollama_mock:
        ollama_mock.return_value = {"content": "fast answer", "model": "qwen3:8b"}
        outcome = await cascade.complete(
            messages=[{"role": "user", "content": "list tasks"}],
            complexity_hint="simple",
        )

    assert outcome["model"] == "qwen3:8b"
    assert outcome["complexity"] == "simple"
|
|
|
|
@pytest.mark.asyncio
async def test_complete_with_complex_hint_routes_to_14b(self):
    """An explicit complexity_hint='complex' routes the call to qwen3:14b."""
    cascade = CascadeRouter(config_path=Path("/nonexistent"))
    cascade.config.fallback_chains = {"routine": ["qwen3:8b"], "complex": ["qwen3:14b"]}
    cascade.providers = [self._make_dual_model_provider()]

    with patch.object(cascade, "_call_ollama") as ollama_mock:
        ollama_mock.return_value = {"content": "detailed answer", "model": "qwen3:14b"}
        outcome = await cascade.complete(
            messages=[{"role": "user", "content": "review this PR"}],
            complexity_hint="complex",
        )

    assert outcome["model"] == "qwen3:14b"
    assert outcome["complexity"] == "complex"
|
|
|
|
@pytest.mark.asyncio
async def test_explicit_model_bypasses_complexity_routing(self):
    """A caller-supplied model skips complexity classification entirely."""
    cascade = CascadeRouter(config_path=Path("/nonexistent"))
    cascade.config.fallback_chains = {"routine": ["qwen3:8b"], "complex": ["qwen3:14b"]}
    cascade.providers = [self._make_dual_model_provider()]

    with patch.object(cascade, "_call_ollama") as ollama_mock:
        ollama_mock.return_value = {"content": "response", "model": "qwen3:14b"}
        outcome = await cascade.complete(
            messages=[{"role": "user", "content": "list tasks"}],
            model="qwen3:14b",  # explicit override
        )

    # Explicit model wins — complexity field is None
    assert outcome["model"] == "qwen3:14b"
    assert outcome["complexity"] is None
|
|
|
|
@pytest.mark.asyncio
async def test_auto_classification_routes_simple_message(self):
    """Without a hint, a short message auto-classifies as simple → 8B."""
    cascade = CascadeRouter(config_path=Path("/nonexistent"))
    cascade.config.fallback_chains = {"routine": ["qwen3:8b"], "complex": ["qwen3:14b"]}
    cascade.providers = [self._make_dual_model_provider()]

    with patch.object(cascade, "_call_ollama") as ollama_mock:
        ollama_mock.return_value = {"content": "ok", "model": "qwen3:8b"}
        # no complexity_hint — auto-classify
        outcome = await cascade.complete(
            messages=[{"role": "user", "content": "status"}],
        )

    assert outcome["complexity"] == "simple"
    assert outcome["model"] == "qwen3:8b"
|
|
|
|
@pytest.mark.asyncio
async def test_auto_classification_routes_complex_message(self):
    """Without a hint, an analytical message auto-classifies as complex → 14B."""
    cascade = CascadeRouter(config_path=Path("/nonexistent"))
    cascade.config.fallback_chains = {"routine": ["qwen3:8b"], "complex": ["qwen3:14b"]}
    cascade.providers = [self._make_dual_model_provider()]

    with patch.object(cascade, "_call_ollama") as ollama_mock:
        ollama_mock.return_value = {"content": "deep analysis", "model": "qwen3:14b"}
        outcome = await cascade.complete(
            messages=[{"role": "user", "content": "analyze and prioritize the backlog"}],
        )

    assert outcome["complexity"] == "complex"
    assert outcome["model"] == "qwen3:14b"
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_complexity_hint_falls_back_to_auto(self):
    """An unrecognized complexity_hint is tolerated and auto-classification runs."""
    cascade = CascadeRouter(config_path=Path("/nonexistent"))
    cascade.config.fallback_chains = {"routine": ["qwen3:8b"], "complex": ["qwen3:14b"]}
    cascade.providers = [self._make_dual_model_provider()]

    with patch.object(cascade, "_call_ollama") as ollama_mock:
        ollama_mock.return_value = {"content": "ok", "model": "qwen3:8b"}
        # Should not raise despite the bogus hint
        outcome = await cascade.complete(
            messages=[{"role": "user", "content": "status"}],
            complexity_hint="INVALID_HINT",
        )

    assert outcome["complexity"] in ("simple", "complex")  # auto-classified
|