Files
Timmy-time-dashboard/tests/infrastructure/test_router_cascade.py
Alexander Whitestone 6c5f55230b WIP: Claude Code progress on #1065
Automated salvage commit — agent session ended (exit 124).
Work in progress, may need continuation.
2026-03-23 14:41:42 -04:00

1163 lines
40 KiB
Python

"""Tests for Cascade LLM Router."""
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import yaml

from infrastructure.router.cascade import (
    CascadeRouter,
    CircuitState,
    Provider,
    ProviderMetrics,
    ProviderStatus,
    RouterConfig,
)
class TestProviderMetrics:
    """Unit tests for the derived properties on ``ProviderMetrics``."""

    def test_empty_metrics(self):
        """A freshly constructed metrics object reports zero everywhere."""
        fresh = ProviderMetrics()

        assert fresh.total_requests == 0
        assert fresh.avg_latency_ms == 0.0
        assert fresh.error_rate == 0.0

    def test_avg_latency_calculation(self):
        """Average latency is total latency divided by request count."""
        # 1000 ms spread over 4 requests -> 250 ms per request.
        stats = ProviderMetrics(total_latency_ms=1000.0, total_requests=4)

        assert stats.avg_latency_ms == 250.0

    def test_error_rate_calculation(self):
        """Error rate is the fraction of requests that failed."""
        stats = ProviderMetrics(
            failed_requests=3,
            successful_requests=7,
            total_requests=10,
        )

        assert stats.error_rate == 0.3
class TestProvider:
    """Unit tests for the ``Provider`` dataclass."""

    @staticmethod
    def _ollama_with(models):
        """Return an enabled, priority-1 Ollama provider named ``test`` holding *models*."""
        return Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
            models=models,
        )

    def test_get_default_model(self):
        """The model flagged ``default: True`` is reported as the default."""
        subject = self._ollama_with(
            [
                {"name": "llama3", "default": True},
                {"name": "mistral"},
            ]
        )

        assert subject.get_default_model() == "llama3"

    def test_get_default_model_no_default(self):
        """With no ``default`` flag anywhere, the first listed model is returned."""
        subject = self._ollama_with([{"name": "llama3"}, {"name": "mistral"}])

        assert subject.get_default_model() == "llama3"

    def test_get_default_model_empty(self):
        """A provider with no models has no default model."""
        subject = self._ollama_with([])

        assert subject.get_default_model() is None
class TestRouterConfig:
    """Unit tests for ``RouterConfig`` defaults."""

    def test_default_config(self):
        """A bare ``RouterConfig`` exposes the expected default values."""
        cfg = RouterConfig()
        expected_defaults = {
            "timeout_seconds": 30,
            "max_retries_per_provider": 2,
            "retry_delay_seconds": 1,
            "circuit_breaker_failure_threshold": 5,
        }

        for attr_name, default in expected_defaults.items():
            assert getattr(cfg, attr_name) == default, attr_name
class TestCascadeRouterInit:
    """Tests for how ``CascadeRouter`` bootstraps itself from a YAML config file."""

    @staticmethod
    def _dump(tmp_path, payload):
        """Write *payload* as YAML to ``providers.yaml`` under *tmp_path*; return the path."""
        target = tmp_path / "providers.yaml"
        target.write_text(yaml.dump(payload))
        return target

    def test_init_without_config(self, tmp_path):
        """A missing config file yields no providers and default settings."""
        router = CascadeRouter(config_path=tmp_path / "nonexistent.yaml")

        assert len(router.providers) == 0
        assert router.config.timeout_seconds == 30

    def test_init_with_config(self, tmp_path):
        """Cascade settings come from the file; disabled providers are not loaded."""
        disabled_ollama = {
            "name": "test-ollama",
            "type": "ollama",
            # Disabled to avoid the availability check during construction.
            "enabled": False,
            "priority": 1,
            "url": "http://localhost:11434",
        }
        path = self._dump(
            tmp_path,
            {
                "cascade": {"timeout_seconds": 60, "max_retries_per_provider": 3},
                "providers": [disabled_ollama],
            },
        )

        router = CascadeRouter(config_path=path)

        assert router.config.timeout_seconds == 60
        assert router.config.max_retries_per_provider == 3
        assert len(router.providers) == 0  # the only provider is disabled

    def test_env_var_expansion(self, tmp_path, monkeypatch):
        """``${VAR}`` placeholders in the provider config are filled from the environment."""
        monkeypatch.setenv("TEST_API_KEY", "secret123")
        path = self._dump(
            tmp_path,
            {
                "cascade": {},
                "providers": [
                    {
                        "name": "test-openai",
                        "type": "openai",
                        "enabled": True,
                        "priority": 1,
                        "api_key": "${TEST_API_KEY}",
                    }
                ],
            },
        )

        router = CascadeRouter(config_path=path)

        assert len(router.providers) == 1
        assert router.providers[0].api_key == "secret123"
