# Source file: Timmy-time-dashboard/tests/infrastructure/test_router_cascade.py
# (1707 lines, 59 KiB, Python)

"""Tests for Cascade LLM Router."""
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import yaml
from infrastructure.router.cascade import (
CascadeRouter,
CircuitState,
ContentType,
Provider,
ProviderMetrics,
ProviderStatus,
RouterConfig,
get_router,
)
@pytest.mark.unit
class TestProviderMetrics:
    """Verify the derived statistics exposed by ``ProviderMetrics``."""

    def test_empty_metrics(self):
        """A metrics object built with no arguments reports zero everywhere."""
        metrics = ProviderMetrics()

        assert metrics.total_requests == 0
        assert metrics.avg_latency_ms == 0.0
        assert metrics.error_rate == 0.0

    def test_avg_latency_calculation(self):
        """1000 ms spread over 4 requests is expected to average 250 ms."""
        metrics = ProviderMetrics(
            total_requests=4,
            total_latency_ms=1000.0,  # 4 requests, 1000ms total
        )

        # Derived float: compare with a tolerance rather than bit-for-bit.
        assert metrics.avg_latency_ms == pytest.approx(250.0)

    def test_error_rate_calculation(self):
        """3 failures out of 10 requests is expected to give a 0.3 error rate."""
        metrics = ProviderMetrics(
            total_requests=10,
            successful_requests=7,
            failed_requests=3,
        )

        # Derived float: an equivalent formula such as ``1 - 7 / 10`` would not
        # be bit-identical to 0.3, so avoid exact equality here.
        assert metrics.error_rate == pytest.approx(0.3)
@pytest.mark.unit
class TestProvider:
    """Exercise ``Provider.get_default_model`` for the different model lists."""

    def test_get_default_model(self):
        """The model flagged ``default`` is the one returned."""
        models = [
            {"name": "llama3", "default": True},
            {"name": "mistral"},
        ]
        target = Provider(name="test", type="ollama", enabled=True, priority=1, models=models)

        assert target.get_default_model() == "llama3"

    def test_get_default_model_no_default(self):
        """With no model flagged as default, the first listed model is returned."""
        models = [
            {"name": "llama3"},
            {"name": "mistral"},
        ]
        target = Provider(name="test", type="ollama", enabled=True, priority=1, models=models)

        assert target.get_default_model() == "llama3"

    def test_get_default_model_empty(self):
        """An empty model list yields ``None``."""
        target = Provider(name="test", type="ollama", enabled=True, priority=1, models=[])

        assert target.get_default_model() is None
@pytest.mark.unit
class TestRouterConfig:
    """Pin the default values of ``RouterConfig``."""

    def test_default_config(self):
        """A config built with no arguments carries the expected defaults."""
        defaults = RouterConfig()

        observed = (
            defaults.timeout_seconds,
            defaults.max_retries_per_provider,
            defaults.retry_delay_seconds,
            defaults.circuit_breaker_failure_threshold,
        )

        assert observed == (30, 2, 1, 5)
@pytest.mark.unit
class TestCascadeRouterInit:
    """Construct ``CascadeRouter`` from missing and present config files."""

    def test_init_without_config(self, tmp_path):
        """A missing config file gives no providers and the default timeout."""
        missing = tmp_path / "nonexistent.yaml"

        router = CascadeRouter(config_path=missing)

        assert len(router.providers) == 0
        assert router.config.timeout_seconds == 30

    def test_init_with_config(self, tmp_path):
        """Cascade settings are read from YAML; a disabled provider is not loaded."""
        disabled_provider = {
            "name": "test-ollama",
            "type": "ollama",
            "enabled": False,  # Disabled to avoid availability check
            "priority": 1,
            "url": "http://localhost:11434",
        }
        cascade_settings = {
            "timeout_seconds": 60,
            "max_retries_per_provider": 3,
        }
        config_path = tmp_path / "providers.yaml"
        config_path.write_text(
            yaml.dump({"cascade": cascade_settings, "providers": [disabled_provider]})
        )

        router = CascadeRouter(config_path=config_path)

        assert router.config.timeout_seconds == 60
        assert router.config.max_retries_per_provider == 3
        assert len(router.providers) == 0  # Provider is disabled

    def test_env_var_expansion(self, tmp_path, monkeypatch):
        """A ``${VAR}`` placeholder in ``api_key`` is replaced by the env value."""
        monkeypatch.setenv("TEST_API_KEY", "secret123")
        openai_provider = {
            "name": "test-openai",
            "type": "openai",
            "enabled": True,
            "priority": 1,
            "api_key": "${TEST_API_KEY}",
        }
        config_path = tmp_path / "providers.yaml"
        config_path.write_text(yaml.dump({"cascade": {}, "providers": [openai_provider]}))

        router = CascadeRouter(config_path=config_path)

        assert len(router.providers) == 1
        assert router.providers[0].api_key == "secret123"
@pytest.mark.unit
class TestCascadeRouterMetrics:
    """Check request bookkeeping and circuit-breaker transitions on the router."""

    def test_record_success(self):
        """One recorded success updates the counters and leaves no failure streak."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)
        router = CascadeRouter(config_path=Path("/nonexistent"))

        router._record_success(provider, 150.0)

        assert provider.metrics.total_requests == 1
        assert provider.metrics.successful_requests == 1
        assert provider.metrics.total_latency_ms == 150.0
        assert provider.metrics.consecutive_failures == 0

    def test_record_failure(self):
        """One recorded failure updates the request, failure and streak counters."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)
        router = CascadeRouter(config_path=Path("/nonexistent"))

        router._record_failure(provider)

        assert provider.metrics.total_requests == 1
        assert provider.metrics.failed_requests == 1
        assert provider.metrics.consecutive_failures == 1

    def test_circuit_breaker_opens(self):
        """Reaching the failure threshold opens the circuit and marks it unhealthy."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.config.circuit_breaker_failure_threshold = 3

        # Record 3 failures
        for _ in range(3):
            router._record_failure(provider)

        assert provider.circuit_state == CircuitState.OPEN
        assert provider.status == ProviderStatus.UNHEALTHY
        assert provider.circuit_opened_at is not None

    def test_circuit_breaker_can_close(self):
        """After the recovery timeout has elapsed the circuit is allowed to close."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.config.circuit_breaker_failure_threshold = 3
        router.config.circuit_breaker_recovery_timeout = 0.1

        # Open the circuit
        for _ in range(3):
            router._record_failure(provider)
        assert provider.circuit_state == CircuitState.OPEN

        # Wait for recovery timeout (reduced for faster tests); ``time`` is the
        # module-level import, so no function-local import is needed.
        time.sleep(0.2)

        # Check if can close
        assert router._can_close_circuit(provider) is True

    def test_half_open_to_closed(self):
        """A half-open circuit closes after the configured number of successes."""
        provider = Provider(name="test", type="ollama", enabled=True, priority=1)
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.config.circuit_breaker_half_open_max_calls = 2

        # Manually set to half-open
        provider.circuit_state = CircuitState.HALF_OPEN
        provider.half_open_calls = 0

        # First success: below the max-calls limit, so still half-open
        router._record_success(provider, 100.0)
        assert provider.circuit_state == CircuitState.HALF_OPEN

        # Second success: limit reached, circuit closes
        router._record_success(provider, 100.0)
        assert provider.circuit_state == CircuitState.CLOSED
        assert provider.status == ProviderStatus.HEALTHY
