This commit was merged in pull request #1150.
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
@@ -10,13 +10,16 @@ import yaml
|
||||
from infrastructure.router.cascade import (
|
||||
CascadeRouter,
|
||||
CircuitState,
|
||||
ContentType,
|
||||
Provider,
|
||||
ProviderMetrics,
|
||||
ProviderStatus,
|
||||
RouterConfig,
|
||||
get_router,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProviderMetrics:
|
||||
"""Test provider metrics tracking."""
|
||||
|
||||
@@ -45,6 +48,7 @@ class TestProviderMetrics:
|
||||
assert metrics.error_rate == 0.3
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProvider:
|
||||
"""Test Provider dataclass."""
|
||||
|
||||
@@ -88,6 +92,7 @@ class TestProvider:
|
||||
assert provider.get_default_model() is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRouterConfig:
|
||||
"""Test router configuration."""
|
||||
|
||||
@@ -100,6 +105,7 @@ class TestRouterConfig:
|
||||
assert config.circuit_breaker_failure_threshold == 5
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCascadeRouterInit:
|
||||
"""Test CascadeRouter initialization."""
|
||||
|
||||
@@ -158,6 +164,7 @@ class TestCascadeRouterInit:
|
||||
assert router.providers[0].api_key == "secret123"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCascadeRouterMetrics:
|
||||
"""Test metrics tracking."""
|
||||
|
||||
@@ -241,6 +248,7 @@ class TestCascadeRouterMetrics:
|
||||
assert provider.status == ProviderStatus.HEALTHY
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCascadeRouterGetMetrics:
|
||||
"""Test get_metrics method."""
|
||||
|
||||
@@ -280,6 +288,7 @@ class TestCascadeRouterGetMetrics:
|
||||
assert p_metrics["metrics"]["avg_latency_ms"] == 200.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCascadeRouterGetStatus:
|
||||
"""Test get_status method."""
|
||||
|
||||
@@ -305,6 +314,7 @@ class TestCascadeRouterGetStatus:
|
||||
assert len(status["providers"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
class TestCascadeRouterComplete:
|
||||
"""Test complete method with failover."""
|
||||
@@ -436,6 +446,7 @@ class TestCascadeRouterComplete:
|
||||
assert result["provider"] == "healthy"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProviderAvailabilityCheck:
|
||||
"""Test provider availability checking."""
|
||||
|
||||
@@ -577,6 +588,7 @@ class TestProviderAvailabilityCheck:
|
||||
mock_requests.get.assert_called_once_with("http://localhost:8000/health", timeout=5)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
class TestVllmMlxProvider:
|
||||
"""Test vllm-mlx provider integration."""
|
||||
@@ -681,6 +693,8 @@ class TestVllmMlxProvider:
|
||||
assert result["content"] == "Local MLX response"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
class TestMetabolicProtocol:
|
||||
"""Test metabolic protocol: cloud providers skip when quota is ACTIVE/RESTING."""
|
||||
|
||||
@@ -790,6 +804,7 @@ class TestMetabolicProtocol:
|
||||
assert result["content"] == "Cloud response"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCascadeRouterReload:
|
||||
"""Test hot-reload of providers.yaml."""
|
||||
|
||||
@@ -968,3 +983,395 @@ class TestCascadeRouterReload:
|
||||
|
||||
assert router.providers[0].name == "low-priority"
|
||||
assert router.providers[1].name == "high-priority"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContentTypeDetection:
|
||||
"""Test _detect_content_type logic."""
|
||||
|
||||
def _router(self) -> CascadeRouter:
|
||||
return CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
def test_text_only(self):
|
||||
router = self._router()
|
||||
msgs = [{"role": "user", "content": "Hello"}]
|
||||
assert router._detect_content_type(msgs) == ContentType.TEXT
|
||||
|
||||
def test_images_key_triggers_vision(self):
|
||||
router = self._router()
|
||||
msgs = [{"role": "user", "content": "Describe this", "images": ["pic.jpg"]}]
|
||||
assert router._detect_content_type(msgs) == ContentType.VISION
|
||||
|
||||
def test_image_extension_in_content_triggers_vision(self):
|
||||
router = self._router()
|
||||
msgs = [{"role": "user", "content": "Look at photo.png please"}]
|
||||
assert router._detect_content_type(msgs) == ContentType.VISION
|
||||
|
||||
def test_base64_data_uri_triggers_vision(self):
|
||||
router = self._router()
|
||||
msgs = [{"role": "user", "content": "data:image/jpeg;base64,/9j/4AA..."}]
|
||||
assert router._detect_content_type(msgs) == ContentType.VISION
|
||||
|
||||
def test_audio_key_triggers_audio(self):
|
||||
router = self._router()
|
||||
msgs = [{"role": "user", "content": "", "audio": b"bytes"}]
|
||||
assert router._detect_content_type(msgs) == ContentType.AUDIO
|
||||
|
||||
def test_image_and_audio_triggers_multimodal(self):
|
||||
router = self._router()
|
||||
msgs = [
|
||||
{"role": "user", "content": "check photo.jpg", "audio": b"bytes"},
|
||||
]
|
||||
assert router._detect_content_type(msgs) == ContentType.MULTIMODAL
|
||||
|
||||
def test_list_content_image_url_type(self):
|
||||
router = self._router()
|
||||
msgs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What?"},
|
||||
{"type": "image_url", "image_url": {"url": "http://example.com/a.jpg"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
assert router._detect_content_type(msgs) == ContentType.VISION
|
||||
|
||||
def test_list_content_audio_type(self):
|
||||
router = self._router()
|
||||
msgs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "data": "base64..."},
|
||||
],
|
||||
}
|
||||
]
|
||||
assert router._detect_content_type(msgs) == ContentType.AUDIO
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTransformMessagesForOllama:
|
||||
"""Test _transform_messages_for_ollama."""
|
||||
|
||||
def _router(self) -> CascadeRouter:
|
||||
return CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
def test_plain_text_message(self):
|
||||
router = self._router()
|
||||
result = router._transform_messages_for_ollama(
|
||||
[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
assert result == [{"role": "user", "content": "Hello"}]
|
||||
|
||||
def test_base64_image_stripped(self):
|
||||
router = self._router()
|
||||
msgs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Describe",
|
||||
"images": ["data:image/png;base64,abc123"],
|
||||
}
|
||||
]
|
||||
result = router._transform_messages_for_ollama(msgs)
|
||||
assert result[0]["images"] == ["abc123"]
|
||||
|
||||
def test_http_url_skipped(self):
|
||||
router = self._router()
|
||||
msgs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Describe",
|
||||
"images": ["http://example.com/img.jpg"],
|
||||
}
|
||||
]
|
||||
result = router._transform_messages_for_ollama(msgs)
|
||||
# URL is skipped — images list should be empty or absent
|
||||
assert result[0].get("images", []) == []
|
||||
|
||||
def test_missing_local_file_skipped(self):
|
||||
router = self._router()
|
||||
msgs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Describe",
|
||||
"images": ["/nonexistent/path/image.png"],
|
||||
}
|
||||
]
|
||||
result = router._transform_messages_for_ollama(msgs)
|
||||
assert result[0].get("images", []) == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProviderCapabilityMethods:
|
||||
"""Test Provider.get_model_with_capability and model_has_capability."""
|
||||
|
||||
def _provider(self) -> Provider:
|
||||
return Provider(
|
||||
name="test",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
models=[
|
||||
{"name": "llava:7b", "capabilities": ["vision"]},
|
||||
{"name": "llama3.2", "default": True},
|
||||
],
|
||||
)
|
||||
|
||||
def test_get_model_with_capability_found(self):
|
||||
p = self._provider()
|
||||
assert p.get_model_with_capability("vision") == "llava:7b"
|
||||
|
||||
def test_get_model_with_capability_falls_back_to_default(self):
|
||||
p = self._provider()
|
||||
assert p.get_model_with_capability("audio") == "llama3.2"
|
||||
|
||||
def test_model_has_capability_true(self):
|
||||
p = self._provider()
|
||||
assert p.model_has_capability("llava:7b", "vision") is True
|
||||
|
||||
def test_model_has_capability_false(self):
|
||||
p = self._provider()
|
||||
assert p.model_has_capability("llama3.2", "vision") is False
|
||||
|
||||
def test_model_has_capability_unknown_model(self):
|
||||
p = self._provider()
|
||||
assert p.model_has_capability("unknown-model", "vision") is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetFallbackModel:
|
||||
"""Test _get_fallback_model."""
|
||||
|
||||
def _router_with_provider(self) -> tuple[CascadeRouter, Provider]:
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
provider = Provider(
|
||||
name="test",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
models=[
|
||||
{"name": "llava:7b", "capabilities": ["vision"]},
|
||||
{"name": "llama3.2", "default": True},
|
||||
],
|
||||
)
|
||||
return router, provider
|
||||
|
||||
def test_returns_vision_model(self):
|
||||
router, provider = self._router_with_provider()
|
||||
result = router._get_fallback_model(provider, "llama3.2", ContentType.VISION)
|
||||
assert result == "llava:7b"
|
||||
|
||||
def test_returns_none_if_no_capability(self):
|
||||
router, provider = self._router_with_provider()
|
||||
result = router._get_fallback_model(provider, "llama3.2", ContentType.AUDIO)
|
||||
# No audio model; falls back to default which is same as original
|
||||
assert result is None or result == "llama3.2"
|
||||
|
||||
def test_text_content_returns_none(self):
|
||||
router, provider = self._router_with_provider()
|
||||
result = router._get_fallback_model(provider, "llama3.2", ContentType.TEXT)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
class TestCascadeTierFiltering:
|
||||
"""Test cascade_tier parameter in complete()."""
|
||||
|
||||
def _make_router(self) -> CascadeRouter:
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.providers = [
|
||||
Provider(
|
||||
name="anthropic-primary",
|
||||
type="anthropic",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
api_key="test-key",
|
||||
models=[{"name": "claude-sonnet-4-6", "default": True}],
|
||||
),
|
||||
Provider(
|
||||
name="ollama-local",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=2,
|
||||
models=[{"name": "llama3.2", "default": True}],
|
||||
),
|
||||
]
|
||||
return router
|
||||
|
||||
async def test_frontier_required_uses_anthropic(self):
|
||||
router = self._make_router()
|
||||
with patch("infrastructure.router.cascade._quota_monitor", None):
|
||||
with patch.object(router, "_call_anthropic") as mock_call:
|
||||
mock_call.return_value = {"content": "frontier response", "model": "claude-sonnet-4-6"}
|
||||
result = await router.complete(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
cascade_tier="frontier_required",
|
||||
)
|
||||
assert result["provider"] == "anthropic-primary"
|
||||
mock_call.assert_called_once()
|
||||
|
||||
async def test_frontier_required_no_anthropic_raises(self):
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.providers = [
|
||||
Provider(
|
||||
name="ollama-local",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
models=[{"name": "llama3.2", "default": True}],
|
||||
)
|
||||
]
|
||||
with pytest.raises(RuntimeError, match="No Anthropic provider configured"):
|
||||
await router.complete(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
cascade_tier="frontier_required",
|
||||
)
|
||||
|
||||
async def test_unknown_tier_raises(self):
|
||||
router = self._make_router()
|
||||
with pytest.raises(RuntimeError, match="No providers found for tier"):
|
||||
await router.complete(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
cascade_tier="nonexistent_tier",
|
||||
)
|
||||
|
||||
async def test_tier_filter_only_matching_providers(self):
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.providers = [
|
||||
Provider(
|
||||
name="local-primary",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
tier="local",
|
||||
models=[{"name": "llama3.2", "default": True}],
|
||||
),
|
||||
Provider(
|
||||
name="cloud-secondary",
|
||||
type="anthropic",
|
||||
enabled=True,
|
||||
priority=2,
|
||||
tier="cloud",
|
||||
api_key="key",
|
||||
models=[{"name": "claude-sonnet-4-6", "default": True}],
|
||||
),
|
||||
]
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.return_value = {"content": "local response", "model": "llama3.2"}
|
||||
result = await router.complete(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
cascade_tier="local",
|
||||
)
|
||||
assert result["provider"] == "local-primary"
|
||||
mock_call.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
class TestGenerateWithImage:
|
||||
"""Test generate_with_image convenience method."""
|
||||
|
||||
async def test_delegates_to_complete(self):
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.providers = [
|
||||
Provider(
|
||||
name="ollama-vision",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
models=[{"name": "llava:7b", "capabilities": ["vision"], "default": True}],
|
||||
)
|
||||
]
|
||||
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.return_value = {"content": "A cat", "model": "llava:7b"}
|
||||
result = await router.generate_with_image(
|
||||
prompt="What is this?",
|
||||
image_path="/tmp/cat.jpg",
|
||||
model="llava:7b",
|
||||
)
|
||||
|
||||
assert result["content"] == "A cat"
|
||||
assert result["provider"] == "ollama-vision"
|
||||
# complete() should have been called with images in messages
|
||||
call_kwargs = mock_call.call_args
|
||||
messages_passed = call_kwargs.kwargs.get("messages") or call_kwargs[1].get("messages")
|
||||
assert messages_passed[0]["images"] == ["/tmp/cat.jpg"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetRouterSingleton:
|
||||
"""Test get_router() returns a singleton and creates CascadeRouter."""
|
||||
|
||||
def test_get_router_returns_cascade_router(self):
|
||||
import infrastructure.router.cascade as cascade_module
|
||||
|
||||
# Reset singleton to test creation
|
||||
original = cascade_module.cascade_router
|
||||
cascade_module.cascade_router = None
|
||||
try:
|
||||
router = get_router()
|
||||
assert isinstance(router, CascadeRouter)
|
||||
finally:
|
||||
cascade_module.cascade_router = original
|
||||
|
||||
def test_get_router_returns_same_instance(self):
|
||||
import infrastructure.router.cascade as cascade_module
|
||||
|
||||
original = cascade_module.cascade_router
|
||||
cascade_module.cascade_router = None
|
||||
try:
|
||||
r1 = get_router()
|
||||
r2 = get_router()
|
||||
assert r1 is r2
|
||||
finally:
|
||||
cascade_module.cascade_router = original
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestIsProviderAvailable:
|
||||
"""Test _is_provider_available with circuit breaker transitions."""
|
||||
|
||||
def _router(self) -> CascadeRouter:
|
||||
return CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
def test_disabled_provider_not_available(self):
|
||||
router = self._router()
|
||||
provider = Provider(name="p", type="ollama", enabled=False, priority=1)
|
||||
assert router._is_provider_available(provider) is False
|
||||
|
||||
def test_healthy_provider_available(self):
|
||||
router = self._router()
|
||||
provider = Provider(name="p", type="ollama", enabled=True, priority=1)
|
||||
assert router._is_provider_available(provider) is True
|
||||
|
||||
def test_unhealthy_open_circuit_not_available(self):
|
||||
router = self._router()
|
||||
provider = Provider(
|
||||
name="p",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
status=ProviderStatus.UNHEALTHY,
|
||||
circuit_state=CircuitState.OPEN,
|
||||
circuit_opened_at=time.time(), # Just opened — not yet recoverable
|
||||
)
|
||||
assert router._is_provider_available(provider) is False
|
||||
|
||||
def test_unhealthy_after_timeout_transitions_to_half_open(self):
|
||||
router = self._router()
|
||||
router.config.circuit_breaker_recovery_timeout = 0
|
||||
provider = Provider(
|
||||
name="p",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
status=ProviderStatus.UNHEALTHY,
|
||||
circuit_state=CircuitState.OPEN,
|
||||
circuit_opened_at=time.time() - 10, # Long ago
|
||||
)
|
||||
result = router._is_provider_available(provider)
|
||||
assert result is True
|
||||
assert provider.circuit_state == CircuitState.HALF_OPEN
|
||||
|
||||
Reference in New Issue
Block a user