class TestCascadeRouterMetrics:
    """Test metrics tracking and circuit-breaker state transitions."""

    @staticmethod
    def _make_provider() -> Provider:
        """Return an enabled, priority-1 Ollama provider with fresh metrics."""
        return Provider(name="test", type="ollama", enabled=True, priority=1)

    @staticmethod
    def _make_router() -> CascadeRouter:
        """Return a router built from a config path that does not exist."""
        return CascadeRouter(config_path=Path("/nonexistent"))

    def test_record_success(self):
        """A success bumps totals, accumulates latency and keeps failures at zero."""
        provider = self._make_provider()
        router = self._make_router()

        router._record_success(provider, 150.0)

        assert provider.metrics.total_requests == 1
        assert provider.metrics.successful_requests == 1
        assert provider.metrics.total_latency_ms == 150.0
        assert provider.metrics.consecutive_failures == 0

    def test_record_failure(self):
        """A failure bumps totals, failed count and the consecutive-failure streak."""
        provider = self._make_provider()
        router = self._make_router()

        router._record_failure(provider)

        assert provider.metrics.total_requests == 1
        assert provider.metrics.failed_requests == 1
        assert provider.metrics.consecutive_failures == 1

    def test_circuit_breaker_opens(self):
        """Reaching the failure threshold opens the circuit and marks the provider unhealthy."""
        provider = self._make_provider()
        router = self._make_router()
        router.config.circuit_breaker_failure_threshold = 3

        for _ in range(3):
            router._record_failure(provider)

        assert provider.circuit_state == CircuitState.OPEN
        assert provider.status == ProviderStatus.UNHEALTHY
        assert provider.circuit_opened_at is not None

    def test_circuit_breaker_can_close(self):
        """After the recovery timeout elapses an open circuit becomes closable."""
        provider = self._make_provider()
        router = self._make_router()
        router.config.circuit_breaker_failure_threshold = 3
        # Short recovery window keeps the test fast.
        router.config.circuit_breaker_recovery_timeout = 0.1

        for _ in range(3):
            router._record_failure(provider)
        assert provider.circuit_state == CircuitState.OPEN

        # Wait past the recovery timeout (uses the module-level ``time`` import).
        time.sleep(0.2)

        assert router._can_close_circuit(provider) is True

    def test_half_open_to_closed(self):
        """Enough successful half-open probe calls close the circuit again."""
        provider = self._make_provider()
        router = self._make_router()
        router.config.circuit_breaker_half_open_max_calls = 2

        # Force the half-open state directly rather than via timing.
        provider.circuit_state = CircuitState.HALF_OPEN
        provider.half_open_calls = 0

        router._record_success(provider, 100.0)
        assert provider.circuit_state == CircuitState.HALF_OPEN  # one probe is not enough

        router._record_success(provider, 100.0)
        assert provider.circuit_state == CircuitState.CLOSED
        assert provider.status == ProviderStatus.HEALTHY