@pytest.mark.unit
class TestCascadeRouterGetMetrics:
    """Check the structure and values returned by ``CascadeRouter.get_metrics``."""

    def test_get_metrics_empty(self):
        """With no providers the report still has an empty ``providers`` list."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        metrics = router.get_metrics()

        assert "providers" in metrics
        assert len(metrics["providers"]) == 0

    def test_get_metrics_with_providers(self):
        """Per-provider counters and derived rates appear in the report."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        # Add a test provider with pre-populated counters
        provider = Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
        )
        provider.metrics.total_requests = 10
        provider.metrics.successful_requests = 8
        provider.metrics.failed_requests = 2
        provider.metrics.total_latency_ms = 2000.0
        router.providers = [provider]

        metrics = router.get_metrics()

        assert len(metrics["providers"]) == 1
        p_metrics = metrics["providers"][0]
        assert p_metrics["name"] == "test"
        assert p_metrics["metrics"]["total_requests"] == 10
        # Derived floats: compare with a tolerance instead of exact equality so
        # the test does not depend on the exact arithmetic used to compute them.
        assert p_metrics["metrics"]["error_rate"] == pytest.approx(0.2)
        assert p_metrics["metrics"]["avg_latency_ms"] == pytest.approx(200.0)
@pytest.mark.unit
class TestCascadeRouterGetStatus:
    """Check the summary produced by ``CascadeRouter.get_status``."""

    def test_get_status(self):
        """A single default-state provider is counted as healthy and listed once."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [
            Provider(
                name="test",
                type="ollama",
                enabled=True,
                priority=1,
                models=[{"name": "llama3", "default": True}],
            )
        ]

        summary = router.get_status()

        counter_keys = (
            "total_providers",
            "healthy_providers",
            "degraded_providers",
            "unhealthy_providers",
        )
        assert {key: summary[key] for key in counter_keys} == {
            "total_providers": 1,
            "healthy_providers": 1,
            "degraded_providers": 0,
            "unhealthy_providers": 0,
        }
        assert len(summary["providers"]) == 1
@pytest.mark.unit
@pytest.mark.asyncio
class TestCascadeRouterComplete:
    """Exercise ``CascadeRouter.complete``: success, failover and exhaustion."""

    async def test_complete_with_ollama(self):
        """A single Ollama provider answers and is reported in the result."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="ollama-local",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider]

        # Mock the Ollama call. Only the final return value is configured; no
        # throwaway coroutine is created, so nothing is left un-awaited.
        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {
                "content": "Hello, world!",
                "model": "llama3.2",
            }
            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

        assert result["content"] == "Hello, world!"
        assert result["provider"] == "ollama-local"
        assert result["model"] == "llama3.2"

    async def test_failover_to_second_provider(self):
        """When the first provider keeps failing, the second provider answers."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider1 = Provider(
            name="ollama-failing",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[{"name": "llama3.2", "default": True}],
        )
        provider2 = Provider(
            name="ollama-backup",
            type="ollama",
            enabled=True,
            priority=2,
            url="http://backup:11434",
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider1, provider2]

        attempts = 0

        async def fail_then_succeed(*args, **kwargs):
            # Fail for as many calls as one provider is retried, so that the
            # first successful call is the one made for the second provider.
            nonlocal attempts
            attempts += 1
            if attempts <= router.config.max_retries_per_provider:
                raise RuntimeError("Connection failed")
            return {"content": "Backup response", "model": "llama3.2"}

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.side_effect = fail_then_succeed
            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

        assert result["content"] == "Backup response"
        assert result["provider"] == "ollama-backup"

    async def test_all_providers_fail(self):
        """If every provider call raises, ``complete`` raises ``RuntimeError``."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="failing",
            type="ollama",
            enabled=True,
            priority=1,
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.side_effect = RuntimeError("Always fails")
            # ``match`` checks the message as part of the raises context,
            # consistent with the other test classes in this module.
            with pytest.raises(RuntimeError, match="All providers failed"):
                await router.complete(messages=[{"role": "user", "content": "Hi"}])

    async def test_skips_unhealthy_provider(self):
        """A provider whose circuit has just opened is skipped for the next one."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider1 = Provider(
            name="unhealthy",
            type="ollama",
            enabled=True,
            priority=1,
            status=ProviderStatus.UNHEALTHY,
            circuit_state=CircuitState.OPEN,
            circuit_opened_at=time.time(),  # Just opened
            models=[{"name": "llama3.2", "default": True}],
        )
        provider2 = Provider(
            name="healthy",
            type="ollama",
            enabled=True,
            priority=2,
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider1, provider2]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "Success", "model": "llama3.2"}
            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

        # Should use the healthy provider
        assert result["provider"] == "healthy"
@pytest.mark.unit
class TestProviderAvailabilityCheck:
    """Exercise ``CascadeRouter._check_provider_available`` per provider type."""

    # Dotted path of the ``requests`` name as looked up inside the router module.
    _REQUESTS_TARGET = "infrastructure.router.cascade.requests"

    @staticmethod
    def _vllm_provider(**extra) -> Provider:
        """Build the vllm-mlx provider shared by the tests below.

        Keyword arguments (e.g. ``base_url``) are forwarded to ``Provider``.
        """
        return Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            **extra,
        )

    def test_check_ollama_without_requests(self):
        """Ollama is reported available when ``requests`` is ``None`` (fallback)."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="ollama",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
        )

        # When requests is None, assume available. ``patch`` restores the real
        # module attribute on exit, replacing a manual save/try/finally.
        with patch(self._REQUESTS_TARGET, None):
            assert router._check_provider_available(provider) is True

    def test_check_openai_with_key(self):
        """An OpenAI provider with an API key is reported available."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="openai",
            type="openai",
            enabled=True,
            priority=1,
            api_key="sk-test123",
        )

        assert router._check_provider_available(provider) is True

    def test_check_openai_without_key(self):
        """An OpenAI provider without an API key is reported unavailable."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="openai",
            type="openai",
            enabled=True,
            priority=1,
            api_key=None,
        )

        assert router._check_provider_available(provider) is False

    def test_check_vllm_mlx_without_requests(self):
        """vllm-mlx is reported available when ``requests`` is ``None`` (fallback)."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = self._vllm_provider(base_url="http://localhost:8000/v1")

        with patch(self._REQUESTS_TARGET, None):
            assert router._check_provider_available(provider) is True

    def test_check_vllm_mlx_server_healthy(self):
        """A 200 from the health endpoint marks vllm-mlx as available."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = self._vllm_provider(base_url="http://localhost:8000/v1")
        mock_response = MagicMock()
        mock_response.status_code = 200

        with patch(self._REQUESTS_TARGET) as mock_requests:
            mock_requests.get.return_value = mock_response
            result = router._check_provider_available(provider)

        assert result is True
        # The health URL is expected at the server root, without the /v1 suffix.
        mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5)

    def test_check_vllm_mlx_server_down(self):
        """A refused connection marks vllm-mlx as unavailable."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = self._vllm_provider(base_url="http://localhost:8000/v1")

        with patch(self._REQUESTS_TARGET) as mock_requests:
            mock_requests.get.side_effect = ConnectionRefusedError("Connection refused")
            result = router._check_provider_available(provider)

        assert result is False

    def test_check_vllm_mlx_default_url(self):
        """With no URL configured, the health check targets localhost:8000."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = self._vllm_provider()
        mock_response = MagicMock()
        mock_response.status_code = 200

        with patch(self._REQUESTS_TARGET) as mock_requests:
            mock_requests.get.return_value = mock_response
            router._check_provider_available(provider)

        mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@pytest.mark.unit
