Compare commits
4 Commits
claude/iss
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e69228b793 | ||
| 8f8061e224 | |||
| c78922ccbc | |||
| f3093e9dea |
@@ -528,6 +528,71 @@ class CascadeRouter:
|
||||
|
||||
return True
|
||||
|
||||
def _filter_providers(self, cascade_tier: str | None) -> list["Provider"]:
|
||||
"""Return the provider list filtered by tier.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If a tier is specified but no matching providers exist.
|
||||
"""
|
||||
if cascade_tier == "frontier_required":
|
||||
providers = [p for p in self.providers if p.type == "anthropic"]
|
||||
if not providers:
|
||||
raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.")
|
||||
return providers
|
||||
if cascade_tier:
|
||||
providers = [p for p in self.providers if p.tier == cascade_tier]
|
||||
if not providers:
|
||||
raise RuntimeError(f"No providers found for tier: {cascade_tier}")
|
||||
return providers
|
||||
return self.providers
|
||||
|
||||
async def _try_single_provider(
|
||||
self,
|
||||
provider: "Provider",
|
||||
messages: list[dict],
|
||||
model: str | None,
|
||||
temperature: float,
|
||||
max_tokens: int | None,
|
||||
content_type: ContentType,
|
||||
errors: list[str],
|
||||
) -> dict | None:
|
||||
"""Attempt one provider, returning a result dict on success or None on failure.
|
||||
|
||||
On failure the error string is appended to *errors* and the provider's
|
||||
failure metrics are updated so the caller can move on to the next provider.
|
||||
"""
|
||||
if not self._is_provider_available(provider):
|
||||
return None
|
||||
|
||||
# Metabolic protocol: skip cloud providers when quota is low
|
||||
if provider.type in ("anthropic", "openai", "grok"):
|
||||
if not self._quota_allows_cloud(provider):
|
||||
logger.info(
|
||||
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
||||
provider.name,
|
||||
)
|
||||
return None
|
||||
|
||||
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
|
||||
|
||||
try:
|
||||
result = await self._attempt_with_retry(
|
||||
provider, messages, selected_model, temperature, max_tokens, content_type
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
errors.append(str(exc))
|
||||
self._record_failure(provider)
|
||||
return None
|
||||
|
||||
self._record_success(provider, result.get("latency_ms", 0))
|
||||
return {
|
||||
"content": result["content"],
|
||||
"provider": provider.name,
|
||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||||
"latency_ms": result.get("latency_ms", 0),
|
||||
"is_fallback_model": is_fallback_model,
|
||||
}
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict],
|
||||
@@ -561,55 +626,15 @@ class CascadeRouter:
|
||||
if content_type != ContentType.TEXT:
|
||||
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
||||
|
||||
errors = []
|
||||
|
||||
providers = self.providers
|
||||
if cascade_tier == "frontier_required":
|
||||
providers = [p for p in self.providers if p.type == "anthropic"]
|
||||
if not providers:
|
||||
raise RuntimeError("No Anthropic provider configured for 'frontier_required' tier.")
|
||||
elif cascade_tier:
|
||||
providers = [p for p in self.providers if p.tier == cascade_tier]
|
||||
if not providers:
|
||||
raise RuntimeError(f"No providers found for tier: {cascade_tier}")
|
||||
errors: list[str] = []
|
||||
providers = self._filter_providers(cascade_tier)
|
||||
|
||||
for provider in providers:
|
||||
if not self._is_provider_available(provider):
|
||||
continue
|
||||
|
||||
# Metabolic protocol: skip cloud providers when quota is low
|
||||
if provider.type in ("anthropic", "openai", "grok"):
|
||||
if not self._quota_allows_cloud(provider):
|
||||
logger.info(
|
||||
"Metabolic protocol: skipping cloud provider %s (quota too low)",
|
||||
provider.name,
|
||||
)
|
||||
continue
|
||||
|
||||
selected_model, is_fallback_model = self._select_model(provider, model, content_type)
|
||||
|
||||
try:
|
||||
result = await self._attempt_with_retry(
|
||||
provider,
|
||||
messages,
|
||||
selected_model,
|
||||
temperature,
|
||||
max_tokens,
|
||||
content_type,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
errors.append(str(exc))
|
||||
self._record_failure(provider)
|
||||
continue
|
||||
|
||||
self._record_success(provider, result.get("latency_ms", 0))
|
||||
return {
|
||||
"content": result["content"],
|
||||
"provider": provider.name,
|
||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||||
"latency_ms": result.get("latency_ms", 0),
|
||||
"is_fallback_model": is_fallback_model,
|
||||
}
|
||||
result = await self._try_single_provider(
|
||||
provider, messages, model, temperature, max_tokens, content_type, errors
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||
|
||||
|
||||
@@ -110,6 +110,92 @@ async def _get_or_create_label(
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch action helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _apply_label_to_issue(
|
||||
client: Any,
|
||||
base_url: str,
|
||||
headers: dict,
|
||||
repo: str,
|
||||
issue_number: int,
|
||||
label_name: str,
|
||||
) -> bool:
|
||||
"""Get-or-create the label then apply it to the issue. Returns True on success."""
|
||||
label_id = await _get_or_create_label(client, base_url, headers, repo, label_name)
|
||||
if label_id is None:
|
||||
return False
|
||||
resp = await client.post(
|
||||
f"{base_url}/repos/{repo}/issues/{issue_number}/labels",
|
||||
headers=headers,
|
||||
json={"labels": [label_id]},
|
||||
)
|
||||
return resp.status_code in (200, 201)
|
||||
|
||||
|
||||
async def _post_dispatch_comment(
|
||||
client: Any,
|
||||
base_url: str,
|
||||
headers: dict,
|
||||
repo: str,
|
||||
issue: TriagedIssue,
|
||||
label_name: str,
|
||||
) -> bool:
|
||||
"""Post the vassal routing comment. Returns True on success."""
|
||||
agent_name = issue.agent_target.value.capitalize()
|
||||
comment_body = (
|
||||
f"🤖 **Vassal dispatch** → routed to **{agent_name}**\n\n"
|
||||
f"Priority score: {issue.priority_score} \n"
|
||||
f"Rationale: {issue.rationale} \n"
|
||||
f"Label: `{label_name}`"
|
||||
)
|
||||
resp = await client.post(
|
||||
f"{base_url}/repos/{repo}/issues/{issue.number}/comments",
|
||||
headers=headers,
|
||||
json={"body": comment_body},
|
||||
)
|
||||
return resp.status_code in (200, 201)
|
||||
|
||||
|
||||
async def _perform_gitea_dispatch(
|
||||
issue: TriagedIssue,
|
||||
record: DispatchRecord,
|
||||
) -> None:
|
||||
"""Apply label and post comment via Gitea. Mutates *record* in-place."""
|
||||
try:
|
||||
import httpx
|
||||
|
||||
from config import settings
|
||||
except ImportError as exc:
|
||||
logger.warning("dispatch_issue: missing dependency — %s", exc)
|
||||
return
|
||||
|
||||
if not settings.gitea_enabled or not settings.gitea_token:
|
||||
logger.info("dispatch_issue: Gitea disabled — skipping label/comment")
|
||||
return
|
||||
|
||||
base_url = f"{settings.gitea_url}/api/v1"
|
||||
repo = settings.gitea_repo
|
||||
headers = {
|
||||
"Authorization": f"token {settings.gitea_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
label_name = _LABEL_MAP[issue.agent_target]
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
record.label_applied = await _apply_label_to_issue(
|
||||
client, base_url, headers, repo, issue.number, label_name
|
||||
)
|
||||
record.comment_posted = await _post_dispatch_comment(
|
||||
client, base_url, headers, repo, issue, label_name
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("dispatch_issue: Gitea action failed — %s", exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch action
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -144,58 +230,7 @@ async def dispatch_issue(issue: TriagedIssue) -> DispatchRecord:
|
||||
_registry[issue.number] = record
|
||||
return record
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
from config import settings
|
||||
except ImportError as exc:
|
||||
logger.warning("dispatch_issue: missing dependency — %s", exc)
|
||||
_registry[issue.number] = record
|
||||
return record
|
||||
|
||||
if not settings.gitea_enabled or not settings.gitea_token:
|
||||
logger.info("dispatch_issue: Gitea disabled — skipping label/comment")
|
||||
_registry[issue.number] = record
|
||||
return record
|
||||
|
||||
base_url = f"{settings.gitea_url}/api/v1"
|
||||
repo = settings.gitea_repo
|
||||
headers = {
|
||||
"Authorization": f"token {settings.gitea_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
label_name = _LABEL_MAP[issue.agent_target]
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15) as client:
|
||||
label_id = await _get_or_create_label(client, base_url, headers, repo, label_name)
|
||||
|
||||
# Apply label
|
||||
if label_id is not None:
|
||||
resp = await client.post(
|
||||
f"{base_url}/repos/{repo}/issues/{issue.number}/labels",
|
||||
headers=headers,
|
||||
json={"labels": [label_id]},
|
||||
)
|
||||
record.label_applied = resp.status_code in (200, 201)
|
||||
|
||||
# Post routing comment
|
||||
agent_name = issue.agent_target.value.capitalize()
|
||||
comment_body = (
|
||||
f"🤖 **Vassal dispatch** → routed to **{agent_name}**\n\n"
|
||||
f"Priority score: {issue.priority_score} \n"
|
||||
f"Rationale: {issue.rationale} \n"
|
||||
f"Label: `{label_name}`"
|
||||
)
|
||||
resp = await client.post(
|
||||
f"{base_url}/repos/{repo}/issues/{issue.number}/comments",
|
||||
headers=headers,
|
||||
json={"body": comment_body},
|
||||
)
|
||||
record.comment_posted = resp.status_code in (200, 201)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("dispatch_issue: Gitea action failed — %s", exc)
|
||||
await _perform_gitea_dispatch(issue, record)
|
||||
|
||||
_registry[issue.number] = record
|
||||
logger.info(
|
||||
|
||||
@@ -95,6 +95,106 @@ def _get_config_dir() -> Path:
|
||||
return DEFAULT_CONFIG_DIR
|
||||
|
||||
|
||||
def _load_daily_run_config() -> dict[str, Any]:
|
||||
"""Load and validate the daily run configuration."""
|
||||
config_path = _get_config_dir() / "daily_run.json"
|
||||
config = _load_json_config(config_path)
|
||||
|
||||
if not config:
|
||||
console.print("[yellow]No daily run configuration found.[/yellow]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _display_schedules_table(schedules: dict[str, Any]) -> None:
|
||||
"""Display the daily run schedules in a table."""
|
||||
table = Table(title="Daily Run Schedules")
|
||||
table.add_column("Schedule", style="cyan")
|
||||
table.add_column("Description", style="green")
|
||||
table.add_column("Automations", style="yellow")
|
||||
|
||||
for schedule_name, schedule_data in schedules.items():
|
||||
automations = schedule_data.get("automations", [])
|
||||
table.add_row(
|
||||
schedule_name,
|
||||
schedule_data.get("description", ""),
|
||||
", ".join(automations) if automations else "—",
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
||||
|
||||
def _display_triggers_table(triggers: dict[str, Any]) -> None:
|
||||
"""Display the triggers in a table."""
|
||||
trigger_table = Table(title="Triggers")
|
||||
trigger_table.add_column("Trigger", style="cyan")
|
||||
trigger_table.add_column("Description", style="green")
|
||||
trigger_table.add_column("Automations", style="yellow")
|
||||
|
||||
for trigger_name, trigger_data in triggers.items():
|
||||
automations = trigger_data.get("automations", [])
|
||||
trigger_table.add_row(
|
||||
trigger_name,
|
||||
trigger_data.get("description", ""),
|
||||
", ".join(automations) if automations else "—",
|
||||
)
|
||||
|
||||
console.print(trigger_table)
|
||||
console.print()
|
||||
|
||||
|
||||
def _execute_automation(auto: dict[str, Any], verbose: bool) -> None:
|
||||
"""Execute a single automation and display results."""
|
||||
cmd = auto.get("command")
|
||||
name = auto.get("name", auto.get("id", "unnamed"))
|
||||
if not cmd:
|
||||
console.print(f"[yellow]Skipping {name} — no command defined.[/yellow]")
|
||||
return
|
||||
|
||||
console.print(f"[cyan]▶ Running: {name}[/cyan]")
|
||||
if verbose:
|
||||
console.print(f"[dim] $ {cmd}[/dim]")
|
||||
|
||||
try:
|
||||
result = subprocess.run( # noqa: S602
|
||||
cmd,
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
if result.stdout.strip():
|
||||
console.print(result.stdout.strip())
|
||||
if result.returncode != 0:
|
||||
console.print(f"[red] ✗ {name} exited with code {result.returncode}[/red]")
|
||||
if result.stderr.strip():
|
||||
console.print(f"[red]{result.stderr.strip()}[/red]")
|
||||
else:
|
||||
console.print(f"[green] ✓ {name} completed successfully[/green]")
|
||||
except subprocess.TimeoutExpired:
|
||||
console.print(f"[red] ✗ {name} timed out after 120s[/red]")
|
||||
except Exception as exc:
|
||||
console.print(f"[red] ✗ {name} failed: {exc}[/red]")
|
||||
|
||||
|
||||
def _execute_all_automations(verbose: bool) -> None:
|
||||
"""Execute all enabled automations."""
|
||||
console.print("[green]Executing daily run automations...[/green]")
|
||||
auto_config_path = _get_config_dir() / "automations.json"
|
||||
auto_config = _load_json_config(auto_config_path)
|
||||
all_automations = auto_config.get("automations", [])
|
||||
enabled = [a for a in all_automations if a.get("enabled", False)]
|
||||
|
||||
if not enabled:
|
||||
console.print("[yellow]No enabled automations found.[/yellow]")
|
||||
return
|
||||
|
||||
for auto in enabled:
|
||||
_execute_automation(auto, verbose)
|
||||
|
||||
|
||||
@app.command()
|
||||
def daily_run(
|
||||
dry_run: bool = typer.Option(
|
||||
@@ -113,93 +213,22 @@ def daily_run(
|
||||
console.print("[bold green]Timmy Daily Run[/bold green]")
|
||||
console.print()
|
||||
|
||||
config_path = _get_config_dir() / "daily_run.json"
|
||||
config = _load_json_config(config_path)
|
||||
|
||||
if not config:
|
||||
console.print("[yellow]No daily run configuration found.[/yellow]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
config = _load_daily_run_config()
|
||||
schedules = config.get("schedules", {})
|
||||
triggers = config.get("triggers", {})
|
||||
|
||||
if verbose:
|
||||
config_path = _get_config_dir() / "daily_run.json"
|
||||
console.print(f"[dim]Config loaded from: {config_path}[/dim]")
|
||||
console.print()
|
||||
|
||||
# Show the daily run schedule
|
||||
table = Table(title="Daily Run Schedules")
|
||||
table.add_column("Schedule", style="cyan")
|
||||
table.add_column("Description", style="green")
|
||||
table.add_column("Automations", style="yellow")
|
||||
|
||||
for schedule_name, schedule_data in schedules.items():
|
||||
automations = schedule_data.get("automations", [])
|
||||
table.add_row(
|
||||
schedule_name,
|
||||
schedule_data.get("description", ""),
|
||||
", ".join(automations) if automations else "—",
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
||||
# Show triggers
|
||||
trigger_table = Table(title="Triggers")
|
||||
trigger_table.add_column("Trigger", style="cyan")
|
||||
trigger_table.add_column("Description", style="green")
|
||||
trigger_table.add_column("Automations", style="yellow")
|
||||
|
||||
for trigger_name, trigger_data in triggers.items():
|
||||
automations = trigger_data.get("automations", [])
|
||||
trigger_table.add_row(
|
||||
trigger_name,
|
||||
trigger_data.get("description", ""),
|
||||
", ".join(automations) if automations else "—",
|
||||
)
|
||||
|
||||
console.print(trigger_table)
|
||||
console.print()
|
||||
_display_schedules_table(schedules)
|
||||
_display_triggers_table(triggers)
|
||||
|
||||
if dry_run:
|
||||
console.print("[yellow]Dry run mode — no actions executed.[/yellow]")
|
||||
else:
|
||||
console.print("[green]Executing daily run automations...[/green]")
|
||||
auto_config_path = _get_config_dir() / "automations.json"
|
||||
auto_config = _load_json_config(auto_config_path)
|
||||
all_automations = auto_config.get("automations", [])
|
||||
enabled = [a for a in all_automations if a.get("enabled", False)]
|
||||
if not enabled:
|
||||
console.print("[yellow]No enabled automations found.[/yellow]")
|
||||
for auto in enabled:
|
||||
cmd = auto.get("command")
|
||||
name = auto.get("name", auto.get("id", "unnamed"))
|
||||
if not cmd:
|
||||
console.print(f"[yellow]Skipping {name} — no command defined.[/yellow]")
|
||||
continue
|
||||
console.print(f"[cyan]▶ Running: {name}[/cyan]")
|
||||
if verbose:
|
||||
console.print(f"[dim] $ {cmd}[/dim]")
|
||||
try:
|
||||
result = subprocess.run( # noqa: S602
|
||||
cmd,
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
if result.stdout.strip():
|
||||
console.print(result.stdout.strip())
|
||||
if result.returncode != 0:
|
||||
console.print(f"[red] ✗ {name} exited with code {result.returncode}[/red]")
|
||||
if result.stderr.strip():
|
||||
console.print(f"[red]{result.stderr.strip()}[/red]")
|
||||
else:
|
||||
console.print(f"[green] ✓ {name} completed successfully[/green]")
|
||||
except subprocess.TimeoutExpired:
|
||||
console.print(f"[red] ✗ {name} timed out after 120s[/red]")
|
||||
except Exception as exc:
|
||||
console.print(f"[red] ✗ {name} failed: {exc}[/red]")
|
||||
_execute_all_automations(verbose)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
"""Tests for the async event bus (infrastructure.events.bus)."""
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.events.bus import Event, EventBus, emit, event_bus, on
|
||||
import infrastructure.events.bus as bus_module
|
||||
from infrastructure.events.bus import (
|
||||
Event,
|
||||
EventBus,
|
||||
emit,
|
||||
event_bus,
|
||||
get_event_bus,
|
||||
init_event_bus_persistence,
|
||||
on,
|
||||
)
|
||||
|
||||
|
||||
class TestEvent:
|
||||
@@ -349,3 +360,111 @@ class TestEventBusPersistence:
|
||||
assert mode == "wal"
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
async def test_persist_event_exception_is_swallowed(self, tmp_path):
|
||||
"""_persist_event must not propagate SQLite errors."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
bus = EventBus()
|
||||
bus.enable_persistence(tmp_path / "events.db")
|
||||
|
||||
# Make the INSERT raise an OperationalError
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute.side_effect = sqlite3.OperationalError("simulated failure")
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def fake_ctx():
|
||||
yield mock_conn
|
||||
|
||||
with patch.object(bus, "_get_persistence_conn", fake_ctx):
|
||||
# Should not raise
|
||||
bus._persist_event(Event(type="x", source="s"))
|
||||
|
||||
async def test_replay_exception_returns_empty(self, tmp_path):
|
||||
"""replay() must return [] when SQLite query fails."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
bus = EventBus()
|
||||
bus.enable_persistence(tmp_path / "events.db")
|
||||
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.execute.side_effect = sqlite3.OperationalError("simulated failure")
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def fake_ctx():
|
||||
yield mock_conn
|
||||
|
||||
with patch.object(bus, "_get_persistence_conn", fake_ctx):
|
||||
result = bus.replay()
|
||||
assert result == []
|
||||
|
||||
|
||||
# ── Singleton helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSingletonHelpers:
|
||||
"""Test get_event_bus(), init_event_bus_persistence(), and module __getattr__."""
|
||||
|
||||
def test_get_event_bus_returns_same_instance(self):
|
||||
"""get_event_bus() is a true singleton."""
|
||||
a = get_event_bus()
|
||||
b = get_event_bus()
|
||||
assert a is b
|
||||
|
||||
def test_module_event_bus_attr_is_singleton(self):
|
||||
"""Accessing bus_module.event_bus via __getattr__ returns the singleton."""
|
||||
assert bus_module.event_bus is get_event_bus()
|
||||
|
||||
def test_module_getattr_unknown_raises(self):
|
||||
"""Accessing an unknown module attribute raises AttributeError."""
|
||||
with pytest.raises(AttributeError):
|
||||
_ = bus_module.no_such_attr # type: ignore[attr-defined]
|
||||
|
||||
def test_init_event_bus_persistence_sets_path(self, tmp_path):
|
||||
"""init_event_bus_persistence() enables persistence on the singleton."""
|
||||
bus = get_event_bus()
|
||||
original_path = bus._persistence_db_path
|
||||
try:
|
||||
bus._persistence_db_path = None # reset for the test
|
||||
db_path = tmp_path / "test_init.db"
|
||||
init_event_bus_persistence(db_path)
|
||||
assert bus._persistence_db_path == db_path
|
||||
finally:
|
||||
bus._persistence_db_path = original_path
|
||||
|
||||
def test_init_event_bus_persistence_is_idempotent(self, tmp_path):
|
||||
"""Calling init_event_bus_persistence() twice keeps the first path."""
|
||||
bus = get_event_bus()
|
||||
original_path = bus._persistence_db_path
|
||||
try:
|
||||
bus._persistence_db_path = None
|
||||
first_path = tmp_path / "first.db"
|
||||
second_path = tmp_path / "second.db"
|
||||
init_event_bus_persistence(first_path)
|
||||
init_event_bus_persistence(second_path) # should be ignored
|
||||
assert bus._persistence_db_path == first_path
|
||||
finally:
|
||||
bus._persistence_db_path = original_path
|
||||
|
||||
def test_init_event_bus_persistence_default_path(self):
|
||||
"""init_event_bus_persistence() uses 'data/events.db' when no path given."""
|
||||
bus = get_event_bus()
|
||||
original_path = bus._persistence_db_path
|
||||
try:
|
||||
bus._persistence_db_path = None
|
||||
# Patch enable_persistence to capture what path it receives
|
||||
captured = {}
|
||||
|
||||
def fake_enable(path: Path) -> None:
|
||||
captured["path"] = path
|
||||
|
||||
with patch.object(bus, "enable_persistence", side_effect=fake_enable):
|
||||
init_event_bus_persistence()
|
||||
|
||||
assert captured["path"] == Path("data/events.db")
|
||||
finally:
|
||||
bus._persistence_db_path = original_path
|
||||
|
||||
@@ -1376,3 +1376,141 @@ class TestIsProviderAvailable:
|
||||
result = router._is_provider_available(provider)
|
||||
assert result is True
|
||||
assert provider.circuit_state == CircuitState.HALF_OPEN
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFilterProviders:
|
||||
"""Test _filter_providers helper extracted from complete()."""
|
||||
|
||||
def _router(self) -> CascadeRouter:
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.providers = [
|
||||
Provider(
|
||||
name="anthropic-p",
|
||||
type="anthropic",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
api_key="key",
|
||||
tier="frontier",
|
||||
),
|
||||
Provider(
|
||||
name="ollama-p",
|
||||
type="ollama",
|
||||
enabled=True,
|
||||
priority=2,
|
||||
tier="local",
|
||||
),
|
||||
]
|
||||
return router
|
||||
|
||||
def test_no_tier_returns_all_providers(self):
|
||||
router = self._router()
|
||||
result = router._filter_providers(None)
|
||||
assert result is router.providers
|
||||
|
||||
def test_frontier_required_returns_only_anthropic(self):
|
||||
router = self._router()
|
||||
result = router._filter_providers("frontier_required")
|
||||
assert len(result) == 1
|
||||
assert result[0].type == "anthropic"
|
||||
|
||||
def test_frontier_required_no_anthropic_raises(self):
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.providers = [
|
||||
Provider(name="ollama-p", type="ollama", enabled=True, priority=1)
|
||||
]
|
||||
with pytest.raises(RuntimeError, match="No Anthropic provider configured"):
|
||||
router._filter_providers("frontier_required")
|
||||
|
||||
def test_named_tier_filters_by_tier(self):
|
||||
router = self._router()
|
||||
result = router._filter_providers("local")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "ollama-p"
|
||||
|
||||
def test_named_tier_not_found_raises(self):
|
||||
router = self._router()
|
||||
with pytest.raises(RuntimeError, match="No providers found for tier"):
|
||||
router._filter_providers("nonexistent")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
class TestTrySingleProvider:
|
||||
"""Test _try_single_provider helper extracted from complete()."""
|
||||
|
||||
def _router(self) -> CascadeRouter:
|
||||
return CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
def _provider(self, name: str = "test", ptype: str = "ollama") -> Provider:
|
||||
return Provider(
|
||||
name=name,
|
||||
type=ptype,
|
||||
enabled=True,
|
||||
priority=1,
|
||||
models=[{"name": "llama3.2", "default": True}],
|
||||
)
|
||||
|
||||
async def test_unavailable_provider_returns_none(self):
|
||||
router = self._router()
|
||||
provider = self._provider()
|
||||
provider.enabled = False
|
||||
errors: list[str] = []
|
||||
result = await router._try_single_provider(
|
||||
provider, [], None, 0.7, None, ContentType.TEXT, errors
|
||||
)
|
||||
assert result is None
|
||||
assert errors == []
|
||||
|
||||
async def test_quota_blocked_cloud_provider_returns_none(self):
|
||||
router = self._router()
|
||||
provider = self._provider(ptype="anthropic")
|
||||
errors: list[str] = []
|
||||
with patch("infrastructure.router.cascade._quota_monitor") as mock_qm:
|
||||
mock_qm.select_model.return_value = "qwen3:14b" # non-cloud → ACTIVE tier
|
||||
mock_qm.check.return_value = None
|
||||
result = await router._try_single_provider(
|
||||
provider, [], None, 0.7, None, ContentType.TEXT, errors
|
||||
)
|
||||
assert result is None
|
||||
assert errors == []
|
||||
|
||||
async def test_success_returns_result_dict(self):
|
||||
router = self._router()
|
||||
provider = self._provider()
|
||||
errors: list[str] = []
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.return_value = {"content": "hi", "model": "llama3.2"}
|
||||
result = await router._try_single_provider(
|
||||
provider,
|
||||
[{"role": "user", "content": "hi"}],
|
||||
None,
|
||||
0.7,
|
||||
None,
|
||||
ContentType.TEXT,
|
||||
errors,
|
||||
)
|
||||
assert result is not None
|
||||
assert result["content"] == "hi"
|
||||
assert result["provider"] == "test"
|
||||
assert errors == []
|
||||
|
||||
async def test_failure_appends_error_and_returns_none(self):
|
||||
router = self._router()
|
||||
provider = self._provider()
|
||||
errors: list[str] = []
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.side_effect = RuntimeError("boom")
|
||||
result = await router._try_single_provider(
|
||||
provider,
|
||||
[{"role": "user", "content": "hi"}],
|
||||
None,
|
||||
0.7,
|
||||
None,
|
||||
ContentType.TEXT,
|
||||
errors,
|
||||
)
|
||||
assert result is None
|
||||
assert len(errors) == 1
|
||||
assert "boom" in errors[0]
|
||||
assert provider.metrics.failed_requests == 1
|
||||
|
||||
Reference in New Issue
Block a user