class TestCascadeRouterGetMetrics:
    """Tests for ``CascadeRouter.get_metrics``."""

    def test_get_metrics_empty(self):
        """With no providers configured the report has an empty provider list."""
        report = CascadeRouter(config_path=Path("/nonexistent")).get_metrics()

        assert "providers" in report
        assert len(report["providers"]) == 0

    def test_get_metrics_with_providers(self):
        """Per-provider counters and derived rates appear in the report."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        tracked = Provider(name="test", type="ollama", enabled=True, priority=1)

        # 10 requests, 2 failures, 2000 ms total -> 20% errors, 200 ms average.
        counters = tracked.metrics
        counters.total_requests = 10
        counters.successful_requests = 8
        counters.failed_requests = 2
        counters.total_latency_ms = 2000.0
        router.providers = [tracked]

        report = router.get_metrics()

        assert len(report["providers"]) == 1
        entry = report["providers"][0]
        assert entry["name"] == "test"
        assert entry["metrics"]["total_requests"] == 10
        assert entry["metrics"]["error_rate"] == 0.2
        assert entry["metrics"]["avg_latency_ms"] == 200.0
class TestCascadeRouterGetStatus:
    """Tests for ``CascadeRouter.get_status``."""

    def test_get_status(self):
        """A single freshly built provider is reported as the one healthy provider."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [
            Provider(
                name="test",
                type="ollama",
                enabled=True,
                priority=1,
                models=[{"name": "llama3", "default": True}],
            )
        ]

        snapshot = router.get_status()

        expected_counts = {
            "total_providers": 1,
            "healthy_providers": 1,
            "degraded_providers": 0,
            "unhealthy_providers": 0,
        }
        for key, count in expected_counts.items():
            assert snapshot[key] == count, key
        assert len(snapshot["providers"]) == 1
@pytest.mark.asyncio
class TestCascadeRouterComplete:
    """Test complete method with failover."""

    @staticmethod
    def _ollama(name: str, priority: int, **extra) -> Provider:
        """Build an enabled Ollama provider whose default model is ``llama3.2``.

        Extra keyword arguments (``url``, ``status``, circuit fields, ...) are
        forwarded to ``Provider``.
        """
        return Provider(
            name=name,
            type="ollama",
            enabled=True,
            priority=priority,
            models=[{"name": "llama3.2", "default": True}],
            **extra,
        )

    async def test_complete_with_ollama(self):
        """Test successful completion with Ollama."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._ollama("ollama-local", 1, url="http://localhost:11434")]

        # patch.object substitutes an AsyncMock for an async method, so a plain
        # return_value is what the awaited call yields. (The previous
        # ``AsyncMock()()`` assignment was dead and left an un-awaited coroutine.)
        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {
                "content": "Hello, world!",
                "model": "llama3.2",
            }
            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

        assert result["content"] == "Hello, world!"
        assert result["provider"] == "ollama-local"
        assert result["model"] == "llama3.2"

    async def test_failover_to_second_provider(self):
        """Test failover when first provider fails."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [
            self._ollama("ollama-failing", 1, url="http://localhost:11434"),
            self._ollama("ollama-backup", 2, url="http://backup:11434"),
        ]
        retries = router.config.max_retries_per_provider
        attempts = 0

        async def side_effect(*args, **kwargs):
            # Every attempt against the first provider fails; the next call
            # (the first one against the backup) succeeds.
            nonlocal attempts
            attempts += 1
            if attempts <= retries:
                raise RuntimeError("Connection failed")
            return {"content": "Backup response", "model": "llama3.2"}

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.side_effect = side_effect
            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

        assert result["content"] == "Backup response"
        assert result["provider"] == "ollama-backup"
        # All retries on the failing provider, plus one successful backup call.
        assert mock_call.call_count == retries + 1

    async def test_all_providers_fail(self):
        """Test error when all providers fail."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._ollama("failing", 1)]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.side_effect = RuntimeError("Always fails")
            with pytest.raises(RuntimeError, match="All providers failed"):
                await router.complete(messages=[{"role": "user", "content": "Hi"}])

    async def test_skips_unhealthy_provider(self):
        """Test that unhealthy providers are skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [
            self._ollama(
                "unhealthy",
                1,
                status=ProviderStatus.UNHEALTHY,
                circuit_state=CircuitState.OPEN,
                circuit_opened_at=time.time(),  # Just opened
            ),
            self._ollama("healthy", 2),
        ]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "Success", "model": "llama3.2"}
            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

        # Should use the healthy provider
        assert result["provider"] == "healthy"