@pytest.mark.asyncio
class TestVllmMlxProvider:
    """Exercise the vllm-mlx provider path of the router."""

    async def test_complete_with_vllm_mlx(self):
        """A single vllm-mlx provider answers and is reported in the result."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url="http://localhost:8000/v1",
            models=[{"name": "Qwen/Qwen2.5-14B-Instruct-MLX", "default": True}],
        )
        router.providers = [provider]

        with patch.object(router, "_call_vllm_mlx") as mock_call:
            mock_call.return_value = {
                "content": "MLX response",
                "model": "Qwen/Qwen2.5-14B-Instruct-MLX",
            }
            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

        assert result["content"] == "MLX response"
        assert result["provider"] == "vllm-mlx-local"
        assert result["model"] == "Qwen/Qwen2.5-14B-Instruct-MLX"

    async def test_vllm_mlx_base_url_normalization(self):
        """The OpenAI client receives a base URL ending in ``/v1`` when it is missing."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url="http://localhost:8000",  # No /v1
            models=[{"name": "qwen-mlx", "default": True}],
        )

        # Minimal stand-in for a chat-completions response object.
        mock_choice = MagicMock()
        mock_choice.message.content = "hello"
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.model = "qwen-mlx"

        with patch("openai.AsyncOpenAI") as mock_openai_cls:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
            mock_openai_cls.return_value = mock_client

            await router._call_vllm_mlx(
                provider=provider,
                messages=[{"role": "user", "content": "hi"}],
                model="qwen-mlx",
                temperature=0.7,
                max_tokens=None,
            )

        # ``call_args.kwargs`` and ``call_args[1]`` are the same mapping, so a
        # single lookup suffices; fail clearly if no base_url keyword was passed.
        base_url_used = mock_openai_cls.call_args.kwargs.get("base_url")
        assert base_url_used is not None
        assert base_url_used.endswith("/v1")

    async def test_vllm_mlx_is_local_not_cloud(self):
        """vllm_mlx is still tried when the quota monitor selects a local model."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url="http://localhost:8000/v1",
            models=[{"name": "qwen-mlx", "default": True}],
        )
        router.providers = [provider]

        # Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried
        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            mock_qm.select_model.return_value = "qwen3:14b"
            mock_qm.check.return_value = None
            with patch.object(router, "_call_vllm_mlx") as mock_call:
                mock_call.return_value = {
                    "content": "Local MLX response",
                    "model": "qwen-mlx",
                }
                result = await router.complete(
                    messages=[{"role": "user", "content": "hi"}],
                )

        assert result["content"] == "Local MLX response"
@pytest.mark.unit
@pytest.mark.asyncio
class TestMetabolicProtocol:
    """Metabolic protocol: cloud providers are skipped when quota is ACTIVE/RESTING."""

    def _make_anthropic_provider(self) -> "Provider":
        """Return an enabled Anthropic provider with one default model."""
        default_model = {"name": "claude-sonnet-4-6", "default": True}
        return Provider(
            name="anthropic-primary",
            type="anthropic",
            enabled=True,
            priority=1,
            api_key="test-key",
            models=[default_model],
        )

    async def test_cloud_provider_allowed_in_burst_tier(self):
        """BURST tier (quota healthy): the cloud provider is called."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]
        cloud_reply = {"content": "Cloud response", "model": "claude-sonnet-4-6"}

        with patch("infrastructure.router.cascade._quota_monitor") as quota:
            # select_model returns cloud model → BURST tier
            quota.select_model.return_value = "claude-sonnet-4-6"
            quota.check.return_value = None
            with patch.object(router, "_call_anthropic", return_value=cloud_reply) as cloud_call:
                result = await router.complete(
                    messages=[{"role": "user", "content": "hard question"}],
                )

        cloud_call.assert_called_once()
        assert result["content"] == "Cloud response"

    async def test_cloud_provider_skipped_in_active_tier(self):
        """ACTIVE tier (5-hour >= 50%): the cloud provider is never called."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]
        request = [{"role": "user", "content": "question"}]

        with patch("infrastructure.router.cascade._quota_monitor") as quota:
            # select_model returns local 14B → ACTIVE tier
            quota.select_model.return_value = "qwen3:14b"
            quota.check.return_value = None
            with patch.object(router, "_call_anthropic") as cloud_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(messages=request)

        cloud_call.assert_not_called()

    async def test_cloud_provider_skipped_in_resting_tier(self):
        """RESTING tier (7-day >= 80%): the cloud provider is never called."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]
        request = [{"role": "user", "content": "simple question"}]

        with patch("infrastructure.router.cascade._quota_monitor") as quota:
            # select_model returns local 8B → RESTING tier
            quota.select_model.return_value = "qwen3:8b"
            quota.check.return_value = None
            with patch.object(router, "_call_anthropic") as cloud_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(messages=request)

        cloud_call.assert_not_called()

    async def test_local_provider_always_tried_regardless_of_quota(self):
        """An Ollama provider is still called when the monitor selects a local model."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [
            Provider(
                name="ollama-local",
                type="ollama",
                enabled=True,
                priority=1,
                url="http://localhost:11434",
                models=[{"name": "qwen3:14b", "default": True}],
            )
        ]
        local_reply = {"content": "Local response", "model": "qwen3:14b"}

        with patch("infrastructure.router.cascade._quota_monitor") as quota:
            quota.select_model.return_value = "qwen3:8b"  # RESTING tier
            with patch.object(router, "_call_ollama", return_value=local_reply) as local_call:
                result = await router.complete(
                    messages=[{"role": "user", "content": "hi"}],
                )

        local_call.assert_called_once()
        assert result["content"] == "Local response"

    async def test_no_quota_monitor_allows_cloud(self):
        """With the quota monitor set to ``None``, the cloud provider is called."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]
        cloud_reply = {"content": "Cloud response", "model": "claude-sonnet-4-6"}

        with patch("infrastructure.router.cascade._quota_monitor", None):
            with patch.object(router, "_call_anthropic", return_value=cloud_reply) as cloud_call:
                result = await router.complete(
                    messages=[{"role": "user", "content": "question"}],
                )

        cloud_call.assert_called_once()
        assert result["content"] == "Cloud response"
