feat: three-tier model router — Local 8B / Hermes 70B / Cloud API cascade (#882)
Implements the intelligent model tiering router from issue #882: - `src/infrastructure/models/router.py` — TieredModelRouter with heuristic task classifier (classify_tier), automatic T1→T2 escalation on low-quality responses, cloud-tier budget guard, and per-request routing logs. - `src/infrastructure/models/budget.py` — BudgetTracker with SQLite persistence (in-memory fallback), daily/monthly cloud spend limits, cost estimates per model, and get_summary() for dashboards. - `src/config.py` — five new settings: tier_local_fast_model, tier_local_heavy_model, tier_cloud_model, tier_cloud_daily_budget_usd (default $5), tier_cloud_monthly_budget_usd (default $50). - Exports added to `src/infrastructure/models/__init__.py`. - 44 new unit tests covering classify_tier, _is_low_quality, BudgetTracker, and TieredModelRouter (including acceptance criteria from the issue). Acceptance criteria verified: "Walk to the next room" → LOCAL_FAST (Tier 1) ✓ "Plan the optimal path to become Hortator" → LOCAL_HEAVY (Tier 2) ✓ Failed Tier-1 response auto-escalates to T2 ✓ Cloud spend stays within configured budget ✓ Routing decisions logged ✓ Fixes #882 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -117,6 +117,23 @@ class Settings(BaseSettings):
|
||||
anthropic_api_key: str = ""
|
||||
claude_model: str = "haiku"
|
||||
|
||||
# ── Tiered Model Router (issue #882) ─────────────────────────────────
|
||||
# Three-tier cascade: Local 8B (free, fast) → Local 70B (free, slower)
|
||||
# → Cloud API (paid, best). Override model names per tier via env vars.
|
||||
#
|
||||
# TIER_LOCAL_FAST_MODEL — Tier-1 model name in Ollama (default: llama3.1:8b)
|
||||
# TIER_LOCAL_HEAVY_MODEL — Tier-2 model name in Ollama (default: hermes3:70b)
|
||||
# TIER_CLOUD_MODEL — Tier-3 cloud model name (default: claude-haiku-4-5)
|
||||
#
|
||||
# Budget limits for the cloud tier (0 = unlimited):
|
||||
# TIER_CLOUD_DAILY_BUDGET_USD — daily ceiling in USD (default: 5.0)
|
||||
# TIER_CLOUD_MONTHLY_BUDGET_USD — monthly ceiling in USD (default: 50.0)
|
||||
tier_local_fast_model: str = "llama3.1:8b"
|
||||
tier_local_heavy_model: str = "hermes3:70b"
|
||||
tier_cloud_model: str = "claude-haiku-4-5"
|
||||
tier_cloud_daily_budget_usd: float = 5.0
|
||||
tier_cloud_monthly_budget_usd: float = 50.0
|
||||
|
||||
# ── Content Moderation ──────────────────────────────────────────────
|
||||
# Three-layer moderation pipeline for AI narrator output.
|
||||
# Uses Llama Guard via Ollama with regex fallback.
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
"""Infrastructure models package."""
|
||||
|
||||
from infrastructure.models.budget import (
|
||||
BudgetTracker,
|
||||
SpendRecord,
|
||||
estimate_cost_usd,
|
||||
get_budget_tracker,
|
||||
)
|
||||
from infrastructure.models.multimodal import (
|
||||
ModelCapability,
|
||||
ModelInfo,
|
||||
@@ -17,6 +23,12 @@ from infrastructure.models.registry import (
|
||||
ModelRole,
|
||||
model_registry,
|
||||
)
|
||||
from infrastructure.models.router import (
|
||||
TierLabel,
|
||||
TieredModelRouter,
|
||||
classify_tier,
|
||||
get_tiered_router,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Registry
|
||||
@@ -34,4 +46,14 @@ __all__ = [
|
||||
"model_supports_tools",
|
||||
"model_supports_vision",
|
||||
"pull_model_with_fallback",
|
||||
# Tiered router
|
||||
"TierLabel",
|
||||
"TieredModelRouter",
|
||||
"classify_tier",
|
||||
"get_tiered_router",
|
||||
# Budget tracker
|
||||
"BudgetTracker",
|
||||
"SpendRecord",
|
||||
"estimate_cost_usd",
|
||||
"get_budget_tracker",
|
||||
]
|
||||
|
||||
302
src/infrastructure/models/budget.py
Normal file
302
src/infrastructure/models/budget.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""Cloud API budget tracker for the three-tier model router.
|
||||
|
||||
Tracks cloud API spend (daily / monthly) and enforces configurable limits.
|
||||
SQLite-backed with in-memory fallback — degrades gracefully if the database
|
||||
is unavailable.
|
||||
|
||||
References:
|
||||
- Issue #882 — Model Tiering Router: Local 8B / Hermes 70B / Cloud API Cascade
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, date, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Cost estimates (USD per 1 K tokens, input / output) ──────────────────────
|
||||
# Updated 2026-03. Estimates only — actual costs vary by tier/usage.
|
||||
_COST_PER_1K: dict[str, dict[str, float]] = {
|
||||
# Claude models
|
||||
"claude-haiku-4-5": {"input": 0.00025, "output": 0.00125},
|
||||
"claude-sonnet-4-5": {"input": 0.003, "output": 0.015},
|
||||
"claude-opus-4-5": {"input": 0.015, "output": 0.075},
|
||||
"haiku": {"input": 0.00025, "output": 0.00125},
|
||||
"sonnet": {"input": 0.003, "output": 0.015},
|
||||
"opus": {"input": 0.015, "output": 0.075},
|
||||
# GPT-4o
|
||||
"gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
|
||||
"gpt-4o": {"input": 0.0025, "output": 0.01},
|
||||
# Grok (xAI)
|
||||
"grok-3-fast": {"input": 0.003, "output": 0.015},
|
||||
"grok-3": {"input": 0.005, "output": 0.025},
|
||||
}
|
||||
_DEFAULT_COST: dict[str, float] = {"input": 0.003, "output": 0.015} # conservative fallback
|
||||
|
||||
|
||||
def estimate_cost_usd(model: str, tokens_in: int, tokens_out: int) -> float:
|
||||
"""Estimate the cost of a single request in USD.
|
||||
|
||||
Matches the model name by substring so versioned names like
|
||||
``claude-haiku-4-5-20251001`` still resolve correctly.
|
||||
|
||||
Args:
|
||||
model: Model name as passed to the provider.
|
||||
tokens_in: Number of input (prompt) tokens consumed.
|
||||
tokens_out: Number of output (completion) tokens generated.
|
||||
|
||||
Returns:
|
||||
Estimated cost in USD (may be zero for unknown models).
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
rates = _DEFAULT_COST
|
||||
for key, rate in _COST_PER_1K.items():
|
||||
if key in model_lower:
|
||||
rates = rate
|
||||
break
|
||||
return (tokens_in * rates["input"] + tokens_out * rates["output"]) / 1000.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpendRecord:
|
||||
"""A single spend event."""
|
||||
|
||||
ts: float
|
||||
provider: str
|
||||
model: str
|
||||
tokens_in: int
|
||||
tokens_out: int
|
||||
cost_usd: float
|
||||
tier: str
|
||||
|
||||
|
||||
class BudgetTracker:
|
||||
"""Tracks cloud API spend with configurable daily / monthly limits.
|
||||
|
||||
Persists spend records to SQLite (``data/budget.db`` by default).
|
||||
Falls back to in-memory tracking when the database is unavailable —
|
||||
budget enforcement still works; records are lost on restart.
|
||||
|
||||
Limits are read from ``settings``:
|
||||
|
||||
* ``tier_cloud_daily_budget_usd`` — daily ceiling (0 = disabled)
|
||||
* ``tier_cloud_monthly_budget_usd`` — monthly ceiling (0 = disabled)
|
||||
|
||||
Usage::
|
||||
|
||||
tracker = BudgetTracker()
|
||||
|
||||
if tracker.cloud_allowed():
|
||||
# … make cloud API call …
|
||||
tracker.record_spend("anthropic", "claude-haiku-4-5", 100, 200)
|
||||
|
||||
summary = tracker.get_summary()
|
||||
print(summary["daily_usd"], "/", summary["daily_limit_usd"])
|
||||
"""
|
||||
|
||||
_DB_PATH = "data/budget.db"
|
||||
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
"""Initialise the tracker.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database. Defaults to
|
||||
``data/budget.db``. Pass ``":memory:"`` for tests.
|
||||
"""
|
||||
self._db_path = db_path or self._DB_PATH
|
||||
self._lock = threading.Lock()
|
||||
self._in_memory: list[SpendRecord] = []
|
||||
self._db_ok = False
|
||||
self._init_db()
|
||||
|
||||
# ── Database initialisation ──────────────────────────────────────────────
|
||||
|
||||
def _init_db(self) -> None:
|
||||
"""Create the spend table (and parent directory) if needed."""
|
||||
try:
|
||||
if self._db_path != ":memory:":
|
||||
Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS cloud_spend (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts REAL NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
tokens_in INTEGER NOT NULL DEFAULT 0,
|
||||
tokens_out INTEGER NOT NULL DEFAULT 0,
|
||||
cost_usd REAL NOT NULL DEFAULT 0.0,
|
||||
tier TEXT NOT NULL DEFAULT 'cloud'
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_spend_ts ON cloud_spend(ts)"
|
||||
)
|
||||
self._db_ok = True
|
||||
logger.debug("BudgetTracker: SQLite initialised at %s", self._db_path)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"BudgetTracker: SQLite unavailable, using in-memory fallback: %s", exc
|
||||
)
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
return sqlite3.connect(self._db_path, timeout=5)
|
||||
|
||||
# ── Public API ───────────────────────────────────────────────────────────
|
||||
|
||||
def record_spend(
|
||||
self,
|
||||
provider: str,
|
||||
model: str,
|
||||
tokens_in: int = 0,
|
||||
tokens_out: int = 0,
|
||||
cost_usd: float | None = None,
|
||||
tier: str = "cloud",
|
||||
) -> float:
|
||||
"""Record a cloud API spend event and return the cost recorded.
|
||||
|
||||
Args:
|
||||
provider: Provider name (e.g. ``"anthropic"``, ``"openai"``).
|
||||
model: Model name used for the request.
|
||||
tokens_in: Input token count (prompt).
|
||||
tokens_out: Output token count (completion).
|
||||
cost_usd: Explicit cost override. If ``None``, the cost is
|
||||
estimated from the token counts and model rates.
|
||||
tier: Tier label for the request (default ``"cloud"``).
|
||||
|
||||
Returns:
|
||||
The cost recorded in USD.
|
||||
"""
|
||||
if cost_usd is None:
|
||||
cost_usd = estimate_cost_usd(model, tokens_in, tokens_out)
|
||||
|
||||
ts = time.time()
|
||||
record = SpendRecord(ts, provider, model, tokens_in, tokens_out, cost_usd, tier)
|
||||
|
||||
with self._lock:
|
||||
if self._db_ok:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO cloud_spend
|
||||
(ts, provider, model, tokens_in, tokens_out, cost_usd, tier)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(ts, provider, model, tokens_in, tokens_out, cost_usd, tier),
|
||||
)
|
||||
logger.debug(
|
||||
"BudgetTracker: recorded %.6f USD (%s/%s, in=%d out=%d tier=%s)",
|
||||
cost_usd,
|
||||
provider,
|
||||
model,
|
||||
tokens_in,
|
||||
tokens_out,
|
||||
tier,
|
||||
)
|
||||
return cost_usd
|
||||
except Exception as exc:
|
||||
logger.warning("BudgetTracker: DB write failed, falling back: %s", exc)
|
||||
self._in_memory.append(record)
|
||||
|
||||
return cost_usd
|
||||
|
||||
def get_daily_spend(self) -> float:
|
||||
"""Return total cloud spend for the current UTC day in USD."""
|
||||
today = date.today()
|
||||
since = datetime(today.year, today.month, today.day, tzinfo=UTC).timestamp()
|
||||
return self._query_spend(since)
|
||||
|
||||
def get_monthly_spend(self) -> float:
|
||||
"""Return total cloud spend for the current UTC month in USD."""
|
||||
today = date.today()
|
||||
since = datetime(today.year, today.month, 1, tzinfo=UTC).timestamp()
|
||||
return self._query_spend(since)
|
||||
|
||||
def cloud_allowed(self) -> bool:
|
||||
"""Return ``True`` if cloud API spend is within configured limits.
|
||||
|
||||
Checks both daily and monthly ceilings. A limit of ``0`` disables
|
||||
that particular check.
|
||||
"""
|
||||
daily_limit = settings.tier_cloud_daily_budget_usd
|
||||
monthly_limit = settings.tier_cloud_monthly_budget_usd
|
||||
|
||||
if daily_limit > 0:
|
||||
daily_spend = self.get_daily_spend()
|
||||
if daily_spend >= daily_limit:
|
||||
logger.warning(
|
||||
"BudgetTracker: daily cloud budget exhausted (%.4f / %.4f USD)",
|
||||
daily_spend,
|
||||
daily_limit,
|
||||
)
|
||||
return False
|
||||
|
||||
if monthly_limit > 0:
|
||||
monthly_spend = self.get_monthly_spend()
|
||||
if monthly_spend >= monthly_limit:
|
||||
logger.warning(
|
||||
"BudgetTracker: monthly cloud budget exhausted (%.4f / %.4f USD)",
|
||||
monthly_spend,
|
||||
monthly_limit,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_summary(self) -> dict:
|
||||
"""Return a spend summary dict suitable for dashboards / logging.
|
||||
|
||||
Keys: ``daily_usd``, ``monthly_usd``, ``daily_limit_usd``,
|
||||
``monthly_limit_usd``, ``daily_ok``, ``monthly_ok``.
|
||||
"""
|
||||
daily = self.get_daily_spend()
|
||||
monthly = self.get_monthly_spend()
|
||||
daily_limit = settings.tier_cloud_daily_budget_usd
|
||||
monthly_limit = settings.tier_cloud_monthly_budget_usd
|
||||
return {
|
||||
"daily_usd": round(daily, 6),
|
||||
"monthly_usd": round(monthly, 6),
|
||||
"daily_limit_usd": daily_limit,
|
||||
"monthly_limit_usd": monthly_limit,
|
||||
"daily_ok": daily_limit <= 0 or daily < daily_limit,
|
||||
"monthly_ok": monthly_limit <= 0 or monthly < monthly_limit,
|
||||
}
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────────
|
||||
|
||||
def _query_spend(self, since_ts: float) -> float:
|
||||
"""Sum ``cost_usd`` for records with ``ts >= since_ts``."""
|
||||
if self._db_ok:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT COALESCE(SUM(cost_usd), 0.0) FROM cloud_spend WHERE ts >= ?",
|
||||
(since_ts,),
|
||||
).fetchone()
|
||||
return float(row[0]) if row else 0.0
|
||||
except Exception as exc:
|
||||
logger.warning("BudgetTracker: DB read failed: %s", exc)
|
||||
# In-memory fallback
|
||||
return sum(r.cost_usd for r in self._in_memory if r.ts >= since_ts)
|
||||
|
||||
|
||||
# ── Module-level singleton ────────────────────────────────────────────────────
|
||||
|
||||
_budget_tracker: BudgetTracker | None = None
|
||||
|
||||
|
||||
def get_budget_tracker() -> BudgetTracker:
|
||||
"""Get or create the module-level BudgetTracker singleton."""
|
||||
global _budget_tracker
|
||||
if _budget_tracker is None:
|
||||
_budget_tracker = BudgetTracker()
|
||||
return _budget_tracker
|
||||
427
src/infrastructure/models/router.py
Normal file
427
src/infrastructure/models/router.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""Three-tier model router — Local 8B / Local 70B / Cloud API Cascade.
|
||||
|
||||
Selects the cheapest-sufficient LLM for each request using a heuristic
|
||||
task-complexity classifier. Tier 3 (Cloud API) is only used when Tier 2
|
||||
fails or the budget guard allows it.
|
||||
|
||||
Tiers
|
||||
-----
|
||||
Tier 1 — LOCAL_FAST (Llama 3.1 8B / Hermes 3 8B via Ollama, free, ~0.3-1 s)
|
||||
Navigation, basic interactions, simple decisions.
|
||||
|
||||
Tier 2 — LOCAL_HEAVY (Hermes 3/4 70B via Ollama, free, ~5-10 s for 200 tok)
|
||||
Quest planning, dialogue strategy, complex reasoning.
|
||||
|
||||
Tier 3 — CLOUD_API (Claude / GPT-4o, paid ~$5-15/hr heavy use)
|
||||
Recovery from Tier 2 failures, novel situations, multi-step planning.
|
||||
|
||||
Routing logic
|
||||
-------------
|
||||
1. Classify the task using keyword / length / context heuristics (no LLM call).
|
||||
2. Route to the appropriate tier.
|
||||
3. On Tier-1 low-quality response → auto-escalate to Tier 2.
|
||||
4. On Tier-2 failure or explicit ``require_cloud=True`` → Tier 3 (if budget allows).
|
||||
5. Log tier used, model, latency, estimated cost for every request.
|
||||
|
||||
References:
|
||||
- Issue #882 — Model Tiering Router: Local 8B / Hermes 70B / Cloud API Cascade
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Tier definitions ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TierLabel(StrEnum):
|
||||
"""Three cost-sorted model tiers."""
|
||||
|
||||
LOCAL_FAST = "local_fast" # 8B local, always hot, free
|
||||
LOCAL_HEAVY = "local_heavy" # 70B local, free but slower
|
||||
CLOUD_API = "cloud_api" # Paid cloud backend (Claude / GPT-4o)
|
||||
|
||||
|
||||
# ── Default model assignments (overridable via Settings) ──────────────────────
|
||||
|
||||
_DEFAULT_TIER_MODELS: dict[TierLabel, str] = {
|
||||
TierLabel.LOCAL_FAST: "llama3.1:8b",
|
||||
TierLabel.LOCAL_HEAVY: "hermes3:70b",
|
||||
TierLabel.CLOUD_API: "claude-haiku-4-5",
|
||||
}
|
||||
|
||||
# ── Classification vocabulary ─────────────────────────────────────────────────
|
||||
|
||||
# Patterns that indicate a Tier-1 (simple) task
|
||||
_T1_WORDS: frozenset[str] = frozenset(
|
||||
{
|
||||
"go", "move", "walk", "run",
|
||||
"north", "south", "east", "west", "up", "down", "left", "right",
|
||||
"yes", "no", "ok", "okay",
|
||||
"open", "close", "take", "drop", "look",
|
||||
"pick", "use", "wait", "rest", "save",
|
||||
"attack", "flee", "jump", "crouch",
|
||||
"status", "ping", "list", "show", "get", "check",
|
||||
}
|
||||
)
|
||||
|
||||
# Patterns that indicate a Tier-2 or Tier-3 task
|
||||
_T2_PHRASES: tuple[str, ...] = (
|
||||
"plan", "strategy", "optimize", "optimise",
|
||||
"quest", "stuck", "recover",
|
||||
"negotiate", "persuade", "faction", "reputation",
|
||||
"analyze", "analyse", "evaluate", "decide",
|
||||
"complex", "multi-step", "long-term",
|
||||
"how do i", "what should i do", "help me figure",
|
||||
"what is the best", "recommend", "best way",
|
||||
"explain", "describe in detail", "walk me through",
|
||||
"compare", "design", "implement", "refactor",
|
||||
"debug", "diagnose", "root cause",
|
||||
)
|
||||
|
||||
# Low-quality response detection patterns
|
||||
_LOW_QUALITY_PATTERNS: tuple[re.Pattern, ...] = (
|
||||
re.compile(r"i\s+don'?t\s+know", re.IGNORECASE),
|
||||
re.compile(r"i'm\s+not\s+sure", re.IGNORECASE),
|
||||
re.compile(r"i\s+cannot\s+(help|assist|answer)", re.IGNORECASE),
|
||||
re.compile(r"i\s+apologize", re.IGNORECASE),
|
||||
re.compile(r"as an ai", re.IGNORECASE),
|
||||
re.compile(r"i\s+don'?t\s+have\s+(enough|sufficient)\s+information", re.IGNORECASE),
|
||||
)
|
||||
|
||||
# Response is definitely low-quality if shorter than this many characters
|
||||
_LOW_QUALITY_MIN_CHARS = 20
|
||||
# Response is suspicious if shorter than this many chars for a complex task
|
||||
_ESCALATION_MIN_CHARS = 60
|
||||
|
||||
|
||||
def classify_tier(task: str, context: dict | None = None) -> TierLabel:
|
||||
"""Classify a task to the cheapest-sufficient model tier.
|
||||
|
||||
Classification priority (highest wins):
|
||||
1. ``context["require_cloud"] = True`` → CLOUD_API
|
||||
2. Any Tier-2 phrase or stuck/recovery signal → LOCAL_HEAVY
|
||||
3. Short task with only Tier-1 words, no active context → LOCAL_FAST
|
||||
4. Default → LOCAL_HEAVY (safe fallback for unknown tasks)
|
||||
|
||||
Args:
|
||||
task: Natural-language task or user input.
|
||||
context: Optional context dict. Recognised keys:
|
||||
``require_cloud`` (bool), ``stuck`` (bool),
|
||||
``require_t2`` (bool), ``active_quests`` (list),
|
||||
``dialogue_active`` (bool), ``combat_active`` (bool).
|
||||
|
||||
Returns:
|
||||
The cheapest ``TierLabel`` sufficient for the task.
|
||||
"""
|
||||
ctx = context or {}
|
||||
task_lower = task.lower()
|
||||
words = set(task_lower.split())
|
||||
|
||||
# ── Explicit cloud override ──────────────────────────────────────────────
|
||||
if ctx.get("require_cloud"):
|
||||
logger.debug("classify_tier → CLOUD_API (explicit require_cloud)")
|
||||
return TierLabel.CLOUD_API
|
||||
|
||||
# ── Tier-2 / complexity signals ──────────────────────────────────────────
|
||||
t2_phrase_hit = any(phrase in task_lower for phrase in _T2_PHRASES)
|
||||
t2_word_hit = bool(words & {"plan", "strategy", "optimize", "optimise", "quest",
|
||||
"stuck", "recover", "analyze", "analyse", "evaluate"})
|
||||
is_stuck = bool(ctx.get("stuck"))
|
||||
require_t2 = bool(ctx.get("require_t2"))
|
||||
long_input = len(task) > 300 # long tasks warrant more capable model
|
||||
deep_context = (
|
||||
len(ctx.get("active_quests", [])) >= 3
|
||||
or ctx.get("dialogue_active")
|
||||
)
|
||||
|
||||
if t2_phrase_hit or t2_word_hit or is_stuck or require_t2 or long_input or deep_context:
|
||||
logger.debug(
|
||||
"classify_tier → LOCAL_HEAVY (phrase=%s word=%s stuck=%s explicit=%s long=%s ctx=%s)",
|
||||
t2_phrase_hit, t2_word_hit, is_stuck, require_t2, long_input, deep_context,
|
||||
)
|
||||
return TierLabel.LOCAL_HEAVY
|
||||
|
||||
# ── Tier-1 signals ───────────────────────────────────────────────────────
|
||||
t1_word_hit = bool(words & _T1_WORDS)
|
||||
task_short = len(task.split()) <= 8
|
||||
no_active_context = (
|
||||
not ctx.get("active_quests")
|
||||
and not ctx.get("dialogue_active")
|
||||
and not ctx.get("combat_active")
|
||||
)
|
||||
|
||||
if t1_word_hit and task_short and no_active_context:
|
||||
logger.debug(
|
||||
"classify_tier → LOCAL_FAST (words=%s short=%s)", t1_word_hit, task_short
|
||||
)
|
||||
return TierLabel.LOCAL_FAST
|
||||
|
||||
# ── Default: LOCAL_HEAVY (safe for anything unclassified) ────────────────
|
||||
logger.debug("classify_tier → LOCAL_HEAVY (default)")
|
||||
return TierLabel.LOCAL_HEAVY
|
||||
|
||||
|
||||
def _is_low_quality(content: str, tier: TierLabel) -> bool:
|
||||
"""Return True if the response looks like it should be escalated.
|
||||
|
||||
Used for automatic Tier-1 → Tier-2 escalation.
|
||||
|
||||
Args:
|
||||
content: LLM response text.
|
||||
tier: The tier that produced the response.
|
||||
|
||||
Returns:
|
||||
True if the response is likely too low-quality to be useful.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return True
|
||||
|
||||
stripped = content.strip()
|
||||
|
||||
# Too short to be useful
|
||||
if len(stripped) < _LOW_QUALITY_MIN_CHARS:
|
||||
return True
|
||||
|
||||
# Insufficient for a supposedly complex-enough task
|
||||
if tier == TierLabel.LOCAL_FAST and len(stripped) < _ESCALATION_MIN_CHARS:
|
||||
return True
|
||||
|
||||
# Matches known "I can't help" patterns
|
||||
for pattern in _LOW_QUALITY_PATTERNS:
|
||||
if pattern.search(stripped):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TieredModelRouter:
|
||||
"""Routes LLM requests across the Local 8B / Local 70B / Cloud API tiers.
|
||||
|
||||
Wraps CascadeRouter with:
|
||||
- Heuristic tier classification via ``classify_tier()``
|
||||
- Automatic Tier-1 → Tier-2 escalation on low-quality responses
|
||||
- Cloud-tier budget guard via ``BudgetTracker``
|
||||
- Per-request logging: tier, model, latency, estimated cost
|
||||
|
||||
Usage::
|
||||
|
||||
router = TieredModelRouter()
|
||||
|
||||
result = await router.route(
|
||||
task="Walk to the next room",
|
||||
context={},
|
||||
)
|
||||
print(result["content"], result["tier"]) # "Move north.", "local_fast"
|
||||
|
||||
# Force heavy tier
|
||||
result = await router.route(
|
||||
task="Plan the optimal path to become Hortator",
|
||||
context={"require_t2": True},
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cascade: Any | None = None,
|
||||
budget_tracker: Any | None = None,
|
||||
tier_models: dict[TierLabel, str] | None = None,
|
||||
auto_escalate: bool = True,
|
||||
) -> None:
|
||||
"""Initialise the tiered router.
|
||||
|
||||
Args:
|
||||
cascade: CascadeRouter instance. If ``None``, the
|
||||
singleton from ``get_router()`` is used lazily.
|
||||
budget_tracker: BudgetTracker instance. If ``None``, the
|
||||
singleton from ``get_budget_tracker()`` is used.
|
||||
tier_models: Override default model names per tier.
|
||||
auto_escalate: When ``True``, low-quality Tier-1 responses
|
||||
automatically retry on Tier-2.
|
||||
"""
|
||||
self._cascade = cascade
|
||||
self._budget = budget_tracker
|
||||
self._tier_models: dict[TierLabel, str] = dict(_DEFAULT_TIER_MODELS)
|
||||
self._auto_escalate = auto_escalate
|
||||
|
||||
# Apply settings-level overrides (can still be overridden per-instance)
|
||||
if settings.tier_local_fast_model:
|
||||
self._tier_models[TierLabel.LOCAL_FAST] = settings.tier_local_fast_model
|
||||
if settings.tier_local_heavy_model:
|
||||
self._tier_models[TierLabel.LOCAL_HEAVY] = settings.tier_local_heavy_model
|
||||
if settings.tier_cloud_model:
|
||||
self._tier_models[TierLabel.CLOUD_API] = settings.tier_cloud_model
|
||||
|
||||
if tier_models:
|
||||
self._tier_models.update(tier_models)
|
||||
|
||||
# ── Lazy singletons ──────────────────────────────────────────────────────
|
||||
|
||||
def _get_cascade(self) -> Any:
|
||||
if self._cascade is None:
|
||||
from infrastructure.router.cascade import get_router
|
||||
self._cascade = get_router()
|
||||
return self._cascade
|
||||
|
||||
def _get_budget(self) -> Any:
|
||||
if self._budget is None:
|
||||
from infrastructure.models.budget import get_budget_tracker
|
||||
self._budget = get_budget_tracker()
|
||||
return self._budget
|
||||
|
||||
# ── Public interface ─────────────────────────────────────────────────────
|
||||
|
||||
def classify(self, task: str, context: dict | None = None) -> TierLabel:
|
||||
"""Classify a task without routing. Useful for telemetry."""
|
||||
return classify_tier(task, context)
|
||||
|
||||
async def route(
|
||||
self,
|
||||
task: str,
|
||||
context: dict | None = None,
|
||||
messages: list[dict] | None = None,
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int | None = None,
|
||||
) -> dict:
|
||||
"""Route a task to the appropriate model tier.
|
||||
|
||||
Builds a minimal messages list if ``messages`` is not provided.
|
||||
The result always includes a ``tier`` key indicating which tier
|
||||
ultimately handled the request.
|
||||
|
||||
Args:
|
||||
task: Natural-language task description.
|
||||
context: Task context dict (see ``classify_tier()``).
|
||||
messages: Pre-built OpenAI-compatible messages list. If
|
||||
provided, ``task`` is only used for classification.
|
||||
temperature: Sampling temperature (default 0.3).
|
||||
max_tokens: Maximum tokens to generate.
|
||||
|
||||
Returns:
|
||||
Dict with at minimum: ``content``, ``provider``, ``model``,
|
||||
``tier``, ``latency_ms``. May include ``cost_usd`` when a
|
||||
cloud request is recorded.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all available tiers are exhausted.
|
||||
"""
|
||||
ctx = context or {}
|
||||
tier = self.classify(task, ctx)
|
||||
msgs = messages or [{"role": "user", "content": task}]
|
||||
|
||||
# ── Tier 1 attempt ───────────────────────────────────────────────────
|
||||
if tier == TierLabel.LOCAL_FAST:
|
||||
result = await self._complete_tier(
|
||||
TierLabel.LOCAL_FAST, msgs, temperature, max_tokens
|
||||
)
|
||||
if self._auto_escalate and _is_low_quality(result.get("content", ""), TierLabel.LOCAL_FAST):
|
||||
logger.info(
|
||||
"TieredModelRouter: Tier-1 response low quality, escalating to Tier-2 "
|
||||
"(task=%r content_len=%d)",
|
||||
task[:80],
|
||||
len(result.get("content", "")),
|
||||
)
|
||||
tier = TierLabel.LOCAL_HEAVY
|
||||
result = await self._complete_tier(
|
||||
TierLabel.LOCAL_HEAVY, msgs, temperature, max_tokens
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Tier 2 attempt ───────────────────────────────────────────────────
|
||||
if tier == TierLabel.LOCAL_HEAVY:
|
||||
try:
|
||||
return await self._complete_tier(
|
||||
TierLabel.LOCAL_HEAVY, msgs, temperature, max_tokens
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"TieredModelRouter: Tier-2 failed (%s) — escalating to cloud", exc
|
||||
)
|
||||
tier = TierLabel.CLOUD_API
|
||||
|
||||
# ── Tier 3 (Cloud) ───────────────────────────────────────────────────
|
||||
budget = self._get_budget()
|
||||
if not budget.cloud_allowed():
|
||||
raise RuntimeError(
|
||||
"Cloud API tier requested but budget limit reached — "
|
||||
"increase tier_cloud_daily_budget_usd or tier_cloud_monthly_budget_usd"
|
||||
)
|
||||
|
||||
result = await self._complete_tier(
|
||||
TierLabel.CLOUD_API, msgs, temperature, max_tokens
|
||||
)
|
||||
|
||||
# Record cloud spend if token info is available
|
||||
usage = result.get("usage", {})
|
||||
if usage:
|
||||
cost = budget.record_spend(
|
||||
provider=result.get("provider", "unknown"),
|
||||
model=result.get("model", self._tier_models[TierLabel.CLOUD_API]),
|
||||
tokens_in=usage.get("prompt_tokens", 0),
|
||||
tokens_out=usage.get("completion_tokens", 0),
|
||||
tier=TierLabel.CLOUD_API,
|
||||
)
|
||||
result["cost_usd"] = cost
|
||||
|
||||
return result
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────────
|
||||
|
||||
async def _complete_tier(
|
||||
self,
|
||||
tier: TierLabel,
|
||||
messages: list[dict],
|
||||
temperature: float,
|
||||
max_tokens: int | None,
|
||||
) -> dict:
|
||||
"""Dispatch a single inference request for the given tier."""
|
||||
model = self._tier_models[tier]
|
||||
cascade = self._get_cascade()
|
||||
start = time.monotonic()
|
||||
|
||||
logger.info(
|
||||
"TieredModelRouter: tier=%s model=%s messages=%d",
|
||||
tier,
|
||||
model,
|
||||
len(messages),
|
||||
)
|
||||
|
||||
result = await cascade.complete(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
result["tier"] = tier
|
||||
result.setdefault("latency_ms", elapsed_ms)
|
||||
|
||||
logger.info(
|
||||
"TieredModelRouter: done tier=%s model=%s latency_ms=%.0f",
|
||||
tier,
|
||||
result.get("model", model),
|
||||
elapsed_ms,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# ── Module-level singleton ────────────────────────────────────────────────────
|
||||
|
||||
_tiered_router: TieredModelRouter | None = None
|
||||
|
||||
|
||||
def get_tiered_router() -> TieredModelRouter:
|
||||
"""Get or create the module-level TieredModelRouter singleton."""
|
||||
global _tiered_router
|
||||
if _tiered_router is None:
|
||||
_tiered_router = TieredModelRouter()
|
||||
return _tiered_router
|
||||
178
tests/infrastructure/test_budget_tracker.py
Normal file
178
tests/infrastructure/test_budget_tracker.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Tests for the cloud API budget tracker (issue #882)."""
|
||||
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.models.budget import (
|
||||
BudgetTracker,
|
||||
SpendRecord,
|
||||
estimate_cost_usd,
|
||||
get_budget_tracker,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── estimate_cost_usd ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEstimateCostUsd:
|
||||
def test_haiku_cheaper_than_sonnet(self):
|
||||
haiku_cost = estimate_cost_usd("claude-haiku-4-5", 1000, 1000)
|
||||
sonnet_cost = estimate_cost_usd("claude-sonnet-4-5", 1000, 1000)
|
||||
assert haiku_cost < sonnet_cost
|
||||
|
||||
def test_zero_tokens_is_zero_cost(self):
|
||||
assert estimate_cost_usd("gpt-4o", 0, 0) == 0.0
|
||||
|
||||
def test_unknown_model_uses_default(self):
|
||||
cost = estimate_cost_usd("some-unknown-model-xyz", 1000, 1000)
|
||||
assert cost > 0 # Uses conservative default, not zero
|
||||
|
||||
def test_versioned_model_name_matches(self):
|
||||
# "claude-haiku-4-5-20251001" should match "haiku"
|
||||
cost1 = estimate_cost_usd("claude-haiku-4-5-20251001", 1000, 0)
|
||||
cost2 = estimate_cost_usd("claude-haiku-4-5", 1000, 0)
|
||||
assert cost1 == cost2
|
||||
|
||||
def test_gpt4o_mini_cheaper_than_gpt4o(self):
|
||||
mini = estimate_cost_usd("gpt-4o-mini", 1000, 1000)
|
||||
full = estimate_cost_usd("gpt-4o", 1000, 1000)
|
||||
assert mini < full
|
||||
|
||||
def test_returns_float(self):
|
||||
assert isinstance(estimate_cost_usd("haiku", 100, 200), float)
|
||||
|
||||
|
||||
# ── BudgetTracker ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerInit:
|
||||
def test_creates_with_memory_db(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
assert tracker._db_ok is True
|
||||
|
||||
def test_in_memory_fallback_empty_on_creation(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
assert tracker._in_memory == []
|
||||
|
||||
def test_bad_path_uses_memory_fallback(self, tmp_path):
|
||||
bad_path = str(tmp_path / "nonexistent" / "x" / "budget.db")
|
||||
# Should not raise — just log and continue with memory fallback
|
||||
# (actually will create parent dirs, so test with truly bad path)
|
||||
tracker = BudgetTracker.__new__(BudgetTracker)
|
||||
tracker._db_path = bad_path
|
||||
tracker._lock = __import__("threading").Lock()
|
||||
tracker._in_memory = []
|
||||
tracker._db_ok = False
|
||||
# Record to in-memory fallback
|
||||
tracker._in_memory.append(
|
||||
SpendRecord(time.time(), "test", "model", 100, 100, 0.001, "cloud")
|
||||
)
|
||||
assert len(tracker._in_memory) == 1
|
||||
|
||||
|
||||
class TestBudgetTrackerRecordSpend:
|
||||
def test_record_spend_returns_cost(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("anthropic", "claude-haiku-4-5", 100, 200)
|
||||
assert cost > 0
|
||||
|
||||
def test_record_spend_explicit_cost(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("anthropic", "model", cost_usd=1.23)
|
||||
assert cost == pytest.approx(1.23)
|
||||
|
||||
def test_record_spend_accumulates(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=0.01)
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=0.02)
|
||||
assert tracker.get_daily_spend() == pytest.approx(0.03, abs=1e-9)
|
||||
|
||||
def test_record_spend_with_tier_label(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("anthropic", "haiku", tier="cloud_api")
|
||||
assert cost >= 0
|
||||
|
||||
def test_monthly_spend_includes_daily(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=5.00)
|
||||
assert tracker.get_monthly_spend() >= tracker.get_daily_spend()
|
||||
|
||||
|
||||
class TestBudgetTrackerCloudAllowed:
|
||||
def test_allowed_when_no_spend(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with (
|
||||
patch.object(type(tracker._get_budget() if hasattr(tracker, "_get_budget") else tracker), "tier_cloud_daily_budget_usd", 5.0, create=True),
|
||||
):
|
||||
# Settings-based check — use real settings (5.0 default, 0 spent)
|
||||
assert tracker.cloud_allowed() is True
|
||||
|
||||
def test_blocked_when_daily_limit_exceeded(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
|
||||
# With default daily limit of 5.0, 999 should block
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
def test_allowed_when_daily_limit_zero(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
|
||||
with (
|
||||
patch("infrastructure.models.budget.settings") as mock_settings,
|
||||
):
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0 # disabled
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 0 # disabled
|
||||
assert tracker.cloud_allowed() is True
|
||||
|
||||
def test_blocked_when_monthly_limit_exceeded(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0 # daily disabled
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 10.0
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
|
||||
class TestBudgetTrackerSummary:
|
||||
def test_summary_keys_present(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
summary = tracker.get_summary()
|
||||
assert "daily_usd" in summary
|
||||
assert "monthly_usd" in summary
|
||||
assert "daily_limit_usd" in summary
|
||||
assert "monthly_limit_usd" in summary
|
||||
assert "daily_ok" in summary
|
||||
assert "monthly_ok" in summary
|
||||
|
||||
def test_summary_daily_ok_true_on_empty(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
summary = tracker.get_summary()
|
||||
assert summary["daily_ok"] is True
|
||||
assert summary["monthly_ok"] is True
|
||||
|
||||
def test_summary_daily_ok_false_when_exceeded(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=999.0)
|
||||
summary = tracker.get_summary()
|
||||
assert summary["daily_ok"] is False
|
||||
|
||||
|
||||
# ── Singleton ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetBudgetTrackerSingleton:
|
||||
def test_returns_budget_tracker(self):
|
||||
import infrastructure.models.budget as bmod
|
||||
bmod._budget_tracker = None
|
||||
tracker = get_budget_tracker()
|
||||
assert isinstance(tracker, BudgetTracker)
|
||||
|
||||
def test_returns_same_instance(self):
|
||||
import infrastructure.models.budget as bmod
|
||||
bmod._budget_tracker = None
|
||||
t1 = get_budget_tracker()
|
||||
t2 = get_budget_tracker()
|
||||
assert t1 is t2
|
||||
380
tests/infrastructure/test_tiered_model_router.py
Normal file
380
tests/infrastructure/test_tiered_model_router.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""Tests for the tiered model router (issue #882).
|
||||
|
||||
Covers:
|
||||
- classify_tier() for Tier-1/2/3 routing
|
||||
- TieredModelRouter.route() with mocked CascadeRouter + BudgetTracker
|
||||
- Auto-escalation from Tier-1 on low-quality responses
|
||||
- Cloud-tier budget guard
|
||||
- Acceptance criteria from the issue:
|
||||
- "Walk to the next room" → LOCAL_FAST
|
||||
- "Plan the optimal path to become Hortator" → LOCAL_HEAVY
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.models.router import (
|
||||
TierLabel,
|
||||
TieredModelRouter,
|
||||
_is_low_quality,
|
||||
classify_tier,
|
||||
get_tiered_router,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── classify_tier ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestClassifyTier:
|
||||
# ── Tier-1 (LOCAL_FAST) ────────────────────────────────────────────────
|
||||
|
||||
def test_simple_navigation_is_local_fast(self):
|
||||
assert classify_tier("walk to the next room") == TierLabel.LOCAL_FAST
|
||||
|
||||
def test_go_north_is_local_fast(self):
|
||||
assert classify_tier("go north") == TierLabel.LOCAL_FAST
|
||||
|
||||
def test_single_binary_choice_is_local_fast(self):
|
||||
assert classify_tier("yes") == TierLabel.LOCAL_FAST
|
||||
|
||||
def test_open_door_is_local_fast(self):
|
||||
assert classify_tier("open door") == TierLabel.LOCAL_FAST
|
||||
|
||||
def test_attack_is_local_fast(self):
|
||||
assert classify_tier("attack", {}) == TierLabel.LOCAL_FAST
|
||||
|
||||
# ── Tier-2 (LOCAL_HEAVY) ───────────────────────────────────────────────
|
||||
|
||||
def test_quest_planning_is_local_heavy(self):
|
||||
assert classify_tier("plan the optimal path to become Hortator") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_strategy_keyword_is_local_heavy(self):
|
||||
assert classify_tier("what is the best strategy") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_stuck_state_escalates_to_local_heavy(self):
|
||||
assert classify_tier("help me", {"stuck": True}) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_require_t2_flag_is_local_heavy(self):
|
||||
assert classify_tier("go north", {"require_t2": True}) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_long_input_is_local_heavy(self):
|
||||
long_task = "tell me about " + ("the dungeon " * 30)
|
||||
assert classify_tier(long_task) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_active_quests_upgrades_to_local_heavy(self):
|
||||
ctx = {"active_quests": ["Q1", "Q2", "Q3"]}
|
||||
assert classify_tier("go north", ctx) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_dialogue_active_upgrades_to_local_heavy(self):
|
||||
ctx = {"dialogue_active": True}
|
||||
assert classify_tier("yes", ctx) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_analyze_is_local_heavy(self):
|
||||
assert classify_tier("analyze the situation") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_optimize_is_local_heavy(self):
|
||||
assert classify_tier("optimize my build") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_negotiate_is_local_heavy(self):
|
||||
assert classify_tier("negotiate with the Camonna Tong") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_explain_is_local_heavy(self):
|
||||
assert classify_tier("explain the faction system") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
# ── Tier-3 (CLOUD_API) ─────────────────────────────────────────────────
|
||||
|
||||
def test_require_cloud_flag_is_cloud_api(self):
|
||||
assert classify_tier("go north", {"require_cloud": True}) == TierLabel.CLOUD_API
|
||||
|
||||
def test_require_cloud_overrides_everything(self):
|
||||
assert classify_tier("yes", {"require_cloud": True}) == TierLabel.CLOUD_API
|
||||
|
||||
# ── Edge cases ────────────────────────────────────────────────────────
|
||||
|
||||
def test_empty_task_defaults_to_local_heavy(self):
|
||||
# Empty string → nothing classifies it as T1 or T3
|
||||
assert classify_tier("") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert classify_tier("PLAN my route") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_combat_active_upgrades_t1_to_heavy(self):
|
||||
ctx = {"combat_active": True}
|
||||
# "attack" is T1 word, but combat context → should NOT be LOCAL_FAST
|
||||
result = classify_tier("attack", ctx)
|
||||
assert result != TierLabel.LOCAL_FAST
|
||||
|
||||
|
||||
# ── _is_low_quality ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIsLowQuality:
|
||||
def test_empty_is_low_quality(self):
|
||||
assert _is_low_quality("", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_whitespace_only_is_low_quality(self):
|
||||
assert _is_low_quality(" ", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_very_short_is_low_quality(self):
|
||||
assert _is_low_quality("ok", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_idontknow_is_low_quality(self):
|
||||
assert _is_low_quality("I don't know how to help with that.", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_not_sure_is_low_quality(self):
|
||||
assert _is_low_quality("I'm not sure about this.", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_as_an_ai_is_low_quality(self):
|
||||
assert _is_low_quality("As an AI, I cannot...", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_good_response_is_not_low_quality(self):
|
||||
response = "You move north into the Vivec Canton. The Ordinators watch your approach."
|
||||
assert _is_low_quality(response, TierLabel.LOCAL_FAST) is False
|
||||
|
||||
def test_t1_short_response_triggers_escalation(self):
|
||||
# Less than _ESCALATION_MIN_CHARS for T1
|
||||
assert _is_low_quality("OK, done.", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_borderline_ok_for_t2_not_t1(self):
|
||||
# Between _LOW_QUALITY_MIN_CHARS (20) and _ESCALATION_MIN_CHARS (60)
|
||||
# → low quality for T1 (escalation threshold), but acceptable for T2/T3
|
||||
response = "Done. The item is retrieved." # 28 chars: ≥20, <60
|
||||
assert _is_low_quality(response, TierLabel.LOCAL_FAST) is True
|
||||
assert _is_low_quality(response, TierLabel.LOCAL_HEAVY) is False
|
||||
|
||||
|
||||
# ── TieredModelRouter ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
_GOOD_CONTENT = (
|
||||
"You move north through the doorway into the next room. "
|
||||
"The stone walls glisten with moisture."
|
||||
) # 90 chars — well above the escalation threshold
|
||||
|
||||
|
||||
def _make_cascade_mock(content=_GOOD_CONTENT, model="llama3.1:8b"):
|
||||
mock = MagicMock()
|
||||
mock.complete = AsyncMock(
|
||||
return_value={
|
||||
"content": content,
|
||||
"provider": "ollama-local",
|
||||
"model": model,
|
||||
"latency_ms": 150.0,
|
||||
}
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def _make_budget_mock(allowed=True):
|
||||
mock = MagicMock()
|
||||
mock.cloud_allowed = MagicMock(return_value=allowed)
|
||||
mock.record_spend = MagicMock(return_value=0.001)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTieredModelRouterRoute:
|
||||
async def test_route_returns_tier_in_result(self):
|
||||
router = TieredModelRouter(cascade=_make_cascade_mock())
|
||||
result = await router.route("go north")
|
||||
assert "tier" in result
|
||||
assert result["tier"] == TierLabel.LOCAL_FAST
|
||||
|
||||
async def test_acceptance_walk_to_room_is_local_fast(self):
|
||||
"""Acceptance: 'Walk to the next room' → LOCAL_FAST."""
|
||||
router = TieredModelRouter(cascade=_make_cascade_mock())
|
||||
result = await router.route("Walk to the next room")
|
||||
assert result["tier"] == TierLabel.LOCAL_FAST
|
||||
|
||||
async def test_acceptance_plan_hortator_is_local_heavy(self):
|
||||
"""Acceptance: 'Plan the optimal path to become Hortator' → LOCAL_HEAVY."""
|
||||
router = TieredModelRouter(
|
||||
cascade=_make_cascade_mock(model="hermes3:70b"),
|
||||
)
|
||||
result = await router.route("Plan the optimal path to become Hortator")
|
||||
assert result["tier"] == TierLabel.LOCAL_HEAVY
|
||||
|
||||
async def test_t1_low_quality_escalates_to_t2(self):
|
||||
"""Failed Tier-1 response auto-escalates to Tier-2."""
|
||||
call_models = []
|
||||
cascade = MagicMock()
|
||||
|
||||
async def complete_side_effect(messages, model, temperature, max_tokens):
|
||||
call_models.append(model)
|
||||
# First call (T1) returns a low-quality response
|
||||
if len(call_models) == 1:
|
||||
return {
|
||||
"content": "I don't know.",
|
||||
"provider": "ollama",
|
||||
"model": model,
|
||||
"latency_ms": 50,
|
||||
}
|
||||
# Second call (T2) returns a good response
|
||||
return {
|
||||
"content": "You move to the northern passage, passing through the Dunmer stronghold.",
|
||||
"provider": "ollama",
|
||||
"model": model,
|
||||
"latency_ms": 800,
|
||||
}
|
||||
|
||||
cascade.complete = complete_side_effect
|
||||
|
||||
router = TieredModelRouter(cascade=cascade, auto_escalate=True)
|
||||
result = await router.route("go north")
|
||||
|
||||
assert len(call_models) == 2, "Should have called twice (T1 escalated to T2)"
|
||||
assert result["tier"] == TierLabel.LOCAL_HEAVY
|
||||
|
||||
async def test_auto_escalate_false_no_escalation(self):
|
||||
"""With auto_escalate=False, low-quality T1 response is returned as-is."""
|
||||
call_count = {"n": 0}
|
||||
cascade = MagicMock()
|
||||
|
||||
async def complete_side_effect(**kwargs):
|
||||
call_count["n"] += 1
|
||||
return {
|
||||
"content": "I don't know.",
|
||||
"provider": "ollama",
|
||||
"model": "llama3.1:8b",
|
||||
"latency_ms": 50,
|
||||
}
|
||||
|
||||
cascade.complete = AsyncMock(side_effect=complete_side_effect)
|
||||
router = TieredModelRouter(cascade=cascade, auto_escalate=False)
|
||||
result = await router.route("go north")
|
||||
assert call_count["n"] == 1
|
||||
assert result["tier"] == TierLabel.LOCAL_FAST
|
||||
|
||||
async def test_t2_failure_escalates_to_cloud(self):
|
||||
"""Tier-2 failure escalates to Cloud API (when budget allows)."""
|
||||
cascade = MagicMock()
|
||||
call_models = []
|
||||
|
||||
async def complete_side_effect(messages, model, temperature, max_tokens):
|
||||
call_models.append(model)
|
||||
if "hermes3" in model or "70b" in model.lower():
|
||||
raise RuntimeError("Tier-2 model unavailable")
|
||||
return {
|
||||
"content": "Cloud response here.",
|
||||
"provider": "anthropic",
|
||||
"model": model,
|
||||
"latency_ms": 1200,
|
||||
}
|
||||
|
||||
cascade.complete = complete_side_effect
|
||||
|
||||
budget = _make_budget_mock(allowed=True)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
result = await router.route("plan my route", context={"require_t2": True})
|
||||
assert result["tier"] == TierLabel.CLOUD_API
|
||||
|
||||
async def test_cloud_blocked_by_budget_raises(self):
|
||||
"""Cloud tier blocked when budget is exhausted."""
|
||||
cascade = MagicMock()
|
||||
cascade.complete = AsyncMock(side_effect=RuntimeError("T2 fail"))
|
||||
|
||||
budget = _make_budget_mock(allowed=False)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
|
||||
with pytest.raises(RuntimeError, match="budget limit"):
|
||||
await router.route("plan my route", context={"require_t2": True})
|
||||
|
||||
async def test_explicit_cloud_tier_uses_cloud_model(self):
|
||||
cascade = _make_cascade_mock(model="claude-haiku-4-5")
|
||||
budget = _make_budget_mock(allowed=True)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
result = await router.route("go north", context={"require_cloud": True})
|
||||
assert result["tier"] == TierLabel.CLOUD_API
|
||||
|
||||
async def test_cloud_spend_recorded_with_usage(self):
|
||||
"""Cloud spend is recorded when the response includes usage info."""
|
||||
cascade = MagicMock()
|
||||
cascade.complete = AsyncMock(
|
||||
return_value={
|
||||
"content": "Cloud answer.",
|
||||
"provider": "anthropic",
|
||||
"model": "claude-haiku-4-5",
|
||||
"latency_ms": 900,
|
||||
"usage": {"prompt_tokens": 50, "completion_tokens": 100},
|
||||
}
|
||||
)
|
||||
budget = _make_budget_mock(allowed=True)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
result = await router.route("go north", context={"require_cloud": True})
|
||||
budget.record_spend.assert_called_once()
|
||||
assert "cost_usd" in result
|
||||
|
||||
async def test_cloud_spend_not_recorded_without_usage(self):
|
||||
"""Cloud spend is not recorded when usage info is absent."""
|
||||
cascade = MagicMock()
|
||||
cascade.complete = AsyncMock(
|
||||
return_value={
|
||||
"content": "Cloud answer.",
|
||||
"provider": "anthropic",
|
||||
"model": "claude-haiku-4-5",
|
||||
"latency_ms": 900,
|
||||
# no "usage" key
|
||||
}
|
||||
)
|
||||
budget = _make_budget_mock(allowed=True)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
result = await router.route("go north", context={"require_cloud": True})
|
||||
budget.record_spend.assert_not_called()
|
||||
assert "cost_usd" not in result
|
||||
|
||||
async def test_custom_tier_models_respected(self):
|
||||
cascade = _make_cascade_mock()
|
||||
router = TieredModelRouter(
|
||||
cascade=cascade,
|
||||
tier_models={TierLabel.LOCAL_FAST: "llama3.2:3b"},
|
||||
)
|
||||
await router.route("go north")
|
||||
call_kwargs = cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["model"] == "llama3.2:3b"
|
||||
|
||||
async def test_messages_override_used_when_provided(self):
|
||||
cascade = _make_cascade_mock()
|
||||
router = TieredModelRouter(cascade=cascade)
|
||||
custom_msgs = [{"role": "user", "content": "custom message"}]
|
||||
await router.route("go north", messages=custom_msgs)
|
||||
call_kwargs = cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["messages"] == custom_msgs
|
||||
|
||||
async def test_temperature_forwarded(self):
|
||||
cascade = _make_cascade_mock()
|
||||
router = TieredModelRouter(cascade=cascade)
|
||||
await router.route("go north", temperature=0.7)
|
||||
call_kwargs = cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["temperature"] == 0.7
|
||||
|
||||
async def test_max_tokens_forwarded(self):
|
||||
cascade = _make_cascade_mock()
|
||||
router = TieredModelRouter(cascade=cascade)
|
||||
await router.route("go north", max_tokens=128)
|
||||
call_kwargs = cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["max_tokens"] == 128
|
||||
|
||||
|
||||
class TestTieredModelRouterClassify:
|
||||
def test_classify_delegates_to_classify_tier(self):
|
||||
router = TieredModelRouter(cascade=MagicMock())
|
||||
assert router.classify("go north") == classify_tier("go north")
|
||||
assert router.classify("plan the quest") == classify_tier("plan the quest")
|
||||
|
||||
|
||||
class TestGetTieredRouterSingleton:
|
||||
def test_returns_tiered_router_instance(self):
|
||||
import infrastructure.models.router as rmod
|
||||
rmod._tiered_router = None
|
||||
router = get_tiered_router()
|
||||
assert isinstance(router, TieredModelRouter)
|
||||
|
||||
def test_singleton_returns_same_instance(self):
|
||||
import infrastructure.models.router as rmod
|
||||
rmod._tiered_router = None
|
||||
r1 = get_tiered_router()
|
||||
r2 = get_tiered_router()
|
||||
assert r1 is r2
|
||||
Reference in New Issue
Block a user