class TestProviderAvailabilityCheck:
    """Test provider availability checking."""

    # Module attribute patched either to None (simulating requests being
    # unavailable) or to a mock that stubs out HTTP health checks.
    _REQUESTS_TARGET = "infrastructure.router.cascade.requests"

    @staticmethod
    def _make_router() -> CascadeRouter:
        """Return a router built from a config path that does not exist."""
        return CascadeRouter(config_path=Path("/nonexistent"))

    @staticmethod
    def _make_vllm_provider(**extra) -> Provider:
        """Build the enabled, priority-2 ``vllm_mlx`` provider used by these tests."""
        return Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            **extra,
        )

    @staticmethod
    def _ok_response():
        """Return a stand-in HTTP response whose status code is 200."""
        response = MagicMock()
        response.status_code = 200
        return response

    def test_check_ollama_without_requests(self):
        """Test Ollama returns True when requests not available (fallback)."""
        router = self._make_router()
        provider = Provider(
            name="ollama",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
        )

        # patch() restores the original attribute even if the assertion fails,
        # replacing the hand-rolled save / try / finally swap.
        with patch(self._REQUESTS_TARGET, None):
            assert router._check_provider_available(provider) is True

    def test_check_openai_with_key(self):
        """Test OpenAI with API key."""
        router = self._make_router()
        provider = Provider(
            name="openai",
            type="openai",
            enabled=True,
            priority=1,
            api_key="sk-test123",
        )

        assert router._check_provider_available(provider) is True

    def test_check_openai_without_key(self):
        """Test OpenAI without API key."""
        router = self._make_router()
        provider = Provider(
            name="openai",
            type="openai",
            enabled=True,
            priority=1,
            api_key=None,
        )

        assert router._check_provider_available(provider) is False

    def test_check_vllm_mlx_without_requests(self):
        """Test vllm-mlx returns True when requests not available (fallback)."""
        router = self._make_router()
        provider = self._make_vllm_provider(base_url="http://localhost:8000/v1")

        with patch(self._REQUESTS_TARGET, None):
            assert router._check_provider_available(provider) is True

    def test_check_vllm_mlx_server_healthy(self):
        """Test vllm-mlx when health check succeeds."""
        router = self._make_router()
        provider = self._make_vllm_provider(base_url="http://localhost:8000/v1")

        with patch(self._REQUESTS_TARGET) as mock_requests:
            mock_requests.get.return_value = self._ok_response()
            result = router._check_provider_available(provider)

        assert result is True
        # The /v1 API suffix is dropped to reach the server's /health endpoint.
        mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5)

    def test_check_vllm_mlx_server_down(self):
        """Test vllm-mlx when server is not running."""
        router = self._make_router()
        provider = self._make_vllm_provider(base_url="http://localhost:8000/v1")

        with patch(self._REQUESTS_TARGET) as mock_requests:
            mock_requests.get.side_effect = ConnectionRefusedError("Connection refused")
            result = router._check_provider_available(provider)

        assert result is False

    def test_check_vllm_mlx_default_url(self):
        """Test vllm-mlx uses default localhost:8000 when no URL configured."""
        router = self._make_router()
        provider = self._make_vllm_provider()

        with patch(self._REQUESTS_TARGET) as mock_requests:
            mock_requests.get.return_value = self._ok_response()
            router._check_provider_available(provider)

        mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@pytest.mark.asyncio