@pytest.mark.unit
class TestCascadeRouterReload:
    """Exercise ``reload_config`` against a providers.yaml rewritten on disk."""

    @staticmethod
    def _entry(name, kind, priority, api_key):
        """Build one enabled provider entry for the YAML config."""
        return {
            "name": name,
            "type": kind,
            "enabled": True,
            "priority": priority,
            "api_key": api_key,
        }

    @staticmethod
    def _save(config_path, config):
        """Serialize *config* as YAML into *config_path*."""
        config_path.write_text(yaml.dump(config))

    def test_reload_preserves_metrics(self, tmp_path):
        """Request counters of an unchanged provider survive a reload."""
        config_path = tmp_path / "providers.yaml"
        self._save(
            config_path,
            {"providers": [self._entry("test-openai", "openai", 1, "sk-test")]},
        )
        router = CascadeRouter(config_path=config_path)
        assert len(router.providers) == 1

        # Simulate some traffic
        for latency_ms in (150.0, 250.0):
            router._record_success(router.providers[0], latency_ms)
        assert router.providers[0].metrics.total_requests == 2

        # Reload
        outcome = router.reload_config()

        assert outcome["total_providers"] == 1
        assert outcome["preserved"] == 1
        assert outcome["added"] == []
        assert outcome["removed"] == []
        # Metrics survived
        reloaded = router.providers[0]
        assert reloaded.metrics.total_requests == 2
        assert reloaded.metrics.total_latency_ms == 400.0

    def test_reload_preserves_circuit_breaker(self, tmp_path):
        """An open circuit on an unchanged provider stays open after a reload."""
        config_path = tmp_path / "providers.yaml"
        self._save(
            config_path,
            {
                "cascade": {"circuit_breaker": {"failure_threshold": 2}},
                "providers": [self._entry("test-openai", "openai", 1, "sk-test")],
            },
        )
        router = CascadeRouter(config_path=config_path)

        # Open circuit breaker
        router._record_failure(router.providers[0])
        router._record_failure(router.providers[0])
        assert router.providers[0].circuit_state == CircuitState.OPEN

        # Reload
        router.reload_config()

        # Circuit breaker state preserved
        reloaded = router.providers[0]
        assert reloaded.circuit_state == CircuitState.OPEN
        assert reloaded.status == ProviderStatus.UNHEALTHY

    def test_reload_detects_added_provider(self, tmp_path):
        """A provider appended to the file is reported under ``added``."""
        config = {"providers": [self._entry("openai-1", "openai", 1, "sk-test")]}
        config_path = tmp_path / "providers.yaml"
        self._save(config_path, config)
        router = CascadeRouter(config_path=config_path)
        assert len(router.providers) == 1

        # Add a second provider to config
        config["providers"].append(self._entry("anthropic-1", "anthropic", 2, "sk-ant-test"))
        self._save(config_path, config)
        outcome = router.reload_config()

        assert outcome["total_providers"] == 2
        assert outcome["preserved"] == 1
        assert outcome["added"] == ["anthropic-1"]
        assert outcome["removed"] == []

    def test_reload_detects_removed_provider(self, tmp_path):
        """A provider dropped from the file is reported under ``removed``."""
        config = {
            "providers": [
                self._entry("openai-1", "openai", 1, "sk-test"),
                self._entry("anthropic-1", "anthropic", 2, "sk-ant-test"),
            ],
        }
        config_path = tmp_path / "providers.yaml"
        self._save(config_path, config)
        router = CascadeRouter(config_path=config_path)
        assert len(router.providers) == 2

        # Remove anthropic
        config["providers"] = config["providers"][:1]
        self._save(config_path, config)
        outcome = router.reload_config()

        assert outcome["total_providers"] == 1
        assert outcome["preserved"] == 1
        assert outcome["removed"] == ["anthropic-1"]

    def test_reload_re_sorts_by_priority(self, tmp_path):
        """Provider order follows the priorities found in the reloaded file."""
        config = {
            "providers": [
                self._entry("low-priority", "openai", 10, "sk-test"),
                self._entry("high-priority", "openai", 1, "sk-test2"),
            ],
        }
        config_path = tmp_path / "providers.yaml"
        self._save(config_path, config)
        router = CascadeRouter(config_path=config_path)
        assert router.providers[0].name == "high-priority"

        # Swap priorities
        first, second = config["providers"]
        first["priority"], second["priority"] = 1, 10
        self._save(config_path, config)
        router.reload_config()

        assert [p.name for p in router.providers[:2]] == ["low-priority", "high-priority"]
