forked from Rockachopa/Timmy-time-dashboard
feat: code quality audit + autoresearch integration + infra hardening (#150)
This commit is contained in:
committed by
GitHub
parent
fd0ede0d51
commit
ae3bb1cc21
@@ -1,15 +1,16 @@
|
||||
"""Tests for infrastructure.error_capture module."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.error_capture import (
|
||||
_stack_hash,
|
||||
_is_duplicate,
|
||||
_get_git_context,
|
||||
capture_error,
|
||||
_dedup_cache,
|
||||
_get_git_context,
|
||||
_is_duplicate,
|
||||
_stack_hash,
|
||||
capture_error,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
"""Tests for the event broadcaster (infrastructure.events.broadcaster)."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.events.broadcaster import (
|
||||
EventBroadcaster,
|
||||
event_broadcaster,
|
||||
get_event_icon,
|
||||
get_event_label,
|
||||
format_event_for_display,
|
||||
EVENT_ICONS,
|
||||
EVENT_LABELS,
|
||||
EventBroadcaster,
|
||||
event_broadcaster,
|
||||
format_event_for_display,
|
||||
get_event_icon,
|
||||
get_event_label,
|
||||
)
|
||||
|
||||
|
||||
# ── Fake EventLogEntry for testing ──────────────────────────────────────────
|
||||
|
||||
|
||||
class FakeEventType(Enum):
|
||||
TASK_CREATED = "task.created"
|
||||
TASK_ASSIGNED = "task.assigned"
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Tests for the async event bus (infrastructure.events.bus)."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from infrastructure.events.bus import EventBus, Event, emit, on, event_bus
|
||||
|
||||
from infrastructure.events.bus import Event, EventBus, emit, event_bus, on
|
||||
|
||||
|
||||
class TestEvent:
|
||||
|
||||
@@ -10,17 +10,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.router.cascade import CascadeRouter, Provider, ProviderStatus, CircuitState
|
||||
from infrastructure.router.cascade import CascadeRouter, CircuitState, Provider, ProviderStatus
|
||||
|
||||
|
||||
class TestCascadeRouterFunctional:
|
||||
"""Functional tests for Cascade Router with mocked providers."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def router(self):
|
||||
"""Create a router with no config file."""
|
||||
return CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_healthy_provider(self):
|
||||
"""Create a mock healthy provider."""
|
||||
@@ -32,7 +32,7 @@ class TestCascadeRouterFunctional:
|
||||
models=[{"name": "test-model", "default": True}],
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_failing_provider(self):
|
||||
"""Create a mock failing provider."""
|
||||
@@ -44,12 +44,12 @@ class TestCascadeRouterFunctional:
|
||||
models=[{"name": "test-model", "default": True}],
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_completion_single_provider(self, router, mock_healthy_provider):
|
||||
"""Test successful completion with a single working provider."""
|
||||
router.providers = [mock_healthy_provider]
|
||||
|
||||
|
||||
# Mock the provider's call method
|
||||
with patch.object(router, "_try_provider") as mock_try:
|
||||
mock_try.return_value = {
|
||||
@@ -57,16 +57,16 @@ class TestCascadeRouterFunctional:
|
||||
"model": "test-model",
|
||||
"latency_ms": 100.0,
|
||||
}
|
||||
|
||||
|
||||
result = await router.complete(
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
)
|
||||
|
||||
|
||||
assert result["content"] == "Hello, world!"
|
||||
assert result["provider"] == "test-healthy"
|
||||
assert result["model"] == "test-model"
|
||||
assert result["latency_ms"] == 100.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failover_to_second_provider(self, router):
|
||||
"""Test failover when first provider fails."""
|
||||
@@ -85,23 +85,23 @@ class TestCascadeRouterFunctional:
|
||||
models=[{"name": "model", "default": True}],
|
||||
)
|
||||
router.providers = [provider1, provider2]
|
||||
|
||||
|
||||
call_count = [0]
|
||||
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= router.config.max_retries_per_provider:
|
||||
raise RuntimeError("Connection failed")
|
||||
return {"content": "Backup works!", "model": "model"}
|
||||
|
||||
|
||||
with patch.object(router, "_try_provider", side_effect=side_effect):
|
||||
result = await router.complete(
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
)
|
||||
|
||||
|
||||
assert result["content"] == "Backup works!"
|
||||
assert result["provider"] == "backup"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_providers_fail_raises_error(self, router):
|
||||
"""Test that RuntimeError is raised when all providers fail."""
|
||||
@@ -113,15 +113,15 @@ class TestCascadeRouterFunctional:
|
||||
models=[{"name": "model", "default": True}],
|
||||
)
|
||||
router.providers = [provider]
|
||||
|
||||
|
||||
with patch.object(router, "_try_provider") as mock_try:
|
||||
mock_try.side_effect = RuntimeError("Always fails")
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await router.complete(messages=[{"role": "user", "content": "Hi"}])
|
||||
|
||||
|
||||
assert "All providers failed" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_opens_after_failures(self, router):
|
||||
"""Test circuit breaker opens after threshold failures."""
|
||||
@@ -134,14 +134,14 @@ class TestCascadeRouterFunctional:
|
||||
)
|
||||
router.providers = [provider]
|
||||
router.config.circuit_breaker_failure_threshold = 3
|
||||
|
||||
|
||||
# Record 3 failures
|
||||
for _ in range(3):
|
||||
router._record_failure(provider)
|
||||
|
||||
|
||||
assert provider.circuit_state == CircuitState.OPEN
|
||||
assert provider.status == ProviderStatus.UNHEALTHY
|
||||
|
||||
|
||||
def test_metrics_tracking(self, router):
|
||||
"""Test that metrics are tracked correctly."""
|
||||
provider = Provider(
|
||||
@@ -151,14 +151,14 @@ class TestCascadeRouterFunctional:
|
||||
priority=1,
|
||||
)
|
||||
router.providers = [provider]
|
||||
|
||||
|
||||
# Record some successes and failures
|
||||
router._record_success(provider, 100.0)
|
||||
router._record_success(provider, 200.0)
|
||||
router._record_failure(provider)
|
||||
|
||||
|
||||
metrics = router.get_metrics()
|
||||
|
||||
|
||||
assert len(metrics["providers"]) == 1
|
||||
p_metrics = metrics["providers"][0]
|
||||
assert p_metrics["metrics"]["total_requests"] == 3
|
||||
@@ -166,7 +166,7 @@ class TestCascadeRouterFunctional:
|
||||
assert p_metrics["metrics"]["failed"] == 1
|
||||
# Average latency is over ALL requests (including failures with 0 latency)
|
||||
assert p_metrics["metrics"]["avg_latency_ms"] == 100.0 # (100+200+0)/3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_disabled_providers(self, router):
|
||||
"""Test that disabled providers are skipped."""
|
||||
@@ -185,23 +185,23 @@ class TestCascadeRouterFunctional:
|
||||
models=[{"name": "model", "default": True}],
|
||||
)
|
||||
router.providers = [disabled, enabled]
|
||||
|
||||
|
||||
# The router should try enabled provider
|
||||
with patch.object(router, "_try_provider") as mock_try:
|
||||
mock_try.return_value = {"content": "Success", "model": "model"}
|
||||
|
||||
|
||||
result = await router.complete(messages=[{"role": "user", "content": "Hi"}])
|
||||
|
||||
|
||||
assert result["provider"] == "enabled"
|
||||
|
||||
|
||||
class TestProviderAvailability:
|
||||
"""Test provider availability checking."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def router(self):
|
||||
return CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
def test_openai_available_with_key(self, router):
|
||||
"""Test OpenAI provider is available when API key is set."""
|
||||
provider = Provider(
|
||||
@@ -211,9 +211,9 @@ class TestProviderAvailability:
|
||||
priority=1,
|
||||
api_key="sk-test123",
|
||||
)
|
||||
|
||||
|
||||
assert router._check_provider_available(provider) is True
|
||||
|
||||
|
||||
def test_openai_unavailable_without_key(self, router):
|
||||
"""Test OpenAI provider is unavailable without API key."""
|
||||
provider = Provider(
|
||||
@@ -223,9 +223,9 @@ class TestProviderAvailability:
|
||||
priority=1,
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
|
||||
assert router._check_provider_available(provider) is False
|
||||
|
||||
|
||||
def test_anthropic_available_with_key(self, router):
|
||||
"""Test Anthropic provider is available when API key is set."""
|
||||
provider = Provider(
|
||||
@@ -235,17 +235,17 @@ class TestProviderAvailability:
|
||||
priority=1,
|
||||
api_key="sk-test123",
|
||||
)
|
||||
|
||||
|
||||
assert router._check_provider_available(provider) is True
|
||||
|
||||
|
||||
class TestRouterConfigLoading:
|
||||
"""Test router configuration loading."""
|
||||
|
||||
|
||||
def test_loads_timeout_from_config(self, tmp_path):
|
||||
"""Test that timeout is loaded from config."""
|
||||
import yaml
|
||||
|
||||
|
||||
config = {
|
||||
"cascade": {
|
||||
"timeout_seconds": 60,
|
||||
@@ -253,18 +253,18 @@ class TestRouterConfigLoading:
|
||||
},
|
||||
"providers": [],
|
||||
}
|
||||
|
||||
|
||||
config_path = tmp_path / "providers.yaml"
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
|
||||
router = CascadeRouter(config_path=config_path)
|
||||
|
||||
|
||||
assert router.config.timeout_seconds == 60
|
||||
assert router.config.max_retries_per_provider == 3
|
||||
|
||||
|
||||
def test_uses_defaults_without_config(self):
|
||||
"""Test that defaults are used when config file doesn't exist."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
assert router.config.timeout_seconds == 30
|
||||
assert router.config.max_retries_per_provider == 2
|
||||
|
||||
@@ -6,12 +6,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
)
|
||||
from infrastructure.models.registry import CustomModel, ModelFormat, ModelRegistry, ModelRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -199,9 +194,7 @@ class TestCustomModelDataclass:
|
||||
"""Test CustomModel construction."""
|
||||
|
||||
def test_default_registered_at(self):
|
||||
model = CustomModel(
|
||||
name="test", format=ModelFormat.OLLAMA, path="test"
|
||||
)
|
||||
model = CustomModel(name="test", format=ModelFormat.OLLAMA, path="test")
|
||||
assert model.registered_at != ""
|
||||
|
||||
def test_model_roles(self):
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
"""Tests for the custom models API routes."""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
)
|
||||
from infrastructure.models.registry import CustomModel, ModelFormat, ModelRegistry, ModelRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -27,9 +22,7 @@ class TestModelsAPIList:
|
||||
def test_list_models_empty(self, client, tmp_path):
|
||||
db = tmp_path / "api.db"
|
||||
with patch("infrastructure.models.registry.DB_PATH", db):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.list_models.return_value = []
|
||||
resp = client.get("/api/v1/models")
|
||||
assert resp.status_code == 200
|
||||
@@ -44,9 +37,7 @@ class TestModelsAPIList:
|
||||
path="llama3.2",
|
||||
role=ModelRole.GENERAL,
|
||||
)
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.list_models.return_value = [model]
|
||||
resp = client.get("/api/v1/models")
|
||||
assert resp.status_code == 200
|
||||
@@ -59,9 +50,7 @@ class TestModelsAPIRegister:
|
||||
"""Test model registration via the API."""
|
||||
|
||||
def test_register_ollama_model(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.register.return_value = CustomModel(
|
||||
name="my-model",
|
||||
format=ModelFormat.OLLAMA,
|
||||
@@ -111,17 +100,13 @@ class TestModelsAPIDelete:
|
||||
"""Test model deletion via the API."""
|
||||
|
||||
def test_delete_model(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.unregister.return_value = True
|
||||
resp = client.delete("/api/v1/models/my-model")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_delete_nonexistent(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.unregister.return_value = False
|
||||
resp = client.delete("/api/v1/models/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
@@ -137,18 +122,14 @@ class TestModelsAPIGet:
|
||||
path="llama3.2",
|
||||
role=ModelRole.GENERAL,
|
||||
)
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.get.return_value = model
|
||||
resp = client.get("/api/v1/models/my-model")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "my-model"
|
||||
|
||||
def test_get_nonexistent(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.get.return_value = None
|
||||
resp = client.get("/api/v1/models/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
@@ -158,9 +139,7 @@ class TestModelsAPIAssignments:
|
||||
"""Test agent model assignment endpoints."""
|
||||
|
||||
def test_assign_model(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.assign_model.return_value = True
|
||||
resp = client.post(
|
||||
"/api/v1/models/assignments",
|
||||
@@ -169,9 +148,7 @@ class TestModelsAPIAssignments:
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_assign_nonexistent_model(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.assign_model.return_value = False
|
||||
resp = client.post(
|
||||
"/api/v1/models/assignments",
|
||||
@@ -180,25 +157,19 @@ class TestModelsAPIAssignments:
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_unassign_model(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.unassign_model.return_value = True
|
||||
resp = client.delete("/api/v1/models/assignments/agent-1")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_unassign_nonexistent(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.unassign_model.return_value = False
|
||||
resp = client.delete("/api/v1/models/assignments/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_list_assignments(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.get_agent_assignments.return_value = {
|
||||
"agent-1": "model-a",
|
||||
"agent-2": "model-b",
|
||||
@@ -219,9 +190,7 @@ class TestModelsAPIRoles:
|
||||
path="deepseek-r1:1.5b",
|
||||
role=ModelRole.REWARD,
|
||||
)
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.get_reward_model.return_value = model
|
||||
resp = client.get("/api/v1/models/roles/reward")
|
||||
assert resp.status_code == 200
|
||||
@@ -229,18 +198,14 @@ class TestModelsAPIRoles:
|
||||
assert data["reward_model"]["name"] == "reward-m"
|
||||
|
||||
def test_get_reward_model_none(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.get_reward_model.return_value = None
|
||||
resp = client.get("/api/v1/models/roles/reward")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["reward_model"] is None
|
||||
|
||||
def test_get_teacher_model(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.get_teacher_model.return_value = None
|
||||
resp = client.get("/api/v1/models/roles/teacher")
|
||||
assert resp.status_code == 200
|
||||
@@ -251,9 +216,7 @@ class TestModelsAPISetActive:
|
||||
"""Test enable/disable model endpoint."""
|
||||
|
||||
def test_enable_model(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.set_active.return_value = True
|
||||
resp = client.patch(
|
||||
"/api/v1/models/my-model/active",
|
||||
@@ -262,9 +225,7 @@ class TestModelsAPISetActive:
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_disable_nonexistent(self, client):
|
||||
with patch(
|
||||
"dashboard.routes.models.model_registry"
|
||||
) as mock_reg:
|
||||
with patch("dashboard.routes.models.model_registry") as mock_reg:
|
||||
mock_reg.set_active.return_value = False
|
||||
resp = client.patch(
|
||||
"/api/v1/models/nonexistent/active",
|
||||
|
||||
@@ -5,14 +5,14 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from infrastructure.router.api import get_cascade_router, router
|
||||
from infrastructure.router.cascade import CircuitState, Provider, ProviderStatus
|
||||
from infrastructure.router.api import router, get_cascade_router
|
||||
|
||||
|
||||
def make_mock_router():
|
||||
"""Create a mock CascadeRouter."""
|
||||
router = MagicMock()
|
||||
|
||||
|
||||
# Create test providers
|
||||
provider1 = Provider(
|
||||
name="ollama-local",
|
||||
@@ -24,7 +24,7 @@ def make_mock_router():
|
||||
)
|
||||
provider1.status = ProviderStatus.HEALTHY
|
||||
provider1.circuit_state = CircuitState.CLOSED
|
||||
|
||||
|
||||
provider2 = Provider(
|
||||
name="openai-backup",
|
||||
type="openai",
|
||||
@@ -35,12 +35,12 @@ def make_mock_router():
|
||||
)
|
||||
provider2.status = ProviderStatus.DEGRADED
|
||||
provider2.circuit_state = CircuitState.CLOSED
|
||||
|
||||
|
||||
router.providers = [provider1, provider2]
|
||||
router.config.timeout_seconds = 30
|
||||
router.config.max_retries_per_provider = 2
|
||||
router.config.circuit_breaker_failure_threshold = 5
|
||||
|
||||
|
||||
return router
|
||||
|
||||
|
||||
@@ -48,74 +48,87 @@ def make_mock_router():
|
||||
def mock_router():
|
||||
"""Create test client with mocked router."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
# Create mock router
|
||||
mock = make_mock_router()
|
||||
|
||||
|
||||
# Override dependency
|
||||
async def mock_get_router():
|
||||
return mock
|
||||
|
||||
|
||||
app.dependency_overrides[get_cascade_router] = mock_get_router
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
return client, mock
|
||||
|
||||
|
||||
class TestCompleteEndpoint:
|
||||
"""Test /complete endpoint."""
|
||||
|
||||
|
||||
def test_complete_success(self, mock_router):
|
||||
"""Test successful completion."""
|
||||
client, mock = mock_router
|
||||
mock.complete = AsyncMock(return_value={
|
||||
"content": "Hello! How can I help?",
|
||||
"provider": "ollama-local",
|
||||
"model": "llama3.2",
|
||||
"latency_ms": 250.5,
|
||||
})
|
||||
|
||||
response = client.post("/api/v1/router/complete", json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"model": "llama3.2",
|
||||
"temperature": 0.7,
|
||||
})
|
||||
|
||||
mock.complete = AsyncMock(
|
||||
return_value={
|
||||
"content": "Hello! How can I help?",
|
||||
"provider": "ollama-local",
|
||||
"model": "llama3.2",
|
||||
"latency_ms": 250.5,
|
||||
}
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/router/complete",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
"model": "llama3.2",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["content"] == "Hello! How can I help?"
|
||||
assert data["provider"] == "ollama-local"
|
||||
assert data["latency_ms"] == 250.5
|
||||
|
||||
|
||||
def test_complete_all_providers_fail(self, mock_router):
|
||||
"""Test 503 when all providers fail."""
|
||||
client, mock = mock_router
|
||||
mock.complete = AsyncMock(side_effect=RuntimeError("All providers failed"))
|
||||
|
||||
response = client.post("/api/v1/router/complete", json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
})
|
||||
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/router/complete",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 503
|
||||
assert "All providers failed" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_complete_default_temperature(self, mock_router):
|
||||
"""Test completion with default temperature."""
|
||||
client, mock = mock_router
|
||||
mock.complete = AsyncMock(return_value={
|
||||
"content": "Response",
|
||||
"provider": "ollama-local",
|
||||
"model": "llama3.2",
|
||||
"latency_ms": 100.0,
|
||||
})
|
||||
|
||||
response = client.post("/api/v1/router/complete", json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
})
|
||||
|
||||
mock.complete = AsyncMock(
|
||||
return_value={
|
||||
"content": "Response",
|
||||
"provider": "ollama-local",
|
||||
"model": "llama3.2",
|
||||
"latency_ms": 100.0,
|
||||
}
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/router/complete",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "Hi"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Check that complete was called with correct temperature
|
||||
call_args = mock.complete.call_args
|
||||
@@ -124,35 +137,37 @@ class TestCompleteEndpoint:
|
||||
|
||||
class TestStatusEndpoint:
|
||||
"""Test /status endpoint."""
|
||||
|
||||
|
||||
def test_get_status(self, mock_router):
|
||||
"""Test getting router status."""
|
||||
client, mock = mock_router
|
||||
mock.get_status = MagicMock(return_value={
|
||||
"total_providers": 2,
|
||||
"healthy_providers": 1,
|
||||
"degraded_providers": 1,
|
||||
"unhealthy_providers": 0,
|
||||
"providers": [
|
||||
{
|
||||
"name": "ollama-local",
|
||||
"type": "ollama",
|
||||
"status": "healthy",
|
||||
"priority": 1,
|
||||
"default_model": "llama3.2",
|
||||
},
|
||||
{
|
||||
"name": "openai-backup",
|
||||
"type": "openai",
|
||||
"status": "degraded",
|
||||
"priority": 2,
|
||||
"default_model": "gpt-4o-mini",
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
mock.get_status = MagicMock(
|
||||
return_value={
|
||||
"total_providers": 2,
|
||||
"healthy_providers": 1,
|
||||
"degraded_providers": 1,
|
||||
"unhealthy_providers": 0,
|
||||
"providers": [
|
||||
{
|
||||
"name": "ollama-local",
|
||||
"type": "ollama",
|
||||
"status": "healthy",
|
||||
"priority": 1,
|
||||
"default_model": "llama3.2",
|
||||
},
|
||||
{
|
||||
"name": "openai-backup",
|
||||
"type": "openai",
|
||||
"status": "degraded",
|
||||
"priority": 2,
|
||||
"default_model": "gpt-4o-mini",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/router/status")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_providers"] == 2
|
||||
@@ -163,31 +178,33 @@ class TestStatusEndpoint:
|
||||
|
||||
class TestMetricsEndpoint:
|
||||
"""Test /metrics endpoint."""
|
||||
|
||||
|
||||
def test_get_metrics(self, mock_router):
|
||||
"""Test getting detailed metrics."""
|
||||
client, mock = mock_router
|
||||
# Setup the mock return value on the mock_router object
|
||||
mock.get_metrics = MagicMock(return_value={
|
||||
"providers": [
|
||||
{
|
||||
"name": "ollama-local",
|
||||
"type": "ollama",
|
||||
"status": "healthy",
|
||||
"circuit_state": "closed",
|
||||
"metrics": {
|
||||
"total_requests": 100,
|
||||
"successful": 98,
|
||||
"failed": 2,
|
||||
"error_rate": 0.02,
|
||||
"avg_latency_ms": 150.5,
|
||||
mock.get_metrics = MagicMock(
|
||||
return_value={
|
||||
"providers": [
|
||||
{
|
||||
"name": "ollama-local",
|
||||
"type": "ollama",
|
||||
"status": "healthy",
|
||||
"circuit_state": "closed",
|
||||
"metrics": {
|
||||
"total_requests": 100,
|
||||
"successful": 98,
|
||||
"failed": 2,
|
||||
"error_rate": 0.02,
|
||||
"avg_latency_ms": 150.5,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/router/metrics")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["providers"]) == 1
|
||||
@@ -199,17 +216,17 @@ class TestMetricsEndpoint:
|
||||
|
||||
class TestListProvidersEndpoint:
|
||||
"""Test /providers endpoint."""
|
||||
|
||||
|
||||
def test_list_providers(self, mock_router):
|
||||
"""Test listing all providers."""
|
||||
client, mock = mock_router
|
||||
|
||||
|
||||
response = client.get("/api/v1/router/providers")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
|
||||
|
||||
# Check first provider
|
||||
assert data[0]["name"] == "ollama-local"
|
||||
assert data[0]["type"] == "ollama"
|
||||
@@ -221,40 +238,38 @@ class TestListProvidersEndpoint:
|
||||
|
||||
class TestControlProviderEndpoint:
|
||||
"""Test /providers/{name}/control endpoint."""
|
||||
|
||||
|
||||
def test_disable_provider(self, mock_router):
|
||||
"""Test disabling a provider."""
|
||||
client, mock = mock_router
|
||||
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/router/providers/ollama-local/control",
|
||||
json={"action": "disable"}
|
||||
"/api/v1/router/providers/ollama-local/control", json={"action": "disable"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "disabled" in response.json()["message"]
|
||||
|
||||
|
||||
# Check that the provider was disabled
|
||||
provider = mock.providers[0]
|
||||
assert provider.enabled is False
|
||||
assert provider.status == ProviderStatus.DISABLED
|
||||
|
||||
|
||||
def test_enable_provider(self, mock_router):
|
||||
"""Test enabling a provider."""
|
||||
client, mock = mock_router
|
||||
# First disable it
|
||||
mock.providers[0].enabled = False
|
||||
mock.providers[0].status = ProviderStatus.DISABLED
|
||||
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/router/providers/ollama-local/control",
|
||||
json={"action": "enable"}
|
||||
"/api/v1/router/providers/ollama-local/control", json={"action": "enable"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "enabled" in response.json()["message"]
|
||||
assert mock.providers[0].enabled is True
|
||||
|
||||
|
||||
def test_reset_circuit(self, mock_router):
|
||||
"""Test resetting circuit breaker."""
|
||||
client, mock = mock_router
|
||||
@@ -262,73 +277,70 @@ class TestControlProviderEndpoint:
|
||||
mock.providers[0].circuit_state = CircuitState.OPEN
|
||||
mock.providers[0].status = ProviderStatus.UNHEALTHY
|
||||
mock.providers[0].metrics.consecutive_failures = 10
|
||||
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/router/providers/ollama-local/control",
|
||||
json={"action": "reset_circuit"}
|
||||
"/api/v1/router/providers/ollama-local/control", json={"action": "reset_circuit"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "reset" in response.json()["message"]
|
||||
|
||||
|
||||
provider = mock.providers[0]
|
||||
assert provider.circuit_state == CircuitState.CLOSED
|
||||
assert provider.status == ProviderStatus.HEALTHY
|
||||
assert provider.metrics.consecutive_failures == 0
|
||||
|
||||
|
||||
def test_control_unknown_provider(self, mock_router):
|
||||
"""Test controlling unknown provider returns 404."""
|
||||
client, mock = mock_router
|
||||
response = client.post(
|
||||
"/api/v1/router/providers/unknown/control",
|
||||
json={"action": "disable"}
|
||||
"/api/v1/router/providers/unknown/control", json={"action": "disable"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_control_unknown_action(self, mock_router):
|
||||
"""Test unknown action returns 400."""
|
||||
client, mock = mock_router
|
||||
response = client.post(
|
||||
"/api/v1/router/providers/ollama-local/control",
|
||||
json={"action": "invalid_action"}
|
||||
"/api/v1/router/providers/ollama-local/control", json={"action": "invalid_action"}
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Unknown action" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestHealthCheckEndpoint:
|
||||
"""Test /health-check endpoint."""
|
||||
|
||||
|
||||
def test_health_check_all_healthy(self, mock_router):
|
||||
"""Test health check when all providers are healthy."""
|
||||
client, mock = mock_router
|
||||
|
||||
|
||||
with patch.object(mock, "_check_provider_available") as mock_check:
|
||||
mock_check.return_value = True
|
||||
|
||||
|
||||
response = client.post("/api/v1/router/health-check")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["healthy_count"] == 2
|
||||
assert len(data["providers"]) == 2
|
||||
|
||||
|
||||
for p in data["providers"]:
|
||||
assert p["healthy"] is True
|
||||
|
||||
|
||||
def test_health_check_with_failure(self, mock_router):
|
||||
"""Test health check when some providers fail."""
|
||||
client, mock = mock_router
|
||||
|
||||
|
||||
with patch.object(mock, "_check_provider_available") as mock_check:
|
||||
# First provider fails, second succeeds
|
||||
mock_check.side_effect = [False, True]
|
||||
|
||||
|
||||
response = client.post("/api/v1/router/health-check")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["healthy_count"] == 1
|
||||
@@ -338,21 +350,21 @@ class TestHealthCheckEndpoint:
|
||||
|
||||
class TestGetConfigEndpoint:
|
||||
"""Test /config endpoint."""
|
||||
|
||||
|
||||
def test_get_config(self, mock_router):
|
||||
"""Test getting router configuration."""
|
||||
client, mock = mock_router
|
||||
|
||||
|
||||
response = client.get("/api/v1/router/config")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
|
||||
assert data["timeout_seconds"] == 30
|
||||
assert data["max_retries_per_provider"] == 2
|
||||
assert "circuit_breaker" in data
|
||||
assert data["circuit_breaker"]["failure_threshold"] == 5
|
||||
|
||||
|
||||
# Check providers list (without secrets)
|
||||
assert len(data["providers"]) == 2
|
||||
assert "api_key" not in data["providers"][0]
|
||||
|
||||
Reference in New Issue
Block a user