class TestVllmMlxProvider:
    """Test vllm-mlx provider integration."""

    @staticmethod
    def _make_provider(
        model: str = "qwen-mlx",
        base_url: str = "http://localhost:8000/v1",
    ) -> Provider:
        """Build an enabled, priority-2 ``vllm_mlx`` provider with *model* as default."""
        return Provider(
            name="vllm-mlx-local",
            type="vllm_mlx",
            enabled=True,
            priority=2,
            base_url=base_url,
            models=[{"name": model, "default": True}],
        )

    async def test_complete_with_vllm_mlx(self):
        """Test successful completion via vllm-mlx."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_provider(model="Qwen/Qwen2.5-14B-Instruct-MLX")]

        with patch.object(router, "_call_vllm_mlx") as mock_call:
            mock_call.return_value = {
                "content": "MLX response",
                "model": "Qwen/Qwen2.5-14B-Instruct-MLX",
            }
            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

        assert result["content"] == "MLX response"
        assert result["provider"] == "vllm-mlx-local"
        assert result["model"] == "Qwen/Qwen2.5-14B-Instruct-MLX"

    async def test_vllm_mlx_base_url_normalization(self):
        """Test _call_vllm_mlx appends /v1 when missing."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        provider = self._make_provider(base_url="http://localhost:8000")  # No /v1

        mock_choice = MagicMock()
        mock_choice.message.content = "hello"
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.model = "qwen-mlx"

        with patch("openai.AsyncOpenAI") as mock_openai_cls:
            mock_client = MagicMock()
            # Awaiting an AsyncMock yields its return_value; no wrapper coroutine needed.
            mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
            mock_openai_cls.return_value = mock_client
            await router._call_vllm_mlx(
                provider=provider,
                messages=[{"role": "user", "content": "hi"}],
                model="qwen-mlx",
                temperature=0.7,
                max_tokens=None,
            )

        # call_args.kwargs and call_args[1] are the same mapping, so one lookup
        # suffices; a missing key now fails loudly instead of via None.endswith.
        mock_openai_cls.assert_called_once()
        base_url_used = mock_openai_cls.call_args.kwargs["base_url"]
        assert base_url_used.endswith("/v1")

    async def test_vllm_mlx_is_local_not_cloud(self):
        """Confirm vllm_mlx is not subject to metabolic protocol cloud skip."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_provider()]

        # Quota monitor downshifts to local (ACTIVE tier) — vllm_mlx should still be tried
        with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
            mock_qm.select_model.return_value = "qwen3:14b"
            mock_qm.check.return_value = None
            with patch.object(router, "_call_vllm_mlx") as mock_call:
                mock_call.return_value = {
                    "content": "Local MLX response",
                    "model": "qwen-mlx",
                }
                result = await router.complete(
                    messages=[{"role": "user", "content": "hi"}],
                )

        mock_call.assert_called_once()
        assert result["content"] == "Local MLX response"
@pytest.mark.asyncio
class TestMetabolicProtocol:
    """Test metabolic protocol: cloud providers skip when quota is ACTIVE/RESTING.

    Every test here is a coroutine, so the class carries the same
    ``pytest.mark.asyncio`` marker as the other async test classes in this
    module; without it pytest-asyncio's strict mode would not run them.
    """

    _QUOTA_TARGET = "infrastructure.router.cascade._quota_monitor"

    def _make_anthropic_provider(self) -> Provider:
        """Build an enabled, priority-1 Anthropic provider with a test API key."""
        return Provider(
            name="anthropic-primary",
            type="anthropic",
            enabled=True,
            priority=1,
            api_key="test-key",
            models=[{"name": "claude-sonnet-4-6", "default": True}],
        )

    async def test_cloud_provider_allowed_in_burst_tier(self):
        """BURST tier (quota healthy): cloud provider is tried."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch(self._QUOTA_TARGET) as mock_qm:
            # select_model returns cloud model → BURST tier
            mock_qm.select_model.return_value = "claude-sonnet-4-6"
            mock_qm.check.return_value = None
            with patch.object(router, "_call_anthropic") as mock_call:
                mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "hard question"}],
                )

        mock_call.assert_called_once()
        assert result["content"] == "Cloud response"

    async def test_cloud_provider_skipped_in_active_tier(self):
        """ACTIVE tier (5-hour >= 50%): cloud provider is skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch(self._QUOTA_TARGET) as mock_qm:
            # select_model returns local 14B → ACTIVE tier
            mock_qm.select_model.return_value = "qwen3:14b"
            mock_qm.check.return_value = None
            with patch.object(router, "_call_anthropic") as mock_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(
                        messages=[{"role": "user", "content": "question"}],
                    )

        mock_call.assert_not_called()

    async def test_cloud_provider_skipped_in_resting_tier(self):
        """RESTING tier (7-day >= 80%): cloud provider is skipped."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch(self._QUOTA_TARGET) as mock_qm:
            # select_model returns local 8B → RESTING tier
            mock_qm.select_model.return_value = "qwen3:8b"
            mock_qm.check.return_value = None
            with patch.object(router, "_call_anthropic") as mock_call:
                with pytest.raises(RuntimeError, match="All providers failed"):
                    await router.complete(
                        messages=[{"role": "user", "content": "simple question"}],
                    )

        mock_call.assert_not_called()

    async def test_local_provider_always_tried_regardless_of_quota(self):
        """Local (ollama/vllm_mlx) providers bypass the metabolic protocol."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [
            Provider(
                name="ollama-local",
                type="ollama",
                enabled=True,
                priority=1,
                url="http://localhost:11434",
                models=[{"name": "qwen3:14b", "default": True}],
            )
        ]

        with patch(self._QUOTA_TARGET) as mock_qm:
            mock_qm.select_model.return_value = "qwen3:8b"  # RESTING tier
            # Match the sibling tests so check() does not return a truthy MagicMock.
            mock_qm.check.return_value = None
            with patch.object(router, "_call_ollama") as mock_call:
                mock_call.return_value = {"content": "Local response", "model": "qwen3:14b"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "hi"}],
                )

        mock_call.assert_called_once()
        assert result["content"] == "Local response"

    async def test_no_quota_monitor_allows_cloud(self):
        """When quota monitor is None (unavailable), cloud providers are allowed."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [self._make_anthropic_provider()]

        with patch(self._QUOTA_TARGET, None):
            with patch.object(router, "_call_anthropic") as mock_call:
                mock_call.return_value = {"content": "Cloud response", "model": "claude-sonnet-4-6"}
                result = await router.complete(
                    messages=[{"role": "user", "content": "question"}],
                )

        mock_call.assert_called_once()
        assert result["content"] == "Cloud response"