@pytest.mark.unit
class TestContentTypeDetection:
    """Check how ``_detect_content_type`` classifies message lists."""

    def _router(self) -> CascadeRouter:
        """Return a router built from a non-existent config path."""
        return CascadeRouter(config_path=Path("/nonexistent"))

    def test_text_only(self):
        """A plain string message is classified as TEXT."""
        messages = [{"role": "user", "content": "Hello"}]
        detected = self._router()._detect_content_type(messages)
        assert detected == ContentType.TEXT

    def test_images_key_triggers_vision(self):
        """An ``images`` key on the message is classified as VISION."""
        messages = [{"role": "user", "content": "Describe this", "images": ["pic.jpg"]}]
        detected = self._router()._detect_content_type(messages)
        assert detected == ContentType.VISION

    def test_image_extension_in_content_triggers_vision(self):
        """An image file name inside the text is classified as VISION."""
        messages = [{"role": "user", "content": "Look at photo.png please"}]
        detected = self._router()._detect_content_type(messages)
        assert detected == ContentType.VISION

    def test_base64_data_uri_triggers_vision(self):
        """A ``data:image/...`` URI in the text is classified as VISION."""
        messages = [{"role": "user", "content": "data:image/jpeg;base64,/9j/4AA..."}]
        detected = self._router()._detect_content_type(messages)
        assert detected == ContentType.VISION

    def test_audio_key_triggers_audio(self):
        """An ``audio`` key on the message is classified as AUDIO."""
        messages = [{"role": "user", "content": "", "audio": b"bytes"}]
        detected = self._router()._detect_content_type(messages)
        assert detected == ContentType.AUDIO

    def test_image_and_audio_triggers_multimodal(self):
        """An image reference together with audio is classified as MULTIMODAL."""
        messages = [
            {"role": "user", "content": "check photo.jpg", "audio": b"bytes"},
        ]
        detected = self._router()._detect_content_type(messages)
        assert detected == ContentType.MULTIMODAL

    def test_list_content_image_url_type(self):
        """A list-style content part of type ``image_url`` is classified as VISION."""
        text_part = {"type": "text", "text": "What?"}
        image_part = {"type": "image_url", "image_url": {"url": "http://example.com/a.jpg"}}
        messages = [{"role": "user", "content": [text_part, image_part]}]
        detected = self._router()._detect_content_type(messages)
        assert detected == ContentType.VISION

    def test_list_content_audio_type(self):
        """A list-style content part of type ``audio`` is classified as AUDIO."""
        audio_part = {"type": "audio", "data": "base64..."}
        messages = [{"role": "user", "content": [audio_part]}]
        detected = self._router()._detect_content_type(messages)
        assert detected == ContentType.AUDIO
@pytest.mark.unit
class TestTransformMessagesForOllama:
    """Check the message conversion done by ``_transform_messages_for_ollama``."""

    def _router(self) -> CascadeRouter:
        """Return a router built from a non-existent config path."""
        return CascadeRouter(config_path=Path("/nonexistent"))

    def test_plain_text_message(self):
        """A text-only message comes back unchanged."""
        original = [{"role": "user", "content": "Hello"}]
        converted = self._router()._transform_messages_for_ollama(original)
        assert converted == [{"role": "user", "content": "Hello"}]

    def test_base64_image_stripped(self):
        """The data-URI prefix is dropped, leaving only the base64 payload."""
        message = {
            "role": "user",
            "content": "Describe",
            "images": ["data:image/png;base64,abc123"],
        }
        converted = self._router()._transform_messages_for_ollama([message])
        assert converted[0]["images"] == ["abc123"]

    def test_http_url_skipped(self):
        """An HTTP image URL does not appear in the converted message."""
        message = {
            "role": "user",
            "content": "Describe",
            "images": ["http://example.com/img.jpg"],
        }
        converted = self._router()._transform_messages_for_ollama([message])
        # URL is skipped — images list should be empty or absent
        assert converted[0].get("images", []) == []

    def test_missing_local_file_skipped(self):
        """A path to a file that does not exist does not appear in the result."""
        message = {
            "role": "user",
            "content": "Describe",
            "images": ["/nonexistent/path/image.png"],
        }
        converted = self._router()._transform_messages_for_ollama([message])
        assert converted[0].get("images", []) == []
@pytest.mark.unit
class TestProviderCapabilityMethods:
    """Exercise ``Provider.get_model_with_capability`` and ``model_has_capability``."""

    def _provider(self) -> Provider:
        """Return a provider with one vision model and one default text model."""
        vision_model = {"name": "llava:7b", "capabilities": ["vision"]}
        default_model = {"name": "llama3.2", "default": True}
        return Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
            models=[vision_model, default_model],
        )

    def test_get_model_with_capability_found(self):
        """The model that declares the capability is returned."""
        chosen = self._provider().get_model_with_capability("vision")
        assert chosen == "llava:7b"

    def test_get_model_with_capability_falls_back_to_default(self):
        """With no model declaring the capability, the default model is returned."""
        chosen = self._provider().get_model_with_capability("audio")
        assert chosen == "llama3.2"

    def test_model_has_capability_true(self):
        """A model that lists the capability reports ``True``."""
        provider = self._provider()
        assert provider.model_has_capability("llava:7b", "vision") is True

    def test_model_has_capability_false(self):
        """A model that does not list the capability reports ``False``."""
        provider = self._provider()
        assert provider.model_has_capability("llama3.2", "vision") is False

    def test_model_has_capability_unknown_model(self):
        """A model name that is not configured reports ``False``."""
        provider = self._provider()
        assert provider.model_has_capability("unknown-model", "vision") is False
