forked from Rockachopa/Timmy-time-dashboard
Merge pull request #44 from AlexanderWhitestone/feature/memory-layers-and-conversational-ai
Phase 3-4: Cascade LLM Router + Tool Registry Auto-Discovery
This commit is contained in:
265
tests/test_mcp_bootstrap.py
Normal file
265
tests/test_mcp_bootstrap.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Tests for MCP Auto-Bootstrap.
|
||||
|
||||
Tests follow pytest best practices:
|
||||
- No module-level state
|
||||
- Proper fixture cleanup
|
||||
- Isolated tests
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcp.bootstrap import (
|
||||
auto_bootstrap,
|
||||
bootstrap_from_directory,
|
||||
get_bootstrap_status,
|
||||
DEFAULT_TOOL_PACKAGES,
|
||||
AUTO_BOOTSTRAP_ENV_VAR,
|
||||
)
|
||||
from mcp.discovery import mcp_tool, ToolDiscovery
|
||||
from mcp.registry import ToolRegistry
|
||||
|
||||
|
||||
@pytest.fixture
def fresh_registry():
    """Provide an isolated ToolRegistry instance for each test."""
    registry = ToolRegistry()
    return registry
|
||||
|
||||
|
||||
@pytest.fixture
def fresh_discovery(fresh_registry):
    """Provide a ToolDiscovery bound to the per-test registry."""
    instance = ToolDiscovery(registry=fresh_registry)
    return instance
|
||||
|
||||
|
||||
class TestAutoBootstrap:
    """Behavioural tests for the auto_bootstrap entry point."""

    def test_auto_bootstrap_disabled_by_env(self, fresh_registry):
        """When the env var is set to "0", nothing is registered."""
        with patch.dict(os.environ, {AUTO_BOOTSTRAP_ENV_VAR: "0"}):
            result = auto_bootstrap(registry=fresh_registry)

        assert len(result) == 0

    def test_auto_bootstrap_forced_overrides_env(self, fresh_registry):
        """force=True must win over the disabling env var."""
        with patch.dict(os.environ, {AUTO_BOOTSTRAP_ENV_VAR: "0"}):
            # No packages supplied: we only verify the call proceeds
            # rather than aborting early because of the env var.
            result = auto_bootstrap(
                packages=[],
                registry=fresh_registry,
                force=True,
            )

        assert len(result) == 0  # nothing to register, but it ran

    def test_auto_bootstrap_nonexistent_package(self, fresh_registry):
        """An unimportable package yields zero registrations, no error."""
        result = auto_bootstrap(
            packages=["nonexistent_package_xyz_12345"],
            registry=fresh_registry,
            force=True,
        )

        assert len(result) == 0

    def test_auto_bootstrap_empty_packages(self, fresh_registry):
        """An explicitly empty package list registers nothing."""
        result = auto_bootstrap(
            packages=[],
            registry=fresh_registry,
            force=True,
        )

        assert len(result) == 0

    def test_auto_bootstrap_registers_tools(self, fresh_registry, fresh_discovery):
        """A decorated tool, once registered, is retrievable and flagged."""

        @mcp_tool(name="bootstrap_tool", category="bootstrap")
        def bootstrap_func(value: str) -> str:
            """A bootstrap test tool."""
            return value

        # Register the decorated function by hand.
        fresh_registry.register_tool(
            name="bootstrap_tool",
            function=bootstrap_func,
            category="bootstrap",
        )

        # The registry must expose it and mark it as auto-discovered.
        entry = fresh_registry.get("bootstrap_tool")
        assert entry is not None
        assert entry.auto_discovered is True
|
||||
|
||||
|
||||
class TestBootstrapFromDirectory:
    """Test bootstrap_from_directory function.

    These tests write tool modules into a temp directory and scan them;
    the scan is AST-based, so the files are parsed, never imported.
    """

    def test_bootstrap_from_directory(self, fresh_registry, tmp_path):
        """Test bootstrapping from a directory."""
        tools_dir = tmp_path / "tools"
        tools_dir.mkdir()

        tool_file = tools_dir / "my_tools.py"
        # The written source must be valid standalone Python: the scanner
        # parses it as text rather than importing it.
        tool_file.write_text('''
from mcp.discovery import mcp_tool


@mcp_tool(name="dir_tool", category="directory")
def dir_tool(value: str) -> str:
    """A tool from directory."""
    return value
''')

        registered = bootstrap_from_directory(tools_dir, registry=fresh_registry)

        # Function won't be resolved (AST only), so not registered
        assert len(registered) == 0

    def test_bootstrap_from_nonexistent_directory(self, fresh_registry):
        """Test bootstrapping from non-existent directory."""
        # A missing directory is tolerated: empty result, no exception.
        registered = bootstrap_from_directory(
            Path("/nonexistent/tools"),
            registry=fresh_registry
        )

        assert len(registered) == 0

    def test_bootstrap_skips_private_files(self, fresh_registry, tmp_path):
        """Test that private files are skipped."""
        tools_dir = tmp_path / "tools"
        tools_dir.mkdir()

        # Files prefixed with "_" are treated as private and never scanned.
        private_file = tools_dir / "_private.py"
        private_file.write_text('''
from mcp.discovery import mcp_tool


@mcp_tool(name="private_tool")
def private_tool():
    pass
''')

        registered = bootstrap_from_directory(tools_dir, registry=fresh_registry)
        assert len(registered) == 0