class TestCascadeRouterReload:
    """Test hot-reload of providers.yaml."""

    @staticmethod
    def _entry(name: str, provider_type: str, priority: int, api_key: str) -> dict:
        """Return one enabled provider entry for providers.yaml."""
        return {
            "name": name,
            "type": provider_type,
            "enabled": True,
            "priority": priority,
            "api_key": api_key,
        }

    @staticmethod
    def _write(path: Path, config: dict) -> Path:
        """Dump *config* as YAML to *path* and return *path*."""
        path.write_text(yaml.dump(config))
        return path

    def test_reload_preserves_metrics(self, tmp_path):
        """Test that reload preserves metrics for existing providers."""
        config = {"providers": [self._entry("test-openai", "openai", 1, "sk-test")]}
        router = CascadeRouter(config_path=self._write(tmp_path / "providers.yaml", config))
        assert len(router.providers) == 1

        # Simulate some traffic
        router._record_success(router.providers[0], 150.0)
        router._record_success(router.providers[0], 250.0)
        assert router.providers[0].metrics.total_requests == 2

        result = router.reload_config()

        assert result["total_providers"] == 1
        assert result["preserved"] == 1
        assert result["added"] == []
        assert result["removed"] == []
        # Metrics survived the reload (150 ms + 250 ms).
        assert router.providers[0].metrics.total_requests == 2
        assert router.providers[0].metrics.total_latency_ms == 400.0

    def test_reload_preserves_circuit_breaker(self, tmp_path):
        """Test that reload preserves circuit breaker state."""
        config = {
            "cascade": {"circuit_breaker": {"failure_threshold": 2}},
            "providers": [self._entry("test-openai", "openai", 1, "sk-test")],
        }
        router = CascadeRouter(config_path=self._write(tmp_path / "providers.yaml", config))

        # Two failures reach the configured threshold and open the breaker.
        for _ in range(2):
            router._record_failure(router.providers[0])
        assert router.providers[0].circuit_state == CircuitState.OPEN

        router.reload_config()

        # Circuit breaker state preserved
        assert router.providers[0].circuit_state == CircuitState.OPEN
        assert router.providers[0].status == ProviderStatus.UNHEALTHY

    def test_reload_detects_added_provider(self, tmp_path):
        """Test that reload detects newly added providers."""
        config_path = tmp_path / "providers.yaml"
        config = {"providers": [self._entry("openai-1", "openai", 1, "sk-test")]}
        router = CascadeRouter(config_path=self._write(config_path, config))
        assert len(router.providers) == 1

        # Add a second provider to config
        config["providers"].append(self._entry("anthropic-1", "anthropic", 2, "sk-ant-test"))
        self._write(config_path, config)
        result = router.reload_config()

        assert result["total_providers"] == 2
        assert result["preserved"] == 1
        assert result["added"] == ["anthropic-1"]
        assert result["removed"] == []

    def test_reload_detects_removed_provider(self, tmp_path):
        """Test that reload detects removed providers (and reports nothing added)."""
        config_path = tmp_path / "providers.yaml"
        config = {
            "providers": [
                self._entry("openai-1", "openai", 1, "sk-test"),
                self._entry("anthropic-1", "anthropic", 2, "sk-ant-test"),
            ],
        }
        router = CascadeRouter(config_path=self._write(config_path, config))
        assert len(router.providers) == 2

        # Remove anthropic
        config["providers"] = config["providers"][:1]
        self._write(config_path, config)
        result = router.reload_config()

        assert result["total_providers"] == 1
        assert result["preserved"] == 1
        # Symmetric with the "added" test: a pure removal must not report additions.
        assert result["added"] == []
        assert result["removed"] == ["anthropic-1"]

    def test_reload_re_sorts_by_priority(self, tmp_path):
        """Test that providers are re-sorted by priority after reload."""
        config_path = tmp_path / "providers.yaml"
        config = {
            "providers": [
                self._entry("low-priority", "openai", 10, "sk-test"),
                self._entry("high-priority", "openai", 1, "sk-test2"),
            ],
        }
        router = CascadeRouter(config_path=self._write(config_path, config))
        assert router.providers[0].name == "high-priority"

        # Swap priorities
        config["providers"][0]["priority"] = 1
        config["providers"][1]["priority"] = 10
        self._write(config_path, config)
        router.reload_config()

        assert router.providers[0].name == "low-priority"
        assert router.providers[1].name == "high-priority"