@pytest.mark.unit
class TestGetFallbackModel:
    """Exercise ``CascadeRouter._get_fallback_model`` per content type."""

    def _router_with_provider(self) -> tuple[CascadeRouter, Provider]:
        """Return a router plus a provider with a vision model and a default model."""
        models = [
            {"name": "llava:7b", "capabilities": ["vision"]},
            {"name": "llama3.2", "default": True},
        ]
        provider = Provider(name="test", type="ollama", enabled=True, priority=1, models=models)
        return CascadeRouter(config_path=Path("/nonexistent")), provider

    def test_returns_vision_model(self):
        """Vision content falls back to the vision-capable model."""
        router, provider = self._router_with_provider()
        fallback = router._get_fallback_model(provider, "llama3.2", ContentType.VISION)
        assert fallback == "llava:7b"

    def test_returns_none_if_no_capability(self):
        """Audio content yields either no fallback or the original model."""
        router, provider = self._router_with_provider()
        fallback = router._get_fallback_model(provider, "llama3.2", ContentType.AUDIO)
        # No audio model; falls back to default which is same as original
        assert fallback is None or fallback == "llama3.2"

    def test_text_content_returns_none(self):
        """Text content yields no fallback model."""
        router, provider = self._router_with_provider()
        fallback = router._get_fallback_model(provider, "llama3.2", ContentType.TEXT)
        assert fallback is None
@pytest.mark.unit
@pytest.mark.asyncio
class TestCascadeTierFiltering:
    """Exercise the ``cascade_tier`` argument of ``CascadeRouter.complete``."""

    def _make_router(self) -> CascadeRouter:
        """Return a router with an Anthropic provider followed by an Ollama one."""
        anthropic = Provider(
            name="anthropic-primary",
            type="anthropic",
            enabled=True,
            priority=1,
            api_key="test-key",
            models=[{"name": "claude-sonnet-4-6", "default": True}],
        )
        ollama = Provider(
            name="ollama-local",
            type="ollama",
            enabled=True,
            priority=2,
            models=[{"name": "llama3.2", "default": True}],
        )
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [anthropic, ollama]
        return router

    async def test_frontier_required_uses_anthropic(self):
        """``frontier_required`` routes the request to the Anthropic provider."""
        router = self._make_router()
        frontier_reply = {
            "content": "frontier response",
            "model": "claude-sonnet-4-6",
        }

        with patch("infrastructure.router.cascade._quota_monitor", None):
            with patch.object(router, "_call_anthropic", return_value=frontier_reply) as call:
                result = await router.complete(
                    messages=[{"role": "user", "content": "hi"}],
                    cascade_tier="frontier_required",
                )

        assert result["provider"] == "anthropic-primary"
        call.assert_called_once()

    async def test_frontier_required_no_anthropic_raises(self):
        """``frontier_required`` with no Anthropic provider raises ``RuntimeError``."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [
            Provider(
                name="ollama-local",
                type="ollama",
                enabled=True,
                priority=1,
                models=[{"name": "llama3.2", "default": True}],
            )
        ]
        request = [{"role": "user", "content": "hi"}]

        with pytest.raises(RuntimeError, match="No Anthropic provider configured"):
            await router.complete(messages=request, cascade_tier="frontier_required")

    async def test_unknown_tier_raises(self):
        """A tier that matches no provider raises ``RuntimeError``."""
        router = self._make_router()
        request = [{"role": "user", "content": "hi"}]

        with pytest.raises(RuntimeError, match="No providers found for tier"):
            await router.complete(messages=request, cascade_tier="nonexistent_tier")

    async def test_tier_filter_only_matching_providers(self):
        """Only the provider whose ``tier`` matches the request is called."""
        local = Provider(
            name="local-primary",
            type="ollama",
            enabled=True,
            priority=1,
            tier="local",
            models=[{"name": "llama3.2", "default": True}],
        )
        cloud = Provider(
            name="cloud-secondary",
            type="anthropic",
            enabled=True,
            priority=2,
            tier="cloud",
            api_key="key",
            models=[{"name": "claude-sonnet-4-6", "default": True}],
        )
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [local, cloud]
        local_reply = {"content": "local response", "model": "llama3.2"}

        with patch.object(router, "_call_ollama", return_value=local_reply) as call:
            result = await router.complete(
                messages=[{"role": "user", "content": "hi"}],
                cascade_tier="local",
            )

        assert result["provider"] == "local-primary"
        call.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
class TestGenerateWithImage:
    """Test generate_with_image convenience method."""

    async def test_delegates_to_complete(self):
        """The image path is forwarded to the provider call on the first message."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [
            Provider(
                name="ollama-vision",
                type="ollama",
                enabled=True,
                priority=1,
                models=[{"name": "llava:7b", "capabilities": ["vision"], "default": True}],
            )
        ]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "A cat", "model": "llava:7b"}
            result = await router.generate_with_image(
                prompt="What is this?",
                image_path="/tmp/cat.jpg",
                model="llava:7b",
            )

        assert result["content"] == "A cat"
        assert result["provider"] == "ollama-vision"

        # The mock stands in for ``_call_ollama``; it must have received the
        # image path in its ``messages`` keyword argument.  ``call_args.kwargs``
        # and ``call_args[1]`` are the same mapping, so a single direct lookup
        # is sufficient and fails with a clear KeyError if the keyword is absent.
        messages_passed = mock_call.call_args.kwargs["messages"]
        assert messages_passed[0]["images"] == ["/tmp/cat.jpg"]