|
||||
|
||||
|
||||
class TestGetBootstrapStatus:
    """Tests for the get_bootstrap_status report."""

    def test_status_default_enabled(self):
        """With a clean environment, auto-bootstrap defaults to enabled."""
        with patch.dict(os.environ, {}, clear=True):
            report = get_bootstrap_status()

        assert report["auto_bootstrap_enabled"] is True
        # Both counters must be present in the report.
        for key in ("discovered_tools_count", "registered_tools_count"):
            assert key in report
        assert report["default_packages"] == DEFAULT_TOOL_PACKAGES

    def test_status_disabled(self):
        """Setting the env var to "0" reports auto-bootstrap as disabled."""
        with patch.dict(os.environ, {AUTO_BOOTSTRAP_ENV_VAR: "0"}):
            report = get_bootstrap_status()

        assert report["auto_bootstrap_enabled"] is False
|
||||
|
||||
|
||||
class TestIntegration:
    """Integration tests for bootstrap + discovery + registry."""

    def test_full_workflow(self, fresh_registry):
        """Test the full auto-discovery and registration workflow."""

        @mcp_tool(name="integration_tool", category="integration")
        def integration_func(data: str) -> str:
            """Integration test tool."""
            return f"processed: {data}"

        fresh_registry.register_tool(
            name="integration_tool",
            function=integration_func,
            category="integration",
            source_module="test_module",
        )

        # Registered tool is retrievable with its metadata intact.
        record = fresh_registry.get("integration_tool")
        assert record is not None
        assert record.auto_discovered is True
        assert record.source_module == "test_module"

        # The export reflects the single auto-discovered tool.
        export = fresh_registry.to_dict()
        assert export["total_tools"] == 1
        assert export["auto_discovered_count"] == 1

    def test_tool_execution_after_registration(self, fresh_registry):
        """Test that registered tools can be executed."""
        import asyncio  # local import: only this test needs an event loop

        # Parameter renamed from `input` to `text` so it no longer shadows
        # the builtin; the execute() payload key below matches it.
        @mcp_tool(name="exec_tool", category="execution")
        def exec_func(text: str) -> str:
            """Executable test tool."""
            return text.upper()

        fresh_registry.register_tool(
            name="exec_tool",
            function=exec_func,
            category="execution",
        )

        result = asyncio.run(fresh_registry.execute("exec_tool", {"text": "hello"}))

        assert result == "HELLO"

        # Execution metrics are tracked per tool.
        metrics = fresh_registry.get_metrics("exec_tool")
        assert metrics["executions"] == 1
        assert metrics["health"] == "healthy"

    def test_discover_filtering(self, fresh_registry):
        """Test filtering registered tools."""

        @mcp_tool(name="cat1_tool", category="category1")
        def cat1_func():
            pass

        @mcp_tool(name="cat2_tool", category="category2")
        def cat2_func():
            pass

        fresh_registry.register_tool(
            name="cat1_tool",
            function=cat1_func,
            category="category1"
        )
        fresh_registry.register_tool(
            name="cat2_tool",
            function=cat2_func,
            category="category2"
        )

        # Filtering by category returns only the matching tool.
        cat1_tools = fresh_registry.discover(category="category1")
        assert len(cat1_tools) == 1
        assert cat1_tools[0].name == "cat1_tool"

        # Both tools were registered as auto-discovered.
        auto_tools = fresh_registry.discover(auto_discovered_only=True)
        assert len(auto_tools) == 2

    def test_registry_export_includes_metadata(self, fresh_registry):
        """Test that registry export includes all metadata."""

        @mcp_tool(name="meta_tool", category="meta", tags=["tag1", "tag2"])
        def meta_func():
            pass

        fresh_registry.register_tool(
            name="meta_tool",
            function=meta_func,
            category="meta",
            tags=["tag1", "tag2"],
        )

        export = fresh_registry.to_dict()

        # Every exported tool entry carries the full metadata set.
        for tool_dict in export["tools"]:
            assert "tags" in tool_dict
            assert "source_module" in tool_dict
            assert "auto_discovered" in tool_dict
|
||||
329
tests/test_mcp_discovery.py
Normal file
329
tests/test_mcp_discovery.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""Tests for MCP Tool Auto-Discovery.
|
||||
|
||||
Tests follow pytest best practices:
|
||||
- No module-level state
|
||||
- Proper fixture cleanup
|
||||
- Isolated tests
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcp.discovery import DiscoveredTool, ToolDiscovery, mcp_tool
|
||||
from mcp.registry import ToolRegistry
|
||||
|
||||
|
||||
@pytest.fixture
def fresh_registry():
    """Provide an isolated ToolRegistry instance for each test."""
    registry = ToolRegistry()
    return registry
|
||||
|
||||
|
||||
@pytest.fixture
def discovery(fresh_registry):
    """Provide a ToolDiscovery bound to the per-test registry."""
    instance = ToolDiscovery(registry=fresh_registry)
    return instance
|
||||
|
||||
|
||||
@pytest.fixture
def mock_module_with_tools():
    """Create a synthetic module with MCP tools and inject it into sys.modules.

    Yields the module object; teardown removes it again so no import
    state leaks between tests.
    """
    # Create a fresh module
    mock_module = types.ModuleType("mock_test_module")
    mock_module.__file__ = "mock_test_module.py"

    # Add decorated functions
    @mcp_tool(name="echo", category="test", tags=["utility"])
    def echo_func(message: str) -> str:
        """Echo a message back."""
        return message

    @mcp_tool(category="math")
    def add_func(a: int, b: int) -> int:
        """Add two numbers."""
        return a + b

    def not_decorated():
        """Not a tool."""
        pass

    mock_module.echo_func = echo_func
    mock_module.add_func = add_func
    mock_module.not_decorated = not_decorated

    # Inject into sys.modules so discover_module() can import it by name.
    sys.modules["mock_test_module"] = mock_module

    yield mock_module

    # Cleanup: pop() instead of del so teardown never raises KeyError if a
    # test (or a parallel fixture) already removed the entry.
    sys.modules.pop("mock_test_module", None)
