forked from Rockachopa/Timmy-time-dashboard
Compare commits
5 Commits
claude/iss
...
kimi/issue
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
51b1338453 | ||
| bde7232ece | |||
| fc4426954e | |||
| 5be4ecb9ef | |||
| 4f80cfcd58 |
@@ -0,0 +1 @@
|
||||
"""Timmy Time Dashboard — source root package."""
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
"""Central pydantic-settings configuration for Timmy Time Dashboard.
|
||||
|
||||
All environment variable access goes through the ``settings`` singleton
|
||||
exported from this module — never use ``os.environ.get()`` in app code.
|
||||
"""
|
||||
import logging as _logging
|
||||
import os
|
||||
import sys
|
||||
@@ -128,6 +133,23 @@ class Settings(BaseSettings):
|
||||
anthropic_api_key: str = ""
|
||||
claude_model: str = "haiku"
|
||||
|
||||
# ── Tiered Model Router (issue #882) ─────────────────────────────────
|
||||
# Three-tier cascade: Local 8B (free, fast) → Local 70B (free, slower)
|
||||
# → Cloud API (paid, best). Override model names per tier via env vars.
|
||||
#
|
||||
# TIER_LOCAL_FAST_MODEL — Tier-1 model name in Ollama (default: llama3.1:8b)
|
||||
# TIER_LOCAL_HEAVY_MODEL — Tier-2 model name in Ollama (default: hermes3:70b)
|
||||
# TIER_CLOUD_MODEL — Tier-3 cloud model name (default: claude-haiku-4-5)
|
||||
#
|
||||
# Budget limits for the cloud tier (0 = unlimited):
|
||||
# TIER_CLOUD_DAILY_BUDGET_USD — daily ceiling in USD (default: 5.0)
|
||||
# TIER_CLOUD_MONTHLY_BUDGET_USD — monthly ceiling in USD (default: 50.0)
|
||||
tier_local_fast_model: str = "llama3.1:8b"
|
||||
tier_local_heavy_model: str = "hermes3:70b"
|
||||
tier_cloud_model: str = "claude-haiku-4-5"
|
||||
tier_cloud_daily_budget_usd: float = 5.0
|
||||
tier_cloud_monthly_budget_usd: float = 50.0
|
||||
|
||||
# ── Content Moderation ──────────────────────────────────────────────
|
||||
# Three-layer moderation pipeline for AI narrator output.
|
||||
# Uses Llama Guard via Ollama with regex fallback.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
"""SQLAlchemy ORM models for the CALM task-management and journaling system."""
|
||||
from datetime import UTC, date, datetime
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
"""SQLAlchemy engine, session factory, and declarative Base for the CALM module."""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
"""Dashboard routes for agent chat interactions and tool-call display."""
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
"""Dashboard routes for the CALM task management and daily journaling interface."""
|
||||
import logging
|
||||
from datetime import UTC, date, datetime
|
||||
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
"""Infrastructure models package."""
|
||||
|
||||
from infrastructure.models.budget import (
|
||||
BudgetTracker,
|
||||
SpendRecord,
|
||||
estimate_cost_usd,
|
||||
get_budget_tracker,
|
||||
)
|
||||
from infrastructure.models.multimodal import (
|
||||
ModelCapability,
|
||||
ModelInfo,
|
||||
@@ -17,6 +23,12 @@ from infrastructure.models.registry import (
|
||||
ModelRole,
|
||||
model_registry,
|
||||
)
|
||||
from infrastructure.models.router import (
|
||||
TierLabel,
|
||||
TieredModelRouter,
|
||||
classify_tier,
|
||||
get_tiered_router,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Registry
|
||||
@@ -34,4 +46,14 @@ __all__ = [
|
||||
"model_supports_tools",
|
||||
"model_supports_vision",
|
||||
"pull_model_with_fallback",
|
||||
# Tiered router
|
||||
"TierLabel",
|
||||
"TieredModelRouter",
|
||||
"classify_tier",
|
||||
"get_tiered_router",
|
||||
# Budget tracker
|
||||
"BudgetTracker",
|
||||
"SpendRecord",
|
||||
"estimate_cost_usd",
|
||||
"get_budget_tracker",
|
||||
]
|
||||
|
||||
302
src/infrastructure/models/budget.py
Normal file
302
src/infrastructure/models/budget.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""Cloud API budget tracker for the three-tier model router.
|
||||
|
||||
Tracks cloud API spend (daily / monthly) and enforces configurable limits.
|
||||
SQLite-backed with in-memory fallback — degrades gracefully if the database
|
||||
is unavailable.
|
||||
|
||||
References:
|
||||
- Issue #882 — Model Tiering Router: Local 8B / Hermes 70B / Cloud API Cascade
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, date, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Cost estimates (USD per 1 K tokens, input / output) ──────────────────────
|
||||
# Updated 2026-03. Estimates only — actual costs vary by tier/usage.
|
||||
_COST_PER_1K: dict[str, dict[str, float]] = {
|
||||
# Claude models
|
||||
"claude-haiku-4-5": {"input": 0.00025, "output": 0.00125},
|
||||
"claude-sonnet-4-5": {"input": 0.003, "output": 0.015},
|
||||
"claude-opus-4-5": {"input": 0.015, "output": 0.075},
|
||||
"haiku": {"input": 0.00025, "output": 0.00125},
|
||||
"sonnet": {"input": 0.003, "output": 0.015},
|
||||
"opus": {"input": 0.015, "output": 0.075},
|
||||
# GPT-4o
|
||||
"gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
|
||||
"gpt-4o": {"input": 0.0025, "output": 0.01},
|
||||
# Grok (xAI)
|
||||
"grok-3-fast": {"input": 0.003, "output": 0.015},
|
||||
"grok-3": {"input": 0.005, "output": 0.025},
|
||||
}
|
||||
_DEFAULT_COST: dict[str, float] = {"input": 0.003, "output": 0.015} # conservative fallback
|
||||
|
||||
|
||||
def estimate_cost_usd(model: str, tokens_in: int, tokens_out: int) -> float:
|
||||
"""Estimate the cost of a single request in USD.
|
||||
|
||||
Matches the model name by substring so versioned names like
|
||||
``claude-haiku-4-5-20251001`` still resolve correctly.
|
||||
|
||||
Args:
|
||||
model: Model name as passed to the provider.
|
||||
tokens_in: Number of input (prompt) tokens consumed.
|
||||
tokens_out: Number of output (completion) tokens generated.
|
||||
|
||||
Returns:
|
||||
Estimated cost in USD (may be zero for unknown models).
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
rates = _DEFAULT_COST
|
||||
for key, rate in _COST_PER_1K.items():
|
||||
if key in model_lower:
|
||||
rates = rate
|
||||
break
|
||||
return (tokens_in * rates["input"] + tokens_out * rates["output"]) / 1000.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpendRecord:
|
||||
"""A single spend event."""
|
||||
|
||||
ts: float
|
||||
provider: str
|
||||
model: str
|
||||
tokens_in: int
|
||||
tokens_out: int
|
||||
cost_usd: float
|
||||
tier: str
|
||||
|
||||
|
||||
class BudgetTracker:
|
||||
"""Tracks cloud API spend with configurable daily / monthly limits.
|
||||
|
||||
Persists spend records to SQLite (``data/budget.db`` by default).
|
||||
Falls back to in-memory tracking when the database is unavailable —
|
||||
budget enforcement still works; records are lost on restart.
|
||||
|
||||
Limits are read from ``settings``:
|
||||
|
||||
* ``tier_cloud_daily_budget_usd`` — daily ceiling (0 = disabled)
|
||||
* ``tier_cloud_monthly_budget_usd`` — monthly ceiling (0 = disabled)
|
||||
|
||||
Usage::
|
||||
|
||||
tracker = BudgetTracker()
|
||||
|
||||
if tracker.cloud_allowed():
|
||||
# … make cloud API call …
|
||||
tracker.record_spend("anthropic", "claude-haiku-4-5", 100, 200)
|
||||
|
||||
summary = tracker.get_summary()
|
||||
print(summary["daily_usd"], "/", summary["daily_limit_usd"])
|
||||
"""
|
||||
|
||||
_DB_PATH = "data/budget.db"
|
||||
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
"""Initialise the tracker.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database. Defaults to
|
||||
``data/budget.db``. Pass ``":memory:"`` for tests.
|
||||
"""
|
||||
self._db_path = db_path or self._DB_PATH
|
||||
self._lock = threading.Lock()
|
||||
self._in_memory: list[SpendRecord] = []
|
||||
self._db_ok = False
|
||||
self._init_db()
|
||||
|
||||
# ── Database initialisation ──────────────────────────────────────────────
|
||||
|
||||
def _init_db(self) -> None:
|
||||
"""Create the spend table (and parent directory) if needed."""
|
||||
try:
|
||||
if self._db_path != ":memory:":
|
||||
Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS cloud_spend (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts REAL NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
tokens_in INTEGER NOT NULL DEFAULT 0,
|
||||
tokens_out INTEGER NOT NULL DEFAULT 0,
|
||||
cost_usd REAL NOT NULL DEFAULT 0.0,
|
||||
tier TEXT NOT NULL DEFAULT 'cloud'
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_spend_ts ON cloud_spend(ts)"
|
||||
)
|
||||
self._db_ok = True
|
||||
logger.debug("BudgetTracker: SQLite initialised at %s", self._db_path)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"BudgetTracker: SQLite unavailable, using in-memory fallback: %s", exc
|
||||
)
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
return sqlite3.connect(self._db_path, timeout=5)
|
||||
|
||||
# ── Public API ───────────────────────────────────────────────────────────
|
||||
|
||||
def record_spend(
|
||||
self,
|
||||
provider: str,
|
||||
model: str,
|
||||
tokens_in: int = 0,
|
||||
tokens_out: int = 0,
|
||||
cost_usd: float | None = None,
|
||||
tier: str = "cloud",
|
||||
) -> float:
|
||||
"""Record a cloud API spend event and return the cost recorded.
|
||||
|
||||
Args:
|
||||
provider: Provider name (e.g. ``"anthropic"``, ``"openai"``).
|
||||
model: Model name used for the request.
|
||||
tokens_in: Input token count (prompt).
|
||||
tokens_out: Output token count (completion).
|
||||
cost_usd: Explicit cost override. If ``None``, the cost is
|
||||
estimated from the token counts and model rates.
|
||||
tier: Tier label for the request (default ``"cloud"``).
|
||||
|
||||
Returns:
|
||||
The cost recorded in USD.
|
||||
"""
|
||||
if cost_usd is None:
|
||||
cost_usd = estimate_cost_usd(model, tokens_in, tokens_out)
|
||||
|
||||
ts = time.time()
|
||||
record = SpendRecord(ts, provider, model, tokens_in, tokens_out, cost_usd, tier)
|
||||
|
||||
with self._lock:
|
||||
if self._db_ok:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO cloud_spend
|
||||
(ts, provider, model, tokens_in, tokens_out, cost_usd, tier)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(ts, provider, model, tokens_in, tokens_out, cost_usd, tier),
|
||||
)
|
||||
logger.debug(
|
||||
"BudgetTracker: recorded %.6f USD (%s/%s, in=%d out=%d tier=%s)",
|
||||
cost_usd,
|
||||
provider,
|
||||
model,
|
||||
tokens_in,
|
||||
tokens_out,
|
||||
tier,
|
||||
)
|
||||
return cost_usd
|
||||
except Exception as exc:
|
||||
logger.warning("BudgetTracker: DB write failed, falling back: %s", exc)
|
||||
self._in_memory.append(record)
|
||||
|
||||
return cost_usd
|
||||
|
||||
def get_daily_spend(self) -> float:
|
||||
"""Return total cloud spend for the current UTC day in USD."""
|
||||
today = date.today()
|
||||
since = datetime(today.year, today.month, today.day, tzinfo=UTC).timestamp()
|
||||
return self._query_spend(since)
|
||||
|
||||
def get_monthly_spend(self) -> float:
|
||||
"""Return total cloud spend for the current UTC month in USD."""
|
||||
today = date.today()
|
||||
since = datetime(today.year, today.month, 1, tzinfo=UTC).timestamp()
|
||||
return self._query_spend(since)
|
||||
|
||||
def cloud_allowed(self) -> bool:
|
||||
"""Return ``True`` if cloud API spend is within configured limits.
|
||||
|
||||
Checks both daily and monthly ceilings. A limit of ``0`` disables
|
||||
that particular check.
|
||||
"""
|
||||
daily_limit = settings.tier_cloud_daily_budget_usd
|
||||
monthly_limit = settings.tier_cloud_monthly_budget_usd
|
||||
|
||||
if daily_limit > 0:
|
||||
daily_spend = self.get_daily_spend()
|
||||
if daily_spend >= daily_limit:
|
||||
logger.warning(
|
||||
"BudgetTracker: daily cloud budget exhausted (%.4f / %.4f USD)",
|
||||
daily_spend,
|
||||
daily_limit,
|
||||
)
|
||||
return False
|
||||
|
||||
if monthly_limit > 0:
|
||||
monthly_spend = self.get_monthly_spend()
|
||||
if monthly_spend >= monthly_limit:
|
||||
logger.warning(
|
||||
"BudgetTracker: monthly cloud budget exhausted (%.4f / %.4f USD)",
|
||||
monthly_spend,
|
||||
monthly_limit,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_summary(self) -> dict:
|
||||
"""Return a spend summary dict suitable for dashboards / logging.
|
||||
|
||||
Keys: ``daily_usd``, ``monthly_usd``, ``daily_limit_usd``,
|
||||
``monthly_limit_usd``, ``daily_ok``, ``monthly_ok``.
|
||||
"""
|
||||
daily = self.get_daily_spend()
|
||||
monthly = self.get_monthly_spend()
|
||||
daily_limit = settings.tier_cloud_daily_budget_usd
|
||||
monthly_limit = settings.tier_cloud_monthly_budget_usd
|
||||
return {
|
||||
"daily_usd": round(daily, 6),
|
||||
"monthly_usd": round(monthly, 6),
|
||||
"daily_limit_usd": daily_limit,
|
||||
"monthly_limit_usd": monthly_limit,
|
||||
"daily_ok": daily_limit <= 0 or daily < daily_limit,
|
||||
"monthly_ok": monthly_limit <= 0 or monthly < monthly_limit,
|
||||
}
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────────
|
||||
|
||||
def _query_spend(self, since_ts: float) -> float:
|
||||
"""Sum ``cost_usd`` for records with ``ts >= since_ts``."""
|
||||
if self._db_ok:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT COALESCE(SUM(cost_usd), 0.0) FROM cloud_spend WHERE ts >= ?",
|
||||
(since_ts,),
|
||||
).fetchone()
|
||||
return float(row[0]) if row else 0.0
|
||||
except Exception as exc:
|
||||
logger.warning("BudgetTracker: DB read failed: %s", exc)
|
||||
# In-memory fallback
|
||||
return sum(r.cost_usd for r in self._in_memory if r.ts >= since_ts)
|
||||
|
||||
|
||||
# ── Module-level singleton ────────────────────────────────────────────────────
|
||||
|
||||
_budget_tracker: BudgetTracker | None = None
|
||||
|
||||
|
||||
def get_budget_tracker() -> BudgetTracker:
|
||||
"""Get or create the module-level BudgetTracker singleton."""
|
||||
global _budget_tracker
|
||||
if _budget_tracker is None:
|
||||
_budget_tracker = BudgetTracker()
|
||||
return _budget_tracker
|
||||
427
src/infrastructure/models/router.py
Normal file
427
src/infrastructure/models/router.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""Three-tier model router — Local 8B / Local 70B / Cloud API Cascade.
|
||||
|
||||
Selects the cheapest-sufficient LLM for each request using a heuristic
|
||||
task-complexity classifier. Tier 3 (Cloud API) is only used when Tier 2
|
||||
fails or the budget guard allows it.
|
||||
|
||||
Tiers
|
||||
-----
|
||||
Tier 1 — LOCAL_FAST (Llama 3.1 8B / Hermes 3 8B via Ollama, free, ~0.3-1 s)
|
||||
Navigation, basic interactions, simple decisions.
|
||||
|
||||
Tier 2 — LOCAL_HEAVY (Hermes 3/4 70B via Ollama, free, ~5-10 s for 200 tok)
|
||||
Quest planning, dialogue strategy, complex reasoning.
|
||||
|
||||
Tier 3 — CLOUD_API (Claude / GPT-4o, paid ~$5-15/hr heavy use)
|
||||
Recovery from Tier 2 failures, novel situations, multi-step planning.
|
||||
|
||||
Routing logic
|
||||
-------------
|
||||
1. Classify the task using keyword / length / context heuristics (no LLM call).
|
||||
2. Route to the appropriate tier.
|
||||
3. On Tier-1 low-quality response → auto-escalate to Tier 2.
|
||||
4. On Tier-2 failure or explicit ``require_cloud=True`` → Tier 3 (if budget allows).
|
||||
5. Log tier used, model, latency, estimated cost for every request.
|
||||
|
||||
References:
|
||||
- Issue #882 — Model Tiering Router: Local 8B / Hermes 70B / Cloud API Cascade
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Tier definitions ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TierLabel(StrEnum):
|
||||
"""Three cost-sorted model tiers."""
|
||||
|
||||
LOCAL_FAST = "local_fast" # 8B local, always hot, free
|
||||
LOCAL_HEAVY = "local_heavy" # 70B local, free but slower
|
||||
CLOUD_API = "cloud_api" # Paid cloud backend (Claude / GPT-4o)
|
||||
|
||||
|
||||
# ── Default model assignments (overridable via Settings) ──────────────────────
|
||||
|
||||
_DEFAULT_TIER_MODELS: dict[TierLabel, str] = {
|
||||
TierLabel.LOCAL_FAST: "llama3.1:8b",
|
||||
TierLabel.LOCAL_HEAVY: "hermes3:70b",
|
||||
TierLabel.CLOUD_API: "claude-haiku-4-5",
|
||||
}
|
||||
|
||||
# ── Classification vocabulary ─────────────────────────────────────────────────
|
||||
|
||||
# Patterns that indicate a Tier-1 (simple) task
|
||||
_T1_WORDS: frozenset[str] = frozenset(
|
||||
{
|
||||
"go", "move", "walk", "run",
|
||||
"north", "south", "east", "west", "up", "down", "left", "right",
|
||||
"yes", "no", "ok", "okay",
|
||||
"open", "close", "take", "drop", "look",
|
||||
"pick", "use", "wait", "rest", "save",
|
||||
"attack", "flee", "jump", "crouch",
|
||||
"status", "ping", "list", "show", "get", "check",
|
||||
}
|
||||
)
|
||||
|
||||
# Patterns that indicate a Tier-2 or Tier-3 task
|
||||
_T2_PHRASES: tuple[str, ...] = (
|
||||
"plan", "strategy", "optimize", "optimise",
|
||||
"quest", "stuck", "recover",
|
||||
"negotiate", "persuade", "faction", "reputation",
|
||||
"analyze", "analyse", "evaluate", "decide",
|
||||
"complex", "multi-step", "long-term",
|
||||
"how do i", "what should i do", "help me figure",
|
||||
"what is the best", "recommend", "best way",
|
||||
"explain", "describe in detail", "walk me through",
|
||||
"compare", "design", "implement", "refactor",
|
||||
"debug", "diagnose", "root cause",
|
||||
)
|
||||
|
||||
# Low-quality response detection patterns
|
||||
_LOW_QUALITY_PATTERNS: tuple[re.Pattern, ...] = (
|
||||
re.compile(r"i\s+don'?t\s+know", re.IGNORECASE),
|
||||
re.compile(r"i'm\s+not\s+sure", re.IGNORECASE),
|
||||
re.compile(r"i\s+cannot\s+(help|assist|answer)", re.IGNORECASE),
|
||||
re.compile(r"i\s+apologize", re.IGNORECASE),
|
||||
re.compile(r"as an ai", re.IGNORECASE),
|
||||
re.compile(r"i\s+don'?t\s+have\s+(enough|sufficient)\s+information", re.IGNORECASE),
|
||||
)
|
||||
|
||||
# Response is definitely low-quality if shorter than this many characters
|
||||
_LOW_QUALITY_MIN_CHARS = 20
|
||||
# Response is suspicious if shorter than this many chars for a complex task
|
||||
_ESCALATION_MIN_CHARS = 60
|
||||
|
||||
|
||||
def classify_tier(task: str, context: dict | None = None) -> TierLabel:
|
||||
"""Classify a task to the cheapest-sufficient model tier.
|
||||
|
||||
Classification priority (highest wins):
|
||||
1. ``context["require_cloud"] = True`` → CLOUD_API
|
||||
2. Any Tier-2 phrase or stuck/recovery signal → LOCAL_HEAVY
|
||||
3. Short task with only Tier-1 words, no active context → LOCAL_FAST
|
||||
4. Default → LOCAL_HEAVY (safe fallback for unknown tasks)
|
||||
|
||||
Args:
|
||||
task: Natural-language task or user input.
|
||||
context: Optional context dict. Recognised keys:
|
||||
``require_cloud`` (bool), ``stuck`` (bool),
|
||||
``require_t2`` (bool), ``active_quests`` (list),
|
||||
``dialogue_active`` (bool), ``combat_active`` (bool).
|
||||
|
||||
Returns:
|
||||
The cheapest ``TierLabel`` sufficient for the task.
|
||||
"""
|
||||
ctx = context or {}
|
||||
task_lower = task.lower()
|
||||
words = set(task_lower.split())
|
||||
|
||||
# ── Explicit cloud override ──────────────────────────────────────────────
|
||||
if ctx.get("require_cloud"):
|
||||
logger.debug("classify_tier → CLOUD_API (explicit require_cloud)")
|
||||
return TierLabel.CLOUD_API
|
||||
|
||||
# ── Tier-2 / complexity signals ──────────────────────────────────────────
|
||||
t2_phrase_hit = any(phrase in task_lower for phrase in _T2_PHRASES)
|
||||
t2_word_hit = bool(words & {"plan", "strategy", "optimize", "optimise", "quest",
|
||||
"stuck", "recover", "analyze", "analyse", "evaluate"})
|
||||
is_stuck = bool(ctx.get("stuck"))
|
||||
require_t2 = bool(ctx.get("require_t2"))
|
||||
long_input = len(task) > 300 # long tasks warrant more capable model
|
||||
deep_context = (
|
||||
len(ctx.get("active_quests", [])) >= 3
|
||||
or ctx.get("dialogue_active")
|
||||
)
|
||||
|
||||
if t2_phrase_hit or t2_word_hit or is_stuck or require_t2 or long_input or deep_context:
|
||||
logger.debug(
|
||||
"classify_tier → LOCAL_HEAVY (phrase=%s word=%s stuck=%s explicit=%s long=%s ctx=%s)",
|
||||
t2_phrase_hit, t2_word_hit, is_stuck, require_t2, long_input, deep_context,
|
||||
)
|
||||
return TierLabel.LOCAL_HEAVY
|
||||
|
||||
# ── Tier-1 signals ───────────────────────────────────────────────────────
|
||||
t1_word_hit = bool(words & _T1_WORDS)
|
||||
task_short = len(task.split()) <= 8
|
||||
no_active_context = (
|
||||
not ctx.get("active_quests")
|
||||
and not ctx.get("dialogue_active")
|
||||
and not ctx.get("combat_active")
|
||||
)
|
||||
|
||||
if t1_word_hit and task_short and no_active_context:
|
||||
logger.debug(
|
||||
"classify_tier → LOCAL_FAST (words=%s short=%s)", t1_word_hit, task_short
|
||||
)
|
||||
return TierLabel.LOCAL_FAST
|
||||
|
||||
# ── Default: LOCAL_HEAVY (safe for anything unclassified) ────────────────
|
||||
logger.debug("classify_tier → LOCAL_HEAVY (default)")
|
||||
return TierLabel.LOCAL_HEAVY
|
||||
|
||||
|
||||
def _is_low_quality(content: str, tier: TierLabel) -> bool:
|
||||
"""Return True if the response looks like it should be escalated.
|
||||
|
||||
Used for automatic Tier-1 → Tier-2 escalation.
|
||||
|
||||
Args:
|
||||
content: LLM response text.
|
||||
tier: The tier that produced the response.
|
||||
|
||||
Returns:
|
||||
True if the response is likely too low-quality to be useful.
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return True
|
||||
|
||||
stripped = content.strip()
|
||||
|
||||
# Too short to be useful
|
||||
if len(stripped) < _LOW_QUALITY_MIN_CHARS:
|
||||
return True
|
||||
|
||||
# Insufficient for a supposedly complex-enough task
|
||||
if tier == TierLabel.LOCAL_FAST and len(stripped) < _ESCALATION_MIN_CHARS:
|
||||
return True
|
||||
|
||||
# Matches known "I can't help" patterns
|
||||
for pattern in _LOW_QUALITY_PATTERNS:
|
||||
if pattern.search(stripped):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TieredModelRouter:
|
||||
"""Routes LLM requests across the Local 8B / Local 70B / Cloud API tiers.
|
||||
|
||||
Wraps CascadeRouter with:
|
||||
- Heuristic tier classification via ``classify_tier()``
|
||||
- Automatic Tier-1 → Tier-2 escalation on low-quality responses
|
||||
- Cloud-tier budget guard via ``BudgetTracker``
|
||||
- Per-request logging: tier, model, latency, estimated cost
|
||||
|
||||
Usage::
|
||||
|
||||
router = TieredModelRouter()
|
||||
|
||||
result = await router.route(
|
||||
task="Walk to the next room",
|
||||
context={},
|
||||
)
|
||||
print(result["content"], result["tier"]) # "Move north.", "local_fast"
|
||||
|
||||
# Force heavy tier
|
||||
result = await router.route(
|
||||
task="Plan the optimal path to become Hortator",
|
||||
context={"require_t2": True},
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cascade: Any | None = None,
|
||||
budget_tracker: Any | None = None,
|
||||
tier_models: dict[TierLabel, str] | None = None,
|
||||
auto_escalate: bool = True,
|
||||
) -> None:
|
||||
"""Initialise the tiered router.
|
||||
|
||||
Args:
|
||||
cascade: CascadeRouter instance. If ``None``, the
|
||||
singleton from ``get_router()`` is used lazily.
|
||||
budget_tracker: BudgetTracker instance. If ``None``, the
|
||||
singleton from ``get_budget_tracker()`` is used.
|
||||
tier_models: Override default model names per tier.
|
||||
auto_escalate: When ``True``, low-quality Tier-1 responses
|
||||
automatically retry on Tier-2.
|
||||
"""
|
||||
self._cascade = cascade
|
||||
self._budget = budget_tracker
|
||||
self._tier_models: dict[TierLabel, str] = dict(_DEFAULT_TIER_MODELS)
|
||||
self._auto_escalate = auto_escalate
|
||||
|
||||
# Apply settings-level overrides (can still be overridden per-instance)
|
||||
if settings.tier_local_fast_model:
|
||||
self._tier_models[TierLabel.LOCAL_FAST] = settings.tier_local_fast_model
|
||||
if settings.tier_local_heavy_model:
|
||||
self._tier_models[TierLabel.LOCAL_HEAVY] = settings.tier_local_heavy_model
|
||||
if settings.tier_cloud_model:
|
||||
self._tier_models[TierLabel.CLOUD_API] = settings.tier_cloud_model
|
||||
|
||||
if tier_models:
|
||||
self._tier_models.update(tier_models)
|
||||
|
||||
# ── Lazy singletons ──────────────────────────────────────────────────────
|
||||
|
||||
def _get_cascade(self) -> Any:
|
||||
if self._cascade is None:
|
||||
from infrastructure.router.cascade import get_router
|
||||
self._cascade = get_router()
|
||||
return self._cascade
|
||||
|
||||
def _get_budget(self) -> Any:
|
||||
if self._budget is None:
|
||||
from infrastructure.models.budget import get_budget_tracker
|
||||
self._budget = get_budget_tracker()
|
||||
return self._budget
|
||||
|
||||
# ── Public interface ─────────────────────────────────────────────────────
|
||||
|
||||
def classify(self, task: str, context: dict | None = None) -> TierLabel:
|
||||
"""Classify a task without routing. Useful for telemetry."""
|
||||
return classify_tier(task, context)
|
||||
|
||||
async def route(
|
||||
self,
|
||||
task: str,
|
||||
context: dict | None = None,
|
||||
messages: list[dict] | None = None,
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int | None = None,
|
||||
) -> dict:
|
||||
"""Route a task to the appropriate model tier.
|
||||
|
||||
Builds a minimal messages list if ``messages`` is not provided.
|
||||
The result always includes a ``tier`` key indicating which tier
|
||||
ultimately handled the request.
|
||||
|
||||
Args:
|
||||
task: Natural-language task description.
|
||||
context: Task context dict (see ``classify_tier()``).
|
||||
messages: Pre-built OpenAI-compatible messages list. If
|
||||
provided, ``task`` is only used for classification.
|
||||
temperature: Sampling temperature (default 0.3).
|
||||
max_tokens: Maximum tokens to generate.
|
||||
|
||||
Returns:
|
||||
Dict with at minimum: ``content``, ``provider``, ``model``,
|
||||
``tier``, ``latency_ms``. May include ``cost_usd`` when a
|
||||
cloud request is recorded.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all available tiers are exhausted.
|
||||
"""
|
||||
ctx = context or {}
|
||||
tier = self.classify(task, ctx)
|
||||
msgs = messages or [{"role": "user", "content": task}]
|
||||
|
||||
# ── Tier 1 attempt ───────────────────────────────────────────────────
|
||||
if tier == TierLabel.LOCAL_FAST:
|
||||
result = await self._complete_tier(
|
||||
TierLabel.LOCAL_FAST, msgs, temperature, max_tokens
|
||||
)
|
||||
if self._auto_escalate and _is_low_quality(result.get("content", ""), TierLabel.LOCAL_FAST):
|
||||
logger.info(
|
||||
"TieredModelRouter: Tier-1 response low quality, escalating to Tier-2 "
|
||||
"(task=%r content_len=%d)",
|
||||
task[:80],
|
||||
len(result.get("content", "")),
|
||||
)
|
||||
tier = TierLabel.LOCAL_HEAVY
|
||||
result = await self._complete_tier(
|
||||
TierLabel.LOCAL_HEAVY, msgs, temperature, max_tokens
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Tier 2 attempt ───────────────────────────────────────────────────
|
||||
if tier == TierLabel.LOCAL_HEAVY:
|
||||
try:
|
||||
return await self._complete_tier(
|
||||
TierLabel.LOCAL_HEAVY, msgs, temperature, max_tokens
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"TieredModelRouter: Tier-2 failed (%s) — escalating to cloud", exc
|
||||
)
|
||||
tier = TierLabel.CLOUD_API
|
||||
|
||||
# ── Tier 3 (Cloud) ───────────────────────────────────────────────────
|
||||
budget = self._get_budget()
|
||||
if not budget.cloud_allowed():
|
||||
raise RuntimeError(
|
||||
"Cloud API tier requested but budget limit reached — "
|
||||
"increase tier_cloud_daily_budget_usd or tier_cloud_monthly_budget_usd"
|
||||
)
|
||||
|
||||
result = await self._complete_tier(
|
||||
TierLabel.CLOUD_API, msgs, temperature, max_tokens
|
||||
)
|
||||
|
||||
# Record cloud spend if token info is available
|
||||
usage = result.get("usage", {})
|
||||
if usage:
|
||||
cost = budget.record_spend(
|
||||
provider=result.get("provider", "unknown"),
|
||||
model=result.get("model", self._tier_models[TierLabel.CLOUD_API]),
|
||||
tokens_in=usage.get("prompt_tokens", 0),
|
||||
tokens_out=usage.get("completion_tokens", 0),
|
||||
tier=TierLabel.CLOUD_API,
|
||||
)
|
||||
result["cost_usd"] = cost
|
||||
|
||||
return result
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────────
|
||||
|
||||
async def _complete_tier(
|
||||
self,
|
||||
tier: TierLabel,
|
||||
messages: list[dict],
|
||||
temperature: float,
|
||||
max_tokens: int | None,
|
||||
) -> dict:
|
||||
"""Dispatch a single inference request for the given tier."""
|
||||
model = self._tier_models[tier]
|
||||
cascade = self._get_cascade()
|
||||
start = time.monotonic()
|
||||
|
||||
logger.info(
|
||||
"TieredModelRouter: tier=%s model=%s messages=%d",
|
||||
tier,
|
||||
model,
|
||||
len(messages),
|
||||
)
|
||||
|
||||
result = await cascade.complete(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
result["tier"] = tier
|
||||
result.setdefault("latency_ms", elapsed_ms)
|
||||
|
||||
logger.info(
|
||||
"TieredModelRouter: done tier=%s model=%s latency_ms=%.0f",
|
||||
tier,
|
||||
result.get("model", model),
|
||||
elapsed_ms,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# ── Module-level singleton ────────────────────────────────────────────────────
|
||||
|
||||
_tiered_router: TieredModelRouter | None = None
|
||||
|
||||
|
||||
def get_tiered_router() -> TieredModelRouter:
|
||||
"""Get or create the module-level TieredModelRouter singleton."""
|
||||
global _tiered_router
|
||||
if _tiered_router is None:
|
||||
_tiered_router = TieredModelRouter()
|
||||
return _tiered_router
|
||||
@@ -0,0 +1 @@
|
||||
"""Vendor-specific chat platform adapters (e.g. Discord) for the chat bridge."""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
"""Typer CLI entry point for the ``timmy`` command (chat, think, status)."""
|
||||
import asyncio
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
"""OpenCV template-matching cache for sovereignty perception (screen-state recognition)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
178
tests/infrastructure/test_budget_tracker.py
Normal file
178
tests/infrastructure/test_budget_tracker.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Tests for the cloud API budget tracker (issue #882)."""
|
||||
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.models.budget import (
|
||||
BudgetTracker,
|
||||
SpendRecord,
|
||||
estimate_cost_usd,
|
||||
get_budget_tracker,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── estimate_cost_usd ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEstimateCostUsd:
|
||||
def test_haiku_cheaper_than_sonnet(self):
|
||||
haiku_cost = estimate_cost_usd("claude-haiku-4-5", 1000, 1000)
|
||||
sonnet_cost = estimate_cost_usd("claude-sonnet-4-5", 1000, 1000)
|
||||
assert haiku_cost < sonnet_cost
|
||||
|
||||
def test_zero_tokens_is_zero_cost(self):
|
||||
assert estimate_cost_usd("gpt-4o", 0, 0) == 0.0
|
||||
|
||||
def test_unknown_model_uses_default(self):
|
||||
cost = estimate_cost_usd("some-unknown-model-xyz", 1000, 1000)
|
||||
assert cost > 0 # Uses conservative default, not zero
|
||||
|
||||
def test_versioned_model_name_matches(self):
|
||||
# "claude-haiku-4-5-20251001" should match "haiku"
|
||||
cost1 = estimate_cost_usd("claude-haiku-4-5-20251001", 1000, 0)
|
||||
cost2 = estimate_cost_usd("claude-haiku-4-5", 1000, 0)
|
||||
assert cost1 == cost2
|
||||
|
||||
def test_gpt4o_mini_cheaper_than_gpt4o(self):
|
||||
mini = estimate_cost_usd("gpt-4o-mini", 1000, 1000)
|
||||
full = estimate_cost_usd("gpt-4o", 1000, 1000)
|
||||
assert mini < full
|
||||
|
||||
def test_returns_float(self):
|
||||
assert isinstance(estimate_cost_usd("haiku", 100, 200), float)
|
||||
|
||||
|
||||
# ── BudgetTracker ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerInit:
|
||||
def test_creates_with_memory_db(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
assert tracker._db_ok is True
|
||||
|
||||
def test_in_memory_fallback_empty_on_creation(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
assert tracker._in_memory == []
|
||||
|
||||
def test_bad_path_uses_memory_fallback(self, tmp_path):
|
||||
bad_path = str(tmp_path / "nonexistent" / "x" / "budget.db")
|
||||
# Should not raise — just log and continue with memory fallback
|
||||
# (actually will create parent dirs, so test with truly bad path)
|
||||
tracker = BudgetTracker.__new__(BudgetTracker)
|
||||
tracker._db_path = bad_path
|
||||
tracker._lock = __import__("threading").Lock()
|
||||
tracker._in_memory = []
|
||||
tracker._db_ok = False
|
||||
# Record to in-memory fallback
|
||||
tracker._in_memory.append(
|
||||
SpendRecord(time.time(), "test", "model", 100, 100, 0.001, "cloud")
|
||||
)
|
||||
assert len(tracker._in_memory) == 1
|
||||
|
||||
|
||||
class TestBudgetTrackerRecordSpend:
|
||||
def test_record_spend_returns_cost(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("anthropic", "claude-haiku-4-5", 100, 200)
|
||||
assert cost > 0
|
||||
|
||||
def test_record_spend_explicit_cost(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("anthropic", "model", cost_usd=1.23)
|
||||
assert cost == pytest.approx(1.23)
|
||||
|
||||
def test_record_spend_accumulates(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=0.01)
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=0.02)
|
||||
assert tracker.get_daily_spend() == pytest.approx(0.03, abs=1e-9)
|
||||
|
||||
def test_record_spend_with_tier_label(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("anthropic", "haiku", tier="cloud_api")
|
||||
assert cost >= 0
|
||||
|
||||
def test_monthly_spend_includes_daily(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=5.00)
|
||||
assert tracker.get_monthly_spend() >= tracker.get_daily_spend()
|
||||
|
||||
|
||||
class TestBudgetTrackerCloudAllowed:
|
||||
def test_allowed_when_no_spend(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with (
|
||||
patch.object(type(tracker._get_budget() if hasattr(tracker, "_get_budget") else tracker), "tier_cloud_daily_budget_usd", 5.0, create=True),
|
||||
):
|
||||
# Settings-based check — use real settings (5.0 default, 0 spent)
|
||||
assert tracker.cloud_allowed() is True
|
||||
|
||||
def test_blocked_when_daily_limit_exceeded(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
|
||||
# With default daily limit of 5.0, 999 should block
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
def test_allowed_when_daily_limit_zero(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
|
||||
with (
|
||||
patch("infrastructure.models.budget.settings") as mock_settings,
|
||||
):
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0 # disabled
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 0 # disabled
|
||||
assert tracker.cloud_allowed() is True
|
||||
|
||||
def test_blocked_when_monthly_limit_exceeded(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0 # daily disabled
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 10.0
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
|
||||
class TestBudgetTrackerSummary:
|
||||
def test_summary_keys_present(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
summary = tracker.get_summary()
|
||||
assert "daily_usd" in summary
|
||||
assert "monthly_usd" in summary
|
||||
assert "daily_limit_usd" in summary
|
||||
assert "monthly_limit_usd" in summary
|
||||
assert "daily_ok" in summary
|
||||
assert "monthly_ok" in summary
|
||||
|
||||
def test_summary_daily_ok_true_on_empty(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
summary = tracker.get_summary()
|
||||
assert summary["daily_ok"] is True
|
||||
assert summary["monthly_ok"] is True
|
||||
|
||||
def test_summary_daily_ok_false_when_exceeded(self):
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=999.0)
|
||||
summary = tracker.get_summary()
|
||||
assert summary["daily_ok"] is False
|
||||
|
||||
|
||||
# ── Singleton ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetBudgetTrackerSingleton:
|
||||
def test_returns_budget_tracker(self):
|
||||
import infrastructure.models.budget as bmod
|
||||
bmod._budget_tracker = None
|
||||
tracker = get_budget_tracker()
|
||||
assert isinstance(tracker, BudgetTracker)
|
||||
|
||||
def test_returns_same_instance(self):
|
||||
import infrastructure.models.budget as bmod
|
||||
bmod._budget_tracker = None
|
||||
t1 = get_budget_tracker()
|
||||
t2 = get_budget_tracker()
|
||||
assert t1 is t2
|
||||
380
tests/infrastructure/test_tiered_model_router.py
Normal file
380
tests/infrastructure/test_tiered_model_router.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""Tests for the tiered model router (issue #882).
|
||||
|
||||
Covers:
|
||||
- classify_tier() for Tier-1/2/3 routing
|
||||
- TieredModelRouter.route() with mocked CascadeRouter + BudgetTracker
|
||||
- Auto-escalation from Tier-1 on low-quality responses
|
||||
- Cloud-tier budget guard
|
||||
- Acceptance criteria from the issue:
|
||||
- "Walk to the next room" → LOCAL_FAST
|
||||
- "Plan the optimal path to become Hortator" → LOCAL_HEAVY
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.models.router import (
|
||||
TierLabel,
|
||||
TieredModelRouter,
|
||||
_is_low_quality,
|
||||
classify_tier,
|
||||
get_tiered_router,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── classify_tier ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestClassifyTier:
|
||||
# ── Tier-1 (LOCAL_FAST) ────────────────────────────────────────────────
|
||||
|
||||
def test_simple_navigation_is_local_fast(self):
|
||||
assert classify_tier("walk to the next room") == TierLabel.LOCAL_FAST
|
||||
|
||||
def test_go_north_is_local_fast(self):
|
||||
assert classify_tier("go north") == TierLabel.LOCAL_FAST
|
||||
|
||||
def test_single_binary_choice_is_local_fast(self):
|
||||
assert classify_tier("yes") == TierLabel.LOCAL_FAST
|
||||
|
||||
def test_open_door_is_local_fast(self):
|
||||
assert classify_tier("open door") == TierLabel.LOCAL_FAST
|
||||
|
||||
def test_attack_is_local_fast(self):
|
||||
assert classify_tier("attack", {}) == TierLabel.LOCAL_FAST
|
||||
|
||||
# ── Tier-2 (LOCAL_HEAVY) ───────────────────────────────────────────────
|
||||
|
||||
def test_quest_planning_is_local_heavy(self):
|
||||
assert classify_tier("plan the optimal path to become Hortator") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_strategy_keyword_is_local_heavy(self):
|
||||
assert classify_tier("what is the best strategy") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_stuck_state_escalates_to_local_heavy(self):
|
||||
assert classify_tier("help me", {"stuck": True}) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_require_t2_flag_is_local_heavy(self):
|
||||
assert classify_tier("go north", {"require_t2": True}) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_long_input_is_local_heavy(self):
|
||||
long_task = "tell me about " + ("the dungeon " * 30)
|
||||
assert classify_tier(long_task) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_active_quests_upgrades_to_local_heavy(self):
|
||||
ctx = {"active_quests": ["Q1", "Q2", "Q3"]}
|
||||
assert classify_tier("go north", ctx) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_dialogue_active_upgrades_to_local_heavy(self):
|
||||
ctx = {"dialogue_active": True}
|
||||
assert classify_tier("yes", ctx) == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_analyze_is_local_heavy(self):
|
||||
assert classify_tier("analyze the situation") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_optimize_is_local_heavy(self):
|
||||
assert classify_tier("optimize my build") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_negotiate_is_local_heavy(self):
|
||||
assert classify_tier("negotiate with the Camonna Tong") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_explain_is_local_heavy(self):
|
||||
assert classify_tier("explain the faction system") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
# ── Tier-3 (CLOUD_API) ─────────────────────────────────────────────────
|
||||
|
||||
def test_require_cloud_flag_is_cloud_api(self):
|
||||
assert classify_tier("go north", {"require_cloud": True}) == TierLabel.CLOUD_API
|
||||
|
||||
def test_require_cloud_overrides_everything(self):
|
||||
assert classify_tier("yes", {"require_cloud": True}) == TierLabel.CLOUD_API
|
||||
|
||||
# ── Edge cases ────────────────────────────────────────────────────────
|
||||
|
||||
def test_empty_task_defaults_to_local_heavy(self):
|
||||
# Empty string → nothing classifies it as T1 or T3
|
||||
assert classify_tier("") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert classify_tier("PLAN my route") == TierLabel.LOCAL_HEAVY
|
||||
|
||||
def test_combat_active_upgrades_t1_to_heavy(self):
|
||||
ctx = {"combat_active": True}
|
||||
# "attack" is T1 word, but combat context → should NOT be LOCAL_FAST
|
||||
result = classify_tier("attack", ctx)
|
||||
assert result != TierLabel.LOCAL_FAST
|
||||
|
||||
|
||||
# ── _is_low_quality ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIsLowQuality:
|
||||
def test_empty_is_low_quality(self):
|
||||
assert _is_low_quality("", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_whitespace_only_is_low_quality(self):
|
||||
assert _is_low_quality(" ", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_very_short_is_low_quality(self):
|
||||
assert _is_low_quality("ok", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_idontknow_is_low_quality(self):
|
||||
assert _is_low_quality("I don't know how to help with that.", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_not_sure_is_low_quality(self):
|
||||
assert _is_low_quality("I'm not sure about this.", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_as_an_ai_is_low_quality(self):
|
||||
assert _is_low_quality("As an AI, I cannot...", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_good_response_is_not_low_quality(self):
|
||||
response = "You move north into the Vivec Canton. The Ordinators watch your approach."
|
||||
assert _is_low_quality(response, TierLabel.LOCAL_FAST) is False
|
||||
|
||||
def test_t1_short_response_triggers_escalation(self):
|
||||
# Less than _ESCALATION_MIN_CHARS for T1
|
||||
assert _is_low_quality("OK, done.", TierLabel.LOCAL_FAST) is True
|
||||
|
||||
def test_borderline_ok_for_t2_not_t1(self):
|
||||
# Between _LOW_QUALITY_MIN_CHARS (20) and _ESCALATION_MIN_CHARS (60)
|
||||
# → low quality for T1 (escalation threshold), but acceptable for T2/T3
|
||||
response = "Done. The item is retrieved." # 28 chars: ≥20, <60
|
||||
assert _is_low_quality(response, TierLabel.LOCAL_FAST) is True
|
||||
assert _is_low_quality(response, TierLabel.LOCAL_HEAVY) is False
|
||||
|
||||
|
||||
# ── TieredModelRouter ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
_GOOD_CONTENT = (
|
||||
"You move north through the doorway into the next room. "
|
||||
"The stone walls glisten with moisture."
|
||||
) # 90 chars — well above the escalation threshold
|
||||
|
||||
|
||||
def _make_cascade_mock(content=_GOOD_CONTENT, model="llama3.1:8b"):
|
||||
mock = MagicMock()
|
||||
mock.complete = AsyncMock(
|
||||
return_value={
|
||||
"content": content,
|
||||
"provider": "ollama-local",
|
||||
"model": model,
|
||||
"latency_ms": 150.0,
|
||||
}
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def _make_budget_mock(allowed=True):
|
||||
mock = MagicMock()
|
||||
mock.cloud_allowed = MagicMock(return_value=allowed)
|
||||
mock.record_spend = MagicMock(return_value=0.001)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTieredModelRouterRoute:
|
||||
async def test_route_returns_tier_in_result(self):
|
||||
router = TieredModelRouter(cascade=_make_cascade_mock())
|
||||
result = await router.route("go north")
|
||||
assert "tier" in result
|
||||
assert result["tier"] == TierLabel.LOCAL_FAST
|
||||
|
||||
async def test_acceptance_walk_to_room_is_local_fast(self):
|
||||
"""Acceptance: 'Walk to the next room' → LOCAL_FAST."""
|
||||
router = TieredModelRouter(cascade=_make_cascade_mock())
|
||||
result = await router.route("Walk to the next room")
|
||||
assert result["tier"] == TierLabel.LOCAL_FAST
|
||||
|
||||
async def test_acceptance_plan_hortator_is_local_heavy(self):
|
||||
"""Acceptance: 'Plan the optimal path to become Hortator' → LOCAL_HEAVY."""
|
||||
router = TieredModelRouter(
|
||||
cascade=_make_cascade_mock(model="hermes3:70b"),
|
||||
)
|
||||
result = await router.route("Plan the optimal path to become Hortator")
|
||||
assert result["tier"] == TierLabel.LOCAL_HEAVY
|
||||
|
||||
async def test_t1_low_quality_escalates_to_t2(self):
|
||||
"""Failed Tier-1 response auto-escalates to Tier-2."""
|
||||
call_models = []
|
||||
cascade = MagicMock()
|
||||
|
||||
async def complete_side_effect(messages, model, temperature, max_tokens):
|
||||
call_models.append(model)
|
||||
# First call (T1) returns a low-quality response
|
||||
if len(call_models) == 1:
|
||||
return {
|
||||
"content": "I don't know.",
|
||||
"provider": "ollama",
|
||||
"model": model,
|
||||
"latency_ms": 50,
|
||||
}
|
||||
# Second call (T2) returns a good response
|
||||
return {
|
||||
"content": "You move to the northern passage, passing through the Dunmer stronghold.",
|
||||
"provider": "ollama",
|
||||
"model": model,
|
||||
"latency_ms": 800,
|
||||
}
|
||||
|
||||
cascade.complete = complete_side_effect
|
||||
|
||||
router = TieredModelRouter(cascade=cascade, auto_escalate=True)
|
||||
result = await router.route("go north")
|
||||
|
||||
assert len(call_models) == 2, "Should have called twice (T1 escalated to T2)"
|
||||
assert result["tier"] == TierLabel.LOCAL_HEAVY
|
||||
|
||||
async def test_auto_escalate_false_no_escalation(self):
|
||||
"""With auto_escalate=False, low-quality T1 response is returned as-is."""
|
||||
call_count = {"n": 0}
|
||||
cascade = MagicMock()
|
||||
|
||||
async def complete_side_effect(**kwargs):
|
||||
call_count["n"] += 1
|
||||
return {
|
||||
"content": "I don't know.",
|
||||
"provider": "ollama",
|
||||
"model": "llama3.1:8b",
|
||||
"latency_ms": 50,
|
||||
}
|
||||
|
||||
cascade.complete = AsyncMock(side_effect=complete_side_effect)
|
||||
router = TieredModelRouter(cascade=cascade, auto_escalate=False)
|
||||
result = await router.route("go north")
|
||||
assert call_count["n"] == 1
|
||||
assert result["tier"] == TierLabel.LOCAL_FAST
|
||||
|
||||
async def test_t2_failure_escalates_to_cloud(self):
|
||||
"""Tier-2 failure escalates to Cloud API (when budget allows)."""
|
||||
cascade = MagicMock()
|
||||
call_models = []
|
||||
|
||||
async def complete_side_effect(messages, model, temperature, max_tokens):
|
||||
call_models.append(model)
|
||||
if "hermes3" in model or "70b" in model.lower():
|
||||
raise RuntimeError("Tier-2 model unavailable")
|
||||
return {
|
||||
"content": "Cloud response here.",
|
||||
"provider": "anthropic",
|
||||
"model": model,
|
||||
"latency_ms": 1200,
|
||||
}
|
||||
|
||||
cascade.complete = complete_side_effect
|
||||
|
||||
budget = _make_budget_mock(allowed=True)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
result = await router.route("plan my route", context={"require_t2": True})
|
||||
assert result["tier"] == TierLabel.CLOUD_API
|
||||
|
||||
async def test_cloud_blocked_by_budget_raises(self):
|
||||
"""Cloud tier blocked when budget is exhausted."""
|
||||
cascade = MagicMock()
|
||||
cascade.complete = AsyncMock(side_effect=RuntimeError("T2 fail"))
|
||||
|
||||
budget = _make_budget_mock(allowed=False)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
|
||||
with pytest.raises(RuntimeError, match="budget limit"):
|
||||
await router.route("plan my route", context={"require_t2": True})
|
||||
|
||||
async def test_explicit_cloud_tier_uses_cloud_model(self):
|
||||
cascade = _make_cascade_mock(model="claude-haiku-4-5")
|
||||
budget = _make_budget_mock(allowed=True)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
result = await router.route("go north", context={"require_cloud": True})
|
||||
assert result["tier"] == TierLabel.CLOUD_API
|
||||
|
||||
async def test_cloud_spend_recorded_with_usage(self):
|
||||
"""Cloud spend is recorded when the response includes usage info."""
|
||||
cascade = MagicMock()
|
||||
cascade.complete = AsyncMock(
|
||||
return_value={
|
||||
"content": "Cloud answer.",
|
||||
"provider": "anthropic",
|
||||
"model": "claude-haiku-4-5",
|
||||
"latency_ms": 900,
|
||||
"usage": {"prompt_tokens": 50, "completion_tokens": 100},
|
||||
}
|
||||
)
|
||||
budget = _make_budget_mock(allowed=True)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
result = await router.route("go north", context={"require_cloud": True})
|
||||
budget.record_spend.assert_called_once()
|
||||
assert "cost_usd" in result
|
||||
|
||||
async def test_cloud_spend_not_recorded_without_usage(self):
|
||||
"""Cloud spend is not recorded when usage info is absent."""
|
||||
cascade = MagicMock()
|
||||
cascade.complete = AsyncMock(
|
||||
return_value={
|
||||
"content": "Cloud answer.",
|
||||
"provider": "anthropic",
|
||||
"model": "claude-haiku-4-5",
|
||||
"latency_ms": 900,
|
||||
# no "usage" key
|
||||
}
|
||||
)
|
||||
budget = _make_budget_mock(allowed=True)
|
||||
router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
|
||||
result = await router.route("go north", context={"require_cloud": True})
|
||||
budget.record_spend.assert_not_called()
|
||||
assert "cost_usd" not in result
|
||||
|
||||
async def test_custom_tier_models_respected(self):
|
||||
cascade = _make_cascade_mock()
|
||||
router = TieredModelRouter(
|
||||
cascade=cascade,
|
||||
tier_models={TierLabel.LOCAL_FAST: "llama3.2:3b"},
|
||||
)
|
||||
await router.route("go north")
|
||||
call_kwargs = cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["model"] == "llama3.2:3b"
|
||||
|
||||
async def test_messages_override_used_when_provided(self):
|
||||
cascade = _make_cascade_mock()
|
||||
router = TieredModelRouter(cascade=cascade)
|
||||
custom_msgs = [{"role": "user", "content": "custom message"}]
|
||||
await router.route("go north", messages=custom_msgs)
|
||||
call_kwargs = cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["messages"] == custom_msgs
|
||||
|
||||
async def test_temperature_forwarded(self):
|
||||
cascade = _make_cascade_mock()
|
||||
router = TieredModelRouter(cascade=cascade)
|
||||
await router.route("go north", temperature=0.7)
|
||||
call_kwargs = cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["temperature"] == 0.7
|
||||
|
||||
async def test_max_tokens_forwarded(self):
|
||||
cascade = _make_cascade_mock()
|
||||
router = TieredModelRouter(cascade=cascade)
|
||||
await router.route("go north", max_tokens=128)
|
||||
call_kwargs = cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["max_tokens"] == 128
|
||||
|
||||
|
||||
class TestTieredModelRouterClassify:
|
||||
def test_classify_delegates_to_classify_tier(self):
|
||||
router = TieredModelRouter(cascade=MagicMock())
|
||||
assert router.classify("go north") == classify_tier("go north")
|
||||
assert router.classify("plan the quest") == classify_tier("plan the quest")
|
||||
|
||||
|
||||
class TestGetTieredRouterSingleton:
|
||||
def test_returns_tiered_router_instance(self):
|
||||
import infrastructure.models.router as rmod
|
||||
rmod._tiered_router = None
|
||||
router = get_tiered_router()
|
||||
assert isinstance(router, TieredModelRouter)
|
||||
|
||||
def test_singleton_returns_same_instance(self):
|
||||
import infrastructure.models.router as rmod
|
||||
rmod._tiered_router = None
|
||||
r1 = get_tiered_router()
|
||||
r2 = get_tiered_router()
|
||||
assert r1 is r2
|
||||
0
tests/sovereignty/__init__.py
Normal file
0
tests/sovereignty/__init__.py
Normal file
379
tests/sovereignty/test_perception_cache.py
Normal file
379
tests/sovereignty/test_perception_cache.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""Tests for the sovereignty perception cache (template matching).
|
||||
|
||||
Refs: #1261
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestTemplate:
|
||||
"""Tests for the Template dataclass."""
|
||||
|
||||
def test_template_default_values(self):
|
||||
"""Template dataclass has correct defaults."""
|
||||
from timmy.sovereignty.perception_cache import Template
|
||||
|
||||
image = np.array([[1, 2], [3, 4]])
|
||||
template = Template(name="test_template", image=image)
|
||||
|
||||
assert template.name == "test_template"
|
||||
assert np.array_equal(template.image, image)
|
||||
assert template.threshold == 0.85
|
||||
|
||||
def test_template_custom_threshold(self):
|
||||
"""Template can have custom threshold."""
|
||||
from timmy.sovereignty.perception_cache import Template
|
||||
|
||||
image = np.array([[1, 2], [3, 4]])
|
||||
template = Template(name="test_template", image=image, threshold=0.95)
|
||||
|
||||
assert template.threshold == 0.95
|
||||
|
||||
|
||||
class TestCacheResult:
|
||||
"""Tests for the CacheResult dataclass."""
|
||||
|
||||
def test_cache_result_with_state(self):
|
||||
"""CacheResult stores confidence and state."""
|
||||
from timmy.sovereignty.perception_cache import CacheResult
|
||||
|
||||
result = CacheResult(confidence=0.92, state={"template_name": "test"})
|
||||
assert result.confidence == 0.92
|
||||
assert result.state == {"template_name": "test"}
|
||||
|
||||
def test_cache_result_no_state(self):
|
||||
"""CacheResult can have None state."""
|
||||
from timmy.sovereignty.perception_cache import CacheResult
|
||||
|
||||
result = CacheResult(confidence=0.5, state=None)
|
||||
assert result.confidence == 0.5
|
||||
assert result.state is None
|
||||
|
||||
|
||||
class TestPerceptionCacheInit:
|
||||
"""Tests for PerceptionCache initialization."""
|
||||
|
||||
def test_init_creates_empty_cache_when_no_file(self, tmp_path):
|
||||
"""Cache initializes empty when templates file doesn't exist."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache
|
||||
|
||||
templates_path = tmp_path / "nonexistent_templates.json"
|
||||
cache = PerceptionCache(templates_path=templates_path)
|
||||
|
||||
assert cache.templates_path == templates_path
|
||||
assert cache.templates == []
|
||||
|
||||
def test_init_loads_existing_templates(self, tmp_path):
|
||||
"""Cache loads templates from existing JSON file."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache
|
||||
|
||||
templates_path = tmp_path / "templates.json"
|
||||
templates_data = [
|
||||
{"name": "template1", "threshold": 0.85},
|
||||
{"name": "template2", "threshold": 0.90},
|
||||
]
|
||||
with open(templates_path, "w") as f:
|
||||
json.dump(templates_data, f)
|
||||
|
||||
cache = PerceptionCache(templates_path=templates_path)
|
||||
|
||||
assert len(cache.templates) == 2
|
||||
assert cache.templates[0].name == "template1"
|
||||
assert cache.templates[0].threshold == 0.85
|
||||
assert cache.templates[1].name == "template2"
|
||||
assert cache.templates[1].threshold == 0.90
|
||||
|
||||
def test_init_with_string_path(self, tmp_path):
|
||||
"""Cache accepts string path for templates."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache
|
||||
|
||||
templates_path = str(tmp_path / "templates.json")
|
||||
cache = PerceptionCache(templates_path=templates_path)
|
||||
|
||||
assert str(cache.templates_path) == templates_path
|
||||
|
||||
|
||||
class TestPerceptionCacheMatch:
|
||||
"""Tests for PerceptionCache.match() template matching."""
|
||||
|
||||
def test_match_no_templates_returns_low_confidence(self, tmp_path):
|
||||
"""Matching with no templates returns low confidence and None state."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache
|
||||
|
||||
cache = PerceptionCache(templates_path=tmp_path / "templates.json")
|
||||
screenshot = np.array([[1, 2], [3, 4]])
|
||||
|
||||
result = cache.match(screenshot)
|
||||
|
||||
assert result.confidence == 0.0
|
||||
assert result.state is None
|
||||
|
||||
@patch("timmy.sovereignty.perception_cache.cv2")
|
||||
def test_match_finds_best_template(self, mock_cv2, tmp_path):
|
||||
"""Match returns the best matching template above threshold."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
# Setup mock cv2 behavior
|
||||
mock_cv2.matchTemplate.return_value = np.array([[0.5, 0.6], [0.7, 0.8]])
|
||||
mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
|
||||
mock_cv2.minMaxLoc.return_value = (None, 0.92, None, None)
|
||||
|
||||
cache = PerceptionCache(templates_path=tmp_path / "templates.json")
|
||||
template = Template(name="best_match", image=np.array([[1, 2], [3, 4]]))
|
||||
cache.add([template])
|
||||
|
||||
screenshot = np.array([[5, 6], [7, 8]])
|
||||
result = cache.match(screenshot)
|
||||
|
||||
assert result.confidence == 0.92
|
||||
assert result.state == {"template_name": "best_match"}
|
||||
|
||||
@patch("timmy.sovereignty.perception_cache.cv2")
|
||||
def test_match_respects_global_threshold(self, mock_cv2, tmp_path):
|
||||
"""Match returns None state when confidence is below threshold."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
# Setup mock cv2 to return confidence below 0.85 threshold
|
||||
mock_cv2.matchTemplate.return_value = np.array([[0.1, 0.2], [0.3, 0.4]])
|
||||
mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
|
||||
mock_cv2.minMaxLoc.return_value = (None, 0.75, None, None)
|
||||
|
||||
cache = PerceptionCache(templates_path=tmp_path / "templates.json")
|
||||
template = Template(name="low_match", image=np.array([[1, 2], [3, 4]]))
|
||||
cache.add([template])
|
||||
|
||||
screenshot = np.array([[5, 6], [7, 8]])
|
||||
result = cache.match(screenshot)
|
||||
|
||||
# Confidence is recorded but state is None (below threshold)
|
||||
assert result.confidence == 0.75
|
||||
assert result.state is None
|
||||
|
||||
@patch("timmy.sovereignty.perception_cache.cv2")
|
||||
def test_match_selects_highest_confidence(self, mock_cv2, tmp_path):
|
||||
"""Match selects template with highest confidence across all templates."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
|
||||
|
||||
# Each template will return a different confidence
|
||||
mock_cv2.minMaxLoc.side_effect = [
|
||||
(None, 0.70, None, None), # template1
|
||||
(None, 0.95, None, None), # template2 (best)
|
||||
(None, 0.80, None, None), # template3
|
||||
]
|
||||
|
||||
cache = PerceptionCache(templates_path=tmp_path / "templates.json")
|
||||
templates = [
|
||||
Template(name="template1", image=np.array([[1, 2], [3, 4]])),
|
||||
Template(name="template2", image=np.array([[5, 6], [7, 8]])),
|
||||
Template(name="template3", image=np.array([[9, 10], [11, 12]])),
|
||||
]
|
||||
cache.add(templates)
|
||||
|
||||
screenshot = np.array([[13, 14], [15, 16]])
|
||||
result = cache.match(screenshot)
|
||||
|
||||
assert result.confidence == 0.95
|
||||
assert result.state == {"template_name": "template2"}
|
||||
|
||||
@patch("timmy.sovereignty.perception_cache.cv2")
|
||||
def test_match_exactly_at_threshold(self, mock_cv2, tmp_path):
|
||||
"""Match returns state when confidence is exactly at threshold boundary."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
mock_cv2.matchTemplate.return_value = np.array([[0.1]])
|
||||
mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
|
||||
mock_cv2.minMaxLoc.return_value = (None, 0.85, None, None) # Exactly at threshold
|
||||
|
||||
cache = PerceptionCache(templates_path=tmp_path / "templates.json")
|
||||
template = Template(name="threshold_match", image=np.array([[1, 2], [3, 4]]))
|
||||
cache.add([template])
|
||||
|
||||
screenshot = np.array([[5, 6], [7, 8]])
|
||||
result = cache.match(screenshot)
|
||||
|
||||
# Note: current implementation uses > 0.85, so exactly 0.85 returns None state
|
||||
assert result.confidence == 0.85
|
||||
assert result.state is None
|
||||
|
||||
@patch("timmy.sovereignty.perception_cache.cv2")
|
||||
def test_match_just_above_threshold(self, mock_cv2, tmp_path):
|
||||
"""Match returns state when confidence is just above threshold."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
mock_cv2.matchTemplate.return_value = np.array([[0.1]])
|
||||
mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
|
||||
mock_cv2.minMaxLoc.return_value = (None, 0.851, None, None) # Just above threshold
|
||||
|
||||
cache = PerceptionCache(templates_path=tmp_path / "templates.json")
|
||||
template = Template(name="above_threshold", image=np.array([[1, 2], [3, 4]]))
|
||||
cache.add([template])
|
||||
|
||||
screenshot = np.array([[5, 6], [7, 8]])
|
||||
result = cache.match(screenshot)
|
||||
|
||||
assert result.confidence == 0.851
|
||||
assert result.state == {"template_name": "above_threshold"}
|
||||
|
||||
|
||||
class TestPerceptionCacheAdd:
|
||||
"""Tests for PerceptionCache.add() method."""
|
||||
|
||||
def test_add_single_template(self, tmp_path):
|
||||
"""Can add a single template to the cache."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
cache = PerceptionCache(templates_path=tmp_path / "templates.json")
|
||||
template = Template(name="new_template", image=np.array([[1, 2], [3, 4]]))
|
||||
|
||||
cache.add([template])
|
||||
|
||||
assert len(cache.templates) == 1
|
||||
assert cache.templates[0].name == "new_template"
|
||||
|
||||
def test_add_multiple_templates(self, tmp_path):
|
||||
"""Can add multiple templates at once."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
cache = PerceptionCache(templates_path=tmp_path / "templates.json")
|
||||
templates = [
|
||||
Template(name="template1", image=np.array([[1, 2], [3, 4]])),
|
||||
Template(name="template2", image=np.array([[5, 6], [7, 8]])),
|
||||
]
|
||||
|
||||
cache.add(templates)
|
||||
|
||||
assert len(cache.templates) == 2
|
||||
assert cache.templates[0].name == "template1"
|
||||
assert cache.templates[1].name == "template2"
|
||||
|
||||
def test_add_templates_accumulate(self, tmp_path):
|
||||
"""Adding templates multiple times accumulates them."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
cache = PerceptionCache(templates_path=tmp_path / "templates.json")
|
||||
cache.add([Template(name="first", image=np.array([[1]]))])
|
||||
cache.add([Template(name="second", image=np.array([[2]]))])
|
||||
|
||||
assert len(cache.templates) == 2
|
||||
|
||||
|
||||
class TestPerceptionCachePersist:
|
||||
"""Tests for PerceptionCache.persist() method."""
|
||||
|
||||
def test_persist_creates_file(self, tmp_path):
|
||||
"""Persist creates templates JSON file."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
templates_path = tmp_path / "subdir" / "templates.json"
|
||||
cache = PerceptionCache(templates_path=templates_path)
|
||||
cache.add([Template(name="persisted", image=np.array([[1, 2], [3, 4]]))])
|
||||
|
||||
cache.persist()
|
||||
|
||||
assert templates_path.exists()
|
||||
|
||||
def test_persist_stores_template_names(self, tmp_path):
|
||||
"""Persist stores template names and thresholds."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
templates_path = tmp_path / "templates.json"
|
||||
cache = PerceptionCache(templates_path=templates_path)
|
||||
cache.add([
|
||||
Template(name="template1", image=np.array([[1]]), threshold=0.85),
|
||||
Template(name="template2", image=np.array([[2]]), threshold=0.90),
|
||||
])
|
||||
|
||||
cache.persist()
|
||||
|
||||
with open(templates_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
assert len(data) == 2
|
||||
assert data[0]["name"] == "template1"
|
||||
assert data[0]["threshold"] == 0.85
|
||||
assert data[1]["name"] == "template2"
|
||||
assert data[1]["threshold"] == 0.90
|
||||
|
||||
def test_persist_does_not_store_image_data(self, tmp_path):
|
||||
"""Persist only stores metadata, not actual image arrays."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache, Template
|
||||
|
||||
templates_path = tmp_path / "templates.json"
|
||||
cache = PerceptionCache(templates_path=templates_path)
|
||||
cache.add([Template(name="no_image", image=np.array([[1, 2, 3], [4, 5, 6]]))])
|
||||
|
||||
cache.persist()
|
||||
|
||||
with open(templates_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
assert "image" not in data[0]
|
||||
assert set(data[0].keys()) == {"name", "threshold"}
|
||||
|
||||
|
||||
class TestPerceptionCacheLoad:
|
||||
"""Tests for PerceptionCache.load() method."""
|
||||
|
||||
def test_load_from_existing_file(self, tmp_path):
|
||||
"""Load restores templates from persisted file."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache
|
||||
|
||||
templates_path = tmp_path / "templates.json"
|
||||
|
||||
# Create initial cache with templates and persist
|
||||
cache1 = PerceptionCache(templates_path=templates_path)
|
||||
from timmy.sovereignty.perception_cache import Template
|
||||
|
||||
cache1.add([Template(name="loaded", image=np.array([[1]]), threshold=0.88)])
|
||||
cache1.persist()
|
||||
|
||||
# Create new cache instance that loads from same file
|
||||
cache2 = PerceptionCache(templates_path=templates_path)
|
||||
|
||||
assert len(cache2.templates) == 1
|
||||
assert cache2.templates[0].name == "loaded"
|
||||
assert cache2.templates[0].threshold == 0.88
|
||||
# Note: images are loaded as empty arrays per current implementation
|
||||
assert cache2.templates[0].image.size == 0
|
||||
|
||||
def test_load_empty_file(self, tmp_path):
|
||||
"""Load handles empty template list in file."""
|
||||
from timmy.sovereignty.perception_cache import PerceptionCache
|
||||
|
||||
templates_path = tmp_path / "templates.json"
|
||||
with open(templates_path, "w") as f:
|
||||
json.dump([], f)
|
||||
|
||||
cache = PerceptionCache(templates_path=templates_path)
|
||||
|
||||
assert cache.templates == []
|
||||
|
||||
|
||||
class TestCrystallizePerception:
|
||||
"""Tests for crystallize_perception function."""
|
||||
|
||||
def test_crystallize_returns_empty_list(self, tmp_path):
|
||||
"""crystallize_perception currently returns empty list (placeholder)."""
|
||||
from timmy.sovereignty.perception_cache import crystallize_perception
|
||||
|
||||
screenshot = np.array([[1, 2], [3, 4]])
|
||||
result = crystallize_perception(screenshot, {"some": "response"})
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_crystallize_accepts_any_vlm_response(self, tmp_path):
|
||||
"""crystallize_perception accepts any vlm_response format."""
|
||||
from timmy.sovereignty.perception_cache import crystallize_perception
|
||||
|
||||
screenshot = np.array([[1, 2], [3, 4]])
|
||||
|
||||
# Test with various response types
|
||||
assert crystallize_perception(screenshot, None) == []
|
||||
assert crystallize_perception(screenshot, {}) == []
|
||||
assert crystallize_perception(screenshot, {"items": []}) == []
|
||||
assert crystallize_perception(screenshot, "string response") == []
|
||||
643
tests/timmy/test_kimi_delegation.py
Normal file
643
tests/timmy/test_kimi_delegation.py
Normal file
@@ -0,0 +1,643 @@
|
||||
"""Unit tests for timmy.kimi_delegation — Kimi research delegation pipeline."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# exceeds_local_capacity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExceedsLocalCapacity:
|
||||
def test_heavy_keyword_triggers_delegation(self):
|
||||
from timmy.kimi_delegation import exceeds_local_capacity
|
||||
|
||||
assert exceeds_local_capacity("Do a comprehensive review of the codebase") is True
|
||||
|
||||
def test_all_heavy_keywords_detected(self):
|
||||
from timmy.kimi_delegation import _HEAVY_RESEARCH_KEYWORDS, exceeds_local_capacity
|
||||
|
||||
for kw in _HEAVY_RESEARCH_KEYWORDS:
|
||||
assert exceeds_local_capacity(f"Please {kw} the topic") is True, f"Missed keyword: {kw}"
|
||||
|
||||
def test_long_task_triggers_delegation(self):
|
||||
from timmy.kimi_delegation import _HEAVY_WORD_THRESHOLD, exceeds_local_capacity
|
||||
|
||||
long_task = " ".join(["word"] * (_HEAVY_WORD_THRESHOLD + 1))
|
||||
assert exceeds_local_capacity(long_task) is True
|
||||
|
||||
def test_short_simple_task_returns_false(self):
|
||||
from timmy.kimi_delegation import exceeds_local_capacity
|
||||
|
||||
assert exceeds_local_capacity("Fix the typo in README") is False
|
||||
|
||||
def test_exactly_at_word_threshold_triggers(self):
|
||||
from timmy.kimi_delegation import _HEAVY_WORD_THRESHOLD, exceeds_local_capacity
|
||||
|
||||
task = " ".join(["word"] * _HEAVY_WORD_THRESHOLD)
|
||||
assert exceeds_local_capacity(task) is True
|
||||
|
||||
def test_keyword_case_insensitive(self):
|
||||
from timmy.kimi_delegation import exceeds_local_capacity
|
||||
|
||||
assert exceeds_local_capacity("Run a COMPREHENSIVE analysis") is True
|
||||
|
||||
def test_empty_string_returns_false(self):
|
||||
from timmy.kimi_delegation import exceeds_local_capacity
|
||||
|
||||
assert exceeds_local_capacity("") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _slugify
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlugify:
|
||||
def test_basic_text(self):
|
||||
from timmy.kimi_delegation import _slugify
|
||||
|
||||
assert _slugify("Hello World") == "hello-world"
|
||||
|
||||
def test_special_characters_removed(self):
|
||||
from timmy.kimi_delegation import _slugify
|
||||
|
||||
assert _slugify("Research: AI & ML!") == "research-ai--ml"
|
||||
|
||||
def test_underscores_become_dashes(self):
|
||||
from timmy.kimi_delegation import _slugify
|
||||
|
||||
assert _slugify("some_snake_case") == "some-snake-case"
|
||||
|
||||
def test_long_text_truncated_to_60(self):
|
||||
from timmy.kimi_delegation import _slugify
|
||||
|
||||
long_text = "a" * 100
|
||||
result = _slugify(long_text)
|
||||
assert len(result) <= 60
|
||||
|
||||
def test_leading_trailing_dashes_stripped(self):
|
||||
from timmy.kimi_delegation import _slugify
|
||||
|
||||
result = _slugify(" hello ")
|
||||
assert not result.startswith("-")
|
||||
assert not result.endswith("-")
|
||||
|
||||
def test_multiple_spaces_become_single_dash(self):
|
||||
from timmy.kimi_delegation import _slugify
|
||||
|
||||
assert _slugify("one two") == "one-two"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_research_template
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildResearchTemplate:
|
||||
def test_contains_task_title(self):
|
||||
from timmy.kimi_delegation import _build_research_template
|
||||
|
||||
body = _build_research_template("My Task", "background", "the question?")
|
||||
assert "My Task" in body
|
||||
|
||||
def test_contains_question(self):
|
||||
from timmy.kimi_delegation import _build_research_template
|
||||
|
||||
body = _build_research_template("task", "context", "What is X?")
|
||||
assert "What is X?" in body
|
||||
|
||||
def test_contains_context(self):
|
||||
from timmy.kimi_delegation import _build_research_template
|
||||
|
||||
body = _build_research_template("task", "some context here", "q?")
|
||||
assert "some context here" in body
|
||||
|
||||
def test_default_priority_normal(self):
|
||||
from timmy.kimi_delegation import _build_research_template
|
||||
|
||||
body = _build_research_template("task", "ctx", "q?")
|
||||
assert "normal" in body
|
||||
|
||||
def test_custom_priority_included(self):
|
||||
from timmy.kimi_delegation import _build_research_template
|
||||
|
||||
body = _build_research_template("task", "ctx", "q?", priority="high")
|
||||
assert "high" in body
|
||||
|
||||
def test_kimi_label_mentioned(self):
|
||||
from timmy.kimi_delegation import KIMI_READY_LABEL, _build_research_template
|
||||
|
||||
body = _build_research_template("task", "ctx", "q?")
|
||||
assert KIMI_READY_LABEL in body
|
||||
|
||||
def test_slugified_task_in_artifact_path(self):
|
||||
from timmy.kimi_delegation import _build_research_template
|
||||
|
||||
body = _build_research_template("My Research Task", "ctx", "q?")
|
||||
assert "my-research-task" in body
|
||||
|
||||
def test_sections_present(self):
|
||||
from timmy.kimi_delegation import _build_research_template
|
||||
|
||||
body = _build_research_template("task", "ctx", "q?")
|
||||
assert "## Research Request" in body
|
||||
assert "### Research Question" in body
|
||||
assert "### Background / Context" in body
|
||||
assert "### Deliverables" in body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_action_items
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractActionItems:
|
||||
def test_checkbox_items_extracted(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
text = "- [ ] Fix the bug\n- [ ] Write tests\n"
|
||||
items = _extract_action_items(text)
|
||||
assert "Fix the bug" in items
|
||||
assert "Write tests" in items
|
||||
|
||||
def test_numbered_list_extracted(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
text = "1. Deploy to staging\n2. Run smoke tests\n"
|
||||
items = _extract_action_items(text)
|
||||
assert "Deploy to staging" in items
|
||||
assert "Run smoke tests" in items
|
||||
|
||||
def test_action_prefix_extracted(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
text = "Action: Update the config file\n"
|
||||
items = _extract_action_items(text)
|
||||
assert "Update the config file" in items
|
||||
|
||||
def test_todo_prefix_extracted(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
text = "TODO: Add error handling\n"
|
||||
items = _extract_action_items(text)
|
||||
assert "Add error handling" in items
|
||||
|
||||
def test_next_step_prefix_extracted(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
text = "Next step: Validate results\n"
|
||||
items = _extract_action_items(text)
|
||||
assert "Validate results" in items
|
||||
|
||||
def test_case_insensitive_prefixes(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
text = "todo: lowercase todo\nACTION: uppercase action\n"
|
||||
items = _extract_action_items(text)
|
||||
assert "lowercase todo" in items
|
||||
assert "uppercase action" in items
|
||||
|
||||
def test_deduplication(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
text = "1. Do the thing\n2. Do the thing\n"
|
||||
items = _extract_action_items(text)
|
||||
assert items.count("Do the thing") == 1
|
||||
|
||||
def test_empty_text_returns_empty_list(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
assert _extract_action_items("") == []
|
||||
|
||||
def test_no_action_items_returns_empty_list(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
text = "This is just plain prose with no action items here."
|
||||
assert _extract_action_items(text) == []
|
||||
|
||||
def test_mixed_sources_combined(self):
|
||||
from timmy.kimi_delegation import _extract_action_items
|
||||
|
||||
text = "- [ ] checkbox item\n1. numbered item\nAction: action item\n"
|
||||
items = _extract_action_items(text)
|
||||
assert len(items) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_or_create_label (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetOrCreateLabel:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_existing_label_id(self):
|
||||
from timmy.kimi_delegation import KIMI_READY_LABEL, _get_or_create_label
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = [{"name": KIMI_READY_LABEL, "id": 42}]
|
||||
|
||||
client = MagicMock()
|
||||
client.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
result = await _get_or_create_label(client, "http://git", {"Authorization": "token x"}, "owner/repo")
|
||||
assert result == 42
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_label_when_missing(self):
|
||||
from timmy.kimi_delegation import _get_or_create_label
|
||||
|
||||
list_resp = MagicMock()
|
||||
list_resp.status_code = 200
|
||||
list_resp.json.return_value = [] # no existing labels
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.status_code = 201
|
||||
create_resp.json.return_value = {"id": 99}
|
||||
|
||||
client = MagicMock()
|
||||
client.get = AsyncMock(return_value=list_resp)
|
||||
client.post = AsyncMock(return_value=create_resp)
|
||||
|
||||
result = await _get_or_create_label(client, "http://git", {"Authorization": "token x"}, "owner/repo")
|
||||
assert result == 99
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_list_exception(self):
|
||||
from timmy.kimi_delegation import _get_or_create_label
|
||||
|
||||
client = MagicMock()
|
||||
client.get = AsyncMock(side_effect=Exception("network error"))
|
||||
|
||||
result = await _get_or_create_label(client, "http://git", {}, "owner/repo")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_create_exception(self):
|
||||
from timmy.kimi_delegation import _get_or_create_label
|
||||
|
||||
list_resp = MagicMock()
|
||||
list_resp.status_code = 200
|
||||
list_resp.json.return_value = []
|
||||
|
||||
client = MagicMock()
|
||||
client.get = AsyncMock(return_value=list_resp)
|
||||
client.post = AsyncMock(side_effect=Exception("create failed"))
|
||||
|
||||
result = await _get_or_create_label(client, "http://git", {}, "owner/repo")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# create_kimi_research_issue (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateKimiResearchIssue:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_error_when_gitea_disabled(self):
|
||||
from timmy.kimi_delegation import create_kimi_research_issue
|
||||
|
||||
with patch("timmy.kimi_delegation.settings") as mock_settings:
|
||||
mock_settings.gitea_enabled = False
|
||||
mock_settings.gitea_token = ""
|
||||
result = await create_kimi_research_issue("task", "ctx", "q?")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "not configured" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_error_when_no_token(self):
|
||||
from timmy.kimi_delegation import create_kimi_research_issue
|
||||
|
||||
with patch("timmy.kimi_delegation.settings") as mock_settings:
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = ""
|
||||
result = await create_kimi_research_issue("task", "ctx", "q?")
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_issue_creation(self):
|
||||
from timmy.kimi_delegation import create_kimi_research_issue
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "tok"
|
||||
mock_settings.gitea_url = "http://git"
|
||||
mock_settings.gitea_repo = "owner/repo"
|
||||
|
||||
label_resp = MagicMock()
|
||||
label_resp.status_code = 200
|
||||
label_resp.json.return_value = [{"name": "kimi-ready", "id": 5}]
|
||||
|
||||
issue_resp = MagicMock()
|
||||
issue_resp.status_code = 201
|
||||
issue_resp.json.return_value = {"number": 42, "html_url": "http://git/issues/42"}
|
||||
|
||||
async_client = AsyncMock()
|
||||
async_client.get = AsyncMock(return_value=label_resp)
|
||||
async_client.post = AsyncMock(return_value=issue_resp)
|
||||
async_client.__aenter__ = AsyncMock(return_value=async_client)
|
||||
async_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("timmy.kimi_delegation.settings", mock_settings),
|
||||
patch("timmy.kimi_delegation.httpx") as mock_httpx,
|
||||
):
|
||||
mock_httpx.AsyncClient.return_value = async_client
|
||||
result = await create_kimi_research_issue("task", "ctx", "q?")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["issue_number"] == 42
|
||||
assert "http://git/issues/42" in result["issue_url"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_returns_failure(self):
|
||||
from timmy.kimi_delegation import create_kimi_research_issue
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "tok"
|
||||
mock_settings.gitea_url = "http://git"
|
||||
mock_settings.gitea_repo = "owner/repo"
|
||||
|
||||
label_resp = MagicMock()
|
||||
label_resp.status_code = 200
|
||||
label_resp.json.return_value = []
|
||||
|
||||
create_label_resp = MagicMock()
|
||||
create_label_resp.status_code = 201
|
||||
create_label_resp.json.return_value = {"id": 1}
|
||||
|
||||
issue_resp = MagicMock()
|
||||
issue_resp.status_code = 500
|
||||
issue_resp.text = "Internal Server Error"
|
||||
|
||||
async_client = AsyncMock()
|
||||
async_client.get = AsyncMock(return_value=label_resp)
|
||||
async_client.post = AsyncMock(side_effect=[create_label_resp, issue_resp])
|
||||
async_client.__aenter__ = AsyncMock(return_value=async_client)
|
||||
async_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("timmy.kimi_delegation.settings", mock_settings),
|
||||
patch("timmy.kimi_delegation.httpx") as mock_httpx,
|
||||
):
|
||||
mock_httpx.AsyncClient.return_value = async_client
|
||||
result = await create_kimi_research_issue("task", "ctx", "q?")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "500" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_returns_failure(self):
|
||||
from timmy.kimi_delegation import create_kimi_research_issue
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "tok"
|
||||
mock_settings.gitea_url = "http://git"
|
||||
mock_settings.gitea_repo = "owner/repo"
|
||||
|
||||
async_client = AsyncMock()
|
||||
async_client.__aenter__ = AsyncMock(side_effect=Exception("connection refused"))
|
||||
async_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("timmy.kimi_delegation.settings", mock_settings),
|
||||
patch("timmy.kimi_delegation.httpx") as mock_httpx,
|
||||
):
|
||||
mock_httpx.AsyncClient.return_value = async_client
|
||||
result = await create_kimi_research_issue("task", "ctx", "q?")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] != ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# poll_kimi_issue (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPollKimiIssue:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_error_when_gitea_not_configured(self):
|
||||
from timmy.kimi_delegation import poll_kimi_issue
|
||||
|
||||
with patch("timmy.kimi_delegation.settings") as mock_settings:
|
||||
mock_settings.gitea_enabled = False
|
||||
mock_settings.gitea_token = ""
|
||||
result = await poll_kimi_issue(123)
|
||||
|
||||
assert result["completed"] is False
|
||||
assert "not configured" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_completed_when_issue_closed(self):
|
||||
from timmy.kimi_delegation import poll_kimi_issue
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "tok"
|
||||
mock_settings.gitea_url = "http://git"
|
||||
mock_settings.gitea_repo = "owner/repo"
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"state": "closed", "body": "Done!"}
|
||||
|
||||
async_client = AsyncMock()
|
||||
async_client.get = AsyncMock(return_value=resp)
|
||||
async_client.__aenter__ = AsyncMock(return_value=async_client)
|
||||
async_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("timmy.kimi_delegation.settings", mock_settings),
|
||||
patch("timmy.kimi_delegation.httpx") as mock_httpx,
|
||||
):
|
||||
mock_httpx.AsyncClient.return_value = async_client
|
||||
result = await poll_kimi_issue(42, poll_interval=0, max_wait=1)
|
||||
|
||||
assert result["completed"] is True
|
||||
assert result["state"] == "closed"
|
||||
assert result["body"] == "Done!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_times_out_when_issue_stays_open(self):
|
||||
from timmy.kimi_delegation import poll_kimi_issue
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "tok"
|
||||
mock_settings.gitea_url = "http://git"
|
||||
mock_settings.gitea_repo = "owner/repo"
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"state": "open", "body": ""}
|
||||
|
||||
async_client = AsyncMock()
|
||||
async_client.get = AsyncMock(return_value=resp)
|
||||
async_client.__aenter__ = AsyncMock(return_value=async_client)
|
||||
async_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("timmy.kimi_delegation.settings", mock_settings),
|
||||
patch("timmy.kimi_delegation.httpx") as mock_httpx,
|
||||
patch("timmy.kimi_delegation.asyncio.sleep", new_callable=AsyncMock),
|
||||
):
|
||||
mock_httpx.AsyncClient.return_value = async_client
|
||||
# poll_interval > max_wait so it exits immediately after first sleep
|
||||
result = await poll_kimi_issue(42, poll_interval=10, max_wait=5)
|
||||
|
||||
assert result["completed"] is False
|
||||
assert result["state"] == "timeout"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# index_kimi_artifact (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIndexKimiArtifact:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_artifact_returns_error(self):
|
||||
from timmy.kimi_delegation import index_kimi_artifact
|
||||
|
||||
result = await index_kimi_artifact(1, "title", " ")
|
||||
assert result["success"] is False
|
||||
assert "Empty artifact" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_indexing(self):
|
||||
from timmy.kimi_delegation import index_kimi_artifact
|
||||
|
||||
mock_entry = MagicMock()
|
||||
mock_entry.id = "mem-123"
|
||||
|
||||
with patch("timmy.kimi_delegation.asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
|
||||
mock_thread.return_value = mock_entry
|
||||
result = await index_kimi_artifact(42, "My Research", "Some research content here")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["memory_id"] == "mem-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_returns_failure(self):
|
||||
from timmy.kimi_delegation import index_kimi_artifact
|
||||
|
||||
with patch("timmy.kimi_delegation.asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
|
||||
mock_thread.side_effect = Exception("DB error")
|
||||
result = await index_kimi_artifact(42, "title", "some content")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] != ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_and_create_followups (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractAndCreateFollowups:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_action_items_returns_empty_created(self):
|
||||
from timmy.kimi_delegation import extract_and_create_followups
|
||||
|
||||
result = await extract_and_create_followups("Plain prose, nothing to do.", 1)
|
||||
assert result["success"] is True
|
||||
assert result["created"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gitea_not_configured_returns_error(self):
|
||||
from timmy.kimi_delegation import extract_and_create_followups
|
||||
|
||||
text = "1. Do something important\n"
|
||||
|
||||
with patch("timmy.kimi_delegation.settings") as mock_settings:
|
||||
mock_settings.gitea_enabled = False
|
||||
mock_settings.gitea_token = ""
|
||||
result = await extract_and_create_followups(text, 5)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_followup_issues(self):
|
||||
from timmy.kimi_delegation import extract_and_create_followups
|
||||
|
||||
text = "1. Deploy the service\n2. Run integration tests\n"
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "tok"
|
||||
mock_settings.gitea_url = "http://git"
|
||||
mock_settings.gitea_repo = "owner/repo"
|
||||
|
||||
issue_resp = MagicMock()
|
||||
issue_resp.status_code = 201
|
||||
issue_resp.json.return_value = {"number": 10}
|
||||
|
||||
async_client = AsyncMock()
|
||||
async_client.post = AsyncMock(return_value=issue_resp)
|
||||
async_client.__aenter__ = AsyncMock(return_value=async_client)
|
||||
async_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("timmy.kimi_delegation.settings", mock_settings),
|
||||
patch("timmy.kimi_delegation.httpx") as mock_httpx,
|
||||
):
|
||||
mock_httpx.AsyncClient.return_value = async_client
|
||||
result = await extract_and_create_followups(text, 5)
|
||||
|
||||
assert result["success"] is True
|
||||
assert len(result["created"]) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# delegate_research_to_kimi (async)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDelegateResearchToKimi:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_task_returns_error(self):
|
||||
from timmy.kimi_delegation import delegate_research_to_kimi
|
||||
|
||||
result = await delegate_research_to_kimi("", "ctx", "q?")
|
||||
assert result["success"] is False
|
||||
assert "required" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_task_returns_error(self):
|
||||
from timmy.kimi_delegation import delegate_research_to_kimi
|
||||
|
||||
result = await delegate_research_to_kimi(" ", "ctx", "q?")
|
||||
assert result["success"] is False
|
||||
assert "required" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_question_returns_error(self):
|
||||
from timmy.kimi_delegation import delegate_research_to_kimi
|
||||
|
||||
result = await delegate_research_to_kimi("valid task", "ctx", "")
|
||||
assert result["success"] is False
|
||||
assert "required" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delegates_to_create_issue(self):
|
||||
from timmy.kimi_delegation import delegate_research_to_kimi
|
||||
|
||||
with patch(
|
||||
"timmy.kimi_delegation.create_kimi_research_issue",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create:
|
||||
mock_create.return_value = {"success": True, "issue_number": 7, "issue_url": "http://x", "error": None}
|
||||
result = await delegate_research_to_kimi("Research X", "ctx", "What is X?", priority="high")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["issue_number"] == 7
|
||||
mock_create.assert_awaited_once_with("Research X", "ctx", "What is X?", "high")
|
||||
667
tests/timmy/test_orchestration_loop.py
Normal file
667
tests/timmy/test_orchestration_loop.py
Normal file
@@ -0,0 +1,667 @@
|
||||
"""Tests for timmy.vassal.orchestration_loop — VassalOrchestrator core module.
|
||||
|
||||
Refs #1278
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from timmy.vassal.orchestration_loop import VassalCycleRecord, VassalOrchestrator
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# VassalCycleRecord tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVassalCycleRecord:
|
||||
"""Unit tests for the VassalCycleRecord dataclass."""
|
||||
|
||||
def test_creation_defaults(self):
|
||||
"""Test creating a cycle record with minimal fields."""
|
||||
record = VassalCycleRecord(
|
||||
cycle_id=1,
|
||||
started_at="2026-03-23T12:00:00+00:00",
|
||||
)
|
||||
assert record.cycle_id == 1
|
||||
assert record.started_at == "2026-03-23T12:00:00+00:00"
|
||||
assert record.finished_at == ""
|
||||
assert record.duration_ms == 0
|
||||
assert record.issues_fetched == 0
|
||||
assert record.issues_dispatched == 0
|
||||
assert record.stuck_agents == []
|
||||
assert record.house_warnings == []
|
||||
assert record.errors == []
|
||||
|
||||
def test_healthy_property_no_issues(self):
|
||||
"""Record is healthy when no errors or warnings."""
|
||||
record = VassalCycleRecord(
|
||||
cycle_id=1,
|
||||
started_at="2026-03-23T12:00:00+00:00",
|
||||
)
|
||||
assert record.healthy is True
|
||||
|
||||
def test_healthy_property_with_errors(self):
|
||||
"""Record is unhealthy when errors exist."""
|
||||
record = VassalCycleRecord(
|
||||
cycle_id=1,
|
||||
started_at="2026-03-23T12:00:00+00:00",
|
||||
errors=["backlog: Connection failed"],
|
||||
)
|
||||
assert record.healthy is False
|
||||
|
||||
def test_healthy_property_with_warnings(self):
|
||||
"""Record is unhealthy when house warnings exist."""
|
||||
record = VassalCycleRecord(
|
||||
cycle_id=1,
|
||||
started_at="2026-03-23T12:00:00+00:00",
|
||||
house_warnings=["Disk: 90% used"],
|
||||
)
|
||||
assert record.healthy is False
|
||||
|
||||
def test_full_populated_record(self):
|
||||
"""Test a fully populated cycle record."""
|
||||
record = VassalCycleRecord(
|
||||
cycle_id=5,
|
||||
started_at="2026-03-23T12:00:00+00:00",
|
||||
finished_at="2026-03-23T12:00:01+00:00",
|
||||
duration_ms=1000,
|
||||
issues_fetched=10,
|
||||
issues_dispatched=3,
|
||||
dispatched_to_claude=1,
|
||||
dispatched_to_kimi=1,
|
||||
dispatched_to_timmy=1,
|
||||
stuck_agents=["claude"],
|
||||
nudges_sent=1,
|
||||
house_warnings=[],
|
||||
cleanup_deleted=0,
|
||||
errors=[],
|
||||
)
|
||||
assert record.cycle_id == 5
|
||||
assert record.duration_ms == 1000
|
||||
assert record.healthy is True
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# VassalOrchestrator initialization tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVassalOrchestratorInit:
|
||||
"""Tests for VassalOrchestrator initialization."""
|
||||
|
||||
def test_default_initialization(self):
|
||||
"""Test default initialization with no parameters."""
|
||||
orchestrator = VassalOrchestrator()
|
||||
assert orchestrator.cycle_count == 0
|
||||
assert orchestrator.is_running is False
|
||||
assert orchestrator.history == []
|
||||
assert orchestrator._max_dispatch == 10
|
||||
|
||||
def test_custom_interval(self):
|
||||
"""Test initialization with custom cycle interval."""
|
||||
orchestrator = VassalOrchestrator(cycle_interval=60.0)
|
||||
assert orchestrator._cycle_interval == 60.0
|
||||
|
||||
def test_custom_max_dispatch(self):
|
||||
"""Test initialization with custom max dispatch."""
|
||||
orchestrator = VassalOrchestrator(max_dispatch_per_cycle=5)
|
||||
assert orchestrator._max_dispatch == 5
|
||||
|
||||
def test_get_status_empty_history(self):
|
||||
"""Test get_status when no cycles have run."""
|
||||
orchestrator = VassalOrchestrator()
|
||||
status = orchestrator.get_status()
|
||||
assert status["running"] is False
|
||||
assert status["cycle_count"] == 0
|
||||
assert status["last_cycle"] is None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Run cycle tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunCycle:
|
||||
"""Tests for the run_cycle method."""
|
||||
|
||||
@pytest.fixture
|
||||
def orchestrator(self):
|
||||
"""Create a fresh orchestrator for each test."""
|
||||
return VassalOrchestrator()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_dispatch_registry(self):
|
||||
"""Clear dispatch registry before each test."""
|
||||
from timmy.vassal.dispatch import clear_dispatch_registry
|
||||
|
||||
clear_dispatch_registry()
|
||||
yield
|
||||
clear_dispatch_registry()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_empty_backlog(self, orchestrator):
|
||||
"""Test a cycle with no issues to process."""
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = []
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert record.cycle_id == 1
|
||||
assert record.issues_fetched == 0
|
||||
assert record.issues_dispatched == 0
|
||||
assert record.duration_ms >= 0
|
||||
assert record.finished_at != ""
|
||||
assert orchestrator.cycle_count == 1
|
||||
assert len(orchestrator.history) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_dispatches_issues(self, orchestrator):
|
||||
"""Test dispatching issues to agents."""
|
||||
|
||||
mock_issue = {
|
||||
"number": 123,
|
||||
"title": "Test issue",
|
||||
"body": "Test body",
|
||||
"labels": [],
|
||||
"assignees": [],
|
||||
"html_url": "http://test/123",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = [mock_issue]
|
||||
with patch(
|
||||
"timmy.vassal.dispatch.dispatch_issue", new_callable=AsyncMock
|
||||
) as mock_dispatch:
|
||||
mock_dispatch.return_value = MagicMock()
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert record.cycle_id == 1
|
||||
assert record.issues_fetched == 1
|
||||
assert record.issues_dispatched == 1
|
||||
mock_dispatch.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_respects_max_dispatch(self, orchestrator):
|
||||
"""Test that max_dispatch_per_cycle limits dispatches."""
|
||||
mock_issues = [
|
||||
{
|
||||
"number": i,
|
||||
"title": f"Issue {i}",
|
||||
"body": "Test",
|
||||
"labels": [],
|
||||
"assignees": [],
|
||||
"html_url": f"http://test/{i}",
|
||||
}
|
||||
for i in range(1, 15)
|
||||
]
|
||||
|
||||
orchestrator._max_dispatch = 3
|
||||
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = mock_issues
|
||||
with patch(
|
||||
"timmy.vassal.dispatch.dispatch_issue", new_callable=AsyncMock
|
||||
) as mock_dispatch:
|
||||
mock_dispatch.return_value = MagicMock()
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert record.issues_fetched == 14
|
||||
assert record.issues_dispatched == 3
|
||||
assert mock_dispatch.await_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_skips_already_dispatched(self, orchestrator):
|
||||
"""Test that already dispatched issues are skipped."""
|
||||
mock_issues = [
|
||||
{
|
||||
"number": 1,
|
||||
"title": "Issue 1",
|
||||
"body": "Test",
|
||||
"labels": [],
|
||||
"assignees": [],
|
||||
"html_url": "http://test/1",
|
||||
},
|
||||
{
|
||||
"number": 2,
|
||||
"title": "Issue 2",
|
||||
"body": "Test",
|
||||
"labels": [],
|
||||
"assignees": [],
|
||||
"html_url": "http://test/2",
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = mock_issues
|
||||
with patch(
|
||||
"timmy.vassal.dispatch.get_dispatch_registry"
|
||||
) as mock_registry:
|
||||
# Issue 1 already dispatched
|
||||
mock_registry.return_value = {1: MagicMock()}
|
||||
with patch(
|
||||
"timmy.vassal.dispatch.dispatch_issue", new_callable=AsyncMock
|
||||
) as mock_dispatch:
|
||||
mock_dispatch.return_value = MagicMock()
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert record.issues_fetched == 2
|
||||
assert record.issues_dispatched == 1
|
||||
mock_dispatch.assert_awaited_once()
|
||||
# Should be called with issue 2
|
||||
call_args = mock_dispatch.call_args[0][0]
|
||||
assert call_args.number == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_tracks_agent_targets(self, orchestrator):
|
||||
"""Test that dispatch counts are tracked per agent."""
|
||||
|
||||
mock_issues = [
|
||||
{
|
||||
"number": 1,
|
||||
"title": "Architecture refactor", # Should route to Claude
|
||||
"body": "Test",
|
||||
"labels": [],
|
||||
"assignees": [],
|
||||
"html_url": "http://test/1",
|
||||
},
|
||||
{
|
||||
"number": 2,
|
||||
"title": "Research analysis", # Should route to Kimi
|
||||
"body": "Test",
|
||||
"labels": [],
|
||||
"assignees": [],
|
||||
"html_url": "http://test/2",
|
||||
},
|
||||
{
|
||||
"number": 3,
|
||||
"title": "Docs update", # Should route to Timmy
|
||||
"body": "Test",
|
||||
"labels": [],
|
||||
"assignees": [],
|
||||
"html_url": "http://test/3",
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = mock_issues
|
||||
with patch(
|
||||
"timmy.vassal.dispatch.dispatch_issue", new_callable=AsyncMock
|
||||
) as mock_dispatch:
|
||||
mock_dispatch.return_value = MagicMock()
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert record.issues_dispatched == 3
|
||||
assert record.dispatched_to_claude == 1
|
||||
assert record.dispatched_to_kimi == 1
|
||||
assert record.dispatched_to_timmy == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_handles_backlog_error(self, orchestrator):
|
||||
"""Test graceful handling of backlog step errors."""
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.side_effect = RuntimeError("Gitea down")
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert record.cycle_id == 1
|
||||
assert record.issues_fetched == 0
|
||||
assert len(record.errors) == 1
|
||||
assert "backlog" in record.errors[0]
|
||||
assert record.healthy is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_handles_agent_health_error(self, orchestrator):
|
||||
"""Test graceful handling of agent health step errors."""
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = []
|
||||
with patch(
|
||||
"timmy.vassal.agent_health.get_full_health_report",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_health:
|
||||
mock_health.side_effect = RuntimeError("Health check failed")
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert len(record.errors) == 1
|
||||
assert "agent_health" in record.errors[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_handles_house_health_error(self, orchestrator):
|
||||
"""Test graceful handling of house health step errors."""
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = []
|
||||
with patch(
|
||||
"timmy.vassal.house_health.get_system_snapshot",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_snapshot:
|
||||
mock_snapshot.side_effect = RuntimeError("Snapshot failed")
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert len(record.errors) == 1
|
||||
assert "house_health" in record.errors[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_detects_stuck_agents(self, orchestrator):
|
||||
"""Test detection and nudging of stuck agents."""
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@dataclass
|
||||
class MockAgentStatus:
|
||||
agent: str
|
||||
is_stuck: bool = False
|
||||
is_idle: bool = False
|
||||
stuck_issue_numbers: list = field(default_factory=list)
|
||||
|
||||
mock_report = MagicMock()
|
||||
mock_report.agents = [
|
||||
MockAgentStatus(agent="claude", is_stuck=True, stuck_issue_numbers=[100]),
|
||||
MockAgentStatus(agent="kimi", is_stuck=False),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = []
|
||||
with patch(
|
||||
"timmy.vassal.agent_health.get_full_health_report",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_health:
|
||||
mock_health.return_value = mock_report
|
||||
with patch(
|
||||
"timmy.vassal.agent_health.nudge_stuck_agent",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_nudge:
|
||||
mock_nudge.return_value = True
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert "claude" in record.stuck_agents
|
||||
assert record.nudges_sent == 1
|
||||
mock_nudge.assert_awaited_once_with("claude", 100)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cycle_triggers_cleanup_on_high_disk(self, orchestrator):
|
||||
"""Test cleanup is triggered when disk usage is high."""
|
||||
mock_snapshot = MagicMock()
|
||||
mock_snapshot.disk.percent_used = 85.0 # Above 80% threshold
|
||||
mock_snapshot.warnings = ["Disk: 85% used"]
|
||||
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = []
|
||||
with patch(
|
||||
"timmy.vassal.house_health.get_system_snapshot",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_snapshot_fn:
|
||||
mock_snapshot_fn.return_value = mock_snapshot
|
||||
with patch(
|
||||
"timmy.vassal.house_health.cleanup_stale_files",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_cleanup:
|
||||
mock_cleanup.return_value = {"deleted_count": 5}
|
||||
record = await orchestrator.run_cycle()
|
||||
|
||||
assert record.cleanup_deleted == 5
|
||||
assert record.house_warnings == ["Disk: 85% used"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_status_after_cycle(self, orchestrator):
|
||||
"""Test get_status returns correct info after a cycle."""
|
||||
with patch(
|
||||
"timmy.vassal.orchestration_loop.VassalOrchestrator._broadcast"
|
||||
) as mock_broadcast:
|
||||
mock_broadcast.return_value = None
|
||||
with patch(
|
||||
"timmy.vassal.backlog.fetch_open_issues", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = []
|
||||
await orchestrator.run_cycle()
|
||||
|
||||
status = orchestrator.get_status()
|
||||
assert status["running"] is False
|
||||
assert status["cycle_count"] == 1
|
||||
assert status["last_cycle"] is not None
|
||||
assert status["last_cycle"]["cycle_id"] == 1
|
||||
assert status["last_cycle"]["issues_fetched"] == 0
|
||||
assert status["last_cycle"]["healthy"] is True
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Background loop tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBackgroundLoop:
|
||||
"""Tests for the start/stop background loop methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def orchestrator(self):
|
||||
"""Create a fresh orchestrator for each test."""
|
||||
return VassalOrchestrator(cycle_interval=0.1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_stop_cycle(self, orchestrator):
|
||||
"""Test starting and stopping the background loop."""
|
||||
with patch.object(orchestrator, "run_cycle", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = MagicMock()
|
||||
|
||||
# Start the loop
|
||||
await orchestrator.start()
|
||||
assert orchestrator.is_running is True
|
||||
assert orchestrator._task is not None
|
||||
|
||||
# Let it run for a bit
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
# Stop the loop
|
||||
orchestrator.stop()
|
||||
assert orchestrator.is_running is False
|
||||
|
||||
# Should have run at least once
|
||||
assert mock_run.await_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_already_running(self, orchestrator):
|
||||
"""Test starting when already running is a no-op."""
|
||||
with patch.object(orchestrator, "run_cycle", new_callable=AsyncMock):
|
||||
await orchestrator.start()
|
||||
first_task = orchestrator._task
|
||||
|
||||
# Start again should not create new task
|
||||
await orchestrator.start()
|
||||
assert orchestrator._task is first_task
|
||||
|
||||
orchestrator.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_not_running(self, orchestrator):
|
||||
"""Test stopping when not running is a no-op."""
|
||||
orchestrator.stop()
|
||||
assert orchestrator.is_running is False
|
||||
assert orchestrator._task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_handles_cycle_exceptions(self, orchestrator):
|
||||
"""Test that exceptions in run_cycle don't crash the loop."""
|
||||
with patch.object(
|
||||
orchestrator, "run_cycle", new_callable=AsyncMock
|
||||
) as mock_run:
|
||||
mock_run.side_effect = [RuntimeError("Boom"), MagicMock()]
|
||||
|
||||
await orchestrator.start()
|
||||
await asyncio.sleep(0.25)
|
||||
orchestrator.stop()
|
||||
|
||||
# Should have been called multiple times despite error
|
||||
assert mock_run.await_count >= 2
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Interval resolution tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIntervalResolution:
|
||||
"""Tests for the _resolve_interval method."""
|
||||
|
||||
def test_resolve_interval_explicit(self):
|
||||
"""Test that explicit interval is used when provided."""
|
||||
orchestrator = VassalOrchestrator(cycle_interval=60.0)
|
||||
assert orchestrator._resolve_interval() == 60.0
|
||||
|
||||
def test_resolve_interval_from_settings(self):
|
||||
"""Test interval is read from settings when not explicitly set."""
|
||||
orchestrator = VassalOrchestrator()
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.vassal_cycle_interval = 120.0
|
||||
|
||||
with patch("config.settings", mock_settings):
|
||||
assert orchestrator._resolve_interval() == 120.0
|
||||
|
||||
def test_resolve_interval_default_fallback(self):
|
||||
"""Test default 300s is used when settings fails."""
|
||||
orchestrator = VassalOrchestrator()
|
||||
|
||||
with patch("config.settings", None):
|
||||
assert orchestrator._resolve_interval() == 300.0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Broadcast tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBroadcast:
|
||||
"""Tests for the _broadcast helper."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_success(self):
|
||||
"""Test successful WebSocket broadcast."""
|
||||
orchestrator = VassalOrchestrator()
|
||||
record = VassalCycleRecord(
|
||||
cycle_id=1,
|
||||
started_at="2026-03-23T12:00:00+00:00",
|
||||
finished_at="2026-03-23T12:00:01+00:00",
|
||||
duration_ms=1000,
|
||||
issues_fetched=5,
|
||||
issues_dispatched=2,
|
||||
)
|
||||
|
||||
mock_ws_manager = MagicMock()
|
||||
mock_ws_manager.broadcast = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"infrastructure.ws_manager.handler.ws_manager", mock_ws_manager
|
||||
):
|
||||
await orchestrator._broadcast(record)
|
||||
|
||||
mock_ws_manager.broadcast.assert_awaited_once()
|
||||
call_args = mock_ws_manager.broadcast.call_args[0]
|
||||
assert call_args[0] == "vassal.cycle"
|
||||
assert call_args[1]["cycle_id"] == 1
|
||||
assert call_args[1]["healthy"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_graceful_degradation(self):
|
||||
"""Test broadcast gracefully handles errors."""
|
||||
orchestrator = VassalOrchestrator()
|
||||
record = VassalCycleRecord(cycle_id=1, started_at="2026-03-23T12:00:00+00:00")
|
||||
|
||||
with patch(
|
||||
"infrastructure.ws_manager.handler.ws_manager"
|
||||
) as mock_ws_manager:
|
||||
mock_ws_manager.broadcast = AsyncMock(
|
||||
side_effect=RuntimeError("WS disconnected")
|
||||
)
|
||||
# Should not raise
|
||||
await orchestrator._broadcast(record)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_import_error(self):
|
||||
"""Test broadcast handles missing ws_manager module."""
|
||||
orchestrator = VassalOrchestrator()
|
||||
record = VassalCycleRecord(cycle_id=1, started_at="2026-03-23T12:00:00+00:00")
|
||||
|
||||
with patch.dict("sys.modules", {"infrastructure.ws_manager.handler": None}):
|
||||
# Should not raise
|
||||
await orchestrator._broadcast(record)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Module singleton test
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestModuleSingleton:
|
||||
"""Tests for the module-level vassal_orchestrator singleton."""
|
||||
|
||||
def test_singleton_import(self):
|
||||
"""Test that the module-level singleton is available."""
|
||||
from timmy.vassal import vassal_orchestrator
|
||||
|
||||
assert isinstance(vassal_orchestrator, VassalOrchestrator)
|
||||
|
||||
def test_singleton_is_single_instance(self):
|
||||
"""Test that importing twice returns same instance."""
|
||||
from timmy.vassal import vassal_orchestrator as orch1
|
||||
from timmy.vassal import vassal_orchestrator as orch2
|
||||
|
||||
assert orch1 is orch2
|
||||
|
||||
|
||||
# Need to import asyncio for the background loop tests
|
||||
import asyncio # noqa: E402
|
||||
@@ -1,839 +0,0 @@
|
||||
"""Unit tests for timmy.quest_system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import timmy.quest_system as qs
|
||||
from timmy.quest_system import (
|
||||
QuestDefinition,
|
||||
QuestProgress,
|
||||
QuestStatus,
|
||||
QuestType,
|
||||
_get_progress_key,
|
||||
_get_target_value,
|
||||
_is_on_cooldown,
|
||||
check_daily_run_quest,
|
||||
check_issue_count_quest,
|
||||
check_issue_reduce_quest,
|
||||
claim_quest_reward,
|
||||
evaluate_quest_progress,
|
||||
get_active_quests,
|
||||
get_agent_quests_status,
|
||||
get_or_create_progress,
|
||||
get_quest_definition,
|
||||
get_quest_definitions,
|
||||
get_quest_leaderboard,
|
||||
get_quest_progress,
|
||||
load_quest_config,
|
||||
reset_quest_progress,
|
||||
update_quest_progress,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_quest(
|
||||
quest_id: str = "test_quest",
|
||||
quest_type: QuestType = QuestType.ISSUE_COUNT,
|
||||
reward_tokens: int = 10,
|
||||
enabled: bool = True,
|
||||
repeatable: bool = False,
|
||||
cooldown_hours: int = 0,
|
||||
criteria: dict[str, Any] | None = None,
|
||||
) -> QuestDefinition:
|
||||
return QuestDefinition(
|
||||
id=quest_id,
|
||||
name=f"Quest {quest_id}",
|
||||
description="Test quest",
|
||||
reward_tokens=reward_tokens,
|
||||
quest_type=quest_type,
|
||||
enabled=enabled,
|
||||
repeatable=repeatable,
|
||||
cooldown_hours=cooldown_hours,
|
||||
criteria=criteria or {"target_count": 3},
|
||||
notification_message="Quest Complete! You earned {tokens} tokens.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_state():
|
||||
"""Reset module-level state before and after each test."""
|
||||
reset_quest_progress()
|
||||
qs._quest_definitions.clear()
|
||||
qs._quest_settings.clear()
|
||||
yield
|
||||
reset_quest_progress()
|
||||
qs._quest_definitions.clear()
|
||||
qs._quest_settings.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QuestDefinition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuestDefinition:
|
||||
def test_from_dict_minimal(self):
|
||||
data = {"id": "q1"}
|
||||
defn = QuestDefinition.from_dict(data)
|
||||
assert defn.id == "q1"
|
||||
assert defn.name == "Unnamed Quest"
|
||||
assert defn.reward_tokens == 0
|
||||
assert defn.quest_type == QuestType.CUSTOM
|
||||
assert defn.enabled is True
|
||||
assert defn.repeatable is False
|
||||
assert defn.cooldown_hours == 0
|
||||
|
||||
def test_from_dict_full(self):
|
||||
data = {
|
||||
"id": "q2",
|
||||
"name": "Full Quest",
|
||||
"description": "A full quest",
|
||||
"reward_tokens": 50,
|
||||
"type": "issue_count",
|
||||
"enabled": False,
|
||||
"repeatable": True,
|
||||
"cooldown_hours": 24,
|
||||
"criteria": {"target_count": 5},
|
||||
"notification_message": "You earned {tokens}!",
|
||||
}
|
||||
defn = QuestDefinition.from_dict(data)
|
||||
assert defn.id == "q2"
|
||||
assert defn.name == "Full Quest"
|
||||
assert defn.reward_tokens == 50
|
||||
assert defn.quest_type == QuestType.ISSUE_COUNT
|
||||
assert defn.enabled is False
|
||||
assert defn.repeatable is True
|
||||
assert defn.cooldown_hours == 24
|
||||
assert defn.criteria == {"target_count": 5}
|
||||
assert defn.notification_message == "You earned {tokens}!"
|
||||
|
||||
def test_from_dict_invalid_type_raises(self):
|
||||
data = {"id": "q3", "type": "not_a_real_type"}
|
||||
with pytest.raises(ValueError):
|
||||
QuestDefinition.from_dict(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QuestProgress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuestProgress:
|
||||
def test_to_dict_roundtrip(self):
|
||||
progress = QuestProgress(
|
||||
quest_id="q1",
|
||||
agent_id="agent_a",
|
||||
status=QuestStatus.IN_PROGRESS,
|
||||
current_value=2,
|
||||
target_value=5,
|
||||
started_at="2026-01-01T00:00:00",
|
||||
metadata={"key": "val"},
|
||||
)
|
||||
d = progress.to_dict()
|
||||
assert d["quest_id"] == "q1"
|
||||
assert d["agent_id"] == "agent_a"
|
||||
assert d["status"] == "in_progress"
|
||||
assert d["current_value"] == 2
|
||||
assert d["target_value"] == 5
|
||||
assert d["metadata"] == {"key": "val"}
|
||||
|
||||
def test_to_dict_defaults(self):
|
||||
progress = QuestProgress(
|
||||
quest_id="q1",
|
||||
agent_id="agent_a",
|
||||
status=QuestStatus.NOT_STARTED,
|
||||
)
|
||||
d = progress.to_dict()
|
||||
assert d["completion_count"] == 0
|
||||
assert d["started_at"] == ""
|
||||
assert d["completed_at"] == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_progress_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_progress_key():
|
||||
assert _get_progress_key("q1", "agent_a") == "agent_a:q1"
|
||||
|
||||
|
||||
def test_get_progress_key_different_agents():
|
||||
key_a = _get_progress_key("q1", "agent_a")
|
||||
key_b = _get_progress_key("q1", "agent_b")
|
||||
assert key_a != key_b
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_quest_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLoadQuestConfig:
|
||||
def test_missing_file_returns_empty(self, tmp_path):
|
||||
missing = tmp_path / "nonexistent.yaml"
|
||||
with patch.object(qs, "QUEST_CONFIG_PATH", missing):
|
||||
defs, settings = load_quest_config()
|
||||
assert defs == {}
|
||||
assert settings == {}
|
||||
|
||||
def test_valid_yaml_loads_quests(self, tmp_path):
|
||||
config_path = tmp_path / "quests.yaml"
|
||||
config_path.write_text(
|
||||
"""
|
||||
quests:
|
||||
first_quest:
|
||||
name: First Quest
|
||||
description: Do stuff
|
||||
reward_tokens: 25
|
||||
type: issue_count
|
||||
enabled: true
|
||||
repeatable: false
|
||||
cooldown_hours: 0
|
||||
criteria:
|
||||
target_count: 3
|
||||
notification_message: "Done! {tokens} tokens"
|
||||
settings:
|
||||
some_setting: true
|
||||
"""
|
||||
)
|
||||
with patch.object(qs, "QUEST_CONFIG_PATH", config_path):
|
||||
defs, settings = load_quest_config()
|
||||
|
||||
assert "first_quest" in defs
|
||||
assert defs["first_quest"].name == "First Quest"
|
||||
assert defs["first_quest"].reward_tokens == 25
|
||||
assert settings == {"some_setting": True}
|
||||
|
||||
def test_invalid_yaml_returns_empty(self, tmp_path):
|
||||
config_path = tmp_path / "quests.yaml"
|
||||
config_path.write_text(":: not valid yaml ::")
|
||||
with patch.object(qs, "QUEST_CONFIG_PATH", config_path):
|
||||
defs, settings = load_quest_config()
|
||||
assert defs == {}
|
||||
assert settings == {}
|
||||
|
||||
def test_non_dict_yaml_returns_empty(self, tmp_path):
|
||||
config_path = tmp_path / "quests.yaml"
|
||||
config_path.write_text("- item1\n- item2\n")
|
||||
with patch.object(qs, "QUEST_CONFIG_PATH", config_path):
|
||||
defs, settings = load_quest_config()
|
||||
assert defs == {}
|
||||
assert settings == {}
|
||||
|
||||
def test_bad_quest_entry_is_skipped(self, tmp_path):
|
||||
config_path = tmp_path / "quests.yaml"
|
||||
config_path.write_text(
|
||||
"""
|
||||
quests:
|
||||
good_quest:
|
||||
name: Good
|
||||
type: issue_count
|
||||
reward_tokens: 10
|
||||
enabled: true
|
||||
repeatable: false
|
||||
cooldown_hours: 0
|
||||
criteria: {}
|
||||
notification_message: "{tokens}"
|
||||
bad_quest:
|
||||
type: invalid_type_that_does_not_exist
|
||||
"""
|
||||
)
|
||||
with patch.object(qs, "QUEST_CONFIG_PATH", config_path):
|
||||
defs, _ = load_quest_config()
|
||||
assert "good_quest" in defs
|
||||
assert "bad_quest" not in defs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_quest_definitions / get_quest_definition / get_active_quests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuestLookup:
|
||||
def setup_method(self):
|
||||
q1 = _make_quest("q1", enabled=True)
|
||||
q2 = _make_quest("q2", enabled=False)
|
||||
qs._quest_definitions.update({"q1": q1, "q2": q2})
|
||||
|
||||
def test_get_quest_definitions_returns_all(self):
|
||||
defs = get_quest_definitions()
|
||||
assert "q1" in defs
|
||||
assert "q2" in defs
|
||||
|
||||
def test_get_quest_definition_found(self):
|
||||
defn = get_quest_definition("q1")
|
||||
assert defn is not None
|
||||
assert defn.id == "q1"
|
||||
|
||||
def test_get_quest_definition_not_found(self):
|
||||
assert get_quest_definition("missing") is None
|
||||
|
||||
def test_get_active_quests_only_enabled(self):
|
||||
active = get_active_quests()
|
||||
ids = [q.id for q in active]
|
||||
assert "q1" in ids
|
||||
assert "q2" not in ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_target_value
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetTargetValue:
|
||||
def test_issue_count(self):
|
||||
q = _make_quest(quest_type=QuestType.ISSUE_COUNT, criteria={"target_count": 7})
|
||||
assert _get_target_value(q) == 7
|
||||
|
||||
def test_issue_reduce(self):
|
||||
q = _make_quest(quest_type=QuestType.ISSUE_REDUCE, criteria={"target_reduction": 5})
|
||||
assert _get_target_value(q) == 5
|
||||
|
||||
def test_daily_run(self):
|
||||
q = _make_quest(quest_type=QuestType.DAILY_RUN, criteria={"min_sessions": 3})
|
||||
assert _get_target_value(q) == 3
|
||||
|
||||
def test_docs_update(self):
|
||||
q = _make_quest(quest_type=QuestType.DOCS_UPDATE, criteria={"min_files_changed": 2})
|
||||
assert _get_target_value(q) == 2
|
||||
|
||||
def test_test_improve(self):
|
||||
q = _make_quest(quest_type=QuestType.TEST_IMPROVE, criteria={"min_new_tests": 4})
|
||||
assert _get_target_value(q) == 4
|
||||
|
||||
def test_custom_defaults_to_one(self):
|
||||
q = _make_quest(quest_type=QuestType.CUSTOM, criteria={})
|
||||
assert _get_target_value(q) == 1
|
||||
|
||||
def test_missing_criteria_key_defaults_to_one(self):
|
||||
q = _make_quest(quest_type=QuestType.ISSUE_COUNT, criteria={})
|
||||
assert _get_target_value(q) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_or_create_progress / get_quest_progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProgressCreation:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["q1"] = _make_quest("q1", criteria={"target_count": 5})
|
||||
|
||||
def test_creates_new_progress(self):
|
||||
progress = get_or_create_progress("q1", "agent_a")
|
||||
assert progress.quest_id == "q1"
|
||||
assert progress.agent_id == "agent_a"
|
||||
assert progress.status == QuestStatus.NOT_STARTED
|
||||
assert progress.target_value == 5
|
||||
assert progress.current_value == 0
|
||||
|
||||
def test_returns_existing_progress(self):
|
||||
p1 = get_or_create_progress("q1", "agent_a")
|
||||
p1.current_value = 3
|
||||
p2 = get_or_create_progress("q1", "agent_a")
|
||||
assert p2.current_value == 3
|
||||
assert p1 is p2
|
||||
|
||||
def test_raises_for_unknown_quest(self):
|
||||
with pytest.raises(ValueError, match="Quest unknown not found"):
|
||||
get_or_create_progress("unknown", "agent_a")
|
||||
|
||||
def test_get_quest_progress_none_before_creation(self):
|
||||
assert get_quest_progress("q1", "agent_a") is None
|
||||
|
||||
def test_get_quest_progress_after_creation(self):
|
||||
get_or_create_progress("q1", "agent_a")
|
||||
progress = get_quest_progress("q1", "agent_a")
|
||||
assert progress is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_quest_progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateQuestProgress:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["q1"] = _make_quest("q1", criteria={"target_count": 3})
|
||||
|
||||
def test_updates_current_value(self):
|
||||
progress = update_quest_progress("q1", "agent_a", 2)
|
||||
assert progress.current_value == 2
|
||||
assert progress.status == QuestStatus.NOT_STARTED
|
||||
|
||||
def test_marks_completed_when_target_reached(self):
|
||||
progress = update_quest_progress("q1", "agent_a", 3)
|
||||
assert progress.status == QuestStatus.COMPLETED
|
||||
assert progress.completed_at != ""
|
||||
|
||||
def test_marks_completed_when_value_exceeds_target(self):
|
||||
progress = update_quest_progress("q1", "agent_a", 10)
|
||||
assert progress.status == QuestStatus.COMPLETED
|
||||
|
||||
def test_does_not_re_complete_already_completed(self):
|
||||
p = update_quest_progress("q1", "agent_a", 3)
|
||||
first_completed_at = p.completed_at
|
||||
p2 = update_quest_progress("q1", "agent_a", 5)
|
||||
# should not change completed_at again
|
||||
assert p2.completed_at == first_completed_at
|
||||
|
||||
def test_does_not_re_complete_claimed_quest(self):
|
||||
p = update_quest_progress("q1", "agent_a", 3)
|
||||
p.status = QuestStatus.CLAIMED
|
||||
p2 = update_quest_progress("q1", "agent_a", 5)
|
||||
assert p2.status == QuestStatus.CLAIMED
|
||||
|
||||
def test_updates_metadata(self):
|
||||
progress = update_quest_progress("q1", "agent_a", 1, metadata={"info": "value"})
|
||||
assert progress.metadata["info"] == "value"
|
||||
|
||||
def test_merges_metadata(self):
|
||||
update_quest_progress("q1", "agent_a", 1, metadata={"a": 1})
|
||||
progress = update_quest_progress("q1", "agent_a", 2, metadata={"b": 2})
|
||||
assert progress.metadata["a"] == 1
|
||||
assert progress.metadata["b"] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_on_cooldown
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsOnCooldown:
|
||||
def test_non_repeatable_never_on_cooldown(self):
|
||||
quest = _make_quest(repeatable=False, cooldown_hours=24)
|
||||
progress = QuestProgress(
|
||||
quest_id="q1",
|
||||
agent_id="agent_a",
|
||||
status=QuestStatus.CLAIMED,
|
||||
last_completed_at=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
assert _is_on_cooldown(progress, quest) is False
|
||||
|
||||
def test_no_last_completed_not_on_cooldown(self):
|
||||
quest = _make_quest(repeatable=True, cooldown_hours=24)
|
||||
progress = QuestProgress(
|
||||
quest_id="q1",
|
||||
agent_id="agent_a",
|
||||
status=QuestStatus.NOT_STARTED,
|
||||
last_completed_at="",
|
||||
)
|
||||
assert _is_on_cooldown(progress, quest) is False
|
||||
|
||||
def test_zero_cooldown_not_on_cooldown(self):
|
||||
quest = _make_quest(repeatable=True, cooldown_hours=0)
|
||||
progress = QuestProgress(
|
||||
quest_id="q1",
|
||||
agent_id="agent_a",
|
||||
status=QuestStatus.CLAIMED,
|
||||
last_completed_at=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
assert _is_on_cooldown(progress, quest) is False
|
||||
|
||||
def test_recent_completion_is_on_cooldown(self):
|
||||
quest = _make_quest(repeatable=True, cooldown_hours=24)
|
||||
recent = datetime.now(UTC) - timedelta(hours=1)
|
||||
progress = QuestProgress(
|
||||
quest_id="q1",
|
||||
agent_id="agent_a",
|
||||
status=QuestStatus.NOT_STARTED,
|
||||
last_completed_at=recent.isoformat(),
|
||||
)
|
||||
assert _is_on_cooldown(progress, quest) is True
|
||||
|
||||
def test_expired_cooldown_not_on_cooldown(self):
|
||||
quest = _make_quest(repeatable=True, cooldown_hours=24)
|
||||
old = datetime.now(UTC) - timedelta(hours=25)
|
||||
progress = QuestProgress(
|
||||
quest_id="q1",
|
||||
agent_id="agent_a",
|
||||
status=QuestStatus.NOT_STARTED,
|
||||
last_completed_at=old.isoformat(),
|
||||
)
|
||||
assert _is_on_cooldown(progress, quest) is False
|
||||
|
||||
def test_invalid_last_completed_returns_false(self):
|
||||
quest = _make_quest(repeatable=True, cooldown_hours=24)
|
||||
progress = QuestProgress(
|
||||
quest_id="q1",
|
||||
agent_id="agent_a",
|
||||
status=QuestStatus.NOT_STARTED,
|
||||
last_completed_at="not-a-date",
|
||||
)
|
||||
assert _is_on_cooldown(progress, quest) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# claim_quest_reward
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestClaimQuestReward:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["q1"] = _make_quest("q1", reward_tokens=25)
|
||||
|
||||
def test_returns_none_if_no_progress(self):
|
||||
assert claim_quest_reward("q1", "agent_a") is None
|
||||
|
||||
def test_returns_none_if_not_completed(self):
|
||||
get_or_create_progress("q1", "agent_a")
|
||||
assert claim_quest_reward("q1", "agent_a") is None
|
||||
|
||||
def test_returns_none_if_quest_not_found(self):
|
||||
assert claim_quest_reward("nonexistent", "agent_a") is None
|
||||
|
||||
def test_successful_claim(self):
|
||||
progress = get_or_create_progress("q1", "agent_a")
|
||||
progress.status = QuestStatus.COMPLETED
|
||||
progress.completed_at = datetime.now(UTC).isoformat()
|
||||
|
||||
mock_invoice = MagicMock()
|
||||
mock_invoice.payment_hash = "quest_q1_agent_a_123"
|
||||
|
||||
with (
|
||||
patch("timmy.quest_system.create_invoice_entry", return_value=mock_invoice),
|
||||
patch("timmy.quest_system.mark_settled"),
|
||||
):
|
||||
result = claim_quest_reward("q1", "agent_a")
|
||||
|
||||
assert result is not None
|
||||
assert result["tokens_awarded"] == 25
|
||||
assert result["quest_id"] == "q1"
|
||||
assert result["agent_id"] == "agent_a"
|
||||
assert result["completion_count"] == 1
|
||||
|
||||
def test_successful_claim_marks_claimed(self):
|
||||
progress = get_or_create_progress("q1", "agent_a")
|
||||
progress.status = QuestStatus.COMPLETED
|
||||
progress.completed_at = datetime.now(UTC).isoformat()
|
||||
|
||||
mock_invoice = MagicMock()
|
||||
mock_invoice.payment_hash = "phash"
|
||||
|
||||
with (
|
||||
patch("timmy.quest_system.create_invoice_entry", return_value=mock_invoice),
|
||||
patch("timmy.quest_system.mark_settled"),
|
||||
):
|
||||
claim_quest_reward("q1", "agent_a")
|
||||
|
||||
assert progress.status == QuestStatus.CLAIMED
|
||||
|
||||
def test_repeatable_quest_resets_after_claim(self):
|
||||
qs._quest_definitions["rep"] = _make_quest(
|
||||
"rep", repeatable=True, cooldown_hours=0, reward_tokens=10
|
||||
)
|
||||
progress = get_or_create_progress("rep", "agent_a")
|
||||
progress.status = QuestStatus.COMPLETED
|
||||
progress.completed_at = datetime.now(UTC).isoformat()
|
||||
progress.current_value = 5
|
||||
|
||||
mock_invoice = MagicMock()
|
||||
mock_invoice.payment_hash = "phash"
|
||||
|
||||
with (
|
||||
patch("timmy.quest_system.create_invoice_entry", return_value=mock_invoice),
|
||||
patch("timmy.quest_system.mark_settled"),
|
||||
):
|
||||
result = claim_quest_reward("rep", "agent_a")
|
||||
|
||||
assert result is not None
|
||||
assert progress.status == QuestStatus.NOT_STARTED
|
||||
assert progress.current_value == 0
|
||||
assert progress.completed_at == ""
|
||||
|
||||
def test_on_cooldown_returns_none(self):
|
||||
qs._quest_definitions["rep"] = _make_quest("rep", repeatable=True, cooldown_hours=24)
|
||||
progress = get_or_create_progress("rep", "agent_a")
|
||||
progress.status = QuestStatus.COMPLETED
|
||||
recent = datetime.now(UTC) - timedelta(hours=1)
|
||||
progress.last_completed_at = recent.isoformat()
|
||||
|
||||
assert claim_quest_reward("rep", "agent_a") is None
|
||||
|
||||
def test_ledger_error_returns_none(self):
|
||||
progress = get_or_create_progress("q1", "agent_a")
|
||||
progress.status = QuestStatus.COMPLETED
|
||||
progress.completed_at = datetime.now(UTC).isoformat()
|
||||
|
||||
with patch("timmy.quest_system.create_invoice_entry", side_effect=Exception("ledger error")):
|
||||
result = claim_quest_reward("q1", "agent_a")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_issue_count_quest
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckIssueCountQuest:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["iq"] = _make_quest(
|
||||
"iq", quest_type=QuestType.ISSUE_COUNT, criteria={"target_count": 2, "issue_labels": ["bug"]}
|
||||
)
|
||||
|
||||
def test_counts_matching_issues(self):
|
||||
issues = [
|
||||
{"labels": [{"name": "bug"}]},
|
||||
{"labels": [{"name": "bug"}, {"name": "priority"}]},
|
||||
{"labels": [{"name": "feature"}]}, # doesn't match
|
||||
]
|
||||
progress = check_issue_count_quest(
|
||||
qs._quest_definitions["iq"], "agent_a", issues
|
||||
)
|
||||
assert progress.current_value == 2
|
||||
assert progress.status == QuestStatus.COMPLETED
|
||||
|
||||
def test_empty_issues_returns_zero(self):
|
||||
progress = check_issue_count_quest(qs._quest_definitions["iq"], "agent_a", [])
|
||||
assert progress.current_value == 0
|
||||
|
||||
def test_no_labels_filter_counts_all_labeled(self):
|
||||
q = _make_quest(
|
||||
"nolabel",
|
||||
quest_type=QuestType.ISSUE_COUNT,
|
||||
criteria={"target_count": 1, "issue_labels": []},
|
||||
)
|
||||
qs._quest_definitions["nolabel"] = q
|
||||
issues = [
|
||||
{"labels": [{"name": "bug"}]},
|
||||
{"labels": [{"name": "feature"}]},
|
||||
]
|
||||
progress = check_issue_count_quest(q, "agent_a", issues)
|
||||
assert progress.current_value == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_issue_reduce_quest
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckIssueReduceQuest:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["ir"] = _make_quest(
|
||||
"ir", quest_type=QuestType.ISSUE_REDUCE, criteria={"target_reduction": 5}
|
||||
)
|
||||
|
||||
def test_computes_reduction(self):
|
||||
progress = check_issue_reduce_quest(qs._quest_definitions["ir"], "agent_a", 20, 15)
|
||||
assert progress.current_value == 5
|
||||
assert progress.status == QuestStatus.COMPLETED
|
||||
|
||||
def test_negative_reduction_treated_as_zero(self):
|
||||
progress = check_issue_reduce_quest(qs._quest_definitions["ir"], "agent_a", 10, 15)
|
||||
assert progress.current_value == 0
|
||||
|
||||
def test_no_change_yields_zero(self):
|
||||
progress = check_issue_reduce_quest(qs._quest_definitions["ir"], "agent_a", 10, 10)
|
||||
assert progress.current_value == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_daily_run_quest
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckDailyRunQuest:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["dr"] = _make_quest(
|
||||
"dr", quest_type=QuestType.DAILY_RUN, criteria={"min_sessions": 2}
|
||||
)
|
||||
|
||||
def test_tracks_sessions(self):
|
||||
progress = check_daily_run_quest(qs._quest_definitions["dr"], "agent_a", 2)
|
||||
assert progress.current_value == 2
|
||||
assert progress.status == QuestStatus.COMPLETED
|
||||
|
||||
def test_incomplete_sessions(self):
|
||||
progress = check_daily_run_quest(qs._quest_definitions["dr"], "agent_a", 1)
|
||||
assert progress.current_value == 1
|
||||
assert progress.status != QuestStatus.COMPLETED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# evaluate_quest_progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEvaluateQuestProgress:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["iq"] = _make_quest(
|
||||
"iq", quest_type=QuestType.ISSUE_COUNT, criteria={"target_count": 1}
|
||||
)
|
||||
qs._quest_definitions["dis"] = _make_quest("dis", enabled=False)
|
||||
|
||||
def test_disabled_quest_returns_none(self):
|
||||
result = evaluate_quest_progress("dis", "agent_a", {})
|
||||
assert result is None
|
||||
|
||||
def test_missing_quest_returns_none(self):
|
||||
result = evaluate_quest_progress("nonexistent", "agent_a", {})
|
||||
assert result is None
|
||||
|
||||
def test_issue_count_quest_evaluated(self):
|
||||
context = {"closed_issues": [{"labels": [{"name": "bug"}]}]}
|
||||
result = evaluate_quest_progress("iq", "agent_a", context)
|
||||
assert result is not None
|
||||
assert result.current_value == 1
|
||||
|
||||
def test_issue_reduce_quest_evaluated(self):
|
||||
qs._quest_definitions["ir"] = _make_quest(
|
||||
"ir", quest_type=QuestType.ISSUE_REDUCE, criteria={"target_reduction": 3}
|
||||
)
|
||||
context = {"previous_issue_count": 10, "current_issue_count": 7}
|
||||
result = evaluate_quest_progress("ir", "agent_a", context)
|
||||
assert result is not None
|
||||
assert result.current_value == 3
|
||||
|
||||
def test_daily_run_quest_evaluated(self):
|
||||
qs._quest_definitions["dr"] = _make_quest(
|
||||
"dr", quest_type=QuestType.DAILY_RUN, criteria={"min_sessions": 1}
|
||||
)
|
||||
context = {"sessions_completed": 2}
|
||||
result = evaluate_quest_progress("dr", "agent_a", context)
|
||||
assert result is not None
|
||||
assert result.current_value == 2
|
||||
|
||||
def test_custom_quest_returns_existing_progress(self):
|
||||
qs._quest_definitions["cust"] = _make_quest("cust", quest_type=QuestType.CUSTOM)
|
||||
# No progress yet => None (custom quests don't auto-create progress here)
|
||||
result = evaluate_quest_progress("cust", "agent_a", {})
|
||||
assert result is None
|
||||
|
||||
def test_cooldown_prevents_evaluation(self):
|
||||
q = _make_quest("rep_iq", quest_type=QuestType.ISSUE_COUNT, repeatable=True, cooldown_hours=24, criteria={"target_count": 1})
|
||||
qs._quest_definitions["rep_iq"] = q
|
||||
progress = get_or_create_progress("rep_iq", "agent_a")
|
||||
recent = datetime.now(UTC) - timedelta(hours=1)
|
||||
progress.last_completed_at = recent.isoformat()
|
||||
|
||||
context = {"closed_issues": [{"labels": [{"name": "bug"}]}]}
|
||||
result = evaluate_quest_progress("rep_iq", "agent_a", context)
|
||||
# Should return existing progress without updating
|
||||
assert result is progress
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset_quest_progress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResetQuestProgress:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["q1"] = _make_quest("q1")
|
||||
qs._quest_definitions["q2"] = _make_quest("q2")
|
||||
|
||||
def test_reset_all(self):
|
||||
get_or_create_progress("q1", "agent_a")
|
||||
get_or_create_progress("q2", "agent_a")
|
||||
count = reset_quest_progress()
|
||||
assert count == 2
|
||||
assert get_quest_progress("q1", "agent_a") is None
|
||||
assert get_quest_progress("q2", "agent_a") is None
|
||||
|
||||
def test_reset_specific_quest(self):
|
||||
get_or_create_progress("q1", "agent_a")
|
||||
get_or_create_progress("q2", "agent_a")
|
||||
count = reset_quest_progress(quest_id="q1")
|
||||
assert count == 1
|
||||
assert get_quest_progress("q1", "agent_a") is None
|
||||
assert get_quest_progress("q2", "agent_a") is not None
|
||||
|
||||
def test_reset_specific_agent(self):
|
||||
get_or_create_progress("q1", "agent_a")
|
||||
get_or_create_progress("q1", "agent_b")
|
||||
count = reset_quest_progress(agent_id="agent_a")
|
||||
assert count == 1
|
||||
assert get_quest_progress("q1", "agent_a") is None
|
||||
assert get_quest_progress("q1", "agent_b") is not None
|
||||
|
||||
def test_reset_specific_quest_and_agent(self):
|
||||
get_or_create_progress("q1", "agent_a")
|
||||
get_or_create_progress("q1", "agent_b")
|
||||
count = reset_quest_progress(quest_id="q1", agent_id="agent_a")
|
||||
assert count == 1
|
||||
|
||||
def test_reset_empty_returns_zero(self):
|
||||
count = reset_quest_progress()
|
||||
assert count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_quest_leaderboard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetQuestLeaderboard:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["q1"] = _make_quest("q1", reward_tokens=10)
|
||||
qs._quest_definitions["q2"] = _make_quest("q2", reward_tokens=20)
|
||||
|
||||
def test_empty_progress_returns_empty(self):
|
||||
assert get_quest_leaderboard() == []
|
||||
|
||||
def test_leaderboard_sorted_by_tokens(self):
|
||||
p_a = get_or_create_progress("q1", "agent_a")
|
||||
p_a.completion_count = 1
|
||||
p_b = get_or_create_progress("q2", "agent_b")
|
||||
p_b.completion_count = 2
|
||||
|
||||
board = get_quest_leaderboard()
|
||||
assert board[0]["agent_id"] == "agent_b" # 40 tokens
|
||||
assert board[1]["agent_id"] == "agent_a" # 10 tokens
|
||||
|
||||
def test_leaderboard_aggregates_multiple_quests(self):
|
||||
p1 = get_or_create_progress("q1", "agent_a")
|
||||
p1.completion_count = 2 # 20 tokens
|
||||
p2 = get_or_create_progress("q2", "agent_a")
|
||||
p2.completion_count = 1 # 20 tokens
|
||||
|
||||
board = get_quest_leaderboard()
|
||||
assert len(board) == 1
|
||||
assert board[0]["total_tokens"] == 40
|
||||
assert board[0]["total_completions"] == 3
|
||||
|
||||
def test_leaderboard_counts_unique_quests(self):
|
||||
p1 = get_or_create_progress("q1", "agent_a")
|
||||
p1.completion_count = 2
|
||||
p2 = get_or_create_progress("q2", "agent_a")
|
||||
p2.completion_count = 1
|
||||
|
||||
board = get_quest_leaderboard()
|
||||
assert board[0]["unique_quests_completed"] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_agent_quests_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetAgentQuestsStatus:
|
||||
def setup_method(self):
|
||||
qs._quest_definitions["q1"] = _make_quest("q1", reward_tokens=10)
|
||||
|
||||
def test_returns_status_structure(self):
|
||||
result = get_agent_quests_status("agent_a")
|
||||
assert result["agent_id"] == "agent_a"
|
||||
assert isinstance(result["quests"], list)
|
||||
assert "total_tokens_earned" in result
|
||||
assert "total_quests_completed" in result
|
||||
assert "active_quests_count" in result
|
||||
|
||||
def test_includes_quest_info(self):
|
||||
result = get_agent_quests_status("agent_a")
|
||||
quest_info = result["quests"][0]
|
||||
assert quest_info["quest_id"] == "q1"
|
||||
assert quest_info["reward_tokens"] == 10
|
||||
assert quest_info["status"] == QuestStatus.NOT_STARTED.value
|
||||
|
||||
def test_accumulates_tokens_from_completions(self):
|
||||
p = get_or_create_progress("q1", "agent_a")
|
||||
p.completion_count = 3
|
||||
result = get_agent_quests_status("agent_a")
|
||||
assert result["total_tokens_earned"] == 30
|
||||
assert result["total_quests_completed"] == 3
|
||||
|
||||
def test_cooldown_hours_remaining_calculated(self):
|
||||
q = _make_quest("qcool", repeatable=True, cooldown_hours=24, reward_tokens=5)
|
||||
qs._quest_definitions["qcool"] = q
|
||||
p = get_or_create_progress("qcool", "agent_a")
|
||||
recent = datetime.now(UTC) - timedelta(hours=2)
|
||||
p.last_completed_at = recent.isoformat()
|
||||
p.completion_count = 1
|
||||
|
||||
result = get_agent_quests_status("agent_a")
|
||||
qcool_info = next(qi for qi in result["quests"] if qi["quest_id"] == "qcool")
|
||||
assert qcool_info["on_cooldown"] is True
|
||||
assert qcool_info["cooldown_hours_remaining"] > 0
|
||||
Reference in New Issue
Block a user