@pytest.mark.unit
class TestGetRouterSingleton:
    """Check that get_router() creates a CascadeRouter and then reuses it."""

    def test_get_router_returns_cascade_router(self):
        """With the cached instance cleared, get_router() builds a CascadeRouter."""
        import infrastructure.router.cascade as cascade_module

        # Clear the module-level instance for the duration of the block; the
        # patcher puts the previous value back on exit, even on failure.
        with patch.object(cascade_module, "cascade_router", None):
            created = get_router()
            assert isinstance(created, CascadeRouter)

    def test_get_router_returns_same_instance(self):
        """Two consecutive get_router() calls hand back the very same object."""
        import infrastructure.router.cascade as cascade_module

        with patch.object(cascade_module, "cascade_router", None):
            first = get_router()
            second = get_router()
            assert first is second
@pytest.mark.unit
class TestIsProviderAvailable:
    """Cover ``_is_provider_available`` for enabled/disabled and circuit states."""

    def _router(self) -> CascadeRouter:
        """Return a router built from a config path that should not exist."""
        missing_config = Path("/nonexistent")
        return CascadeRouter(config_path=missing_config)

    def _open_circuit_provider(self, opened_at: float) -> Provider:
        """Return an enabled, UNHEALTHY provider whose circuit opened at *opened_at*."""
        return Provider(
            name="p",
            type="ollama",
            enabled=True,
            priority=1,
            status=ProviderStatus.UNHEALTHY,
            circuit_state=CircuitState.OPEN,
            circuit_opened_at=opened_at,
        )

    def test_disabled_provider_not_available(self):
        """A provider with ``enabled=False`` is reported unavailable."""
        router = self._router()
        disabled = Provider(name="p", type="ollama", enabled=False, priority=1)
        available = router._is_provider_available(disabled)
        assert available is False

    def test_healthy_provider_available(self):
        """An enabled provider left in its default state is reported available."""
        router = self._router()
        healthy = Provider(name="p", type="ollama", enabled=True, priority=1)
        available = router._is_provider_available(healthy)
        assert available is True

    def test_unhealthy_open_circuit_not_available(self):
        """A circuit that opened just now keeps the provider unavailable."""
        router = self._router()
        # Just opened — not yet recoverable
        provider = self._open_circuit_provider(opened_at=time.time())
        available = router._is_provider_available(provider)
        assert available is False

    def test_unhealthy_after_timeout_transitions_to_half_open(self):
        """Past the recovery timeout the provider is offered again as HALF_OPEN."""
        router = self._router()
        router.config.circuit_breaker_recovery_timeout = 0
        # Long ago
        provider = self._open_circuit_provider(opened_at=time.time() - 10)

        available = router._is_provider_available(provider)

        assert available is True
        assert provider.circuit_state == CircuitState.HALF_OPEN
@pytest.mark.unit
class TestFilterProviders:
    """Exercise the ``_filter_providers`` helper directly."""

    def _router(self) -> CascadeRouter:
        """Build a router with one ``frontier`` Anthropic and one ``local`` Ollama provider."""
        frontier = Provider(
            name="anthropic-p",
            type="anthropic",
            enabled=True,
            priority=1,
            api_key="key",
            tier="frontier",
        )
        local = Provider(
            name="ollama-p",
            type="ollama",
            enabled=True,
            priority=2,
            tier="local",
        )
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [frontier, local]
        return router

    def test_no_tier_returns_all_providers(self):
        """Without a tier the router's own provider list object is returned."""
        router = self._router()
        selected = router._filter_providers(None)
        assert selected is router.providers

    def test_frontier_required_returns_only_anthropic(self):
        """``frontier_required`` narrows the selection to the Anthropic provider."""
        router = self._router()
        selected = router._filter_providers("frontier_required")
        assert len(selected) == 1
        only = selected[0]
        assert only.type == "anthropic"

    def test_frontier_required_no_anthropic_raises(self):
        """``frontier_required`` raises when no Anthropic provider is registered."""
        ollama_only = Provider(name="ollama-p", type="ollama", enabled=True, priority=1)
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [ollama_only]

        with pytest.raises(RuntimeError, match="No Anthropic provider configured"):
            router._filter_providers("frontier_required")

    def test_named_tier_filters_by_tier(self):
        """A named tier selects the provider whose ``tier`` matches."""
        router = self._router()
        selected = router._filter_providers("local")
        assert len(selected) == 1
        only = selected[0]
        assert only.name == "ollama-p"

    def test_named_tier_not_found_raises(self):
        """A tier carried by no provider raises."""
        router = self._router()

        with pytest.raises(RuntimeError, match="No providers found for tier"):
            router._filter_providers("nonexistent")
@pytest.mark.unit
@pytest.mark.asyncio
class TestTrySingleProvider:
    """Exercise the ``_try_single_provider`` helper directly."""

    def _router(self) -> CascadeRouter:
        """Return a router built from a config path that should not exist."""
        missing_config = Path("/nonexistent")
        return CascadeRouter(config_path=missing_config)

    def _provider(self, name: str = "test", ptype: str = "ollama") -> Provider:
        """Return an enabled provider of *ptype* whose default model is ``llama3.2``."""
        default_model = {"name": "llama3.2", "default": True}
        return Provider(
            name=name,
            type=ptype,
            enabled=True,
            priority=1,
            models=[default_model],
        )

    async def _attempt(self, router, provider, messages, errors):
        """Call ``_try_single_provider`` with the positional arguments every test shares."""
        return await router._try_single_provider(
            provider, messages, None, 0.7, None, ContentType.TEXT, errors
        )

    async def test_unavailable_provider_returns_none(self):
        """A disabled provider is skipped without recording an error."""
        router = self._router()
        provider = self._provider()
        provider.enabled = False
        errors: list[str] = []

        outcome = await self._attempt(router, provider, [], errors)

        assert outcome is None
        assert errors == []

    async def test_quota_blocked_cloud_provider_returns_none(self):
        """A cloud provider is skipped silently under the mocked quota monitor."""
        router = self._router()
        provider = self._provider(ptype="anthropic")
        errors: list[str] = []

        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            mock_qm.select_model.return_value = "qwen3:14b"  # non-cloud → ACTIVE tier
            mock_qm.check.return_value = None
            outcome = await self._attempt(router, provider, [], errors)

        assert outcome is None
        assert errors == []

    async def test_success_returns_result_dict(self):
        """A successful provider call yields a result dict tagged with the provider name."""
        router = self._router()
        provider = self._provider()
        errors: list[str] = []
        conversation = [{"role": "user", "content": "hi"}]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "hi", "model": "llama3.2"}
            outcome = await self._attempt(router, provider, conversation, errors)

        assert outcome is not None
        assert outcome["content"] == "hi"
        assert outcome["provider"] == "test"
        assert errors == []

    async def test_failure_appends_error_and_returns_none(self):
        """A raising provider call is recorded in *errors* and in the failure metric."""
        router = self._router()
        provider = self._provider()
        errors: list[str] = []
        conversation = [{"role": "user", "content": "hi"}]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.side_effect = RuntimeError("boom")
            outcome = await self._attempt(router, provider, conversation, errors)

        assert outcome is None
        assert len(errors) == 1
        assert "boom" in errors[0]
        assert provider.metrics.failed_requests == 1