|
||||
|
||||
|
||||
class TestMCPToolDecorator:
    """Tests for the metadata attached by the @mcp_tool decorator."""

    def test_decorator_sets_explicit_name(self):
        """An explicit name= argument wins over the function's own name."""

        @mcp_tool(name="custom_name", category="test")
        def my_func():
            pass

        assert my_func._mcp_name == "custom_name"
        assert my_func._mcp_category == "test"

    def test_decorator_uses_function_name(self):
        """Without name=, the function's __name__ becomes the tool name."""

        @mcp_tool(category="math")
        def my_add_func():
            pass

        assert my_add_func._mcp_name == "my_add_func"

    def test_decorator_captures_docstring(self):
        """The docstring is captured as the tool description."""

        @mcp_tool(name="test")
        def with_doc():
            """This is the description."""
            pass

        assert "This is the description" in with_doc._mcp_description

    def test_decorator_sets_tags(self):
        """The tags= argument is stored verbatim."""

        @mcp_tool(name="test", tags=["tag1", "tag2"])
        def tagged_func():
            pass

        assert tagged_func._mcp_tags == ["tag1", "tag2"]

    def test_undecorated_function(self):
        """A plain function carries no MCP marker attributes."""

        def plain_func():
            pass

        assert not hasattr(plain_func, "_mcp_tool")
|
||||
|
||||
|
||||
class TestDiscoveredTool:
    """Tests for the DiscoveredTool dataclass."""

    def test_tool_creation(self):
        """All constructor fields are stored as given."""

        def dummy_func():
            pass

        tool = DiscoveredTool(
            name="test",
            description="A test tool",
            function=dummy_func,
            module="test_module",
            category="test",
            tags=["utility"],
            parameters_schema={"type": "object"},
            returns_schema={"type": "string"},
        )

        assert tool.name == "test"
        assert tool.function is dummy_func
        assert tool.category == "test"
|
||||
|
||||
|
||||
class TestToolDiscoveryInit:
    """Tests for ToolDiscovery construction."""

    def test_uses_provided_registry(self, fresh_registry):
        """An explicitly passed registry is used as-is, not copied."""
        instance = ToolDiscovery(registry=fresh_registry)
        assert instance.registry is fresh_registry
|
||||
|
||||
|
||||
class TestDiscoverModule:
    """Tests for discovering tools from importable modules."""

    def test_discover_finds_decorated_tools(self, discovery, mock_module_with_tools):
        """Only @mcp_tool-decorated functions are reported."""
        found = discovery.discover_module("mock_test_module")

        names = [tool.name for tool in found]
        assert "echo" in names
        assert "add_func" in names
        assert "not_decorated" not in names

    def test_discover_nonexistent_module(self, discovery):
        """A missing module yields an empty result instead of raising."""
        found = discovery.discover_module("nonexistent.module.xyz")
        assert not found

    def test_discovered_tool_has_correct_metadata(self, discovery, mock_module_with_tools):
        """Discovered tools carry the decorator's category and tags."""
        found = discovery.discover_module("mock_test_module")

        echo = next(t for t in found if t.name == "echo")
        assert echo.category == "test"
        assert "utility" in echo.tags

    def test_discovered_tool_has_schema(self, discovery, mock_module_with_tools):
        """Discovered tools carry a parameter schema built from type hints."""
        found = discovery.discover_module("mock_test_module")

        add = next(t for t in found if t.name == "add_func")
        assert "properties" in add.parameters_schema
        assert "a" in add.parameters_schema["properties"]
|
||||
|
||||
|
||||
class TestDiscoverFile:
    """Test discovering tools from Python files.

    These tests write source files to a temp directory; discover_file()
    parses them as text (AST), so decorator metadata is read statically.
    """

    def test_discover_from_file(self, discovery, tmp_path):
        """Test discovering tools from a Python file."""
        test_file = tmp_path / "test_tools.py"
        # Must be standalone-valid Python: the file is parsed, not imported.
        test_file.write_text('''
from mcp.discovery import mcp_tool


@mcp_tool(name="file_tool", category="file_ops", tags=["io"])
def file_tool(path: str) -> dict:
    """Process a file."""
    return {"path": path}
''')

        tools = discovery.discover_file(test_file)

        assert len(tools) == 1
        assert tools[0].name == "file_tool"
        assert tools[0].category == "file_ops"

    def test_discover_from_nonexistent_file(self, discovery, tmp_path):
        """Test discovering from non-existent file."""
        # A missing path is tolerated: empty result, no exception.
        tools = discovery.discover_file(tmp_path / "nonexistent.py")
        assert len(tools) == 0

    def test_discover_from_invalid_python(self, discovery, tmp_path):
        """Test discovering from invalid Python file."""
        test_file = tmp_path / "invalid.py"
        test_file.write_text("not valid python @#$%")

        # A syntax error in the scanned file must not propagate.
        tools = discovery.discover_file(test_file)
        assert len(tools) == 0