class TestComplexityRouting:
    """Tests for Qwen3-8B / Qwen3-14B dual-model routing (issue #1065)."""

    def _make_dual_model_provider(self) -> Provider:
        """Return a local Ollama provider registering both Qwen3 sizes.

        ``qwen3:8b`` carries the ``routine`` capability; ``qwen3:14b`` is the
        default model and carries ``complex`` and ``reasoning``.
        """
        routine_caps = ["text", "tools", "json", "streaming", "routine"]
        complex_caps = ["text", "tools", "json", "streaming", "complex", "reasoning"]
        return Provider(
            name="ollama-local",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[
                {"name": "qwen3:8b", "capabilities": routine_caps},
                {"name": "qwen3:14b", "default": True, "capabilities": complex_caps},
            ],
        )

    def _make_router(self, *, attach_provider: bool = True) -> CascadeRouter:
        """Return a router whose fallback chains map routine -> 8B and complex -> 14B.

        When *attach_provider* is true the dual-model provider becomes the
        router's only provider.
        """
        router = CascadeRouter(config_path=Path("/nonexistent"))
        # Built fresh on every call so tests never share a mutable dict.
        router.config.fallback_chains = {
            "routine": ["qwen3:8b"],
            "complex": ["qwen3:14b"],
        }
        if attach_provider:
            router.providers = [self._make_dual_model_provider()]
        return router

    def test_get_model_for_complexity_simple_returns_8b(self):
        """Simple tasks should select the model with 'routine' capability."""
        from infrastructure.router.classifier import TaskComplexity

        router = self._make_router(attach_provider=False)
        chosen = router._get_model_for_complexity(
            self._make_dual_model_provider(), TaskComplexity.SIMPLE
        )

        assert chosen == "qwen3:8b"

    def test_get_model_for_complexity_complex_returns_14b(self):
        """Complex tasks should select the model with 'complex' capability."""
        from infrastructure.router.classifier import TaskComplexity

        router = self._make_router(attach_provider=False)
        chosen = router._get_model_for_complexity(
            self._make_dual_model_provider(), TaskComplexity.COMPLEX
        )

        assert chosen == "qwen3:14b"

    def test_get_model_for_complexity_returns_none_when_no_match(self):
        """Returns None when provider has no matching model in chain."""
        from infrastructure.router.classifier import TaskComplexity

        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.config.fallback_chains = {}  # no chains configured at all
        lone_provider = Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
            models=[{"name": "llama3.2:3b", "default": True, "capabilities": ["text"]}],
        )

        # Neither a 'routine' nor a 'complex' model is available.
        assert router._get_model_for_complexity(lone_provider, TaskComplexity.SIMPLE) is None

    @pytest.mark.asyncio
    async def test_complete_with_simple_hint_routes_to_8b(self):
        """complexity_hint='simple' should use qwen3:8b."""
        router = self._make_router()

        with patch.object(router, "_call_ollama") as fake_call:
            fake_call.return_value = {"content": "fast answer", "model": "qwen3:8b"}
            outcome = await router.complete(
                messages=[{"role": "user", "content": "list tasks"}],
                complexity_hint="simple",
            )

        assert outcome["model"] == "qwen3:8b"
        assert outcome["complexity"] == "simple"

    @pytest.mark.asyncio
    async def test_complete_with_complex_hint_routes_to_14b(self):
        """complexity_hint='complex' should use qwen3:14b."""
        router = self._make_router()

        with patch.object(router, "_call_ollama") as fake_call:
            fake_call.return_value = {"content": "detailed answer", "model": "qwen3:14b"}
            outcome = await router.complete(
                messages=[{"role": "user", "content": "review this PR"}],
                complexity_hint="complex",
            )

        assert outcome["model"] == "qwen3:14b"
        assert outcome["complexity"] == "complex"

    @pytest.mark.asyncio
    async def test_explicit_model_bypasses_complexity_routing(self):
        """When model is explicitly provided, complexity routing is skipped."""
        router = self._make_router()

        with patch.object(router, "_call_ollama") as fake_call:
            fake_call.return_value = {"content": "response", "model": "qwen3:14b"}
            outcome = await router.complete(
                messages=[{"role": "user", "content": "list tasks"}],
                model="qwen3:14b",  # explicit override
            )

        # The explicit model wins, and no complexity label is attached.
        assert outcome["model"] == "qwen3:14b"
        assert outcome["complexity"] is None

    @pytest.mark.asyncio
    async def test_auto_classification_routes_simple_message(self):
        """Short, simple messages should auto-classify as SIMPLE → 8B."""
        router = self._make_router()

        with patch.object(router, "_call_ollama") as fake_call:
            fake_call.return_value = {"content": "ok", "model": "qwen3:8b"}
            # No complexity_hint: the router must classify on its own.
            outcome = await router.complete(
                messages=[{"role": "user", "content": "status"}],
            )

        assert outcome["complexity"] == "simple"
        assert outcome["model"] == "qwen3:8b"

    @pytest.mark.asyncio
    async def test_auto_classification_routes_complex_message(self):
        """Complex messages should auto-classify → 14B."""
        router = self._make_router()

        with patch.object(router, "_call_ollama") as fake_call:
            fake_call.return_value = {"content": "deep analysis", "model": "qwen3:14b"}
            outcome = await router.complete(
                messages=[{"role": "user", "content": "analyze and prioritize the backlog"}],
            )

        assert outcome["complexity"] == "complex"
        assert outcome["model"] == "qwen3:14b"

    @pytest.mark.asyncio
    async def test_invalid_complexity_hint_falls_back_to_auto(self):
        """Invalid complexity_hint should log a warning and auto-classify."""
        router = self._make_router()

        with patch.object(router, "_call_ollama") as fake_call:
            fake_call.return_value = {"content": "ok", "model": "qwen3:8b"}
            # An unknown hint must not raise.
            outcome = await router.complete(
                messages=[{"role": "user", "content": "status"}],
                complexity_hint="INVALID_HINT",
            )

        assert outcome["complexity"] in ("simple", "complex")  # auto-classified