[claude] Add unit tests for cascade.py (#1138) (#1150)

This commit is contained in:
2026-03-23 18:47:28 +00:00
parent 3a8d9ee380
commit 0e5948632d

View File

@@ -2,7 +2,7 @@
import time
from pathlib import Path
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import yaml
@@ -10,13 +10,16 @@ import yaml
from infrastructure.router.cascade import (
CascadeRouter,
CircuitState,
ContentType,
Provider,
ProviderMetrics,
ProviderStatus,
RouterConfig,
get_router,
)
@pytest.mark.unit
class TestProviderMetrics:
"""Test provider metrics tracking."""
@@ -45,6 +48,7 @@ class TestProviderMetrics:
assert metrics.error_rate == 0.3
@pytest.mark.unit
class TestProvider:
"""Test Provider dataclass."""
@@ -88,6 +92,7 @@ class TestProvider:
assert provider.get_default_model() is None
@pytest.mark.unit
class TestRouterConfig:
"""Test router configuration."""
@@ -100,6 +105,7 @@ class TestRouterConfig:
assert config.circuit_breaker_failure_threshold == 5
@pytest.mark.unit
class TestCascadeRouterInit:
"""Test CascadeRouter initialization."""
@@ -158,6 +164,7 @@ class TestCascadeRouterInit:
assert router.providers[0].api_key == "secret123"
@pytest.mark.unit
class TestCascadeRouterMetrics:
"""Test metrics tracking."""
@@ -241,6 +248,7 @@ class TestCascadeRouterMetrics:
assert provider.status == ProviderStatus.HEALTHY
@pytest.mark.unit
class TestCascadeRouterGetMetrics:
"""Test get_metrics method."""
@@ -280,6 +288,7 @@ class TestCascadeRouterGetMetrics:
assert p_metrics["metrics"]["avg_latency_ms"] == 200.0
@pytest.mark.unit
class TestCascadeRouterGetStatus:
"""Test get_status method."""
@@ -305,6 +314,7 @@ class TestCascadeRouterGetStatus:
assert len(status["providers"]) == 1
@pytest.mark.unit
@pytest.mark.asyncio
class TestCascadeRouterComplete:
"""Test complete method with failover."""
@@ -436,6 +446,7 @@ class TestCascadeRouterComplete:
assert result["provider"] == "healthy"
@pytest.mark.unit
class TestProviderAvailabilityCheck:
"""Test provider availability checking."""
@@ -577,6 +588,7 @@ class TestProviderAvailabilityCheck:
mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5)
@pytest.mark.unit
@pytest.mark.asyncio
class TestVllmMlxProvider:
"""Test vllm-mlx provider integration."""
@@ -681,6 +693,8 @@ class TestVllmMlxProvider:
assert result["content"] == "Local MLX response"
@pytest.mark.unit
@pytest.mark.asyncio
class TestMetabolicProtocol:
"""Test metabolic protocol: cloud providers skip when quota is ACTIVE/RESTING."""
@@ -790,6 +804,7 @@ class TestMetabolicProtocol:
assert result["content"] == "Cloud response"
@pytest.mark.unit
class TestCascadeRouterReload:
"""Test hot-reload of providers.yaml."""
@@ -968,3 +983,395 @@ class TestCascadeRouterReload:
assert router.providers[0].name == "low-priority"
assert router.providers[1].name == "high-priority"
@pytest.mark.unit
class TestContentTypeDetection:
"""Test _detect_content_type logic."""
def _router(self) -> CascadeRouter:
return CascadeRouter(config_path=Path("/nonexistent"))
def test_text_only(self):
router = self._router()
msgs = [{"role": "user", "content": "Hello"}]
assert router._detect_content_type(msgs) == ContentType.TEXT
def test_images_key_triggers_vision(self):
router = self._router()
msgs = [{"role": "user", "content": "Describe this", "images": ["pic.jpg"]}]
assert router._detect_content_type(msgs) == ContentType.VISION
def test_image_extension_in_content_triggers_vision(self):
router = self._router()
msgs = [{"role": "user", "content": "Look at photo.png please"}]
assert router._detect_content_type(msgs) == ContentType.VISION
def test_base64_data_uri_triggers_vision(self):
router = self._router()
msgs = [{"role": "user", "content": "data:image/jpeg;base64,/9j/4AA..."}]
assert router._detect_content_type(msgs) == ContentType.VISION
def test_audio_key_triggers_audio(self):
router = self._router()
msgs = [{"role": "user", "content": "", "audio": b"bytes"}]
assert router._detect_content_type(msgs) == ContentType.AUDIO
def test_image_and_audio_triggers_multimodal(self):
router = self._router()
msgs = [
{"role": "user", "content": "check photo.jpg", "audio": b"bytes"},
]
assert router._detect_content_type(msgs) == ContentType.MULTIMODAL
def test_list_content_image_url_type(self):
router = self._router()
msgs = [
{
"role": "user",
"content": [
{"type": "text", "text": "What?"},
{"type": "image_url", "image_url": {"url": "http://example.com/a.jpg"}},
],
}
]
assert router._detect_content_type(msgs) == ContentType.VISION
def test_list_content_audio_type(self):
router = self._router()
msgs = [
{
"role": "user",
"content": [
{"type": "audio", "data": "base64..."},
],
}
]
assert router._detect_content_type(msgs) == ContentType.AUDIO
@pytest.mark.unit
class TestTransformMessagesForOllama:
"""Test _transform_messages_for_ollama."""
def _router(self) -> CascadeRouter:
return CascadeRouter(config_path=Path("/nonexistent"))
def test_plain_text_message(self):
router = self._router()
result = router._transform_messages_for_ollama(
[{"role": "user", "content": "Hello"}]
)
assert result == [{"role": "user", "content": "Hello"}]
def test_base64_image_stripped(self):
router = self._router()
msgs = [
{
"role": "user",
"content": "Describe",
"images": ["data:image/png;base64,abc123"],
}
]
result = router._transform_messages_for_ollama(msgs)
assert result[0]["images"] == ["abc123"]
def test_http_url_skipped(self):
router = self._router()
msgs = [
{
"role": "user",
"content": "Describe",
"images": ["http://example.com/img.jpg"],
}
]
result = router._transform_messages_for_ollama(msgs)
# URL is skipped — images list should be empty or absent
assert result[0].get("images", []) == []
def test_missing_local_file_skipped(self):
router = self._router()
msgs = [
{
"role": "user",
"content": "Describe",
"images": ["/nonexistent/path/image.png"],
}
]
result = router._transform_messages_for_ollama(msgs)
assert result[0].get("images", []) == []
@pytest.mark.unit
class TestProviderCapabilityMethods:
"""Test Provider.get_model_with_capability and model_has_capability."""
def _provider(self) -> Provider:
return Provider(
name="test",
type="ollama",
enabled=True,
priority=1,
models=[
{"name": "llava:7b", "capabilities": ["vision"]},
{"name": "llama3.2", "default": True},
],
)
def test_get_model_with_capability_found(self):
p = self._provider()
assert p.get_model_with_capability("vision") == "llava:7b"
def test_get_model_with_capability_falls_back_to_default(self):
p = self._provider()
assert p.get_model_with_capability("audio") == "llama3.2"
def test_model_has_capability_true(self):
p = self._provider()
assert p.model_has_capability("llava:7b", "vision") is True
def test_model_has_capability_false(self):
p = self._provider()
assert p.model_has_capability("llama3.2", "vision") is False
def test_model_has_capability_unknown_model(self):
p = self._provider()
assert p.model_has_capability("unknown-model", "vision") is False
@pytest.mark.unit
class TestGetFallbackModel:
"""Test _get_fallback_model."""
def _router_with_provider(self) -> tuple[CascadeRouter, Provider]:
router = CascadeRouter(config_path=Path("/nonexistent"))
provider = Provider(
name="test",
type="ollama",
enabled=True,
priority=1,
models=[
{"name": "llava:7b", "capabilities": ["vision"]},
{"name": "llama3.2", "default": True},
],
)
return router, provider
def test_returns_vision_model(self):
router, provider = self._router_with_provider()
result = router._get_fallback_model(provider, "llama3.2", ContentType.VISION)
assert result == "llava:7b"
def test_returns_none_if_no_capability(self):
router, provider = self._router_with_provider()
result = router._get_fallback_model(provider, "llama3.2", ContentType.AUDIO)
# No audio model; falls back to default which is same as original
assert result is None or result == "llama3.2"
def test_text_content_returns_none(self):
router, provider = self._router_with_provider()
result = router._get_fallback_model(provider, "llama3.2", ContentType.TEXT)
assert result is None
@pytest.mark.unit
@pytest.mark.asyncio
class TestCascadeTierFiltering:
"""Test cascade_tier parameter in complete()."""
def _make_router(self) -> CascadeRouter:
router = CascadeRouter(config_path=Path("/nonexistent"))
router.providers = [
Provider(
name="anthropic-primary",
type="anthropic",
enabled=True,
priority=1,
api_key="test-key",
models=[{"name": "claude-sonnet-4-6", "default": True}],
),
Provider(
name="ollama-local",
type="ollama",
enabled=True,
priority=2,
models=[{"name": "llama3.2", "default": True}],
),
]
return router
async def test_frontier_required_uses_anthropic(self):
router = self._make_router()
with patch("infrastructure.router.cascade._quota_monitor", None):
with patch.object(router, "_call_anthropic") as mock_call:
mock_call.return_value = {"content": "frontier response", "model": "claude-sonnet-4-6"}
result = await router.complete(
messages=[{"role": "user", "content": "hi"}],
cascade_tier="frontier_required",
)
assert result["provider"] == "anthropic-primary"
mock_call.assert_called_once()
async def test_frontier_required_no_anthropic_raises(self):
router = CascadeRouter(config_path=Path("/nonexistent"))
router.providers = [
Provider(
name="ollama-local",
type="ollama",
enabled=True,
priority=1,
models=[{"name": "llama3.2", "default": True}],
)
]
with pytest.raises(RuntimeError, match="No Anthropic provider configured"):
await router.complete(
messages=[{"role": "user", "content": "hi"}],
cascade_tier="frontier_required",
)
async def test_unknown_tier_raises(self):
router = self._make_router()
with pytest.raises(RuntimeError, match="No providers found for tier"):
await router.complete(
messages=[{"role": "user", "content": "hi"}],
cascade_tier="nonexistent_tier",
)
async def test_tier_filter_only_matching_providers(self):
router = CascadeRouter(config_path=Path("/nonexistent"))
router.providers = [
Provider(
name="local-primary",
type="ollama",
enabled=True,
priority=1,
tier="local",
models=[{"name": "llama3.2", "default": True}],
),
Provider(
name="cloud-secondary",
type="anthropic",
enabled=True,
priority=2,
tier="cloud",
api_key="key",
models=[{"name": "claude-sonnet-4-6", "default": True}],
),
]
with patch.object(router, "_call_ollama") as mock_call:
mock_call.return_value = {"content": "local response", "model": "llama3.2"}
result = await router.complete(
messages=[{"role": "user", "content": "hi"}],
cascade_tier="local",
)
assert result["provider"] == "local-primary"
mock_call.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
class TestGenerateWithImage:
"""Test generate_with_image convenience method."""
async def test_delegates_to_complete(self):
router = CascadeRouter(config_path=Path("/nonexistent"))
router.providers = [
Provider(
name="ollama-vision",
type="ollama",
enabled=True,
priority=1,
models=[{"name": "llava:7b", "capabilities": ["vision"], "default": True}],
)
]
with patch.object(router, "_call_ollama") as mock_call:
mock_call.return_value = {"content": "A cat", "model": "llava:7b"}
result = await router.generate_with_image(
prompt="What is this?",
image_path="/tmp/cat.jpg",
model="llava:7b",
)
assert result["content"] == "A cat"
assert result["provider"] == "ollama-vision"
# complete() should have been called with images in messages
call_kwargs = mock_call.call_args
messages_passed = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages")
assert messages_passed[0]["images"] == ["/tmp/cat.jpg"]
@pytest.mark.unit
class TestGetRouterSingleton:
"""Test get_router() returns a singleton and creates CascadeRouter."""
def test_get_router_returns_cascade_router(self):
import infrastructure.router.cascade as cascade_module
# Reset singleton to test creation
original = cascade_module.cascade_router
cascade_module.cascade_router = None
try:
router = get_router()
assert isinstance(router, CascadeRouter)
finally:
cascade_module.cascade_router = original
def test_get_router_returns_same_instance(self):
import infrastructure.router.cascade as cascade_module
original = cascade_module.cascade_router
cascade_module.cascade_router = None
try:
r1 = get_router()
r2 = get_router()
assert r1 is r2
finally:
cascade_module.cascade_router = original
@pytest.mark.unit
class TestIsProviderAvailable:
"""Test _is_provider_available with circuit breaker transitions."""
def _router(self) -> CascadeRouter:
return CascadeRouter(config_path=Path("/nonexistent"))
def test_disabled_provider_not_available(self):
router = self._router()
provider = Provider(name="p", type="ollama", enabled=False, priority=1)
assert router._is_provider_available(provider) is False
def test_healthy_provider_available(self):
router = self._router()
provider = Provider(name="p", type="ollama", enabled=True, priority=1)
assert router._is_provider_available(provider) is True
def test_unhealthy_open_circuit_not_available(self):
router = self._router()
provider = Provider(
name="p",
type="ollama",
enabled=True,
priority=1,
status=ProviderStatus.UNHEALTHY,
circuit_state=CircuitState.OPEN,
circuit_opened_at=time.time(), # Just opened — not yet recoverable
)
assert router._is_provider_available(provider) is False
def test_unhealthy_after_timeout_transitions_to_half_open(self):
router = self._router()
router.config.circuit_breaker_recovery_timeout = 0
provider = Provider(
name="p",
type="ollama",
enabled=True,
priority=1,
status=ProviderStatus.UNHEALTHY,
circuit_state=CircuitState.OPEN,
circuit_opened_at=time.time() - 10, # Long ago
)
result = router._is_provider_available(provider)
assert result is True
assert provider.circuit_state == CircuitState.HALF_OPEN