|
||||
|
||||
|
||||
class TestSchemaBuilding:
    """Tests for JSON-schema construction from function signatures."""

    def test_string_parameter(self, discovery):
        """A str annotation maps to JSON type "string"."""

        def func(name: str) -> str:
            return name

        schema = discovery._build_parameters_schema(inspect.signature(func))

        assert schema["properties"]["name"]["type"] == "string"

    def test_int_parameter(self, discovery):
        """An int annotation maps to JSON type "number"."""

        def func(count: int) -> int:
            return count

        schema = discovery._build_parameters_schema(inspect.signature(func))

        assert schema["properties"]["count"]["type"] == "number"

    def test_bool_parameter(self, discovery):
        """A bool annotation maps to JSON type "boolean"."""

        def func(enabled: bool) -> bool:
            return enabled

        schema = discovery._build_parameters_schema(inspect.signature(func))

        assert schema["properties"]["enabled"]["type"] == "boolean"

    def test_required_parameters(self, discovery):
        """Parameters without defaults land in the "required" list."""

        def func(required: str, optional: str = "default") -> str:
            return required

        schema = discovery._build_parameters_schema(inspect.signature(func))

        assert "required" in schema["required"]
        assert "optional" not in schema["required"]

    def test_default_values(self, discovery):
        """Default values are copied into the property schema."""

        def func(name: str = "default") -> str:
            return name

        schema = discovery._build_parameters_schema(inspect.signature(func))

        assert schema["properties"]["name"]["default"] == "default"
|
||||
|
||||
|
||||
class TestTypeToSchema:
    """Tests for annotation -> JSON schema conversion."""

    def test_str_annotation(self, discovery):
        """str converts to a "string" schema."""
        assert discovery._type_to_schema(str)["type"] == "string"

    def test_int_annotation(self, discovery):
        """int converts to a "number" schema."""
        assert discovery._type_to_schema(int)["type"] == "number"

    def test_optional_annotation(self, discovery):
        """Optional[T] unwraps to the schema of T."""
        from typing import Optional

        assert discovery._type_to_schema(Optional[str])["type"] == "string"
|
||||
|
||||
|
||||
class TestAutoRegister:
    """Tests for auto-registration of discovered tools."""

    def test_auto_register_module(self, discovery, mock_module_with_tools, fresh_registry):
        """Tools discovered in a module end up registered by name."""
        names = discovery.auto_register("mock_test_module")

        assert "echo" in names
        assert "add_func" in names
        assert fresh_registry.get("echo") is not None

    def test_auto_register_skips_unresolved_functions(self, discovery, fresh_registry):
        """A discovered tool without a callable is never registered."""
        # Seed the private cache with an entry whose function is missing.
        orphan = DiscoveredTool(
            name="no_func",
            description="No function",
            function=None,  # type: ignore
            module="test",
            category="test",
            tags=[],
            parameters_schema={},
            returns_schema={},
        )
        discovery._discovered.append(orphan)

        names = discovery.auto_register("mock_test_module")
        assert "no_func" not in names
|
||||
|
||||
|
||||
class TestClearDiscovered:
    """Tests for clearing the discovered-tools cache."""

    def test_clear_discovered(self, discovery, mock_module_with_tools):
        """clear() empties the cache populated by discover_module()."""
        discovery.discover_module("mock_test_module")
        assert len(discovery.get_discovered()) > 0

        discovery.clear()
        assert not discovery.get_discovered()
|
||||
358
tests/test_router_api.py
Normal file
358
tests/test_router_api.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""Tests for Cascade Router API endpoints."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from router.cascade import CircuitState, Provider, ProviderStatus
|
||||
from router.api import router, get_cascade_router
|
||||
|
||||
|
||||
def make_mock_router():
    """Create a mock CascadeRouter.

    Returns a MagicMock standing in for CascadeRouter carrying two real
    Provider objects — a healthy local ollama provider (priority 1) and a
    degraded openai backup (priority 2) — plus the config scalars read by
    the /config endpoint.
    """
    # NOTE: this local name shadows the module-level APIRouter `router`
    # only inside this helper; the import is unaffected.
    router = MagicMock()

    # Create test providers
    provider1 = Provider(
        name="ollama-local",
        type="ollama",
        enabled=True,
        priority=1,
        url="http://localhost:11434",
        models=[{"name": "llama3.2", "default": True, "context_window": 128000}],
    )
    provider1.status = ProviderStatus.HEALTHY
    provider1.circuit_state = CircuitState.CLOSED

    provider2 = Provider(
        name="openai-backup",
        type="openai",
        enabled=True,
        priority=2,
        api_key="sk-test",
        models=[{"name": "gpt-4o-mini", "default": True, "context_window": 128000}],
    )
    provider2.status = ProviderStatus.DEGRADED
    provider2.circuit_state = CircuitState.CLOSED

    router.providers = [provider1, provider2]
    # MagicMock auto-creates `config`; these values are asserted by the
    # /config endpoint tests.
    router.config.timeout_seconds = 30
    router.config.max_retries_per_provider = 2
    router.config.circuit_breaker_failure_threshold = 5

    return router
|
||||
|
||||
|
||||
@pytest.fixture
def mock_router():
    """Build a TestClient wired to a mocked CascadeRouter.

    Returns a (client, mock) pair so tests can both issue requests and
    inspect/configure the mock.
    """
    from fastapi import FastAPI

    app = FastAPI()
    app.include_router(router)

    mock = make_mock_router()

    # Replace the real dependency with one returning our mock.
    async def override_router():
        return mock

    app.dependency_overrides[get_cascade_router] = override_router

    return TestClient(app), mock
