forked from Rockachopa/Timmy-time-dashboard
Compare commits
1 Commits
claude/iss
...
kimi/issue
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2d662257a |
@@ -528,18 +528,14 @@ class CascadeRouter:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]:
|
def _get_providers_for_tier(self, cascade_tier: str | None) -> list[Provider]:
|
||||||
"""Return the provider list filtered by tier.
|
"""Filter providers by tier, returning eligible providers."""
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If a tier is specified but no matching providers exist.
|
|
||||||
"""
|
|
||||||
if cascade_tier == "frontier_required":
|
if cascade_tier == "frontier_required":
|
||||||
providers = [p for p in self.providers if p.type == "anthropic"]
|
providers = [p for p in self.providers if p.type == "anthropic"]
|
||||||
if not providers:
|
if not providers:
|
||||||
raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.")
|
raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.")
|
||||||
return providers
|
return providers
|
||||||
if cascade_tier:
|
elif cascade_tier:
|
||||||
providers = [p for p in self.providers if p.tier == cascade_tier]
|
providers = [p for p in self.providers if p.tier == cascade_tier]
|
||||||
if not providers:
|
if not providers:
|
||||||
raise RuntimeError(f"No providers found for tier: {cascade_tier}")
|
raise RuntimeError(f"No providers found for tier: {cascade_tier}")
|
||||||
@@ -548,18 +544,19 @@ class CascadeRouter:
|
|||||||
|
|
||||||
async def _try_single_provider(
|
async def _try_single_provider(
|
||||||
self,
|
self,
|
||||||
provider: "Provider",
|
provider: Provider,
|
||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
model: str | None,
|
model: str | None,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
max_tokens: int | None,
|
max_tokens: int | None,
|
||||||
content_type: ContentType,
|
content_type: ContentType,
|
||||||
errors: list[str],
|
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
"""Attempt one provider, returning a result dict on success or None on failure.
|
"""Attempt a single provider request.
|
||||||
|
|
||||||
On failure the error string is appended to *errors* and the provider's
|
Returns:
|
||||||
failure metrics are updated so the caller can move on to the next provider.
|
Response dict on success, None if provider should be skipped.
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the provider attempt fails.
|
||||||
"""
|
"""
|
||||||
if not self._is_provider_available(provider):
|
if not self._is_provider_available(provider):
|
||||||
return None
|
return None
|
||||||
@@ -575,14 +572,14 @@ class CascadeRouter:
|
|||||||
|
|
||||||
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
|
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
|
||||||
|
|
||||||
try:
|
result = await self._attempt_with_retry(
|
||||||
result = await self._attempt_with_retry(
|
provider,
|
||||||
provider, messages, selected_model, temperature, max_tokens, content_type
|
messages,
|
||||||
)
|
selected_model,
|
||||||
except RuntimeError as exc:
|
temperature,
|
||||||
errors.append(str(exc))
|
max_tokens,
|
||||||
self._record_failure(provider)
|
content_type,
|
||||||
return None
|
)
|
||||||
|
|
||||||
self._record_success(provider, result.get("latency_ms", 0))
|
self._record_success(provider, result.get("latency_ms", 0))
|
||||||
return {
|
return {
|
||||||
@@ -627,14 +624,23 @@ class CascadeRouter:
|
|||||||
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
||||||
|
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
providers = self._filter_providers(cascade_tier)
|
providers = self._get_providers_for_tier(cascade_tier)
|
||||||
|
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
result = await self._try_single_provider(
|
try:
|
||||||
provider, messages, model, temperature, max_tokens, content_type, errors
|
result = await self._try_single_provider(
|
||||||
)
|
provider,
|
||||||
if result is not None:
|
messages,
|
||||||
return result
|
model,
|
||||||
|
temperature,
|
||||||
|
max_tokens,
|
||||||
|
content_type,
|
||||||
|
)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(str(exc))
|
||||||
|
self._record_failure(provider)
|
||||||
|
|
||||||
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,10 @@
|
|||||||
"""Tests for the async event bus (infrastructure.events.bus)."""
|
"""Tests for the async event bus (infrastructure.events.bus)."""
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import infrastructure.events.bus as bus_module
|
from infrastructure.events.bus import Event, EventBus, emit, event_bus, on
|
||||||
from infrastructure.events.bus import (
|
|
||||||
Event,
|
|
||||||
EventBus,
|
|
||||||
emit,
|
|
||||||
event_bus,
|
|
||||||
get_event_bus,
|
|
||||||
init_event_bus_persistence,
|
|
||||||
on,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestEvent:
|
class TestEvent:
|
||||||
@@ -360,111 +349,3 @@ class TestEventBusPersistence:
|
|||||||
assert mode == "wal"
|
assert mode == "wal"
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
async def test_persist_event_exception_is_swallowed(self, tmp_path):
|
|
||||||
"""_persist_event must not propagate SQLite errors."""
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
bus = EventBus()
|
|
||||||
bus.enable_persistence(tmp_path / "events.db")
|
|
||||||
|
|
||||||
# Make the INSERT raise an OperationalError
|
|
||||||
mock_conn = MagicMock()
|
|
||||||
mock_conn.execute.side_effect = sqlite3.OperationalError("simulated failure")
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def fake_ctx():
|
|
||||||
yield mock_conn
|
|
||||||
|
|
||||||
with patch.object(bus, "_get_persistence_conn", fake_ctx):
|
|
||||||
# Should not raise
|
|
||||||
bus._persist_event(Event(type="x", source="s"))
|
|
||||||
|
|
||||||
async def test_replay_exception_returns_empty(self, tmp_path):
|
|
||||||
"""replay() must return [] when SQLite query fails."""
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
bus = EventBus()
|
|
||||||
bus.enable_persistence(tmp_path / "events.db")
|
|
||||||
|
|
||||||
mock_conn = MagicMock()
|
|
||||||
mock_conn.execute.side_effect = sqlite3.OperationalError("simulated failure")
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def fake_ctx():
|
|
||||||
yield mock_conn
|
|
||||||
|
|
||||||
with patch.object(bus, "_get_persistence_conn", fake_ctx):
|
|
||||||
result = bus.replay()
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── Singleton helpers ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestSingletonHelpers:
|
|
||||||
"""Test get_event_bus(), init_event_bus_persistence(), and module __getattr__."""
|
|
||||||
|
|
||||||
def test_get_event_bus_returns_same_instance(self):
|
|
||||||
"""get_event_bus() is a true singleton."""
|
|
||||||
a = get_event_bus()
|
|
||||||
b = get_event_bus()
|
|
||||||
assert a is b
|
|
||||||
|
|
||||||
def test_module_event_bus_attr_is_singleton(self):
|
|
||||||
"""Accessing bus_module.event_bus via __getattr__ returns the singleton."""
|
|
||||||
assert bus_module.event_bus is get_event_bus()
|
|
||||||
|
|
||||||
def test_module_getattr_unknown_raises(self):
|
|
||||||
"""Accessing an unknown module attribute raises AttributeError."""
|
|
||||||
with pytest.raises(AttributeError):
|
|
||||||
_ = bus_module.no_such_attr # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def test_init_event_bus_persistence_sets_path(self, tmp_path):
|
|
||||||
"""init_event_bus_persistence() enables persistence on the singleton."""
|
|
||||||
bus = get_event_bus()
|
|
||||||
original_path = bus._persistence_db_path
|
|
||||||
try:
|
|
||||||
bus._persistence_db_path = None # reset for the test
|
|
||||||
db_path = tmp_path / "test_init.db"
|
|
||||||
init_event_bus_persistence(db_path)
|
|
||||||
assert bus._persistence_db_path == db_path
|
|
||||||
finally:
|
|
||||||
bus._persistence_db_path = original_path
|
|
||||||
|
|
||||||
def test_init_event_bus_persistence_is_idempotent(self, tmp_path):
|
|
||||||
"""Calling init_event_bus_persistence() twice keeps the first path."""
|
|
||||||
bus = get_event_bus()
|
|
||||||
original_path = bus._persistence_db_path
|
|
||||||
try:
|
|
||||||
bus._persistence_db_path = None
|
|
||||||
first_path = tmp_path / "first.db"
|
|
||||||
second_path = tmp_path / "second.db"
|
|
||||||
init_event_bus_persistence(first_path)
|
|
||||||
init_event_bus_persistence(second_path) # should be ignored
|
|
||||||
assert bus._persistence_db_path == first_path
|
|
||||||
finally:
|
|
||||||
bus._persistence_db_path = original_path
|
|
||||||
|
|
||||||
def test_init_event_bus_persistence_default_path(self):
|
|
||||||
"""init_event_bus_persistence() uses 'data/events.db' when no path given."""
|
|
||||||
bus = get_event_bus()
|
|
||||||
original_path = bus._persistence_db_path
|
|
||||||
try:
|
|
||||||
bus._persistence_db_path = None
|
|
||||||
# Patch enable_persistence to capture what path it receives
|
|
||||||
captured = {}
|
|
||||||
|
|
||||||
def fake_enable(path: Path) -> None:
|
|
||||||
captured["path"] = path
|
|
||||||
|
|
||||||
with patch.object(bus, "enable_persistence", side_effect=fake_enable):
|
|
||||||
init_event_bus_persistence()
|
|
||||||
|
|
||||||
assert captured["path"] == Path("data/events.db")
|
|
||||||
finally:
|
|
||||||
bus._persistence_db_path = original_path
|
|
||||||
|
|||||||
@@ -1376,141 +1376,3 @@ class TestIsProviderAvailable:
|
|||||||
result = router._is_provider_available(provider)
|
result = router._is_provider_available(provider)
|
||||||
assert result is True
|
assert result is True
|
||||||
assert provider.circuit_state == CircuitState.HALF_OPEN
|
assert provider.circuit_state == CircuitState.HALF_OPEN
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
class TestFilterProviders:
|
|
||||||
"""Test _filter_providers helper extracted from complete()."""
|
|
||||||
|
|
||||||
def _router(self) -> CascadeRouter:
|
|
||||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
|
||||||
router.providers = [
|
|
||||||
Provider(
|
|
||||||
name="anthropic-p",
|
|
||||||
type="anthropic",
|
|
||||||
enabled=True,
|
|
||||||
priority=1,
|
|
||||||
api_key="key",
|
|
||||||
tier="frontier",
|
|
||||||
),
|
|
||||||
Provider(
|
|
||||||
name="ollama-p",
|
|
||||||
type="ollama",
|
|
||||||
enabled=True,
|
|
||||||
priority=2,
|
|
||||||
tier="local",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return router
|
|
||||||
|
|
||||||
def test_no_tier_returns_all_providers(self):
|
|
||||||
router = self._router()
|
|
||||||
result = router._filter_providers(None)
|
|
||||||
assert result is router.providers
|
|
||||||
|
|
||||||
def test_frontier_required_returns_only_anthropic(self):
|
|
||||||
router = self._router()
|
|
||||||
result = router._filter_providers("frontier_required")
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0].type == "anthropic"
|
|
||||||
|
|
||||||
def test_frontier_required_no_anthropic_raises(self):
|
|
||||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
|
||||||
router.providers = [
|
|
||||||
Provider(name="ollama-p", type="ollama", enabled=True, priority=1)
|
|
||||||
]
|
|
||||||
with pytest.raises(RuntimeError, match="No Anthropic provider configured"):
|
|
||||||
router._filter_providers("frontier_required")
|
|
||||||
|
|
||||||
def test_named_tier_filters_by_tier(self):
|
|
||||||
router = self._router()
|
|
||||||
result = router._filter_providers("local")
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0].name == "ollama-p"
|
|
||||||
|
|
||||||
def test_named_tier_not_found_raises(self):
|
|
||||||
router = self._router()
|
|
||||||
with pytest.raises(RuntimeError, match="No providers found for tier"):
|
|
||||||
router._filter_providers("nonexistent")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
class TestTrySingleProvider:
|
|
||||||
"""Test _try_single_provider helper extracted from complete()."""
|
|
||||||
|
|
||||||
def _router(self) -> CascadeRouter:
|
|
||||||
return CascadeRouter(config_path=Path("/nonexistent"))
|
|
||||||
|
|
||||||
def _provider(self, name: str = "test", ptype: str = "ollama") -> Provider:
|
|
||||||
return Provider(
|
|
||||||
name=name,
|
|
||||||
type=ptype,
|
|
||||||
enabled=True,
|
|
||||||
priority=1,
|
|
||||||
models=[{"name": "llama3.2", "default": True}],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_unavailable_provider_returns_none(self):
|
|
||||||
router = self._router()
|
|
||||||
provider = self._provider()
|
|
||||||
provider.enabled = False
|
|
||||||
errors: list[str] = []
|
|
||||||
result = await router._try_single_provider(
|
|
||||||
provider, [], None, 0.7, None, ContentType.TEXT, errors
|
|
||||||
)
|
|
||||||
assert result is None
|
|
||||||
assert errors == []
|
|
||||||
|
|
||||||
async def test_quota_blocked_cloud_provider_returns_none(self):
|
|
||||||
router = self._router()
|
|
||||||
provider = self._provider(ptype="anthropic")
|
|
||||||
errors: list[str] = []
|
|
||||||
with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
|
|
||||||
mock_qm.select_model.return_value = "qwen3:14b" # non-cloud → ACTIVE tier
|
|
||||||
mock_qm.check.return_value = None
|
|
||||||
result = await router._try_single_provider(
|
|
||||||
provider, [], None, 0.7, None, ContentType.TEXT, errors
|
|
||||||
)
|
|
||||||
assert result is None
|
|
||||||
assert errors == []
|
|
||||||
|
|
||||||
async def test_success_returns_result_dict(self):
|
|
||||||
router = self._router()
|
|
||||||
provider = self._provider()
|
|
||||||
errors: list[str] = []
|
|
||||||
with patch.object(router, "_call_ollama") as mock_call:
|
|
||||||
mock_call.return_value = {"content": "hi", "model": "llama3.2"}
|
|
||||||
result = await router._try_single_provider(
|
|
||||||
provider,
|
|
||||||
[{"role": "user", "content": "hi"}],
|
|
||||||
None,
|
|
||||||
0.7,
|
|
||||||
None,
|
|
||||||
ContentType.TEXT,
|
|
||||||
errors,
|
|
||||||
)
|
|
||||||
assert result is not None
|
|
||||||
assert result["content"] == "hi"
|
|
||||||
assert result["provider"] == "test"
|
|
||||||
assert errors == []
|
|
||||||
|
|
||||||
async def test_failure_appends_error_and_returns_none(self):
|
|
||||||
router = self._router()
|
|
||||||
provider = self._provider()
|
|
||||||
errors: list[str] = []
|
|
||||||
with patch.object(router, "_call_ollama") as mock_call:
|
|
||||||
mock_call.side_effect = RuntimeError("boom")
|
|
||||||
result = await router._try_single_provider(
|
|
||||||
provider,
|
|
||||||
[{"role": "user", "content": "hi"}],
|
|
||||||
None,
|
|
||||||
0.7,
|
|
||||||
None,
|
|
||||||
ContentType.TEXT,
|
|
||||||
errors,
|
|
||||||
)
|
|
||||||
assert result is None
|
|
||||||
assert len(errors) == 1
|
|
||||||
assert "boom" in errors[0]
|
|
||||||
assert provider.metrics.failed_requests == 1
|
|
||||||
|
|||||||
Reference in New Issue
Block a user