feat: three-tier model router — Local 8B / Hermes 70B / Cloud API cascade (#882)
Some checks failed
Tests / lint (pull_request) Failing after 16s
Tests / test (pull_request) Has been skipped

Implements the intelligent model tiering router from issue #882:

- `src/infrastructure/models/router.py` — TieredModelRouter with heuristic
  task classifier (classify_tier), automatic T1→T2 escalation on low-quality
  responses, cloud-tier budget guard, and per-request routing logs.

- `src/infrastructure/models/budget.py` — BudgetTracker with SQLite
  persistence (in-memory fallback), daily/monthly cloud spend limits,
  cost estimates per model, and get_summary() for dashboards.

- `src/config.py` — five new settings: tier_local_fast_model,
  tier_local_heavy_model, tier_cloud_model, tier_cloud_daily_budget_usd
  (default $5), tier_cloud_monthly_budget_usd (default $50).

- Exports added to `src/infrastructure/models/__init__.py`.

- 44 new unit tests covering classify_tier, _is_low_quality, BudgetTracker,
  and TieredModelRouter (including acceptance criteria from the issue).

Acceptance criteria verified:
  "Walk to the next room"                       → LOCAL_FAST (Tier 1) ✓
  "Plan the optimal path to become Hortator"    → LOCAL_HEAVY (Tier 2) ✓
  Failed Tier-1 response auto-escalates to T2   ✓
  Cloud spend stays within configured budget    ✓
  Routing decisions logged                      ✓

Fixes #882

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Alexander Whitestone
2026-03-23 21:51:11 -04:00
parent 2d6bfe6ba1
commit 80d798a94b
6 changed files with 1326 additions and 0 deletions

View File

@@ -117,6 +117,23 @@ class Settings(BaseSettings):
anthropic_api_key: str = ""
claude_model: str = "haiku"
# ── Tiered Model Router (issue #882) ─────────────────────────────────
# Three-tier cascade: Local 8B (free, fast) → Local 70B (free, slower)
# → Cloud API (paid, best). Override model names per tier via env vars.
#
# TIER_LOCAL_FAST_MODEL — Tier-1 model name in Ollama (default: llama3.1:8b)
# TIER_LOCAL_HEAVY_MODEL — Tier-2 model name in Ollama (default: hermes3:70b)
# TIER_CLOUD_MODEL — Tier-3 cloud model name (default: claude-haiku-4-5)
#
# Budget limits for the cloud tier (0 = unlimited):
# TIER_CLOUD_DAILY_BUDGET_USD — daily ceiling in USD (default: 5.0)
# TIER_CLOUD_MONTHLY_BUDGET_USD — monthly ceiling in USD (default: 50.0)
tier_local_fast_model: str = "llama3.1:8b"
tier_local_heavy_model: str = "hermes3:70b"
tier_cloud_model: str = "claude-haiku-4-5"
tier_cloud_daily_budget_usd: float = 5.0
tier_cloud_monthly_budget_usd: float = 50.0
# ── Content Moderation ──────────────────────────────────────────────
# Three-layer moderation pipeline for AI narrator output.
# Uses Llama Guard via Ollama with regex fallback.

View File

@@ -1,5 +1,11 @@
"""Infrastructure models package."""
from infrastructure.models.budget import (
BudgetTracker,
SpendRecord,
estimate_cost_usd,
get_budget_tracker,
)
from infrastructure.models.multimodal import (
ModelCapability,
ModelInfo,
@@ -17,6 +23,12 @@ from infrastructure.models.registry import (
ModelRole,
model_registry,
)
from infrastructure.models.router import (
TierLabel,
TieredModelRouter,
classify_tier,
get_tiered_router,
)
__all__ = [
# Registry
@@ -34,4 +46,14 @@ __all__ = [
"model_supports_tools",
"model_supports_vision",
"pull_model_with_fallback",
# Tiered router
"TierLabel",
"TieredModelRouter",
"classify_tier",
"get_tiered_router",
# Budget tracker
"BudgetTracker",
"SpendRecord",
"estimate_cost_usd",
"get_budget_tracker",
]

View File

@@ -0,0 +1,302 @@
"""Cloud API budget tracker for the three-tier model router.
Tracks cloud API spend (daily / monthly) and enforces configurable limits.
SQLite-backed with in-memory fallback — degrades gracefully if the database
is unavailable.
References:
- Issue #882 — Model Tiering Router: Local 8B / Hermes 70B / Cloud API Cascade
"""
import logging
import sqlite3
import threading
import time
from dataclasses import dataclass
from datetime import UTC, date, datetime
from pathlib import Path
from config import settings
logger = logging.getLogger(__name__)
# ── Cost estimates (USD per 1 K tokens, input / output) ──────────────────────
# Updated 2026-03. Estimates only — actual costs vary by tier/usage.
_COST_PER_1K: dict[str, dict[str, float]] = {
# Claude models
"claude-haiku-4-5": {"input": 0.00025, "output": 0.00125},
"claude-sonnet-4-5": {"input": 0.003, "output": 0.015},
"claude-opus-4-5": {"input": 0.015, "output": 0.075},
"haiku": {"input": 0.00025, "output": 0.00125},
"sonnet": {"input": 0.003, "output": 0.015},
"opus": {"input": 0.015, "output": 0.075},
# GPT-4o
"gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
"gpt-4o": {"input": 0.0025, "output": 0.01},
# Grok (xAI)
"grok-3-fast": {"input": 0.003, "output": 0.015},
"grok-3": {"input": 0.005, "output": 0.025},
}
_DEFAULT_COST: dict[str, float] = {"input": 0.003, "output": 0.015} # conservative fallback
def estimate_cost_usd(model: str, tokens_in: int, tokens_out: int) -> float:
"""Estimate the cost of a single request in USD.
Matches the model name by substring so versioned names like
``claude-haiku-4-5-20251001`` still resolve correctly.
Args:
model: Model name as passed to the provider.
tokens_in: Number of input (prompt) tokens consumed.
tokens_out: Number of output (completion) tokens generated.
Returns:
Estimated cost in USD (may be zero for unknown models).
"""
model_lower = model.lower()
rates = _DEFAULT_COST
for key, rate in _COST_PER_1K.items():
if key in model_lower:
rates = rate
break
return (tokens_in * rates["input"] + tokens_out * rates["output"]) / 1000.0
@dataclass
class SpendRecord:
"""A single spend event."""
ts: float
provider: str
model: str
tokens_in: int
tokens_out: int
cost_usd: float
tier: str
class BudgetTracker:
"""Tracks cloud API spend with configurable daily / monthly limits.
Persists spend records to SQLite (``data/budget.db`` by default).
Falls back to in-memory tracking when the database is unavailable —
budget enforcement still works; records are lost on restart.
Limits are read from ``settings``:
* ``tier_cloud_daily_budget_usd`` — daily ceiling (0 = disabled)
* ``tier_cloud_monthly_budget_usd`` — monthly ceiling (0 = disabled)
Usage::
tracker = BudgetTracker()
if tracker.cloud_allowed():
# … make cloud API call …
tracker.record_spend("anthropic", "claude-haiku-4-5", 100, 200)
summary = tracker.get_summary()
print(summary["daily_usd"], "/", summary["daily_limit_usd"])
"""
_DB_PATH = "data/budget.db"
def __init__(self, db_path: str | None = None) -> None:
"""Initialise the tracker.
Args:
db_path: Path to the SQLite database. Defaults to
``data/budget.db``. Pass ``":memory:"`` for tests.
"""
self._db_path = db_path or self._DB_PATH
self._lock = threading.Lock()
self._in_memory: list[SpendRecord] = []
self._db_ok = False
self._init_db()
# ── Database initialisation ──────────────────────────────────────────────
def _init_db(self) -> None:
"""Create the spend table (and parent directory) if needed."""
try:
if self._db_path != ":memory:":
Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
with self._connect() as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS cloud_spend (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts REAL NOT NULL,
provider TEXT NOT NULL,
model TEXT NOT NULL,
tokens_in INTEGER NOT NULL DEFAULT 0,
tokens_out INTEGER NOT NULL DEFAULT 0,
cost_usd REAL NOT NULL DEFAULT 0.0,
tier TEXT NOT NULL DEFAULT 'cloud'
)
"""
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_spend_ts ON cloud_spend(ts)"
)
self._db_ok = True
logger.debug("BudgetTracker: SQLite initialised at %s", self._db_path)
except Exception as exc:
logger.warning(
"BudgetTracker: SQLite unavailable, using in-memory fallback: %s", exc
)
def _connect(self) -> sqlite3.Connection:
return sqlite3.connect(self._db_path, timeout=5)
# ── Public API ───────────────────────────────────────────────────────────
def record_spend(
self,
provider: str,
model: str,
tokens_in: int = 0,
tokens_out: int = 0,
cost_usd: float | None = None,
tier: str = "cloud",
) -> float:
"""Record a cloud API spend event and return the cost recorded.
Args:
provider: Provider name (e.g. ``"anthropic"``, ``"openai"``).
model: Model name used for the request.
tokens_in: Input token count (prompt).
tokens_out: Output token count (completion).
cost_usd: Explicit cost override. If ``None``, the cost is
estimated from the token counts and model rates.
tier: Tier label for the request (default ``"cloud"``).
Returns:
The cost recorded in USD.
"""
if cost_usd is None:
cost_usd = estimate_cost_usd(model, tokens_in, tokens_out)
ts = time.time()
record = SpendRecord(ts, provider, model, tokens_in, tokens_out, cost_usd, tier)
with self._lock:
if self._db_ok:
try:
with self._connect() as conn:
conn.execute(
"""
INSERT INTO cloud_spend
(ts, provider, model, tokens_in, tokens_out, cost_usd, tier)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(ts, provider, model, tokens_in, tokens_out, cost_usd, tier),
)
logger.debug(
"BudgetTracker: recorded %.6f USD (%s/%s, in=%d out=%d tier=%s)",
cost_usd,
provider,
model,
tokens_in,
tokens_out,
tier,
)
return cost_usd
except Exception as exc:
logger.warning("BudgetTracker: DB write failed, falling back: %s", exc)
self._in_memory.append(record)
return cost_usd
def get_daily_spend(self) -> float:
"""Return total cloud spend for the current UTC day in USD."""
today = date.today()
since = datetime(today.year, today.month, today.day, tzinfo=UTC).timestamp()
return self._query_spend(since)
def get_monthly_spend(self) -> float:
"""Return total cloud spend for the current UTC month in USD."""
today = date.today()
since = datetime(today.year, today.month, 1, tzinfo=UTC).timestamp()
return self._query_spend(since)
def cloud_allowed(self) -> bool:
"""Return ``True`` if cloud API spend is within configured limits.
Checks both daily and monthly ceilings. A limit of ``0`` disables
that particular check.
"""
daily_limit = settings.tier_cloud_daily_budget_usd
monthly_limit = settings.tier_cloud_monthly_budget_usd
if daily_limit > 0:
daily_spend = self.get_daily_spend()
if daily_spend >= daily_limit:
logger.warning(
"BudgetTracker: daily cloud budget exhausted (%.4f / %.4f USD)",
daily_spend,
daily_limit,
)
return False
if monthly_limit > 0:
monthly_spend = self.get_monthly_spend()
if monthly_spend >= monthly_limit:
logger.warning(
"BudgetTracker: monthly cloud budget exhausted (%.4f / %.4f USD)",
monthly_spend,
monthly_limit,
)
return False
return True
def get_summary(self) -> dict:
"""Return a spend summary dict suitable for dashboards / logging.
Keys: ``daily_usd``, ``monthly_usd``, ``daily_limit_usd``,
``monthly_limit_usd``, ``daily_ok``, ``monthly_ok``.
"""
daily = self.get_daily_spend()
monthly = self.get_monthly_spend()
daily_limit = settings.tier_cloud_daily_budget_usd
monthly_limit = settings.tier_cloud_monthly_budget_usd
return {
"daily_usd": round(daily, 6),
"monthly_usd": round(monthly, 6),
"daily_limit_usd": daily_limit,
"monthly_limit_usd": monthly_limit,
"daily_ok": daily_limit <= 0 or daily < daily_limit,
"monthly_ok": monthly_limit <= 0 or monthly < monthly_limit,
}
# ── Internal helpers ─────────────────────────────────────────────────────
def _query_spend(self, since_ts: float) -> float:
"""Sum ``cost_usd`` for records with ``ts >= since_ts``."""
if self._db_ok:
try:
with self._connect() as conn:
row = conn.execute(
"SELECT COALESCE(SUM(cost_usd), 0.0) FROM cloud_spend WHERE ts >= ?",
(since_ts,),
).fetchone()
return float(row[0]) if row else 0.0
except Exception as exc:
logger.warning("BudgetTracker: DB read failed: %s", exc)
# In-memory fallback
return sum(r.cost_usd for r in self._in_memory if r.ts >= since_ts)
# ── Module-level singleton ────────────────────────────────────────────────────
_budget_tracker: BudgetTracker | None = None
def get_budget_tracker() -> BudgetTracker:
"""Get or create the module-level BudgetTracker singleton."""
global _budget_tracker
if _budget_tracker is None:
_budget_tracker = BudgetTracker()
return _budget_tracker

View File

@@ -0,0 +1,427 @@
"""Three-tier model router — Local 8B / Local 70B / Cloud API Cascade.
Selects the cheapest-sufficient LLM for each request using a heuristic
task-complexity classifier. Tier 3 (Cloud API) is only used when Tier 2
fails or the budget guard allows it.
Tiers
-----
Tier 1 — LOCAL_FAST (Llama 3.1 8B / Hermes 3 8B via Ollama, free, ~0.3-1 s)
Navigation, basic interactions, simple decisions.
Tier 2 — LOCAL_HEAVY (Hermes 3/4 70B via Ollama, free, ~5-10 s for 200 tok)
Quest planning, dialogue strategy, complex reasoning.
Tier 3 — CLOUD_API (Claude / GPT-4o, paid ~$5-15/hr heavy use)
Recovery from Tier 2 failures, novel situations, multi-step planning.
Routing logic
-------------
1. Classify the task using keyword / length / context heuristics (no LLM call).
2. Route to the appropriate tier.
3. On Tier-1 low-quality response → auto-escalate to Tier 2.
4. On Tier-2 failure or explicit ``require_cloud=True`` → Tier 3 (if budget allows).
5. Log tier used, model, latency, estimated cost for every request.
References:
- Issue #882 — Model Tiering Router: Local 8B / Hermes 70B / Cloud API Cascade
"""
import asyncio
import logging
import re
import time
from enum import StrEnum
from typing import Any
from config import settings
logger = logging.getLogger(__name__)
# ── Tier definitions ──────────────────────────────────────────────────────────
class TierLabel(StrEnum):
"""Three cost-sorted model tiers."""
LOCAL_FAST = "local_fast" # 8B local, always hot, free
LOCAL_HEAVY = "local_heavy" # 70B local, free but slower
CLOUD_API = "cloud_api" # Paid cloud backend (Claude / GPT-4o)
# ── Default model assignments (overridable via Settings) ──────────────────────
_DEFAULT_TIER_MODELS: dict[TierLabel, str] = {
TierLabel.LOCAL_FAST: "llama3.1:8b",
TierLabel.LOCAL_HEAVY: "hermes3:70b",
TierLabel.CLOUD_API: "claude-haiku-4-5",
}
# ── Classification vocabulary ─────────────────────────────────────────────────
# Patterns that indicate a Tier-1 (simple) task
_T1_WORDS: frozenset[str] = frozenset(
{
"go", "move", "walk", "run",
"north", "south", "east", "west", "up", "down", "left", "right",
"yes", "no", "ok", "okay",
"open", "close", "take", "drop", "look",
"pick", "use", "wait", "rest", "save",
"attack", "flee", "jump", "crouch",
"status", "ping", "list", "show", "get", "check",
}
)
# Patterns that indicate a Tier-2 or Tier-3 task
_T2_PHRASES: tuple[str, ...] = (
"plan", "strategy", "optimize", "optimise",
"quest", "stuck", "recover",
"negotiate", "persuade", "faction", "reputation",
"analyze", "analyse", "evaluate", "decide",
"complex", "multi-step", "long-term",
"how do i", "what should i do", "help me figure",
"what is the best", "recommend", "best way",
"explain", "describe in detail", "walk me through",
"compare", "design", "implement", "refactor",
"debug", "diagnose", "root cause",
)
# Low-quality response detection patterns
_LOW_QUALITY_PATTERNS: tuple[re.Pattern, ...] = (
re.compile(r"i\s+don'?t\s+know", re.IGNORECASE),
re.compile(r"i'm\s+not\s+sure", re.IGNORECASE),
re.compile(r"i\s+cannot\s+(help|assist|answer)", re.IGNORECASE),
re.compile(r"i\s+apologize", re.IGNORECASE),
re.compile(r"as an ai", re.IGNORECASE),
re.compile(r"i\s+don'?t\s+have\s+(enough|sufficient)\s+information", re.IGNORECASE),
)
# Response is definitely low-quality if shorter than this many characters
_LOW_QUALITY_MIN_CHARS = 20
# Response is suspicious if shorter than this many chars for a complex task
_ESCALATION_MIN_CHARS = 60
def classify_tier(task: str, context: dict | None = None) -> TierLabel:
"""Classify a task to the cheapest-sufficient model tier.
Classification priority (highest wins):
1. ``context["require_cloud"] = True`` → CLOUD_API
2. Any Tier-2 phrase or stuck/recovery signal → LOCAL_HEAVY
3. Short task with only Tier-1 words, no active context → LOCAL_FAST
4. Default → LOCAL_HEAVY (safe fallback for unknown tasks)
Args:
task: Natural-language task or user input.
context: Optional context dict. Recognised keys:
``require_cloud`` (bool), ``stuck`` (bool),
``require_t2`` (bool), ``active_quests`` (list),
``dialogue_active`` (bool), ``combat_active`` (bool).
Returns:
The cheapest ``TierLabel`` sufficient for the task.
"""
ctx = context or {}
task_lower = task.lower()
words = set(task_lower.split())
# ── Explicit cloud override ──────────────────────────────────────────────
if ctx.get("require_cloud"):
logger.debug("classify_tier → CLOUD_API (explicit require_cloud)")
return TierLabel.CLOUD_API
# ── Tier-2 / complexity signals ──────────────────────────────────────────
t2_phrase_hit = any(phrase in task_lower for phrase in _T2_PHRASES)
t2_word_hit = bool(words & {"plan", "strategy", "optimize", "optimise", "quest",
"stuck", "recover", "analyze", "analyse", "evaluate"})
is_stuck = bool(ctx.get("stuck"))
require_t2 = bool(ctx.get("require_t2"))
long_input = len(task) > 300 # long tasks warrant more capable model
deep_context = (
len(ctx.get("active_quests", [])) >= 3
or ctx.get("dialogue_active")
)
if t2_phrase_hit or t2_word_hit or is_stuck or require_t2 or long_input or deep_context:
logger.debug(
"classify_tier → LOCAL_HEAVY (phrase=%s word=%s stuck=%s explicit=%s long=%s ctx=%s)",
t2_phrase_hit, t2_word_hit, is_stuck, require_t2, long_input, deep_context,
)
return TierLabel.LOCAL_HEAVY
# ── Tier-1 signals ───────────────────────────────────────────────────────
t1_word_hit = bool(words & _T1_WORDS)
task_short = len(task.split()) <= 8
no_active_context = (
not ctx.get("active_quests")
and not ctx.get("dialogue_active")
and not ctx.get("combat_active")
)
if t1_word_hit and task_short and no_active_context:
logger.debug(
"classify_tier → LOCAL_FAST (words=%s short=%s)", t1_word_hit, task_short
)
return TierLabel.LOCAL_FAST
# ── Default: LOCAL_HEAVY (safe for anything unclassified) ────────────────
logger.debug("classify_tier → LOCAL_HEAVY (default)")
return TierLabel.LOCAL_HEAVY
def _is_low_quality(content: str, tier: TierLabel) -> bool:
"""Return True if the response looks like it should be escalated.
Used for automatic Tier-1 → Tier-2 escalation.
Args:
content: LLM response text.
tier: The tier that produced the response.
Returns:
True if the response is likely too low-quality to be useful.
"""
if not content or not content.strip():
return True
stripped = content.strip()
# Too short to be useful
if len(stripped) < _LOW_QUALITY_MIN_CHARS:
return True
# Insufficient for a supposedly complex-enough task
if tier == TierLabel.LOCAL_FAST and len(stripped) < _ESCALATION_MIN_CHARS:
return True
# Matches known "I can't help" patterns
for pattern in _LOW_QUALITY_PATTERNS:
if pattern.search(stripped):
return True
return False
class TieredModelRouter:
"""Routes LLM requests across the Local 8B / Local 70B / Cloud API tiers.
Wraps CascadeRouter with:
- Heuristic tier classification via ``classify_tier()``
- Automatic Tier-1 → Tier-2 escalation on low-quality responses
- Cloud-tier budget guard via ``BudgetTracker``
- Per-request logging: tier, model, latency, estimated cost
Usage::
router = TieredModelRouter()
result = await router.route(
task="Walk to the next room",
context={},
)
print(result["content"], result["tier"]) # "Move north.", "local_fast"
# Force heavy tier
result = await router.route(
task="Plan the optimal path to become Hortator",
context={"require_t2": True},
)
"""
def __init__(
self,
cascade: Any | None = None,
budget_tracker: Any | None = None,
tier_models: dict[TierLabel, str] | None = None,
auto_escalate: bool = True,
) -> None:
"""Initialise the tiered router.
Args:
cascade: CascadeRouter instance. If ``None``, the
singleton from ``get_router()`` is used lazily.
budget_tracker: BudgetTracker instance. If ``None``, the
singleton from ``get_budget_tracker()`` is used.
tier_models: Override default model names per tier.
auto_escalate: When ``True``, low-quality Tier-1 responses
automatically retry on Tier-2.
"""
self._cascade = cascade
self._budget = budget_tracker
self._tier_models: dict[TierLabel, str] = dict(_DEFAULT_TIER_MODELS)
self._auto_escalate = auto_escalate
# Apply settings-level overrides (can still be overridden per-instance)
if settings.tier_local_fast_model:
self._tier_models[TierLabel.LOCAL_FAST] = settings.tier_local_fast_model
if settings.tier_local_heavy_model:
self._tier_models[TierLabel.LOCAL_HEAVY] = settings.tier_local_heavy_model
if settings.tier_cloud_model:
self._tier_models[TierLabel.CLOUD_API] = settings.tier_cloud_model
if tier_models:
self._tier_models.update(tier_models)
# ── Lazy singletons ──────────────────────────────────────────────────────
def _get_cascade(self) -> Any:
if self._cascade is None:
from infrastructure.router.cascade import get_router
self._cascade = get_router()
return self._cascade
def _get_budget(self) -> Any:
if self._budget is None:
from infrastructure.models.budget import get_budget_tracker
self._budget = get_budget_tracker()
return self._budget
# ── Public interface ─────────────────────────────────────────────────────
def classify(self, task: str, context: dict | None = None) -> TierLabel:
"""Classify a task without routing. Useful for telemetry."""
return classify_tier(task, context)
async def route(
self,
task: str,
context: dict | None = None,
messages: list[dict] | None = None,
temperature: float = 0.3,
max_tokens: int | None = None,
) -> dict:
"""Route a task to the appropriate model tier.
Builds a minimal messages list if ``messages`` is not provided.
The result always includes a ``tier`` key indicating which tier
ultimately handled the request.
Args:
task: Natural-language task description.
context: Task context dict (see ``classify_tier()``).
messages: Pre-built OpenAI-compatible messages list. If
provided, ``task`` is only used for classification.
temperature: Sampling temperature (default 0.3).
max_tokens: Maximum tokens to generate.
Returns:
Dict with at minimum: ``content``, ``provider``, ``model``,
``tier``, ``latency_ms``. May include ``cost_usd`` when a
cloud request is recorded.
Raises:
RuntimeError: If all available tiers are exhausted.
"""
ctx = context or {}
tier = self.classify(task, ctx)
msgs = messages or [{"role": "user", "content": task}]
# ── Tier 1 attempt ───────────────────────────────────────────────────
if tier == TierLabel.LOCAL_FAST:
result = await self._complete_tier(
TierLabel.LOCAL_FAST, msgs, temperature, max_tokens
)
if self._auto_escalate and _is_low_quality(result.get("content", ""), TierLabel.LOCAL_FAST):
logger.info(
"TieredModelRouter: Tier-1 response low quality, escalating to Tier-2 "
"(task=%r content_len=%d)",
task[:80],
len(result.get("content", "")),
)
tier = TierLabel.LOCAL_HEAVY
result = await self._complete_tier(
TierLabel.LOCAL_HEAVY, msgs, temperature, max_tokens
)
return result
# ── Tier 2 attempt ───────────────────────────────────────────────────
if tier == TierLabel.LOCAL_HEAVY:
try:
return await self._complete_tier(
TierLabel.LOCAL_HEAVY, msgs, temperature, max_tokens
)
except Exception as exc:
logger.warning(
"TieredModelRouter: Tier-2 failed (%s) — escalating to cloud", exc
)
tier = TierLabel.CLOUD_API
# ── Tier 3 (Cloud) ───────────────────────────────────────────────────
budget = self._get_budget()
if not budget.cloud_allowed():
raise RuntimeError(
"Cloud API tier requested but budget limit reached — "
"increase tier_cloud_daily_budget_usd or tier_cloud_monthly_budget_usd"
)
result = await self._complete_tier(
TierLabel.CLOUD_API, msgs, temperature, max_tokens
)
# Record cloud spend if token info is available
usage = result.get("usage", {})
if usage:
cost = budget.record_spend(
provider=result.get("provider", "unknown"),
model=result.get("model", self._tier_models[TierLabel.CLOUD_API]),
tokens_in=usage.get("prompt_tokens", 0),
tokens_out=usage.get("completion_tokens", 0),
tier=TierLabel.CLOUD_API,
)
result["cost_usd"] = cost
return result
# ── Internal helpers ─────────────────────────────────────────────────────
async def _complete_tier(
self,
tier: TierLabel,
messages: list[dict],
temperature: float,
max_tokens: int | None,
) -> dict:
"""Dispatch a single inference request for the given tier."""
model = self._tier_models[tier]
cascade = self._get_cascade()
start = time.monotonic()
logger.info(
"TieredModelRouter: tier=%s model=%s messages=%d",
tier,
model,
len(messages),
)
result = await cascade.complete(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
)
elapsed_ms = (time.monotonic() - start) * 1000
result["tier"] = tier
result.setdefault("latency_ms", elapsed_ms)
logger.info(
"TieredModelRouter: done tier=%s model=%s latency_ms=%.0f",
tier,
result.get("model", model),
elapsed_ms,
)
return result
# ── Module-level singleton ────────────────────────────────────────────────────
_tiered_router: TieredModelRouter | None = None
def get_tiered_router() -> TieredModelRouter:
"""Get or create the module-level TieredModelRouter singleton."""
global _tiered_router
if _tiered_router is None:
_tiered_router = TieredModelRouter()
return _tiered_router

View File

@@ -0,0 +1,178 @@
"""Tests for the cloud API budget tracker (issue #882)."""
import time
from unittest.mock import patch
import pytest
from infrastructure.models.budget import (
BudgetTracker,
SpendRecord,
estimate_cost_usd,
get_budget_tracker,
)
pytestmark = pytest.mark.unit
# ── estimate_cost_usd ─────────────────────────────────────────────────────────
class TestEstimateCostUsd:
def test_haiku_cheaper_than_sonnet(self):
haiku_cost = estimate_cost_usd("claude-haiku-4-5", 1000, 1000)
sonnet_cost = estimate_cost_usd("claude-sonnet-4-5", 1000, 1000)
assert haiku_cost < sonnet_cost
def test_zero_tokens_is_zero_cost(self):
assert estimate_cost_usd("gpt-4o", 0, 0) == 0.0
def test_unknown_model_uses_default(self):
cost = estimate_cost_usd("some-unknown-model-xyz", 1000, 1000)
assert cost > 0 # Uses conservative default, not zero
def test_versioned_model_name_matches(self):
# "claude-haiku-4-5-20251001" should match "haiku"
cost1 = estimate_cost_usd("claude-haiku-4-5-20251001", 1000, 0)
cost2 = estimate_cost_usd("claude-haiku-4-5", 1000, 0)
assert cost1 == cost2
def test_gpt4o_mini_cheaper_than_gpt4o(self):
mini = estimate_cost_usd("gpt-4o-mini", 1000, 1000)
full = estimate_cost_usd("gpt-4o", 1000, 1000)
assert mini < full
def test_returns_float(self):
assert isinstance(estimate_cost_usd("haiku", 100, 200), float)
# ── BudgetTracker ─────────────────────────────────────────────────────────────
class TestBudgetTrackerInit:
def test_creates_with_memory_db(self):
tracker = BudgetTracker(db_path=":memory:")
assert tracker._db_ok is True
def test_in_memory_fallback_empty_on_creation(self):
tracker = BudgetTracker(db_path=":memory:")
assert tracker._in_memory == []
def test_bad_path_uses_memory_fallback(self, tmp_path):
bad_path = str(tmp_path / "nonexistent" / "x" / "budget.db")
# Should not raise — just log and continue with memory fallback
# (actually will create parent dirs, so test with truly bad path)
tracker = BudgetTracker.__new__(BudgetTracker)
tracker._db_path = bad_path
tracker._lock = __import__("threading").Lock()
tracker._in_memory = []
tracker._db_ok = False
# Record to in-memory fallback
tracker._in_memory.append(
SpendRecord(time.time(), "test", "model", 100, 100, 0.001, "cloud")
)
assert len(tracker._in_memory) == 1
class TestBudgetTrackerRecordSpend:
def test_record_spend_returns_cost(self):
tracker = BudgetTracker(db_path=":memory:")
cost = tracker.record_spend("anthropic", "claude-haiku-4-5", 100, 200)
assert cost > 0
def test_record_spend_explicit_cost(self):
tracker = BudgetTracker(db_path=":memory:")
cost = tracker.record_spend("anthropic", "model", cost_usd=1.23)
assert cost == pytest.approx(1.23)
def test_record_spend_accumulates(self):
tracker = BudgetTracker(db_path=":memory:")
tracker.record_spend("openai", "gpt-4o", cost_usd=0.01)
tracker.record_spend("openai", "gpt-4o", cost_usd=0.02)
assert tracker.get_daily_spend() == pytest.approx(0.03, abs=1e-9)
def test_record_spend_with_tier_label(self):
tracker = BudgetTracker(db_path=":memory:")
cost = tracker.record_spend("anthropic", "haiku", tier="cloud_api")
assert cost >= 0
def test_monthly_spend_includes_daily(self):
tracker = BudgetTracker(db_path=":memory:")
tracker.record_spend("anthropic", "haiku", cost_usd=5.00)
assert tracker.get_monthly_spend() >= tracker.get_daily_spend()
class TestBudgetTrackerCloudAllowed:
def test_allowed_when_no_spend(self):
tracker = BudgetTracker(db_path=":memory:")
with (
patch.object(type(tracker._get_budget() if hasattr(tracker, "_get_budget") else tracker), "tier_cloud_daily_budget_usd", 5.0, create=True),
):
# Settings-based check — use real settings (5.0 default, 0 spent)
assert tracker.cloud_allowed() is True
def test_blocked_when_daily_limit_exceeded(self):
tracker = BudgetTracker(db_path=":memory:")
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
# With default daily limit of 5.0, 999 should block
assert tracker.cloud_allowed() is False
def test_allowed_when_daily_limit_zero(self):
tracker = BudgetTracker(db_path=":memory:")
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
with (
patch("infrastructure.models.budget.settings") as mock_settings,
):
mock_settings.tier_cloud_daily_budget_usd = 0 # disabled
mock_settings.tier_cloud_monthly_budget_usd = 0 # disabled
assert tracker.cloud_allowed() is True
def test_blocked_when_monthly_limit_exceeded(self):
tracker = BudgetTracker(db_path=":memory:")
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
with patch("infrastructure.models.budget.settings") as mock_settings:
mock_settings.tier_cloud_daily_budget_usd = 0 # daily disabled
mock_settings.tier_cloud_monthly_budget_usd = 10.0
assert tracker.cloud_allowed() is False
class TestBudgetTrackerSummary:
def test_summary_keys_present(self):
tracker = BudgetTracker(db_path=":memory:")
summary = tracker.get_summary()
assert "daily_usd" in summary
assert "monthly_usd" in summary
assert "daily_limit_usd" in summary
assert "monthly_limit_usd" in summary
assert "daily_ok" in summary
assert "monthly_ok" in summary
def test_summary_daily_ok_true_on_empty(self):
tracker = BudgetTracker(db_path=":memory:")
summary = tracker.get_summary()
assert summary["daily_ok"] is True
assert summary["monthly_ok"] is True
def test_summary_daily_ok_false_when_exceeded(self):
tracker = BudgetTracker(db_path=":memory:")
tracker.record_spend("openai", "gpt-4o", cost_usd=999.0)
summary = tracker.get_summary()
assert summary["daily_ok"] is False
# ── Singleton ─────────────────────────────────────────────────────────────────
class TestGetBudgetTrackerSingleton:
def test_returns_budget_tracker(self):
import infrastructure.models.budget as bmod
bmod._budget_tracker = None
tracker = get_budget_tracker()
assert isinstance(tracker, BudgetTracker)
def test_returns_same_instance(self):
import infrastructure.models.budget as bmod
bmod._budget_tracker = None
t1 = get_budget_tracker()
t2 = get_budget_tracker()
assert t1 is t2

View File

@@ -0,0 +1,380 @@
"""Tests for the tiered model router (issue #882).
Covers:
- classify_tier() for Tier-1/2/3 routing
- TieredModelRouter.route() with mocked CascadeRouter + BudgetTracker
- Auto-escalation from Tier-1 on low-quality responses
- Cloud-tier budget guard
- Acceptance criteria from the issue:
- "Walk to the next room" → LOCAL_FAST
- "Plan the optimal path to become Hortator" → LOCAL_HEAVY
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from infrastructure.models.router import (
TierLabel,
TieredModelRouter,
_is_low_quality,
classify_tier,
get_tiered_router,
)
pytestmark = pytest.mark.unit
# ── classify_tier ─────────────────────────────────────────────────────────────
class TestClassifyTier:
# ── Tier-1 (LOCAL_FAST) ────────────────────────────────────────────────
def test_simple_navigation_is_local_fast(self):
assert classify_tier("walk to the next room") == TierLabel.LOCAL_FAST
def test_go_north_is_local_fast(self):
assert classify_tier("go north") == TierLabel.LOCAL_FAST
def test_single_binary_choice_is_local_fast(self):
assert classify_tier("yes") == TierLabel.LOCAL_FAST
def test_open_door_is_local_fast(self):
assert classify_tier("open door") == TierLabel.LOCAL_FAST
def test_attack_is_local_fast(self):
assert classify_tier("attack", {}) == TierLabel.LOCAL_FAST
# ── Tier-2 (LOCAL_HEAVY) ───────────────────────────────────────────────
def test_quest_planning_is_local_heavy(self):
assert classify_tier("plan the optimal path to become Hortator") == TierLabel.LOCAL_HEAVY
def test_strategy_keyword_is_local_heavy(self):
assert classify_tier("what is the best strategy") == TierLabel.LOCAL_HEAVY
def test_stuck_state_escalates_to_local_heavy(self):
assert classify_tier("help me", {"stuck": True}) == TierLabel.LOCAL_HEAVY
def test_require_t2_flag_is_local_heavy(self):
assert classify_tier("go north", {"require_t2": True}) == TierLabel.LOCAL_HEAVY
def test_long_input_is_local_heavy(self):
long_task = "tell me about " + ("the dungeon " * 30)
assert classify_tier(long_task) == TierLabel.LOCAL_HEAVY
def test_active_quests_upgrades_to_local_heavy(self):
ctx = {"active_quests": ["Q1", "Q2", "Q3"]}
assert classify_tier("go north", ctx) == TierLabel.LOCAL_HEAVY
def test_dialogue_active_upgrades_to_local_heavy(self):
ctx = {"dialogue_active": True}
assert classify_tier("yes", ctx) == TierLabel.LOCAL_HEAVY
def test_analyze_is_local_heavy(self):
assert classify_tier("analyze the situation") == TierLabel.LOCAL_HEAVY
def test_optimize_is_local_heavy(self):
assert classify_tier("optimize my build") == TierLabel.LOCAL_HEAVY
def test_negotiate_is_local_heavy(self):
assert classify_tier("negotiate with the Camonna Tong") == TierLabel.LOCAL_HEAVY
def test_explain_is_local_heavy(self):
assert classify_tier("explain the faction system") == TierLabel.LOCAL_HEAVY
# ── Tier-3 (CLOUD_API) ─────────────────────────────────────────────────
def test_require_cloud_flag_is_cloud_api(self):
assert classify_tier("go north", {"require_cloud": True}) == TierLabel.CLOUD_API
def test_require_cloud_overrides_everything(self):
assert classify_tier("yes", {"require_cloud": True}) == TierLabel.CLOUD_API
# ── Edge cases ────────────────────────────────────────────────────────
def test_empty_task_defaults_to_local_heavy(self):
# Empty string → nothing classifies it as T1 or T3
assert classify_tier("") == TierLabel.LOCAL_HEAVY
def test_case_insensitive(self):
assert classify_tier("PLAN my route") == TierLabel.LOCAL_HEAVY
def test_combat_active_upgrades_t1_to_heavy(self):
ctx = {"combat_active": True}
# "attack" is T1 word, but combat context → should NOT be LOCAL_FAST
result = classify_tier("attack", ctx)
assert result != TierLabel.LOCAL_FAST
# ── _is_low_quality ───────────────────────────────────────────────────────────
class TestIsLowQuality:
def test_empty_is_low_quality(self):
assert _is_low_quality("", TierLabel.LOCAL_FAST) is True
def test_whitespace_only_is_low_quality(self):
assert _is_low_quality(" ", TierLabel.LOCAL_FAST) is True
def test_very_short_is_low_quality(self):
assert _is_low_quality("ok", TierLabel.LOCAL_FAST) is True
def test_idontknow_is_low_quality(self):
assert _is_low_quality("I don't know how to help with that.", TierLabel.LOCAL_FAST) is True
def test_not_sure_is_low_quality(self):
assert _is_low_quality("I'm not sure about this.", TierLabel.LOCAL_FAST) is True
def test_as_an_ai_is_low_quality(self):
assert _is_low_quality("As an AI, I cannot...", TierLabel.LOCAL_FAST) is True
def test_good_response_is_not_low_quality(self):
response = "You move north into the Vivec Canton. The Ordinators watch your approach."
assert _is_low_quality(response, TierLabel.LOCAL_FAST) is False
def test_t1_short_response_triggers_escalation(self):
# Less than _ESCALATION_MIN_CHARS for T1
assert _is_low_quality("OK, done.", TierLabel.LOCAL_FAST) is True
def test_borderline_ok_for_t2_not_t1(self):
# Between _LOW_QUALITY_MIN_CHARS (20) and _ESCALATION_MIN_CHARS (60)
# → low quality for T1 (escalation threshold), but acceptable for T2/T3
response = "Done. The item is retrieved." # 28 chars: ≥20, <60
assert _is_low_quality(response, TierLabel.LOCAL_FAST) is True
assert _is_low_quality(response, TierLabel.LOCAL_HEAVY) is False
# ── TieredModelRouter ─────────────────────────────────────────────────────────
_GOOD_CONTENT = (
"You move north through the doorway into the next room. "
"The stone walls glisten with moisture."
) # 90 chars — well above the escalation threshold
def _make_cascade_mock(content=_GOOD_CONTENT, model="llama3.1:8b"):
mock = MagicMock()
mock.complete = AsyncMock(
return_value={
"content": content,
"provider": "ollama-local",
"model": model,
"latency_ms": 150.0,
}
)
return mock
def _make_budget_mock(allowed=True):
mock = MagicMock()
mock.cloud_allowed = MagicMock(return_value=allowed)
mock.record_spend = MagicMock(return_value=0.001)
return mock
@pytest.mark.asyncio
class TestTieredModelRouterRoute:
async def test_route_returns_tier_in_result(self):
router = TieredModelRouter(cascade=_make_cascade_mock())
result = await router.route("go north")
assert "tier" in result
assert result["tier"] == TierLabel.LOCAL_FAST
async def test_acceptance_walk_to_room_is_local_fast(self):
"""Acceptance: 'Walk to the next room' → LOCAL_FAST."""
router = TieredModelRouter(cascade=_make_cascade_mock())
result = await router.route("Walk to the next room")
assert result["tier"] == TierLabel.LOCAL_FAST
async def test_acceptance_plan_hortator_is_local_heavy(self):
"""Acceptance: 'Plan the optimal path to become Hortator' → LOCAL_HEAVY."""
router = TieredModelRouter(
cascade=_make_cascade_mock(model="hermes3:70b"),
)
result = await router.route("Plan the optimal path to become Hortator")
assert result["tier"] == TierLabel.LOCAL_HEAVY
async def test_t1_low_quality_escalates_to_t2(self):
"""Failed Tier-1 response auto-escalates to Tier-2."""
call_models = []
cascade = MagicMock()
async def complete_side_effect(messages, model, temperature, max_tokens):
call_models.append(model)
# First call (T1) returns a low-quality response
if len(call_models) == 1:
return {
"content": "I don't know.",
"provider": "ollama",
"model": model,
"latency_ms": 50,
}
# Second call (T2) returns a good response
return {
"content": "You move to the northern passage, passing through the Dunmer stronghold.",
"provider": "ollama",
"model": model,
"latency_ms": 800,
}
cascade.complete = complete_side_effect
router = TieredModelRouter(cascade=cascade, auto_escalate=True)
result = await router.route("go north")
assert len(call_models) == 2, "Should have called twice (T1 escalated to T2)"
assert result["tier"] == TierLabel.LOCAL_HEAVY
async def test_auto_escalate_false_no_escalation(self):
"""With auto_escalate=False, low-quality T1 response is returned as-is."""
call_count = {"n": 0}
cascade = MagicMock()
async def complete_side_effect(**kwargs):
call_count["n"] += 1
return {
"content": "I don't know.",
"provider": "ollama",
"model": "llama3.1:8b",
"latency_ms": 50,
}
cascade.complete = AsyncMock(side_effect=complete_side_effect)
router = TieredModelRouter(cascade=cascade, auto_escalate=False)
result = await router.route("go north")
assert call_count["n"] == 1
assert result["tier"] == TierLabel.LOCAL_FAST
async def test_t2_failure_escalates_to_cloud(self):
"""Tier-2 failure escalates to Cloud API (when budget allows)."""
cascade = MagicMock()
call_models = []
async def complete_side_effect(messages, model, temperature, max_tokens):
call_models.append(model)
if "hermes3" in model or "70b" in model.lower():
raise RuntimeError("Tier-2 model unavailable")
return {
"content": "Cloud response here.",
"provider": "anthropic",
"model": model,
"latency_ms": 1200,
}
cascade.complete = complete_side_effect
budget = _make_budget_mock(allowed=True)
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
result = await router.route("plan my route", context={"require_t2": True})
assert result["tier"] == TierLabel.CLOUD_API
async def test_cloud_blocked_by_budget_raises(self):
"""Cloud tier blocked when budget is exhausted."""
cascade = MagicMock()
cascade.complete = AsyncMock(side_effect=RuntimeError("T2 fail"))
budget = _make_budget_mock(allowed=False)
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
with pytest.raises(RuntimeError, match="budget limit"):
await router.route("plan my route", context={"require_t2": True})
async def test_explicit_cloud_tier_uses_cloud_model(self):
cascade = _make_cascade_mock(model="claude-haiku-4-5")
budget = _make_budget_mock(allowed=True)
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
result = await router.route("go north", context={"require_cloud": True})
assert result["tier"] == TierLabel.CLOUD_API
async def test_cloud_spend_recorded_with_usage(self):
"""Cloud spend is recorded when the response includes usage info."""
cascade = MagicMock()
cascade.complete = AsyncMock(
return_value={
"content": "Cloud answer.",
"provider": "anthropic",
"model": "claude-haiku-4-5",
"latency_ms": 900,
"usage": {"prompt_tokens": 50, "completion_tokens": 100},
}
)
budget = _make_budget_mock(allowed=True)
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
result = await router.route("go north", context={"require_cloud": True})
budget.record_spend.assert_called_once()
assert "cost_usd" in result
async def test_cloud_spend_not_recorded_without_usage(self):
"""Cloud spend is not recorded when usage info is absent."""
cascade = MagicMock()
cascade.complete = AsyncMock(
return_value={
"content": "Cloud answer.",
"provider": "anthropic",
"model": "claude-haiku-4-5",
"latency_ms": 900,
# no "usage" key
}
)
budget = _make_budget_mock(allowed=True)
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
result = await router.route("go north", context={"require_cloud": True})
budget.record_spend.assert_not_called()
assert "cost_usd" not in result
async def test_custom_tier_models_respected(self):
cascade = _make_cascade_mock()
router = TieredModelRouter(
cascade=cascade,
tier_models={TierLabel.LOCAL_FAST: "llama3.2:3b"},
)
await router.route("go north")
call_kwargs = cascade.complete.call_args
assert call_kwargs.kwargs["model"] == "llama3.2:3b"
async def test_messages_override_used_when_provided(self):
cascade = _make_cascade_mock()
router = TieredModelRouter(cascade=cascade)
custom_msgs = [{"role": "user", "content": "custom message"}]
await router.route("go north", messages=custom_msgs)
call_kwargs = cascade.complete.call_args
assert call_kwargs.kwargs["messages"] == custom_msgs
async def test_temperature_forwarded(self):
cascade = _make_cascade_mock()
router = TieredModelRouter(cascade=cascade)
await router.route("go north", temperature=0.7)
call_kwargs = cascade.complete.call_args
assert call_kwargs.kwargs["temperature"] == 0.7
async def test_max_tokens_forwarded(self):
cascade = _make_cascade_mock()
router = TieredModelRouter(cascade=cascade)
await router.route("go north", max_tokens=128)
call_kwargs = cascade.complete.call_args
assert call_kwargs.kwargs["max_tokens"] == 128
class TestTieredModelRouterClassify:
def test_classify_delegates_to_classify_tier(self):
router = TieredModelRouter(cascade=MagicMock())
assert router.classify("go north") == classify_tier("go north")
assert router.classify("plan the quest") == classify_tier("plan the quest")
class TestGetTieredRouterSingleton:
def test_returns_tiered_router_instance(self):
import infrastructure.models.router as rmod
rmod._tiered_router = None
router = get_tiered_router()
assert isinstance(router, TieredModelRouter)
def test_singleton_returns_same_instance(self):
import infrastructure.models.router as rmod
rmod._tiered_router = None
r1 = get_tiered_router()
r2 = get_tiered_router()
assert r1 is r2