|
||||
|
||||
|
||||
class TestCompleteEndpoint:
    """Tests for the /complete endpoint."""

    def test_complete_success(self, mock_router):
        """A successful completion is passed through verbatim."""
        client, mock = mock_router
        fake_result = {
            "content": "Hello! How can I help?",
            "provider": "ollama-local",
            "model": "llama3.2",
            "latency_ms": 250.5,
        }
        mock.complete = AsyncMock(return_value=fake_result)

        payload = {
            "messages": [{"role": "user", "content": "Hi"}],
            "model": "llama3.2",
            "temperature": 0.7,
        }
        response = client.post("/api/v1/router/complete", json=payload)

        assert response.status_code == 200
        body = response.json()
        assert body["content"] == "Hello! How can I help?"
        assert body["provider"] == "ollama-local"
        assert body["latency_ms"] == 250.5

    def test_complete_all_providers_fail(self, mock_router):
        """When every provider fails, the endpoint returns 503."""
        client, mock = mock_router
        mock.complete = AsyncMock(side_effect=RuntimeError("All providers failed"))

        response = client.post("/api/v1/router/complete", json={
            "messages": [{"role": "user", "content": "Hi"}],
        })

        assert response.status_code == 503
        assert "All providers failed" in response.json()["detail"]

    def test_complete_default_temperature(self, mock_router):
        """Omitting temperature falls back to the 0.7 default."""
        client, mock = mock_router
        mock.complete = AsyncMock(return_value={
            "content": "Response",
            "provider": "ollama-local",
            "model": "llama3.2",
            "latency_ms": 100.0,
        })

        response = client.post("/api/v1/router/complete", json={
            "messages": [{"role": "user", "content": "Hi"}],
        })

        assert response.status_code == 200
        # The mock must have been invoked with the default temperature.
        assert mock.complete.call_args.kwargs["temperature"] == 0.7
|
||||
|
||||
|
||||
class TestStatusEndpoint:
    """Test /status endpoint.

    The router's get_status() is stubbed with a canned payload; the test
    checks the endpoint forwards it intact.
    """

    def test_get_status(self, mock_router):
        """Test getting router status."""
        client, mock = mock_router
        # Canned status payload — the assertions below mirror it exactly.
        mock.get_status = MagicMock(return_value={
            "total_providers": 2,
            "healthy_providers": 1,
            "degraded_providers": 1,
            "unhealthy_providers": 0,
            "providers": [
                {
                    "name": "ollama-local",
                    "type": "ollama",
                    "status": "healthy",
                    "priority": 1,
                    "default_model": "llama3.2",
                },
                {
                    "name": "openai-backup",
                    "type": "openai",
                    "status": "degraded",
                    "priority": 2,
                    "default_model": "gpt-4o-mini",
                },
            ],
        })

        response = client.get("/api/v1/router/status")

        assert response.status_code == 200
        data = response.json()
        assert data["total_providers"] == 2
        assert data["healthy_providers"] == 1
        assert data["degraded_providers"] == 1
        assert len(data["providers"]) == 2
|
||||
|
||||
|
||||
class TestMetricsEndpoint:
    """Test /metrics endpoint.

    The router's get_metrics() is stubbed with a canned payload; the test
    checks the endpoint forwards it intact.
    """

    def test_get_metrics(self, mock_router):
        """Test getting detailed metrics."""
        client, mock = mock_router
        # Setup the mock return value on the mock_router object
        mock.get_metrics = MagicMock(return_value={
            "providers": [
                {
                    "name": "ollama-local",
                    "type": "ollama",
                    "status": "healthy",
                    "circuit_state": "closed",
                    "metrics": {
                        "total_requests": 100,
                        "successful": 98,
                        "failed": 2,
                        "error_rate": 0.02,
                        "avg_latency_ms": 150.5,
                    },
                },
            ],
        })

        response = client.get("/api/v1/router/metrics")

        assert response.status_code == 200
        data = response.json()
        assert len(data["providers"]) == 1
        # Per-provider metrics are returned unmodified.
        metrics = data["providers"][0]["metrics"]
        assert metrics["total_requests"] == 100
        assert metrics["error_rate"] == 0.02
        assert metrics["avg_latency_ms"] == 150.5
|
||||
|
||||
|
||||
class TestListProvidersEndpoint:
    """Tests for the /providers endpoint."""

    def test_list_providers(self, mock_router):
        """Both configured providers are listed with their metadata."""
        client, mock = mock_router

        response = client.get("/api/v1/router/providers")
        assert response.status_code == 200

        providers = response.json()
        assert len(providers) == 2

        # The first (priority 1) provider carries the expected fields.
        first = providers[0]
        assert first["name"] == "ollama-local"
        assert first["type"] == "ollama"
        assert first["enabled"] is True
        assert first["priority"] == 1
        assert first["default_model"] == "llama3.2"
        assert "llama3.2" in first["models"]