@pytest.mark.unit
class TestComplexityRouting:
    """Tests for Qwen3-8B / Qwen3-14B dual-model routing (issue #1065)."""

    def _make_dual_model_provider(self) -> Provider:
        """Build an Ollama provider with both Qwen3 models registered."""
        return Provider(
            name="ollama-local",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[
                {
                    "name": "qwen3:8b",
                    "capabilities": ["text", "tools", "json", "streaming", "routine"],
                },
                {
                    "name": "qwen3:14b",
                    "default": True,
                    "capabilities": [
                        "text",
                        "tools",
                        "json",
                        "streaming",
                        "complex",
                        "reasoning",
                    ],
                },
            ],
        )

    def _make_chained_router(self) -> CascadeRouter:
        """Build a router whose fallback chains map each complexity tier to one model.

        Only the chains are configured here; providers are left untouched.
        """
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.config.fallback_chains = {
            "routine": ["qwen3:8b"],
            "complex": ["qwen3:14b"],
        }
        return router

    def _make_dual_model_router(self) -> CascadeRouter:
        """Build a chained router whose only provider is the dual-model Ollama one."""
        router = self._make_chained_router()
        router.providers = [self._make_dual_model_provider()]
        return router

    def test_get_model_for_complexity_simple_returns_8b(self):
        """Simple tasks should select the model with 'routine' capability."""
        from infrastructure.router.classifier import TaskComplexity

        router = self._make_chained_router()
        provider = self._make_dual_model_provider()
        model = router._get_model_for_complexity(provider, TaskComplexity.SIMPLE)
        assert model == "qwen3:8b"

    def test_get_model_for_complexity_complex_returns_14b(self):
        """Complex tasks should select the model with 'complex' capability."""
        from infrastructure.router.classifier import TaskComplexity

        router = self._make_chained_router()
        provider = self._make_dual_model_provider()
        model = router._get_model_for_complexity(provider, TaskComplexity.COMPLEX)
        assert model == "qwen3:14b"

    def test_get_model_for_complexity_returns_none_when_no_match(self):
        """Returns None when provider has no matching model in chain."""
        from infrastructure.router.classifier import TaskComplexity

        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.config.fallback_chains = {}  # empty chains
        provider = Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
            models=[{"name": "llama3.2:3b", "default": True, "capabilities": ["text"]}],
        )
        # No 'routine' or 'complex' model available
        model = router._get_model_for_complexity(provider, TaskComplexity.SIMPLE)
        assert model is None

    # NOTE(review): in the complete() tests below the mocked provider response
    # reports the same model name the test expects, so the ``result["model"]``
    # assertions may not by themselves prove which model was requested —
    # confirm against how complete() fills in that field.

    @pytest.mark.asyncio
    async def test_complete_with_simple_hint_routes_to_8b(self):
        """complexity_hint='simple' should use qwen3:8b."""
        router = self._make_dual_model_router()

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "fast answer", "model": "qwen3:8b"}
            result = await router.complete(
                messages=[{"role": "user", "content": "list tasks"}],
                complexity_hint="simple",
            )

        assert result["model"] == "qwen3:8b"
        assert result["complexity"] == "simple"

    @pytest.mark.asyncio
    async def test_complete_with_complex_hint_routes_to_14b(self):
        """complexity_hint='complex' should use qwen3:14b."""
        router = self._make_dual_model_router()

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "detailed answer", "model": "qwen3:14b"}
            result = await router.complete(
                messages=[{"role": "user", "content": "review this PR"}],
                complexity_hint="complex",
            )

        assert result["model"] == "qwen3:14b"
        assert result["complexity"] == "complex"

    @pytest.mark.asyncio
    async def test_explicit_model_bypasses_complexity_routing(self):
        """When model is explicitly provided, complexity routing is skipped."""
        router = self._make_dual_model_router()

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "response", "model": "qwen3:14b"}
            result = await router.complete(
                messages=[{"role": "user", "content": "list tasks"}],
                model="qwen3:14b",  # explicit override
            )

        # Explicit model wins — complexity field is None
        assert result["model"] == "qwen3:14b"
        assert result["complexity"] is None

    @pytest.mark.asyncio
    async def test_auto_classification_routes_simple_message(self):
        """Short, simple messages should auto-classify as SIMPLE → 8B."""
        router = self._make_dual_model_router()

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "ok", "model": "qwen3:8b"}
            result = await router.complete(
                messages=[{"role": "user", "content": "status"}],
                # no complexity_hint — auto-classify
            )

        assert result["complexity"] == "simple"
        assert result["model"] == "qwen3:8b"

    @pytest.mark.asyncio
    async def test_auto_classification_routes_complex_message(self):
        """Complex messages should auto-classify → 14B."""
        router = self._make_dual_model_router()

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "deep analysis", "model": "qwen3:14b"}
            result = await router.complete(
                messages=[{"role": "user", "content": "analyze and prioritize the backlog"}],
            )

        assert result["complexity"] == "complex"
        assert result["model"] == "qwen3:14b"

    @pytest.mark.asyncio
    async def test_invalid_complexity_hint_falls_back_to_auto(self):
        """Invalid complexity_hint should log a warning and auto-classify."""
        router = self._make_dual_model_router()

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "ok", "model": "qwen3:8b"}
            # Should not raise
            result = await router.complete(
                messages=[{"role": "user", "content": "status"}],
                complexity_hint="INVALID_HINT",
            )

        assert result["complexity"] in ("simple", "complex")  # auto-classified