|
||||
|
||||
|
||||
class TestControlProviderEndpoint:
    """Tests for the /providers/{name}/control endpoint."""

    def test_disable_provider(self, mock_router):
        """Disabling flips the provider's enabled flag and status."""
        client, mock = mock_router

        response = client.post(
            "/api/v1/router/providers/ollama-local/control",
            json={"action": "disable"},
        )

        assert response.status_code == 200
        assert "disabled" in response.json()["message"]

        target = mock.providers[0]
        assert target.enabled is False
        assert target.status == ProviderStatus.DISABLED

    def test_enable_provider(self, mock_router):
        """Enabling restores a previously disabled provider."""
        client, mock = mock_router
        # Start from the disabled state.
        target = mock.providers[0]
        target.enabled = False
        target.status = ProviderStatus.DISABLED

        response = client.post(
            "/api/v1/router/providers/ollama-local/control",
            json={"action": "enable"},
        )

        assert response.status_code == 200
        assert "enabled" in response.json()["message"]
        assert mock.providers[0].enabled is True

    def test_reset_circuit(self, mock_router):
        """Resetting the circuit breaker clears state, status and failures."""
        client, mock = mock_router
        # Force the breaker open with accumulated failures.
        target = mock.providers[0]
        target.circuit_state = CircuitState.OPEN
        target.status = ProviderStatus.UNHEALTHY
        target.metrics.consecutive_failures = 10

        response = client.post(
            "/api/v1/router/providers/ollama-local/control",
            json={"action": "reset_circuit"},
        )

        assert response.status_code == 200
        assert "reset" in response.json()["message"]

        after = mock.providers[0]
        assert after.circuit_state == CircuitState.CLOSED
        assert after.status == ProviderStatus.HEALTHY
        assert after.metrics.consecutive_failures == 0

    def test_control_unknown_provider(self, mock_router):
        """An unknown provider name yields 404."""
        client, mock = mock_router

        response = client.post(
            "/api/v1/router/providers/unknown/control",
            json={"action": "disable"},
        )

        assert response.status_code == 404
        assert "not found" in response.json()["detail"]

    def test_control_unknown_action(self, mock_router):
        """An unrecognized action yields 400."""
        client, mock = mock_router

        response = client.post(
            "/api/v1/router/providers/ollama-local/control",
            json={"action": "invalid_action"},
        )

        assert response.status_code == 400
        assert "Unknown action" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestHealthCheckEndpoint:
    """Tests for the /health-check endpoint."""

    def test_health_check_all_healthy(self, mock_router):
        """When every provider responds, all are reported healthy."""
        client, mock = mock_router

        with patch.object(mock, "_check_provider_available", return_value=True):
            response = client.post("/api/v1/router/health-check")

        assert response.status_code == 200
        data = response.json()
        assert data["healthy_count"] == 2
        assert len(data["providers"]) == 2

        for entry in data["providers"]:
            assert entry["healthy"] is True

    def test_health_check_with_failure(self, mock_router):
        """Failing providers are reported individually."""
        client, mock = mock_router

        # First provider fails, second succeeds.
        with patch.object(
            mock, "_check_provider_available", side_effect=[False, True]
        ):
            response = client.post("/api/v1/router/health-check")

        assert response.status_code == 200
        data = response.json()
        assert data["healthy_count"] == 1
        assert data["providers"][0]["healthy"] is False
        assert data["providers"][1]["healthy"] is True
|
||||
|
||||
|
||||
class TestGetConfigEndpoint:
    """Tests for the /config endpoint."""

    def test_get_config(self, mock_router):
        """Config is exposed with breaker settings and no provider secrets."""
        client, mock = mock_router

        response = client.get("/api/v1/router/config")
        assert response.status_code == 200

        config = response.json()
        assert config["timeout_seconds"] == 30
        assert config["max_retries_per_provider"] == 2
        assert "circuit_breaker" in config
        assert config["circuit_breaker"]["failure_threshold"] == 5

        # Providers are listed, but API keys must never leak.
        assert len(config["providers"]) == 2
        assert "api_key" not in config["providers"][0]
|
||||
523
tests/test_router_cascade.py
Normal file
523
tests/test_router_cascade.py
Normal file
@@ -0,0 +1,523 @@
|
||||
"""Tests for Cascade LLM Router."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from router.cascade import (
|
||||
CascadeRouter,
|
||||
CircuitState,
|
||||
Provider,
|
||||
ProviderMetrics,
|
||||
ProviderStatus,
|
||||
RouterConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderMetrics:
    """Tests for ProviderMetrics derived values."""

    def test_empty_metrics(self):
        """A fresh metrics object reports zeros for counts and derived rates."""
        fresh = ProviderMetrics()
        assert fresh.total_requests == 0
        assert fresh.avg_latency_ms == 0.0
        assert fresh.error_rate == 0.0

    def test_avg_latency_calculation(self):
        """avg_latency_ms is total latency divided by request count."""
        # 1000ms spread across 4 requests -> 250ms average.
        stats = ProviderMetrics(total_requests=4, total_latency_ms=1000.0)
        assert stats.avg_latency_ms == 250.0

    def test_error_rate_calculation(self):
        """error_rate is failed requests over total requests."""
        stats = ProviderMetrics(
            total_requests=10,
            successful_requests=7,
            failed_requests=3,
        )
        assert stats.error_rate == 0.3
||||
class TestProvider:
    """Tests for the Provider dataclass model-selection helper."""

    @staticmethod
    def _provider_with(models):
        """Build a minimal enabled ollama provider holding *models*."""
        return Provider(
            name="test",
            type="ollama",
            enabled=True,
            priority=1,
            models=models,
        )

    def test_get_default_model(self):
        """The model flagged default=True is chosen."""
        provider = self._provider_with(
            [{"name": "llama3", "default": True}, {"name": "mistral"}]
        )
        assert provider.get_default_model() == "llama3"

    def test_get_default_model_no_default(self):
        """Without an explicit default, the first listed model is used."""
        provider = self._provider_with([{"name": "llama3"}, {"name": "mistral"}])
        assert provider.get_default_model() == "llama3"

    def test_get_default_model_empty(self):
        """An empty model list yields None."""
        assert self._provider_with([]).get_default_model() is None
||||
class TestRouterConfig:
    """Tests for RouterConfig defaults."""

    def test_default_config(self):
        """A bare RouterConfig carries the documented default values."""
        defaults = RouterConfig()
        assert defaults.timeout_seconds == 30
        assert defaults.max_retries_per_provider == 2
        assert defaults.retry_delay_seconds == 1
        assert defaults.circuit_breaker_failure_threshold == 5
||||
class TestCascadeRouterInit:
    """Tests for CascadeRouter construction and config-file loading."""

    def test_init_without_config(self, tmp_path):
        """A missing config file leaves the router empty with default settings."""
        router = CascadeRouter(config_path=tmp_path / "nonexistent.yaml")
        assert len(router.providers) == 0
        assert router.config.timeout_seconds == 30

    def test_init_with_config(self, tmp_path):
        """Values from a YAML config file override the defaults."""
        raw_config = {
            "cascade": {
                "timeout_seconds": 60,
                "max_retries_per_provider": 3,
            },
            "providers": [
                {
                    "name": "test-ollama",
                    "type": "ollama",
                    "enabled": False,  # Disabled to avoid availability check
                    "priority": 1,
                    "url": "http://localhost:11434",
                }
            ],
        }
        cfg_file = tmp_path / "providers.yaml"
        cfg_file.write_text(yaml.dump(raw_config))

        router = CascadeRouter(config_path=cfg_file)
        assert router.config.timeout_seconds == 60
        assert router.config.max_retries_per_provider == 3
        # The only configured provider is disabled, so none are loaded.
        assert len(router.providers) == 0

    def test_env_var_expansion(self, tmp_path, monkeypatch):
        """${VAR} placeholders in the config are expanded from the environment."""
        monkeypatch.setenv("TEST_API_KEY", "secret123")

        raw_config = {
            "cascade": {},
            "providers": [
                {
                    "name": "test-openai",
                    "type": "openai",
                    "enabled": True,
                    "priority": 1,
                    "api_key": "${TEST_API_KEY}",
                }
            ],
        }
        cfg_file = tmp_path / "providers.yaml"
        cfg_file.write_text(yaml.dump(raw_config))

        router = CascadeRouter(config_path=cfg_file)
        assert len(router.providers) == 1
        assert router.providers[0].api_key == "secret123"
||||
class TestCascadeRouterMetrics:
    """Exercise success/failure accounting and circuit-breaker transitions."""

    @staticmethod
    def _fixture():
        """Return (router without config, minimal enabled ollama provider)."""
        return (
            CascadeRouter(config_path=Path("/nonexistent")),
            Provider(name="test", type="ollama", enabled=True, priority=1),
        )

    def test_record_success(self):
        """A success bumps request/latency counters and clears the failure streak."""
        router, provider = self._fixture()
        router._record_success(provider, 150.0)

        stats = provider.metrics
        assert stats.total_requests == 1
        assert stats.successful_requests == 1
        assert stats.total_latency_ms == 150.0
        assert stats.consecutive_failures == 0

    def test_record_failure(self):
        """A failure bumps the request, failure, and streak counters."""
        router, provider = self._fixture()
        router._record_failure(provider)

        stats = provider.metrics
        assert stats.total_requests == 1
        assert stats.failed_requests == 1
        assert stats.consecutive_failures == 1

    def test_circuit_breaker_opens(self):
        """Reaching the failure threshold opens the circuit and marks unhealthy."""
        router, provider = self._fixture()
        router.config.circuit_breaker_failure_threshold = 3

        for _ in range(3):
            router._record_failure(provider)

        assert provider.circuit_state == CircuitState.OPEN
        assert provider.status == ProviderStatus.UNHEALTHY
        assert provider.circuit_opened_at is not None

    def test_circuit_breaker_can_close(self):
        """After the recovery timeout, an open circuit reports it can close."""
        router, provider = self._fixture()
        router.config.circuit_breaker_failure_threshold = 3
        router.config.circuit_breaker_recovery_timeout = 1

        # Drive the circuit open.
        for _ in range(3):
            router._record_failure(provider)
        assert provider.circuit_state == CircuitState.OPEN

        # Let the recovery window elapse before probing.
        time.sleep(1.1)
        assert router._can_close_circuit(provider) is True

    def test_half_open_to_closed(self):
        """Enough successful half-open probes close the circuit again."""
        router, provider = self._fixture()
        router.config.circuit_breaker_half_open_max_calls = 2

        # Put the breaker directly into the half-open probing state.
        provider.circuit_state = CircuitState.HALF_OPEN
        provider.half_open_calls = 0

        router._record_success(provider, 100.0)
        assert provider.circuit_state == CircuitState.HALF_OPEN  # one probe left

        router._record_success(provider, 100.0)
        assert provider.circuit_state == CircuitState.CLOSED
        assert provider.status == ProviderStatus.HEALTHY
||||
class TestCascadeRouterGetMetrics:
    """Exercise the get_metrics() snapshot."""

    def test_get_metrics_empty(self):
        """With no providers configured, the providers list is empty."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        snapshot = router.get_metrics()

        assert "providers" in snapshot
        assert len(snapshot["providers"]) == 0

    def test_get_metrics_with_providers(self):
        """Per-provider counters and derived rates appear in the snapshot."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(name="test", type="ollama", enabled=True, priority=1)
        # 10 requests, 2 failures, 2000ms total -> 20% errors, 200ms average.
        provider.metrics.total_requests = 10
        provider.metrics.successful_requests = 8
        provider.metrics.failed_requests = 2
        provider.metrics.total_latency_ms = 2000.0
        router.providers = [provider]

        snapshot = router.get_metrics()
        assert len(snapshot["providers"]) == 1

        entry = snapshot["providers"][0]
        assert entry["name"] == "test"
        assert entry["metrics"]["total_requests"] == 10
        assert entry["metrics"]["error_rate"] == 0.2
        assert entry["metrics"]["avg_latency_ms"] == 200.0
||||
class TestCascadeRouterGetStatus:
    """Exercise the get_status() aggregation."""

    def test_get_status(self):
        """A single healthy provider is counted and listed."""
        router = CascadeRouter(config_path=Path("/nonexistent"))
        router.providers = [
            Provider(
                name="test",
                type="ollama",
                enabled=True,
                priority=1,
                models=[{"name": "llama3", "default": True}],
            )
        ]

        status = router.get_status()
        assert status["total_providers"] == 1
        assert status["healthy_providers"] == 1
        assert status["degraded_providers"] == 0
        assert status["unhealthy_providers"] == 0
        assert len(status["providers"]) == 1
||||
@pytest.mark.asyncio
class TestCascadeRouterComplete:
    """Test complete() with mocked providers: happy path, retry/failover,
    total exhaustion, and circuit-breaker skipping."""

    async def test_complete_with_ollama(self):
        """A single healthy Ollama provider answers and is attributed."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="ollama-local",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider]

        # patch.object wraps async methods in AsyncMock, so a plain dict
        # return_value is awaited correctly. (The old
        # `mock_call.return_value = AsyncMock()()` was dead code — overwritten
        # on the next line — and leaked a never-awaited coroutine warning.)
        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {
                "content": "Hello, world!",
                "model": "llama3.2",
            }

            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

            assert result["content"] == "Hello, world!"
            assert result["provider"] == "ollama-local"
            assert result["model"] == "llama3.2"

    async def test_failover_to_second_provider(self):
        """When the first provider exhausts its retries, the second answers."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider1 = Provider(
            name="ollama-failing",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
            models=[{"name": "llama3.2", "default": True}],
        )
        provider2 = Provider(
            name="ollama-backup",
            type="ollama",
            enabled=True,
            priority=2,
            url="http://backup:11434",
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider1, provider2]

        # First provider fails, second succeeds.
        call_count = [0]

        async def side_effect(*args, **kwargs):
            call_count[0] += 1
            # Every retry for provider1 fails, then provider2 succeeds.
            if call_count[0] <= router.config.max_retries_per_provider:
                raise RuntimeError("Connection failed")
            return {"content": "Backup response", "model": "llama3.2"}

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.side_effect = side_effect

            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

            assert result["content"] == "Backup response"
            assert result["provider"] == "ollama-backup"

    async def test_all_providers_fail(self):
        """If every provider keeps erroring, complete() raises RuntimeError."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="failing",
            type="ollama",
            enabled=True,
            priority=1,
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.side_effect = RuntimeError("Always fails")

            with pytest.raises(RuntimeError) as exc_info:
                await router.complete(messages=[{"role": "user", "content": "Hi"}])

            assert "All providers failed" in str(exc_info.value)

    async def test_skips_unhealthy_provider(self):
        """Providers with an open circuit are skipped in favor of healthy ones."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider1 = Provider(
            name="unhealthy",
            type="ollama",
            enabled=True,
            priority=1,
            status=ProviderStatus.UNHEALTHY,
            circuit_state=CircuitState.OPEN,
            circuit_opened_at=time.time(),  # Just opened — recovery window not elapsed
            models=[{"name": "llama3.2", "default": True}],
        )
        provider2 = Provider(
            name="healthy",
            type="ollama",
            enabled=True,
            priority=2,
            models=[{"name": "llama3.2", "default": True}],
        )
        router.providers = [provider1, provider2]

        with patch.object(router, "_call_ollama") as mock_call:
            mock_call.return_value = {"content": "Success", "model": "llama3.2"}

            result = await router.complete(
                messages=[{"role": "user", "content": "Hi"}],
            )

            # Should use the healthy provider
            assert result["provider"] == "healthy"
||||
class TestProviderAvailabilityCheck:
    """Test _check_provider_available for each provider type."""

    def test_check_ollama_without_requests(self):
        """Ollama is assumed available when the requests library is missing."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="ollama",
            type="ollama",
            enabled=True,
            priority=1,
            url="http://localhost:11434",
        )

        # Simulate `requests` being unimportable. patch.object restores the
        # attribute automatically even if the assertion fails (the previous
        # manual save/try/finally was equivalent but noisier).
        import router.cascade as cascade_module

        with patch.object(cascade_module, "requests", None):
            assert router._check_provider_available(provider) is True

    def test_check_openai_with_key(self):
        """OpenAI counts as available when an API key is configured."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="openai",
            type="openai",
            enabled=True,
            priority=1,
            api_key="sk-test123",
        )

        assert router._check_provider_available(provider) is True

    def test_check_openai_without_key(self):
        """OpenAI is unavailable without an API key."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="openai",
            type="openai",
            enabled=True,
            priority=1,
            api_key=None,
        )

        assert router._check_provider_available(provider) is False

    def test_check_airllm_installed(self):
        """AirLLM is available when the package imports successfully."""
        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="airllm",
            type="airllm",
            enabled=True,
            priority=1,
        )

        with patch("builtins.__import__") as mock_import:
            mock_import.return_value = MagicMock()
            assert router._check_provider_available(provider) is True

    def test_check_airllm_not_installed(self):
        """AirLLM is unavailable when importing it raises ImportError."""
        import builtins

        router = CascadeRouter(config_path=Path("/nonexistent"))

        provider = Provider(
            name="airllm",
            type="airllm",
            enabled=True,
            priority=1,
        )

        # Capture the real import BEFORE patching. The old fallback called
        # `__builtins__.__import__`, which breaks in imported modules where
        # __builtins__ is a dict (not the builtins module) — an AttributeError
        # would mask the behavior under test for any non-airllm import.
        real_import = builtins.__import__

        def raise_import_error(name, *args, **kwargs):
            if name == "airllm":
                raise ImportError("No module named 'airllm'")
            return real_import(name, *args, **kwargs)

        with patch("builtins.__import__", side_effect=raise_import_error):
            assert router._check_provider_available(provider) is False
||||
Reference in New Issue
Block a user