1
0

Compare commits

..

5 Commits

Author SHA1 Message Date
Alexander Whitestone
aa3cad5707 WIP: Claude Code progress on #1277
Automated salvage commit — agent session ended (exit 124).
Work in progress, may need continuation.
2026-03-23 22:09:59 -04:00
bde7232ece [claude] Add unit tests for kimi_delegation.py (#1295) (#1303) 2026-03-24 01:54:44 +00:00
fc4426954e [claude] Add module docstrings to 9 undocumented files (#1296) (#1302)
Co-authored-by: Claude (Opus 4.6) <claude@hermes.local>
Co-committed-by: Claude (Opus 4.6) <claude@hermes.local>
2026-03-24 01:54:18 +00:00
5be4ecb9ef [kimi] Add unit tests for sovereignty/perception_cache.py (#1261) (#1301)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-24 01:53:44 +00:00
4f80cfcd58 [claude] Three-tier model router: Local 8B / Hermes 70B / Cloud API cascade (#882) (#1297)
Co-authored-by: Claude (Opus 4.6) <claude@hermes.local>
Co-committed-by: Claude (Opus 4.6) <claude@hermes.local>
2026-03-24 01:53:25 +00:00
23 changed files with 3374 additions and 1959 deletions

View File

@@ -0,0 +1 @@
"""Timmy Time Dashboard — source root package."""

View File

@@ -1,3 +1,8 @@
"""Central pydantic-settings configuration for Timmy Time Dashboard.
All environment variable access goes through the ``settings`` singleton
exported from this module — never use ``os.environ.get()`` in app code.
"""
import logging as _logging
import os
import sys
@@ -128,6 +133,23 @@ class Settings(BaseSettings):
anthropic_api_key: str = ""
claude_model: str = "haiku"
# ── Tiered Model Router (issue #882) ─────────────────────────────────
# Three-tier cascade: Local 8B (free, fast) → Local 70B (free, slower)
# → Cloud API (paid, best). Override model names per tier via env vars.
#
# TIER_LOCAL_FAST_MODEL — Tier-1 model name in Ollama (default: llama3.1:8b)
# TIER_LOCAL_HEAVY_MODEL — Tier-2 model name in Ollama (default: hermes3:70b)
# TIER_CLOUD_MODEL — Tier-3 cloud model name (default: claude-haiku-4-5)
#
# Budget limits for the cloud tier (0 = unlimited):
# TIER_CLOUD_DAILY_BUDGET_USD — daily ceiling in USD (default: 5.0)
# TIER_CLOUD_MONTHLY_BUDGET_USD — monthly ceiling in USD (default: 50.0)
tier_local_fast_model: str = "llama3.1:8b"
tier_local_heavy_model: str = "hermes3:70b"
tier_cloud_model: str = "claude-haiku-4-5"
tier_cloud_daily_budget_usd: float = 5.0
tier_cloud_monthly_budget_usd: float = 50.0
# ── Content Moderation ──────────────────────────────────────────────
# Three-layer moderation pipeline for AI narrator output.
# Uses Llama Guard via Ollama with regex fallback.

View File

@@ -1,3 +1,4 @@
"""SQLAlchemy ORM models for the CALM task-management and journaling system."""
from datetime import UTC, date, datetime
from enum import StrEnum

View File

@@ -1,3 +1,4 @@
"""SQLAlchemy engine, session factory, and declarative Base for the CALM module."""
import logging
from pathlib import Path

View File

@@ -1,3 +1,4 @@
"""Dashboard routes for agent chat interactions and tool-call display."""
import json
import logging
from datetime import datetime

View File

@@ -1,3 +1,4 @@
"""Dashboard routes for the CALM task management and daily journaling interface."""
import logging
from datetime import UTC, date, datetime

View File

@@ -1,5 +1,11 @@
"""Infrastructure models package."""
from infrastructure.models.budget import (
BudgetTracker,
SpendRecord,
estimate_cost_usd,
get_budget_tracker,
)
from infrastructure.models.multimodal import (
ModelCapability,
ModelInfo,
@@ -17,6 +23,12 @@ from infrastructure.models.registry import (
ModelRole,
model_registry,
)
from infrastructure.models.router import (
TierLabel,
TieredModelRouter,
classify_tier,
get_tiered_router,
)
__all__ = [
# Registry
@@ -34,4 +46,14 @@ __all__ = [
"model_supports_tools",
"model_supports_vision",
"pull_model_with_fallback",
# Tiered router
"TierLabel",
"TieredModelRouter",
"classify_tier",
"get_tiered_router",
# Budget tracker
"BudgetTracker",
"SpendRecord",
"estimate_cost_usd",
"get_budget_tracker",
]

View File

@@ -0,0 +1,302 @@
"""Cloud API budget tracker for the three-tier model router.
Tracks cloud API spend (daily / monthly) and enforces configurable limits.
SQLite-backed with in-memory fallback — degrades gracefully if the database
is unavailable.
References:
- Issue #882 — Model Tiering Router: Local 8B / Hermes 70B / Cloud API Cascade
"""
import logging
import sqlite3
import threading
import time
from dataclasses import dataclass
from datetime import UTC, date, datetime
from pathlib import Path
from config import settings
logger = logging.getLogger(__name__)
# ── Cost estimates (USD per 1 K tokens, input / output) ──────────────────────
# Updated 2026-03. Estimates only — actual costs vary by tier/usage.
# NOTE: lookup is by substring in insertion order, so versioned keys
# ("claude-haiku-4-5") must precede their bare aliases ("haiku"), and
# "gpt-4o-mini" must precede "gpt-4o".
_COST_PER_1K: dict[str, dict[str, float]] = {
    # Claude models
    "claude-haiku-4-5": {"input": 0.00025, "output": 0.00125},
    "claude-sonnet-4-5": {"input": 0.003, "output": 0.015},
    "claude-opus-4-5": {"input": 0.015, "output": 0.075},
    "haiku": {"input": 0.00025, "output": 0.00125},
    "sonnet": {"input": 0.003, "output": 0.015},
    "opus": {"input": 0.015, "output": 0.075},
    # GPT-4o
    "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
    "gpt-4o": {"input": 0.0025, "output": 0.01},
    # Grok (xAI)
    "grok-3-fast": {"input": 0.003, "output": 0.015},
    "grok-3": {"input": 0.005, "output": 0.025},
}

_DEFAULT_COST: dict[str, float] = {"input": 0.003, "output": 0.015}  # conservative fallback


def estimate_cost_usd(model: str, tokens_in: int, tokens_out: int) -> float:
    """Estimate the cost of a single request in USD.

    Matches the model name by substring so versioned names like
    ``claude-haiku-4-5-20251001`` still resolve correctly. The first key
    (in ``_COST_PER_1K`` insertion order) contained in the model name wins.

    Args:
        model: Model name as passed to the provider.
        tokens_in: Number of input (prompt) tokens consumed.
        tokens_out: Number of output (completion) tokens generated.

    Returns:
        Estimated cost in USD. Unknown models are billed at the
        conservative ``_DEFAULT_COST`` rates (never zero), so budget
        enforcement errs on the safe side.
    """
    model_lower = model.lower()
    rates = _DEFAULT_COST
    for key, rate in _COST_PER_1K.items():
        if key in model_lower:
            rates = rate
            break
    return (tokens_in * rates["input"] + tokens_out * rates["output"]) / 1000.0
@dataclass
class SpendRecord:
    """A single spend event.

    One record is created per cloud API request; it is persisted to SQLite
    or, when the database is unavailable, kept in the tracker's in-memory
    fallback list.
    """

    ts: float          # Unix timestamp (seconds) when the spend was recorded
    provider: str      # provider name, e.g. "anthropic" / "openai"
    model: str         # model name as passed to the provider
    tokens_in: int     # prompt token count
    tokens_out: int    # completion token count
    cost_usd: float    # estimated (or explicitly supplied) cost in USD
    tier: str          # tier label, typically "cloud"
class BudgetTracker:
    """Tracks cloud API spend with configurable daily / monthly limits.

    Persists spend records to SQLite (``data/budget.db`` by default).
    Falls back to in-memory tracking when the database is unavailable —
    budget enforcement still works; records are lost on restart.

    Limits are read from ``settings``:

    * ``tier_cloud_daily_budget_usd`` — daily ceiling (0 = disabled)
    * ``tier_cloud_monthly_budget_usd`` — monthly ceiling (0 = disabled)

    Usage::

        tracker = BudgetTracker()
        if tracker.cloud_allowed():
            # … make cloud API call …
            tracker.record_spend("anthropic", "claude-haiku-4-5", 100, 200)
        summary = tracker.get_summary()
        print(summary["daily_usd"], "/", summary["daily_limit_usd"])
    """

    _DB_PATH = "data/budget.db"

    def __init__(self, db_path: str | None = None) -> None:
        """Initialise the tracker.

        Args:
            db_path: Path to the SQLite database. Defaults to
                ``data/budget.db``. Pass ``":memory:"`` for tests.
        """
        self._db_path = db_path or self._DB_PATH
        self._lock = threading.Lock()
        self._in_memory: list[SpendRecord] = []
        self._db_ok = False
        # An SQLite ":memory:" database lives only as long as its connection,
        # so opening a fresh connection per call would see an empty schema
        # every time. Keep one shared connection for the in-memory case.
        self._mem_conn: sqlite3.Connection | None = None
        self._init_db()

    # ── Database initialisation ──────────────────────────────────────────────

    def _init_db(self) -> None:
        """Create the spend table (and parent directory) if needed."""
        try:
            if self._db_path != ":memory:":
                Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
            conn = self._connect()
            try:
                with conn:  # transaction scope — commits on success
                    conn.execute(
                        """
                        CREATE TABLE IF NOT EXISTS cloud_spend (
                            id INTEGER PRIMARY KEY AUTOINCREMENT,
                            ts REAL NOT NULL,
                            provider TEXT NOT NULL,
                            model TEXT NOT NULL,
                            tokens_in INTEGER NOT NULL DEFAULT 0,
                            tokens_out INTEGER NOT NULL DEFAULT 0,
                            cost_usd REAL NOT NULL DEFAULT 0.0,
                            tier TEXT NOT NULL DEFAULT 'cloud'
                        )
                        """
                    )
                    conn.execute(
                        "CREATE INDEX IF NOT EXISTS idx_spend_ts ON cloud_spend(ts)"
                    )
            finally:
                self._release(conn)
            self._db_ok = True
            logger.debug("BudgetTracker: SQLite initialised at %s", self._db_path)
        except Exception as exc:
            logger.warning(
                "BudgetTracker: SQLite unavailable, using in-memory fallback: %s", exc
            )

    def _connect(self) -> sqlite3.Connection:
        """Open a connection (or reuse the shared ``":memory:"`` one)."""
        if self._db_path == ":memory:":
            if self._mem_conn is None:
                # check_same_thread=False: the shared connection may be
                # touched from several threads; writes are serialised by
                # self._lock.
                self._mem_conn = sqlite3.connect(
                    ":memory:", timeout=5, check_same_thread=False
                )
            return self._mem_conn
        return sqlite3.connect(self._db_path, timeout=5)

    def _release(self, conn: sqlite3.Connection) -> None:
        """Close ``conn`` unless it is the long-lived in-memory connection.

        ``sqlite3.Connection`` used as a context manager only scopes a
        transaction — it does NOT close the handle — so file-backed
        connections must be closed explicitly to avoid leaking them.
        """
        if conn is not self._mem_conn:
            conn.close()

    # ── Public API ───────────────────────────────────────────────────────────

    def record_spend(
        self,
        provider: str,
        model: str,
        tokens_in: int = 0,
        tokens_out: int = 0,
        cost_usd: float | None = None,
        tier: str = "cloud",
    ) -> float:
        """Record a cloud API spend event and return the cost recorded.

        Args:
            provider: Provider name (e.g. ``"anthropic"``, ``"openai"``).
            model: Model name used for the request.
            tokens_in: Input token count (prompt).
            tokens_out: Output token count (completion).
            cost_usd: Explicit cost override. If ``None``, the cost is
                estimated from the token counts and model rates.
            tier: Tier label for the request (default ``"cloud"``).

        Returns:
            The cost recorded in USD.
        """
        if cost_usd is None:
            cost_usd = estimate_cost_usd(model, tokens_in, tokens_out)
        ts = time.time()
        record = SpendRecord(ts, provider, model, tokens_in, tokens_out, cost_usd, tier)
        with self._lock:
            if self._db_ok:
                try:
                    conn = self._connect()
                    try:
                        with conn:  # commit on success, rollback on error
                            conn.execute(
                                """
                                INSERT INTO cloud_spend
                                    (ts, provider, model, tokens_in, tokens_out, cost_usd, tier)
                                VALUES (?, ?, ?, ?, ?, ?, ?)
                                """,
                                (ts, provider, model, tokens_in, tokens_out, cost_usd, tier),
                            )
                    finally:
                        self._release(conn)
                    logger.debug(
                        "BudgetTracker: recorded %.6f USD (%s/%s, in=%d out=%d tier=%s)",
                        cost_usd,
                        provider,
                        model,
                        tokens_in,
                        tokens_out,
                        tier,
                    )
                    return cost_usd
                except Exception as exc:
                    logger.warning("BudgetTracker: DB write failed, falling back: %s", exc)
            # DB unavailable (or this write failed): keep the record in memory
            # so budget enforcement still works for this process lifetime.
            self._in_memory.append(record)
        return cost_usd

    def get_daily_spend(self) -> float:
        """Return total cloud spend for the current UTC day in USD."""
        # Use the UTC date, not date.today(): the latter is the LOCAL date
        # and disagrees with the UTC day near midnight.
        now = datetime.now(UTC)
        since = datetime(now.year, now.month, now.day, tzinfo=UTC).timestamp()
        return self._query_spend(since)

    def get_monthly_spend(self) -> float:
        """Return total cloud spend for the current UTC month in USD."""
        now = datetime.now(UTC)
        since = datetime(now.year, now.month, 1, tzinfo=UTC).timestamp()
        return self._query_spend(since)

    def cloud_allowed(self) -> bool:
        """Return ``True`` if cloud API spend is within configured limits.

        Checks both daily and monthly ceilings. A limit of ``0`` disables
        that particular check.
        """
        daily_limit = settings.tier_cloud_daily_budget_usd
        monthly_limit = settings.tier_cloud_monthly_budget_usd
        if daily_limit > 0:
            daily_spend = self.get_daily_spend()
            if daily_spend >= daily_limit:
                logger.warning(
                    "BudgetTracker: daily cloud budget exhausted (%.4f / %.4f USD)",
                    daily_spend,
                    daily_limit,
                )
                return False
        if monthly_limit > 0:
            monthly_spend = self.get_monthly_spend()
            if monthly_spend >= monthly_limit:
                logger.warning(
                    "BudgetTracker: monthly cloud budget exhausted (%.4f / %.4f USD)",
                    monthly_spend,
                    monthly_limit,
                )
                return False
        return True

    def get_summary(self) -> dict:
        """Return a spend summary dict suitable for dashboards / logging.

        Keys: ``daily_usd``, ``monthly_usd``, ``daily_limit_usd``,
        ``monthly_limit_usd``, ``daily_ok``, ``monthly_ok``.
        """
        daily = self.get_daily_spend()
        monthly = self.get_monthly_spend()
        daily_limit = settings.tier_cloud_daily_budget_usd
        monthly_limit = settings.tier_cloud_monthly_budget_usd
        return {
            "daily_usd": round(daily, 6),
            "monthly_usd": round(monthly, 6),
            "daily_limit_usd": daily_limit,
            "monthly_limit_usd": monthly_limit,
            "daily_ok": daily_limit <= 0 or daily < daily_limit,
            "monthly_ok": monthly_limit <= 0 or monthly < monthly_limit,
        }

    # ── Internal helpers ─────────────────────────────────────────────────────

    def _query_spend(self, since_ts: float) -> float:
        """Sum ``cost_usd`` for records with ``ts >= since_ts``."""
        if self._db_ok:
            try:
                conn = self._connect()
                try:
                    row = conn.execute(
                        "SELECT COALESCE(SUM(cost_usd), 0.0) FROM cloud_spend WHERE ts >= ?",
                        (since_ts,),
                    ).fetchone()
                    return float(row[0]) if row else 0.0
                finally:
                    self._release(conn)
            except Exception as exc:
                logger.warning("BudgetTracker: DB read failed: %s", exc)
        # In-memory fallback
        return sum(r.cost_usd for r in self._in_memory if r.ts >= since_ts)
# ── Module-level singleton ────────────────────────────────────────────────────

_budget_tracker: BudgetTracker | None = None


def get_budget_tracker() -> BudgetTracker:
    """Return the process-wide :class:`BudgetTracker`, creating it lazily."""
    global _budget_tracker
    if _budget_tracker is not None:
        return _budget_tracker
    _budget_tracker = BudgetTracker()
    return _budget_tracker

View File

@@ -0,0 +1,427 @@
"""Three-tier model router — Local 8B / Local 70B / Cloud API Cascade.
Selects the cheapest-sufficient LLM for each request using a heuristic
task-complexity classifier. Tier 3 (Cloud API) is only used when Tier 2
fails or the budget guard allows it.
Tiers
-----
Tier 1 — LOCAL_FAST (Llama 3.1 8B / Hermes 3 8B via Ollama, free, ~0.3-1 s)
Navigation, basic interactions, simple decisions.
Tier 2 — LOCAL_HEAVY (Hermes 3/4 70B via Ollama, free, ~5-10 s for 200 tok)
Quest planning, dialogue strategy, complex reasoning.
Tier 3 — CLOUD_API (Claude / GPT-4o, paid ~$5-15/hr heavy use)
Recovery from Tier 2 failures, novel situations, multi-step planning.
Routing logic
-------------
1. Classify the task using keyword / length / context heuristics (no LLM call).
2. Route to the appropriate tier.
3. On Tier-1 low-quality response → auto-escalate to Tier 2.
4. On Tier-2 failure or explicit ``require_cloud=True`` → Tier 3 (if budget allows).
5. Log tier used, model, latency, estimated cost for every request.
References:
- Issue #882 — Model Tiering Router: Local 8B / Hermes 70B / Cloud API Cascade
"""
import asyncio
import logging
import re
import time
from enum import StrEnum
from typing import Any
from config import settings
logger = logging.getLogger(__name__)
# ── Tier definitions ──────────────────────────────────────────────────────────
class TierLabel(StrEnum):
"""Three cost-sorted model tiers."""
LOCAL_FAST = "local_fast" # 8B local, always hot, free
LOCAL_HEAVY = "local_heavy" # 70B local, free but slower
CLOUD_API = "cloud_api" # Paid cloud backend (Claude / GPT-4o)
# ── Default model assignments (overridable via Settings) ──────────────────────
_DEFAULT_TIER_MODELS: dict[TierLabel, str] = {
TierLabel.LOCAL_FAST: "llama3.1:8b",
TierLabel.LOCAL_HEAVY: "hermes3:70b",
TierLabel.CLOUD_API: "claude-haiku-4-5",
}
# ── Classification vocabulary ─────────────────────────────────────────────────

# Patterns that indicate a Tier-1 (simple) task.
# Matched as whole words against the task's split tokens in classify_tier().
_T1_WORDS: frozenset[str] = frozenset(
    {
        "go", "move", "walk", "run",
        "north", "south", "east", "west", "up", "down", "left", "right",
        "yes", "no", "ok", "okay",
        "open", "close", "take", "drop", "look",
        "pick", "use", "wait", "rest", "save",
        "attack", "flee", "jump", "crouch",
        "status", "ping", "list", "show", "get", "check",
    }
)

# Patterns that indicate a Tier-2 or Tier-3 task.
# Matched as SUBSTRINGS of the lower-cased task in classify_tier(), so
# e.g. "plan" also fires on "planning".
_T2_PHRASES: tuple[str, ...] = (
    "plan", "strategy", "optimize", "optimise",
    "quest", "stuck", "recover",
    "negotiate", "persuade", "faction", "reputation",
    "analyze", "analyse", "evaluate", "decide",
    "complex", "multi-step", "long-term",
    "how do i", "what should i do", "help me figure",
    "what is the best", "recommend", "best way",
    "explain", "describe in detail", "walk me through",
    "compare", "design", "implement", "refactor",
    "debug", "diagnose", "root cause",
)

# Low-quality response detection patterns — canned refusals and
# "I don't know" phrasings that trigger Tier-1 → Tier-2 escalation
# in _is_low_quality(). Pre-compiled once at import time.
_LOW_QUALITY_PATTERNS: tuple[re.Pattern, ...] = (
    re.compile(r"i\s+don'?t\s+know", re.IGNORECASE),
    re.compile(r"i'm\s+not\s+sure", re.IGNORECASE),
    re.compile(r"i\s+cannot\s+(help|assist|answer)", re.IGNORECASE),
    re.compile(r"i\s+apologize", re.IGNORECASE),
    re.compile(r"as an ai", re.IGNORECASE),
    re.compile(r"i\s+don'?t\s+have\s+(enough|sufficient)\s+information", re.IGNORECASE),
)

# Response is definitely low-quality if shorter than this many characters
_LOW_QUALITY_MIN_CHARS = 20
# Response is suspicious if shorter than this many chars for a complex task
# (applied only to Tier-1 output — see _is_low_quality()).
_ESCALATION_MIN_CHARS = 60
def classify_tier(task: str, context: dict | None = None) -> TierLabel:
    """Classify a task to the cheapest-sufficient model tier.

    Classification priority (highest wins):

    1. ``context["require_cloud"] = True`` → CLOUD_API
    2. Any Tier-2 phrase or stuck/recovery signal → LOCAL_HEAVY
    3. Short task with only Tier-1 words, no active context → LOCAL_FAST
    4. Default → LOCAL_HEAVY (safe fallback for unknown tasks)

    Args:
        task: Natural-language task or user input.
        context: Optional context dict. Recognised keys:
            ``require_cloud`` (bool), ``stuck`` (bool),
            ``require_t2`` (bool), ``active_quests`` (list),
            ``dialogue_active`` (bool), ``combat_active`` (bool).

    Returns:
        The cheapest ``TierLabel`` sufficient for the task.
    """
    ctx = context or {}
    lowered = task.lower()
    tokens = set(lowered.split())

    # Explicit cloud override beats every heuristic below.
    if ctx.get("require_cloud"):
        logger.debug("classify_tier → CLOUD_API (explicit require_cloud)")
        return TierLabel.CLOUD_API

    # Complexity signals that warrant the heavy local tier.
    phrase_hit = any(p in lowered for p in _T2_PHRASES)
    word_hit = not tokens.isdisjoint(
        {"plan", "strategy", "optimize", "optimise", "quest",
         "stuck", "recover", "analyze", "analyse", "evaluate"}
    )
    stuck = bool(ctx.get("stuck"))
    explicit_t2 = bool(ctx.get("require_t2"))
    lengthy = len(task) > 300  # long tasks warrant more capable model
    rich_ctx = (
        len(ctx.get("active_quests", [])) >= 3
        or ctx.get("dialogue_active")
    )
    if phrase_hit or word_hit or stuck or explicit_t2 or lengthy or rich_ctx:
        logger.debug(
            "classify_tier → LOCAL_HEAVY (phrase=%s word=%s stuck=%s explicit=%s long=%s ctx=%s)",
            phrase_hit, word_hit, stuck, explicit_t2, lengthy, rich_ctx,
        )
        return TierLabel.LOCAL_HEAVY

    # Simple, short, context-free commands can use the fast tier.
    simple_hit = not tokens.isdisjoint(_T1_WORDS)
    brief = len(task.split()) <= 8
    idle = not (
        ctx.get("active_quests")
        or ctx.get("dialogue_active")
        or ctx.get("combat_active")
    )
    if simple_hit and brief and idle:
        logger.debug(
            "classify_tier → LOCAL_FAST (words=%s short=%s)", simple_hit, brief
        )
        return TierLabel.LOCAL_FAST

    # Anything unclassified defaults to the capable-but-free heavy tier.
    logger.debug("classify_tier → LOCAL_HEAVY (default)")
    return TierLabel.LOCAL_HEAVY
def _is_low_quality(content: str, tier: TierLabel) -> bool:
    """Return True if the response looks like it should be escalated.

    Used for automatic Tier-1 → Tier-2 escalation.

    Args:
        content: LLM response text.
        tier: The tier that produced the response.

    Returns:
        True if the response is likely too low-quality to be useful.
    """
    # Empty or whitespace-only output is never useful.
    if not content or not content.strip():
        return True
    text = content.strip()
    # Hard floor: anything this short is useless regardless of tier.
    if len(text) < _LOW_QUALITY_MIN_CHARS:
        return True
    # Tier-1 answers face a stricter length bar before being trusted.
    if tier == TierLabel.LOCAL_FAST and len(text) < _ESCALATION_MIN_CHARS:
        return True
    # Canned refusals / "I don't know" style responses.
    return any(pattern.search(text) for pattern in _LOW_QUALITY_PATTERNS)
class TieredModelRouter:
    """Routes LLM requests across the Local 8B / Local 70B / Cloud API tiers.

    Wraps CascadeRouter with:

    - Heuristic tier classification via ``classify_tier()``
    - Automatic Tier-1 → Tier-2 escalation on low-quality responses
    - Cloud-tier budget guard via ``BudgetTracker``
    - Per-request logging: tier, model, latency, estimated cost

    Usage::

        router = TieredModelRouter()
        result = await router.route(
            task="Walk to the next room",
            context={},
        )
        print(result["content"], result["tier"])  # "Move north.", "local_fast"

        # Force heavy tier
        result = await router.route(
            task="Plan the optimal path to become Hortator",
            context={"require_t2": True},
        )
    """

    def __init__(
        self,
        cascade: Any | None = None,
        budget_tracker: Any | None = None,
        tier_models: dict[TierLabel, str] | None = None,
        auto_escalate: bool = True,
    ) -> None:
        """Initialise the tiered router.

        Args:
            cascade: CascadeRouter instance. If ``None``, the
                singleton from ``get_router()`` is used lazily.
            budget_tracker: BudgetTracker instance. If ``None``, the
                singleton from ``get_budget_tracker()`` is used.
            tier_models: Override default model names per tier.
            auto_escalate: When ``True``, low-quality Tier-1 responses
                automatically retry on Tier-2.
        """
        self._cascade = cascade
        self._budget = budget_tracker
        # Precedence (lowest → highest): module defaults, settings overrides,
        # per-instance tier_models argument.
        self._tier_models: dict[TierLabel, str] = dict(_DEFAULT_TIER_MODELS)
        self._auto_escalate = auto_escalate
        # Apply settings-level overrides (can still be overridden per-instance)
        if settings.tier_local_fast_model:
            self._tier_models[TierLabel.LOCAL_FAST] = settings.tier_local_fast_model
        if settings.tier_local_heavy_model:
            self._tier_models[TierLabel.LOCAL_HEAVY] = settings.tier_local_heavy_model
        if settings.tier_cloud_model:
            self._tier_models[TierLabel.CLOUD_API] = settings.tier_cloud_model
        if tier_models:
            self._tier_models.update(tier_models)

    # ── Lazy singletons ──────────────────────────────────────────────────────

    def _get_cascade(self) -> Any:
        # Imported inside the method — presumably to avoid an import cycle
        # or startup cost at module load; confirm before hoisting.
        if self._cascade is None:
            from infrastructure.router.cascade import get_router
            self._cascade = get_router()
        return self._cascade

    def _get_budget(self) -> Any:
        # Same lazy-import pattern as _get_cascade above.
        if self._budget is None:
            from infrastructure.models.budget import get_budget_tracker
            self._budget = get_budget_tracker()
        return self._budget

    # ── Public interface ─────────────────────────────────────────────────────

    def classify(self, task: str, context: dict | None = None) -> TierLabel:
        """Classify a task without routing. Useful for telemetry."""
        return classify_tier(task, context)

    async def route(
        self,
        task: str,
        context: dict | None = None,
        messages: list[dict] | None = None,
        temperature: float = 0.3,
        max_tokens: int | None = None,
    ) -> dict:
        """Route a task to the appropriate model tier.

        Builds a minimal messages list if ``messages`` is not provided.
        The result always includes a ``tier`` key indicating which tier
        ultimately handled the request.

        Args:
            task: Natural-language task description.
            context: Task context dict (see ``classify_tier()``).
            messages: Pre-built OpenAI-compatible messages list. If
                provided, ``task`` is only used for classification.
            temperature: Sampling temperature (default 0.3).
            max_tokens: Maximum tokens to generate.

        Returns:
            Dict with at minimum: ``content``, ``provider``, ``model``,
            ``tier``, ``latency_ms``. May include ``cost_usd`` when a
            cloud request is recorded.

        Raises:
            RuntimeError: If all available tiers are exhausted.
        """
        ctx = context or {}
        tier = self.classify(task, ctx)
        msgs = messages or [{"role": "user", "content": task}]

        # ── Tier 1 attempt ───────────────────────────────────────────────────
        # NOTE(review): a Tier-1 *exception* propagates to the caller — only a
        # low-quality (non-raising) Tier-1 response escalates. That is
        # asymmetric with the Tier-2 branch below, which catches failures and
        # falls through to cloud. Confirm this asymmetry is intended.
        if tier == TierLabel.LOCAL_FAST:
            result = await self._complete_tier(
                TierLabel.LOCAL_FAST, msgs, temperature, max_tokens
            )
            if self._auto_escalate and _is_low_quality(result.get("content", ""), TierLabel.LOCAL_FAST):
                logger.info(
                    "TieredModelRouter: Tier-1 response low quality, escalating to Tier-2 "
                    "(task=%r content_len=%d)",
                    task[:80],
                    len(result.get("content", "")),
                )
                tier = TierLabel.LOCAL_HEAVY
                result = await self._complete_tier(
                    TierLabel.LOCAL_HEAVY, msgs, temperature, max_tokens
                )
            return result

        # ── Tier 2 attempt ───────────────────────────────────────────────────
        if tier == TierLabel.LOCAL_HEAVY:
            try:
                return await self._complete_tier(
                    TierLabel.LOCAL_HEAVY, msgs, temperature, max_tokens
                )
            except Exception as exc:
                # Any Tier-2 failure falls through to the cloud tier below.
                logger.warning(
                    "TieredModelRouter: Tier-2 failed (%s) — escalating to cloud", exc
                )
                tier = TierLabel.CLOUD_API

        # ── Tier 3 (Cloud) ───────────────────────────────────────────────────
        # Budget guard: refuse rather than silently overspend.
        budget = self._get_budget()
        if not budget.cloud_allowed():
            raise RuntimeError(
                "Cloud API tier requested but budget limit reached — "
                "increase tier_cloud_daily_budget_usd or tier_cloud_monthly_budget_usd"
            )
        result = await self._complete_tier(
            TierLabel.CLOUD_API, msgs, temperature, max_tokens
        )
        # Record cloud spend if token info is available
        usage = result.get("usage", {})
        if usage:
            cost = budget.record_spend(
                provider=result.get("provider", "unknown"),
                model=result.get("model", self._tier_models[TierLabel.CLOUD_API]),
                tokens_in=usage.get("prompt_tokens", 0),
                tokens_out=usage.get("completion_tokens", 0),
                tier=TierLabel.CLOUD_API,
            )
            result["cost_usd"] = cost
        return result

    # ── Internal helpers ─────────────────────────────────────────────────────

    async def _complete_tier(
        self,
        tier: TierLabel,
        messages: list[dict],
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Dispatch a single inference request for the given tier.

        Args:
            tier: Tier whose configured model should handle the request.
            messages: OpenAI-compatible messages list.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate (``None`` = provider default).

        Returns:
            The cascade's result dict, annotated with ``tier`` and (if the
            provider did not report one) a wall-clock ``latency_ms``.
        """
        model = self._tier_models[tier]
        cascade = self._get_cascade()
        start = time.monotonic()
        logger.info(
            "TieredModelRouter: tier=%s model=%s messages=%d",
            tier,
            model,
            len(messages),
        )
        result = await cascade.complete(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        elapsed_ms = (time.monotonic() - start) * 1000
        result["tier"] = tier
        # setdefault: keep the provider's own latency figure when present.
        result.setdefault("latency_ms", elapsed_ms)
        logger.info(
            "TieredModelRouter: done tier=%s model=%s latency_ms=%.0f",
            tier,
            result.get("model", model),
            elapsed_ms,
        )
        return result
# ── Module-level singleton ────────────────────────────────────────────────────

_tiered_router: TieredModelRouter | None = None


def get_tiered_router() -> TieredModelRouter:
    """Return the process-wide :class:`TieredModelRouter`, creating it lazily."""
    global _tiered_router
    if _tiered_router is not None:
        return _tiered_router
    _tiered_router = TieredModelRouter()
    return _tiered_router

View File

@@ -0,0 +1 @@
"""Vendor-specific chat platform adapters (e.g. Discord) for the chat bridge."""

View File

@@ -1,3 +1,4 @@
"""Typer CLI entry point for the ``timmy`` command (chat, think, status)."""
import asyncio
import logging
import subprocess

View File

@@ -1,7 +1,10 @@
"""Memory — Persistent conversation and knowledge memory.
Sub-modules:
embeddings — text-to-vector embedding + similarity functions
unified — unified memory schema and connection management
vector_store — backward compatibility re-exports from memory_system
embeddings — text-to-vector embedding + similarity functions
unified — unified memory schema and connection management
chain — CRUD operations (store, search, delete, stats)
semantic — SemanticMemory and MemorySearcher classes
consolidation — HotMemory and VaultMemory classes
vector_store — backward compatibility re-exports from memory_system
"""

387
src/timmy/memory/chain.py Normal file
View File

@@ -0,0 +1,387 @@
"""CRUD operations for Timmy's unified memory database.
Provides store, search, delete, and management functions for the
`memories` table defined in timmy.memory.unified.
"""
import json
import logging
import sqlite3
import uuid
from contextlib import contextmanager
from datetime import UTC, datetime, timedelta
from pathlib import Path
from config import settings
from timmy.memory.embeddings import (
_keyword_overlap,
cosine_similarity,
embed_text,
)
from timmy.memory.unified import (
DB_PATH,
MemoryEntry,
_ensure_schema,
get_connection,
)
logger = logging.getLogger(__name__)
def store_memory(
    content: str,
    source: str,
    context_type: str = "conversation",
    agent_id: str | None = None,
    task_id: str | None = None,
    session_id: str | None = None,
    metadata: dict | None = None,
    compute_embedding: bool = True,
) -> MemoryEntry:
    """Store a memory entry with optional embedding.

    Args:
        content: The text content to store.
        source: Source of the memory (agent name, user, system).
        context_type: Type of context (conversation, document, fact, vault_chunk).
        agent_id: Associated agent ID.
        task_id: Associated task ID.
        session_id: Session identifier.
        metadata: Additional structured data.
        compute_embedding: Whether to compute a vector embedding.

    Returns:
        The stored MemoryEntry.
    """
    vector = embed_text(content) if compute_embedding else None
    entry = MemoryEntry(
        content=content,
        source=source,
        context_type=context_type,
        agent_id=agent_id,
        task_id=task_id,
        session_id=session_id,
        metadata=metadata,
        embedding=vector,
    )
    row = (
        entry.id,
        entry.content,
        entry.context_type,  # persisted in the memory_type column
        entry.source,
        entry.agent_id,
        entry.task_id,
        entry.session_id,
        json.dumps(metadata) if metadata else None,
        json.dumps(vector) if vector else None,
        entry.timestamp,
    )
    with get_connection() as conn:
        conn.execute(
            """
            INSERT INTO memories
                (id, content, memory_type, source, agent_id, task_id, session_id,
                 metadata, embedding, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            row,
        )
        conn.commit()
    return entry
def _build_search_filters(
context_type: str | None,
agent_id: str | None,
session_id: str | None,
) -> tuple[str, list]:
"""Build SQL WHERE clause and params from search filters."""
conditions: list[str] = []
params: list = []
if context_type:
conditions.append("memory_type = ?")
params.append(context_type)
if agent_id:
conditions.append("agent_id = ?")
params.append(agent_id)
if session_id:
conditions.append("session_id = ?")
params.append(session_id)
where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
return where_clause, params
def _fetch_memory_candidates(
    where_clause: str, params: list, candidate_limit: int
) -> list[sqlite3.Row]:
    """Fetch candidate memory rows from the database.

    Args:
        where_clause: Pre-built ``WHERE …`` clause (may be empty).
        params: Bind parameters matching ``where_clause``. Not mutated.
        candidate_limit: Maximum number of rows to fetch.

    Returns:
        The newest matching rows, up to ``candidate_limit``.
    """
    query_sql = f"""
        SELECT * FROM memories
        {where_clause}
        ORDER BY created_at DESC
        LIMIT ?
    """
    # Bind the limit in a new list — the original appended to `params`,
    # mutating the caller's list as a side effect.
    bind_params = [*params, candidate_limit]
    with get_connection() as conn:
        return conn.execute(query_sql, bind_params).fetchall()
def _row_to_entry(row: sqlite3.Row) -> MemoryEntry:
    """Convert a database row to a MemoryEntry."""
    raw_metadata = row["metadata"]
    raw_embedding = row["embedding"]
    return MemoryEntry(
        id=row["id"],
        content=row["content"],
        source=row["source"],
        # DB column is named memory_type; the API field is context_type.
        context_type=row["memory_type"],
        agent_id=row["agent_id"],
        task_id=row["task_id"],
        session_id=row["session_id"],
        metadata=json.loads(raw_metadata) if raw_metadata else None,
        embedding=json.loads(raw_embedding) if raw_embedding else None,
        timestamp=row["created_at"],
    )
def _score_and_filter(
    rows: list[sqlite3.Row],
    query: str,
    query_embedding: list[float],
    min_relevance: float,
) -> list[MemoryEntry]:
    """Score candidate rows by similarity and filter by min_relevance."""
    scored: list[MemoryEntry] = []
    for candidate in rows:
        entry = _row_to_entry(candidate)
        # Vector similarity when an embedding is stored; otherwise fall
        # back to cheap keyword overlap.
        if entry.embedding:
            relevance = cosine_similarity(query_embedding, entry.embedding)
        else:
            relevance = _keyword_overlap(query, entry.content)
        entry.relevance_score = relevance
        if relevance >= min_relevance:
            scored.append(entry)
    scored.sort(key=lambda e: e.relevance_score or 0, reverse=True)
    return scored
def search_memories(
    query: str,
    limit: int = 10,
    context_type: str | None = None,
    agent_id: str | None = None,
    session_id: str | None = None,
    min_relevance: float = 0.0,
) -> list[MemoryEntry]:
    """Search for memories by semantic similarity.

    Args:
        query: Search query text.
        limit: Maximum number of results.
        context_type: Filter by memory type (maps to the DB memory_type column).
        agent_id: Filter by agent.
        session_id: Filter by session.
        min_relevance: Minimum similarity score (0-1).

    Returns:
        MemoryEntry objects sorted by descending relevance.
    """
    query_embedding = embed_text(query)
    where_clause, params = _build_search_filters(context_type, agent_id, session_id)
    # Over-fetch 3x candidates so post-scoring can still fill `limit` slots.
    candidates = _fetch_memory_candidates(where_clause, params, limit * 3)
    scored = _score_and_filter(candidates, query, query_embedding, min_relevance)
    return scored[:limit]
def delete_memory(memory_id: str) -> bool:
    """Delete a memory entry by ID.

    Args:
        memory_id: Primary key of the entry to remove.

    Returns:
        True if a row was deleted, False if no entry matched.
    """
    with get_connection() as conn:
        cursor = conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
        conn.commit()
        return cursor.rowcount > 0
def get_memory_stats() -> dict:
    """Get statistics about the memory store.

    Returns:
        Dict with the total entry count, per-type counts, how many entries
        carry embeddings, and whether an embedding model is available.
    """
    # Local import mirrors the module's lazy-import convention.
    from timmy.memory.embeddings import _get_embedding_model

    with get_connection() as conn:
        total = conn.execute("SELECT COUNT(*) as count FROM memories").fetchone()["count"]
        type_rows = conn.execute(
            "SELECT memory_type, COUNT(*) as count FROM memories GROUP BY memory_type"
        ).fetchall()
        by_type = {r["memory_type"]: r["count"] for r in type_rows}
        with_embeddings = conn.execute(
            "SELECT COUNT(*) as count FROM memories WHERE embedding IS NOT NULL"
        ).fetchone()["count"]
    return {
        "total_entries": total,
        "by_type": by_type,
        "with_embeddings": with_embeddings,
        # _get_embedding_model() appears to use False as its "unavailable"
        # sentinel — confirm against timmy.memory.embeddings.
        "has_embedding_model": _get_embedding_model() is not False,
    }
def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
    """Delete old memories to manage storage.

    Args:
        older_than_days: Delete memories created before this many days ago.
        keep_facts: Whether to preserve fact-type memories.

    Returns:
        Number of entries deleted.
    """
    cutoff = (datetime.now(UTC) - timedelta(days=older_than_days)).isoformat()
    # Build one DELETE statement instead of two near-duplicate branches.
    sql = "DELETE FROM memories WHERE created_at < ?"
    if keep_facts:
        sql += " AND memory_type != 'fact'"
    with get_connection() as conn:
        deleted = conn.execute(sql, (cutoff,)).rowcount
        conn.commit()
    return deleted
def get_memory_context(query: str, max_tokens: int = 2000, **filters) -> str:
    """Get relevant memory context as formatted text for LLM prompts.

    Args:
        query: Search query.
        max_tokens: Approximate maximum tokens to return.
        **filters: Additional filters (agent_id, session_id, etc.).

    Returns:
        Formatted context string for inclusion in prompts.
    """
    entries = search_memories(query, limit=20, **filters)
    budget = max_tokens * 4  # ~4 chars per token heuristic
    parts: list[str] = []
    used = 0
    for entry in entries:
        text = f"[{entry.source}]: {entry.content}"
        # Stop as soon as the next entry would exceed the character budget.
        if used + len(text) > budget:
            break
        parts.append(text)
        used += len(text)
    if not parts:
        return ""
    return "Relevant context from memory:\n" + "\n\n".join(parts)
def recall_personal_facts(agent_id: str | None = None) -> list[str]:
    """Recall personal facts about the user or system.

    Args:
        agent_id: Optional agent filter.

    Returns:
        List of fact strings, newest first (at most 100).
    """
    # Compose one query instead of two duplicated branches.
    sql = "SELECT content FROM memories WHERE memory_type = 'fact'"
    params: tuple = ()
    if agent_id:
        sql += " AND agent_id = ?"
        params = (agent_id,)
    sql += " ORDER BY created_at DESC LIMIT 100"
    with get_connection() as conn:
        rows = conn.execute(sql, params).fetchall()
    return [r["content"] for r in rows]
def recall_personal_facts_with_ids(agent_id: str | None = None) -> list[dict]:
    """Recall personal facts with their IDs for edit/delete operations."""
    # Compose one query instead of two duplicated branches.
    sql = "SELECT id, content FROM memories WHERE memory_type = 'fact'"
    params: tuple = ()
    if agent_id:
        sql += " AND agent_id = ?"
        params = (agent_id,)
    sql += " ORDER BY created_at DESC LIMIT 100"
    with get_connection() as conn:
        rows = conn.execute(sql, params).fetchall()
    return [{"id": r["id"], "content": r["content"]} for r in rows]
def update_personal_fact(memory_id: str, new_content: str) -> bool:
    """Update a personal fact's content.

    Returns True when a matching fact row was updated.

    NOTE(review): the row's stored embedding is not recomputed here — confirm
    whether callers re-index after edits.
    """
    with get_connection() as conn:
        result = conn.execute(
            "UPDATE memories SET content = ? WHERE id = ? AND memory_type = 'fact'",
            (new_content, memory_id),
        )
        conn.commit()
        return result.rowcount > 0
def store_personal_fact(fact: str, agent_id: str | None = None) -> MemoryEntry:
    """Store a personal fact about the user or system.

    Args:
        fact: The fact to store.
        agent_id: Associated agent.

    Returns:
        The stored MemoryEntry.
    """
    # Facts stored through this entry point are explicitly user/system
    # provided, so mark them as not auto-extracted.
    metadata = {"auto_extracted": False}
    return store_memory(
        content=fact,
        source="system",
        context_type="fact",
        agent_id=agent_id,
        metadata=metadata,
    )

View File

@@ -0,0 +1,310 @@
"""Hot and Vault memory classes for Timmy's memory consolidation tier.
HotMemory: Tier 1 — computed view of top facts from the database.
VaultMemory: Tier 2 — structured vault (memory/ directory), append-only markdown.
"""
import logging
import re
from datetime import UTC, datetime
from pathlib import Path
from timmy.memory.unified import PROJECT_ROOT
logger = logging.getLogger(__name__)
# Root of the tier-2 vault (structured markdown memory directory).
VAULT_PATH = PROJECT_ROOT / "memory"
# Fallback MEMORY.md content written by HotMemory._create_default().
# {date} and {prune_date} are filled via str.format() at write time.
_DEFAULT_HOT_MEMORY_TEMPLATE = """\
# Timmy Hot Memory
> Working RAM — always loaded, ~300 lines max, pruned monthly
> Last updated: {date}
---
## Current Status
**Agent State:** Operational
**Mode:** Development
**Active Tasks:** 0
**Pending Decisions:** None
---
## Standing Rules
1. **Sovereignty First** — No cloud dependencies
2. **Local-Only Inference** — Ollama on localhost
3. **Privacy by Design** — Telemetry disabled
4. **Tool Minimalism** — Use tools only when necessary
5. **Memory Discipline** — Write handoffs at session end
---
## Agent Roster
| Agent | Role | Status |
|-------|------|--------|
| Timmy | Core | Active |
---
## User Profile
**Name:** (not set)
**Interests:** (to be learned)
---
## Key Decisions
(none yet)
---
## Pending Actions
- [ ] Learn user's name
---
*Prune date: {prune_date}*
"""
class HotMemory:
    """Tier 1: Hot memory — computed view of top facts from DB.

    The DB is the source of truth; the MEMORY.md file is only a fallback /
    backward-compatibility artifact.
    """

    def __init__(self, path: Path | None = None) -> None:
        if path is None:
            path = PROJECT_ROOT / "MEMORY.md"
        self.path = path
        self._content: str | None = None          # last content written via update_section()
        self._last_modified: float | None = None  # mtime recorded after the last file write

    def read(self, force_refresh: bool = False) -> str:
        """Read hot memory — computed view of top facts + last reflection from DB.

        Args:
            force_refresh: Unused; kept for backward API compatibility.

        Returns:
            Markdown text assembled from DB facts, or the MEMORY.md file
            contents (or a stub) when the DB path fails.
        """
        from timmy.memory.chain import recall_personal_facts

        # Resolve recall_last_reflection through timmy.memory_system so test
        # patches on that module are honored; fall back to None when the
        # module (or attribute) is unavailable.
        try:
            import timmy.memory_system as _ms

            recall_last_reflection = _ms.recall_last_reflection
        except Exception:
            recall_last_reflection = None
        try:
            facts = recall_personal_facts()
            lines = ["# Timmy Hot Memory\n"]
            if facts:
                lines.append("## Known Facts\n")
                for f in facts[:15]:
                    lines.append(f"- {f}")
            # Include the last reflection if available
            if recall_last_reflection is not None:
                try:
                    reflection = recall_last_reflection()
                    if reflection:
                        lines.append("\n## Last Reflection\n")
                        lines.append(reflection)
                except Exception:
                    pass
            if len(lines) > 1:
                return "\n".join(lines)
        except Exception:
            logger.debug("DB context read failed, falling back to file")
        # Fallback to file if DB unavailable
        if self.path.exists():
            return self.path.read_text()
        return "# Timmy Hot Memory\n\nNo memories stored yet.\n"

    def update_section(self, section: str, content: str) -> None:
        """Update a specific section in MEMORY.md.

        DEPRECATED: Hot memory is now computed from the database.
        This method is kept for backward compatibility during transition.
        Use memory_write() to store facts in the database.
        """
        logger.warning(
            "HotMemory.update_section() is deprecated. "
            "Use memory_write() to store facts in the database."
        )
        # Keep file-writing for backward compatibility during transition.
        # Guard against empty or excessively large writes.
        if not content or not content.strip():
            logger.warning("HotMemory: Refusing empty write to section '%s'", section)
            return
        if len(content) > 2000:
            logger.warning("HotMemory: Truncating oversized write to section '%s'", section)
            content = content[:2000] + "\n... [truncated]"
        if not self.path.exists():
            self._create_default()
        full_content = self.read()
        # Match "## <section>" up to the next "## " header (or end of file).
        pattern = rf"(## {re.escape(section)}.*?)(?=\n## |\Z)"
        match = re.search(pattern, full_content, re.DOTALL)
        new_section = f"## {section}\n\n{content}\n\n"
        if match:
            # Replace the existing section in place.
            full_content = full_content[: match.start()] + new_section + full_content[match.end() :]
        else:
            # Append — insert before the prune marker when present.
            insert_point = full_content.rfind("*Prune date:")
            if insert_point < 0:
                # No prune marker — just append at end
                full_content = full_content.rstrip() + "\n\n" + new_section
            else:
                full_content = (
                    full_content[:insert_point] + new_section + "\n" + full_content[insert_point:]
                )
        self.path.write_text(full_content)
        self._content = full_content
        self._last_modified = self.path.stat().st_mtime
        logger.info("HotMemory: Updated section '%s'", section)

    def _create_default(self) -> None:
        """Create default MEMORY.md if missing.

        DEPRECATED: Hot memory is now computed from the database.
        This method is kept for backward compatibility during transition.
        """
        logger.debug(
            "HotMemory._create_default() - creating default MEMORY.md for backward compatibility"
        )
        now = datetime.now(UTC)
        content = _DEFAULT_HOT_MEMORY_TEMPLATE.format(
            date=now.strftime("%Y-%m-%d"),
            # Prune marker pinned to the 25th of the current month.
            prune_date=now.replace(day=25).strftime("%Y-%m-%d"),
        )
        self.path.write_text(content)
        logger.info("HotMemory: Created default MEMORY.md")
class VaultMemory:
    """Tier 2: Structured vault (memory/) — append-only markdown."""
    def __init__(self) -> None:
        self.path = VAULT_PATH
        self._ensure_structure()
    def _ensure_structure(self) -> None:
        """Ensure vault directory structure exists (self/, notes/, aar/)."""
        (self.path / "self").mkdir(parents=True, exist_ok=True)
        (self.path / "notes").mkdir(parents=True, exist_ok=True)
        (self.path / "aar").mkdir(parents=True, exist_ok=True)
    def write_note(self, name: str, content: str, namespace: str = "notes") -> Path:
        """Write a note to the vault.

        Args:
            name: Base name for the note (underscores become spaces in title).
            content: Markdown body of the note.
            namespace: Vault subdirectory to write into.

        Returns:
            Path of the file written.
        """
        # Add timestamp to filename
        timestamp = datetime.now(UTC).strftime("%Y%m%d")
        filename = f"{timestamp}_{name}.md"
        filepath = self.path / namespace / filename
        # Add header
        full_content = f"""# {name.replace("_", " ").title()}
> Created: {datetime.now(UTC).isoformat()}
> Namespace: {namespace}
---
{content}
---
*Auto-generated by Timmy Memory System*
"""
        filepath.write_text(full_content)
        logger.info("VaultMemory: Wrote %s", filepath)
        return filepath
    def read_file(self, filepath: Path) -> str:
        """Read a file from the vault; returns "" when the file is missing."""
        if not filepath.exists():
            return ""
        return filepath.read_text()
    def update_user_profile(self, key: str, value: str) -> None:
        """Update a field in user_profile.md.

        DEPRECATED: User profile updates should now use memory_write() to store
        facts in the database. This method is kept for backward compatibility.
        """
        logger.warning(
            "VaultMemory.update_user_profile() is deprecated. "
            "Use memory_write() to store user facts in the database."
        )
        # Still update the file for backward compatibility during transition
        profile_path = self.path / "self" / "user_profile.md"
        if not profile_path.exists():
            self._create_default_profile()
        content = profile_path.read_text()
        # Matches "**<key>:**" plus the rest of that line (the old value).
        pattern = rf"(\*\*{re.escape(key)}:\*\*).*"
        if re.search(pattern, content):
            safe_value = value.strip()
            # A function replacement keeps backslashes/group refs in the value
            # from being interpreted as a regex replacement template.
            content = re.sub(pattern, lambda m: f"{m.group(1)} {safe_value}", content)
        else:
            facts_section = "## Important Facts"
            if facts_section in content:
                insert_point = content.find(facts_section) + len(facts_section)
                content = content[:insert_point] + f"\n- {key}: {value}" + content[insert_point:]
            # NOTE(review): if the facts section is absent, the key/value is
            # silently dropped — confirm this is intended.
        # Refresh the "*Last updated:*" footer line.
        content = re.sub(
            r"\*Last updated:.*\*",
            f"*Last updated: {datetime.now(UTC).strftime('%Y-%m-%d')}*",
            content,
        )
        profile_path.write_text(content)
        logger.info("VaultMemory: Updated user profile: %s = %s", key, value)
    def _create_default_profile(self) -> None:
        """Create default user profile markdown at self/user_profile.md."""
        profile_path = self.path / "self" / "user_profile.md"
        default = """# User Profile
> Learned information about the user.
## Basic Information
**Name:** (unknown)
**Location:** (unknown)
**Occupation:** (unknown)
## Interests & Expertise
- (to be learned)
## Preferences
- Response style: concise, technical
- Tool usage: minimal
## Important Facts
- (to be extracted)
---
*Last updated: {date}*
""".format(date=datetime.now(UTC).strftime("%Y-%m-%d"))
        profile_path.write_text(default)

View File

@@ -0,0 +1,278 @@
"""Semantic memory and search classes for Timmy.
Provides SemanticMemory (vector search over vault content) and
MemorySearcher (high-level multi-tier search interface).
"""
import hashlib
import json
import logging
import sqlite3
from contextlib import closing, contextmanager
from collections.abc import Generator
from datetime import UTC, datetime
from pathlib import Path
from config import settings
from timmy.memory.embeddings import (
EMBEDDING_DIM,
_get_embedding_model,
cosine_similarity,
embed_text,
)
from timmy.memory.unified import (
DB_PATH,
PROJECT_ROOT,
_ensure_schema,
get_connection,
)
logger = logging.getLogger(__name__)
VAULT_PATH = PROJECT_ROOT / "memory"
class SemanticMemory:
"""Vector-based semantic search over vault content."""
def __init__(self) -> None:
self.db_path = DB_PATH
self.vault_path = VAULT_PATH
@contextmanager
def _get_conn(self) -> Generator[sqlite3.Connection, None, None]:
"""Get connection to the instance's db_path (backward compatibility).
Uses self.db_path if set differently from global DB_PATH,
otherwise uses the global get_connection().
"""
if self.db_path == DB_PATH:
# Use global connection (normal production path)
with get_connection() as conn:
yield conn
else:
# Use instance-specific db_path (test path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
with closing(sqlite3.connect(str(self.db_path))) as conn:
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}")
# Ensure schema exists
_ensure_schema(conn)
yield conn
def _init_db(self) -> None:
"""Initialize database at self.db_path (backward compatibility).
This method is kept for backward compatibility with existing code and tests.
Schema creation is handled by _get_conn.
"""
# Trigger schema creation via _get_conn
with self._get_conn():
pass
def index_file(self, filepath: Path) -> int:
"""Index a single file into semantic memory."""
if not filepath.exists():
return 0
content = filepath.read_text()
file_hash = hashlib.md5(content.encode()).hexdigest()
with self._get_conn() as conn:
# Check if already indexed with same hash
cursor = conn.execute(
"SELECT metadata FROM memories WHERE source = ? AND memory_type = 'vault_chunk' LIMIT 1",
(str(filepath),),
)
existing = cursor.fetchone()
if existing and existing[0]:
try:
meta = json.loads(existing[0])
if meta.get("source_hash") == file_hash:
return 0 # Already indexed
except json.JSONDecodeError:
pass
# Delete old chunks for this file
conn.execute(
"DELETE FROM memories WHERE source = ? AND memory_type = 'vault_chunk'",
(str(filepath),),
)
# Split into chunks (paragraphs)
chunks = self._split_into_chunks(content)
# Index each chunk
now = datetime.now(UTC).isoformat()
for i, chunk_text in enumerate(chunks):
if len(chunk_text.strip()) < 20: # Skip tiny chunks
continue
chunk_id = f"{filepath.stem}_{i}"
chunk_embedding = embed_text(chunk_text)
conn.execute(
"""INSERT INTO memories
(id, content, memory_type, source, metadata, embedding, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
(
chunk_id,
chunk_text,
"vault_chunk",
str(filepath),
json.dumps({"source_hash": file_hash, "chunk_index": i}),
json.dumps(chunk_embedding),
now,
),
)
conn.commit()
logger.info("SemanticMemory: Indexed %s (%d chunks)", filepath.name, len(chunks))
return len(chunks)
def _split_into_chunks(self, text: str, max_chunk_size: int = 500) -> list[str]:
"""Split text into semantic chunks."""
# Split by paragraphs first
paragraphs = text.split("\n\n")
chunks = []
for para in paragraphs:
para = para.strip()
if not para:
continue
# If paragraph is small enough, keep as one chunk
if len(para) <= max_chunk_size:
chunks.append(para)
else:
# Split long paragraphs by sentences
sentences = para.replace(". ", ".\n").split("\n")
current_chunk = ""
for sent in sentences:
if len(current_chunk) + len(sent) < max_chunk_size:
current_chunk += " " + sent if current_chunk else sent
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sent
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
def index_vault(self) -> int:
"""Index entire vault directory."""
total_chunks = 0
for md_file in self.vault_path.rglob("*.md"):
# Skip handoff file (handled separately)
if "last-session-handoff" in md_file.name:
continue
total_chunks += self.index_file(md_file)
logger.info("SemanticMemory: Indexed vault (%d total chunks)", total_chunks)
return total_chunks
def search(self, query: str, top_k: int = 5) -> list[tuple[str, float]]:
"""Search for relevant memory chunks."""
query_embedding = embed_text(query)
with self._get_conn() as conn:
conn.row_factory = sqlite3.Row
# Get all vault chunks
rows = conn.execute(
"SELECT source, content, embedding FROM memories WHERE memory_type = 'vault_chunk'"
).fetchall()
# Calculate similarities
scored = []
for row in rows:
embedding = json.loads(row["embedding"])
score = cosine_similarity(query_embedding, embedding)
scored.append((row["source"], row["content"], score))
# Sort by score descending
scored.sort(key=lambda x: x[2], reverse=True)
# Return top_k
return [(content, score) for _, content, score in scored[:top_k]]
def get_relevant_context(self, query: str, max_chars: int = 2000) -> str:
"""Get formatted context string for a query."""
results = self.search(query, top_k=3)
if not results:
return ""
parts = []
total_chars = 0
for content, score in results:
if score < 0.3: # Similarity threshold
continue
chunk = f"[Relevant memory - score {score:.2f}]: {content[:400]}..."
if total_chars + len(chunk) > max_chars:
break
parts.append(chunk)
total_chars += len(chunk)
return "\n\n".join(parts) if parts else ""
def stats(self) -> dict:
"""Get indexing statistics."""
with self._get_conn() as conn:
cursor = conn.execute(
"SELECT COUNT(*), COUNT(DISTINCT source) FROM memories WHERE memory_type = 'vault_chunk'"
)
total_chunks, total_files = cursor.fetchone()
return {
"total_chunks": total_chunks,
"total_files": total_files,
"embedding_dim": EMBEDDING_DIM if _get_embedding_model() else 128,
}
class MemorySearcher:
"""High-level interface for memory search."""
def __init__(self) -> None:
self.semantic = SemanticMemory()
def search(self, query: str, tiers: list[str] = None) -> dict:
"""Search across memory tiers.
Args:
query: Search query
tiers: List of tiers to search ["hot", "vault", "semantic"]
Returns:
Dict with results from each tier
"""
tiers = tiers or ["semantic"] # Default to semantic only
results = {}
if "semantic" in tiers:
semantic_results = self.semantic.search(query, top_k=5)
results["semantic"] = [
{"content": content, "score": score} for content, score in semantic_results
]
return results
def get_context_for_query(self, query: str) -> str:
"""Get comprehensive context for a user query."""
# Get semantic context
semantic_context = self.semantic.get_relevant_context(query)
if semantic_context:
return f"## Relevant Past Context\n\n{semantic_context}"
return ""

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,4 @@
"""OpenCV template-matching cache for sovereignty perception (screen-state recognition)."""
from __future__ import annotations
import json

View File

@@ -0,0 +1,178 @@
"""Tests for the cloud API budget tracker (issue #882)."""
import time
from unittest.mock import patch
import pytest
from infrastructure.models.budget import (
BudgetTracker,
SpendRecord,
estimate_cost_usd,
get_budget_tracker,
)
pytestmark = pytest.mark.unit
# ── estimate_cost_usd ─────────────────────────────────────────────────────────
class TestEstimateCostUsd:
    """Sanity checks for the per-model token cost estimator."""

    def test_haiku_cheaper_than_sonnet(self):
        cheap = estimate_cost_usd("claude-haiku-4-5", 1000, 1000)
        pricey = estimate_cost_usd("claude-sonnet-4-5", 1000, 1000)
        assert cheap < pricey

    def test_zero_tokens_is_zero_cost(self):
        assert estimate_cost_usd("gpt-4o", 0, 0) == 0.0

    def test_unknown_model_uses_default(self):
        # Unknown names fall back to a conservative non-zero default rate.
        assert estimate_cost_usd("some-unknown-model-xyz", 1000, 1000) > 0

    def test_versioned_model_name_matches(self):
        # "claude-haiku-4-5-20251001" should match "haiku"
        dated = estimate_cost_usd("claude-haiku-4-5-20251001", 1000, 0)
        bare = estimate_cost_usd("claude-haiku-4-5", 1000, 0)
        assert dated == bare

    def test_gpt4o_mini_cheaper_than_gpt4o(self):
        small = estimate_cost_usd("gpt-4o-mini", 1000, 1000)
        large = estimate_cost_usd("gpt-4o", 1000, 1000)
        assert small < large

    def test_returns_float(self):
        assert isinstance(estimate_cost_usd("haiku", 100, 200), float)
# ── BudgetTracker ─────────────────────────────────────────────────────────────
class TestBudgetTrackerInit:
    """Constructor behavior: SQLite-backed when possible, in-memory fallback otherwise."""

    def test_creates_with_memory_db(self):
        tracker = BudgetTracker(db_path=":memory:")
        assert tracker._db_ok is True

    def test_in_memory_fallback_empty_on_creation(self):
        tracker = BudgetTracker(db_path=":memory:")
        assert tracker._in_memory == []

    def test_bad_path_uses_memory_fallback(self, tmp_path):
        # Plain import instead of the previous __import__("threading") hack.
        import threading

        bad_path = str(tmp_path / "nonexistent" / "x" / "budget.db")
        # Should not raise — just log and continue with memory fallback
        # (actually will create parent dirs, so test with truly bad path)
        tracker = BudgetTracker.__new__(BudgetTracker)
        tracker._db_path = bad_path
        tracker._lock = threading.Lock()
        tracker._in_memory = []
        tracker._db_ok = False
        # Record to in-memory fallback
        tracker._in_memory.append(
            SpendRecord(time.time(), "test", "model", 100, 100, 0.001, "cloud")
        )
        assert len(tracker._in_memory) == 1
class TestBudgetTrackerRecordSpend:
    """record_spend() return values and daily/monthly accumulation."""

    def _tracker(self):
        # Fresh in-memory tracker per test.
        return BudgetTracker(db_path=":memory:")

    def test_record_spend_returns_cost(self):
        assert self._tracker().record_spend("anthropic", "claude-haiku-4-5", 100, 200) > 0

    def test_record_spend_explicit_cost(self):
        result = self._tracker().record_spend("anthropic", "model", cost_usd=1.23)
        assert result == pytest.approx(1.23)

    def test_record_spend_accumulates(self):
        tracker = self._tracker()
        tracker.record_spend("openai", "gpt-4o", cost_usd=0.01)
        tracker.record_spend("openai", "gpt-4o", cost_usd=0.02)
        assert tracker.get_daily_spend() == pytest.approx(0.03, abs=1e-9)

    def test_record_spend_with_tier_label(self):
        assert self._tracker().record_spend("anthropic", "haiku", tier="cloud_api") >= 0

    def test_monthly_spend_includes_daily(self):
        tracker = self._tracker()
        tracker.record_spend("anthropic", "haiku", cost_usd=5.00)
        assert tracker.get_monthly_spend() >= tracker.get_daily_spend()
class TestBudgetTrackerCloudAllowed:
    """cloud_allowed() gating against daily/monthly limits read from settings."""

    def test_allowed_when_no_spend(self):
        tracker = BudgetTracker(db_path=":memory:")
        # Limits are read from infrastructure.models.budget.settings (see the
        # sibling tests below), so patch that — the previous class-attribute
        # patch on the tracker type was a no-op.
        with patch("infrastructure.models.budget.settings") as mock_settings:
            mock_settings.tier_cloud_daily_budget_usd = 5.0
            mock_settings.tier_cloud_monthly_budget_usd = 50.0
            assert tracker.cloud_allowed() is True

    def test_blocked_when_daily_limit_exceeded(self):
        tracker = BudgetTracker(db_path=":memory:")
        tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
        # With default daily limit of 5.0, 999 should block
        assert tracker.cloud_allowed() is False

    def test_allowed_when_daily_limit_zero(self):
        tracker = BudgetTracker(db_path=":memory:")
        tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
        with patch("infrastructure.models.budget.settings") as mock_settings:
            mock_settings.tier_cloud_daily_budget_usd = 0  # disabled
            mock_settings.tier_cloud_monthly_budget_usd = 0  # disabled
            assert tracker.cloud_allowed() is True

    def test_blocked_when_monthly_limit_exceeded(self):
        tracker = BudgetTracker(db_path=":memory:")
        tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
        with patch("infrastructure.models.budget.settings") as mock_settings:
            mock_settings.tier_cloud_daily_budget_usd = 0  # daily disabled
            mock_settings.tier_cloud_monthly_budget_usd = 10.0
            assert tracker.cloud_allowed() is False
class TestBudgetTrackerSummary:
    """Shape and semantics of get_summary()."""

    def test_summary_keys_present(self):
        summary = BudgetTracker(db_path=":memory:").get_summary()
        expected_keys = (
            "daily_usd",
            "monthly_usd",
            "daily_limit_usd",
            "monthly_limit_usd",
            "daily_ok",
            "monthly_ok",
        )
        for key in expected_keys:
            assert key in summary

    def test_summary_daily_ok_true_on_empty(self):
        summary = BudgetTracker(db_path=":memory:").get_summary()
        assert summary["daily_ok"] is True
        assert summary["monthly_ok"] is True

    def test_summary_daily_ok_false_when_exceeded(self):
        tracker = BudgetTracker(db_path=":memory:")
        tracker.record_spend("openai", "gpt-4o", cost_usd=999.0)
        assert tracker.get_summary()["daily_ok"] is False
# ── Singleton ─────────────────────────────────────────────────────────────────
class TestGetBudgetTrackerSingleton:
    """get_budget_tracker() lazily builds and caches one module-level instance."""

    def test_returns_budget_tracker(self):
        import infrastructure.models.budget as budget_module

        budget_module._budget_tracker = None  # force re-creation
        assert isinstance(get_budget_tracker(), BudgetTracker)

    def test_returns_same_instance(self):
        import infrastructure.models.budget as budget_module

        budget_module._budget_tracker = None
        first = get_budget_tracker()
        second = get_budget_tracker()
        assert first is second

View File

@@ -0,0 +1,380 @@
"""Tests for the tiered model router (issue #882).
Covers:
- classify_tier() for Tier-1/2/3 routing
- TieredModelRouter.route() with mocked CascadeRouter + BudgetTracker
- Auto-escalation from Tier-1 on low-quality responses
- Cloud-tier budget guard
- Acceptance criteria from the issue:
- "Walk to the next room" → LOCAL_FAST
- "Plan the optimal path to become Hortator" → LOCAL_HEAVY
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from infrastructure.models.router import (
TierLabel,
TieredModelRouter,
_is_low_quality,
classify_tier,
get_tiered_router,
)
pytestmark = pytest.mark.unit
# ── classify_tier ─────────────────────────────────────────────────────────────
class TestClassifyTier:
    """classify_tier(): task text + context → tier label."""

    # ── Tier-1 (LOCAL_FAST) ────────────────────────────────────────────────
    def test_simple_navigation_is_local_fast(self):
        tier = classify_tier("walk to the next room")
        assert tier == TierLabel.LOCAL_FAST

    def test_go_north_is_local_fast(self):
        tier = classify_tier("go north")
        assert tier == TierLabel.LOCAL_FAST

    def test_single_binary_choice_is_local_fast(self):
        tier = classify_tier("yes")
        assert tier == TierLabel.LOCAL_FAST

    def test_open_door_is_local_fast(self):
        tier = classify_tier("open door")
        assert tier == TierLabel.LOCAL_FAST

    def test_attack_is_local_fast(self):
        tier = classify_tier("attack", {})
        assert tier == TierLabel.LOCAL_FAST

    # ── Tier-2 (LOCAL_HEAVY) ───────────────────────────────────────────────
    def test_quest_planning_is_local_heavy(self):
        tier = classify_tier("plan the optimal path to become Hortator")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_strategy_keyword_is_local_heavy(self):
        tier = classify_tier("what is the best strategy")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_stuck_state_escalates_to_local_heavy(self):
        tier = classify_tier("help me", {"stuck": True})
        assert tier == TierLabel.LOCAL_HEAVY

    def test_require_t2_flag_is_local_heavy(self):
        tier = classify_tier("go north", {"require_t2": True})
        assert tier == TierLabel.LOCAL_HEAVY

    def test_long_input_is_local_heavy(self):
        long_task = "tell me about " + ("the dungeon " * 30)
        assert classify_tier(long_task) == TierLabel.LOCAL_HEAVY

    def test_active_quests_upgrades_to_local_heavy(self):
        context = {"active_quests": ["Q1", "Q2", "Q3"]}
        assert classify_tier("go north", context) == TierLabel.LOCAL_HEAVY

    def test_dialogue_active_upgrades_to_local_heavy(self):
        context = {"dialogue_active": True}
        assert classify_tier("yes", context) == TierLabel.LOCAL_HEAVY

    def test_analyze_is_local_heavy(self):
        tier = classify_tier("analyze the situation")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_optimize_is_local_heavy(self):
        tier = classify_tier("optimize my build")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_negotiate_is_local_heavy(self):
        tier = classify_tier("negotiate with the Camonna Tong")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_explain_is_local_heavy(self):
        tier = classify_tier("explain the faction system")
        assert tier == TierLabel.LOCAL_HEAVY

    # ── Tier-3 (CLOUD_API) ─────────────────────────────────────────────────
    def test_require_cloud_flag_is_cloud_api(self):
        tier = classify_tier("go north", {"require_cloud": True})
        assert tier == TierLabel.CLOUD_API

    def test_require_cloud_overrides_everything(self):
        tier = classify_tier("yes", {"require_cloud": True})
        assert tier == TierLabel.CLOUD_API

    # ── Edge cases ────────────────────────────────────────────────────────
    def test_empty_task_defaults_to_local_heavy(self):
        # Empty string → nothing classifies it as T1 or T3
        assert classify_tier("") == TierLabel.LOCAL_HEAVY

    def test_case_insensitive(self):
        assert classify_tier("PLAN my route") == TierLabel.LOCAL_HEAVY

    def test_combat_active_upgrades_t1_to_heavy(self):
        context = {"combat_active": True}
        # "attack" is a T1 word, but combat context must not stay LOCAL_FAST.
        assert classify_tier("attack", context) != TierLabel.LOCAL_FAST
# ── _is_low_quality ───────────────────────────────────────────────────────────
class TestIsLowQuality:
    """_is_low_quality(): empty, short, or hedging responses trigger escalation."""

    def test_empty_is_low_quality(self):
        verdict = _is_low_quality("", TierLabel.LOCAL_FAST)
        assert verdict is True

    def test_whitespace_only_is_low_quality(self):
        verdict = _is_low_quality("   ", TierLabel.LOCAL_FAST)
        assert verdict is True

    def test_very_short_is_low_quality(self):
        verdict = _is_low_quality("ok", TierLabel.LOCAL_FAST)
        assert verdict is True

    def test_idontknow_is_low_quality(self):
        verdict = _is_low_quality("I don't know how to help with that.", TierLabel.LOCAL_FAST)
        assert verdict is True

    def test_not_sure_is_low_quality(self):
        verdict = _is_low_quality("I'm not sure about this.", TierLabel.LOCAL_FAST)
        assert verdict is True

    def test_as_an_ai_is_low_quality(self):
        verdict = _is_low_quality("As an AI, I cannot...", TierLabel.LOCAL_FAST)
        assert verdict is True

    def test_good_response_is_not_low_quality(self):
        response = "You move north into the Vivec Canton. The Ordinators watch your approach."
        assert _is_low_quality(response, TierLabel.LOCAL_FAST) is False

    def test_t1_short_response_triggers_escalation(self):
        # Less than _ESCALATION_MIN_CHARS for T1
        assert _is_low_quality("OK, done.", TierLabel.LOCAL_FAST) is True

    def test_borderline_ok_for_t2_not_t1(self):
        # Between _LOW_QUALITY_MIN_CHARS (20) and _ESCALATION_MIN_CHARS (60)
        # → low quality for T1 (escalation threshold), but acceptable for T2/T3
        response = "Done. The item is retrieved."  # 28 chars: ≥20, <60
        assert _is_low_quality(response, TierLabel.LOCAL_FAST) is True
        assert _is_low_quality(response, TierLabel.LOCAL_HEAVY) is False
# ── TieredModelRouter ─────────────────────────────────────────────────────────
_GOOD_CONTENT = (
"You move north through the doorway into the next room. "
"The stone walls glisten with moisture."
) # 90 chars — well above the escalation threshold
def _make_cascade_mock(content=_GOOD_CONTENT, model="llama3.1:8b"):
mock = MagicMock()
mock.complete = AsyncMock(
return_value={
"content": content,
"provider": "ollama-local",
"model": model,
"latency_ms": 150.0,
}
)
return mock
def _make_budget_mock(allowed=True):
mock = MagicMock()
mock.cloud_allowed = MagicMock(return_value=allowed)
mock.record_spend = MagicMock(return_value=0.001)
return mock
@pytest.mark.asyncio
class TestTieredModelRouterRoute:
    """Behavioral tests for TieredModelRouter.route().

    Covers: tier selection for simple vs. planning prompts, auto-escalation
    from Tier-1 to Tier-2 on low-quality output, Tier-2 failure falling
    through to the cloud tier, budget gating of cloud calls, spend recording,
    and pass-through of caller-supplied kwargs to cascade.complete().
    """

    async def test_route_returns_tier_in_result(self):
        # A trivial command should classify to the cheap local tier.
        router = TieredModelRouter(cascade=_make_cascade_mock())
        result = await router.route("go north")
        assert "tier" in result
        assert result["tier"] == TierLabel.LOCAL_FAST

    async def test_acceptance_walk_to_room_is_local_fast(self):
        """Acceptance: 'Walk to the next room' → LOCAL_FAST."""
        router = TieredModelRouter(cascade=_make_cascade_mock())
        result = await router.route("Walk to the next room")
        assert result["tier"] == TierLabel.LOCAL_FAST

    async def test_acceptance_plan_hortator_is_local_heavy(self):
        """Acceptance: 'Plan the optimal path to become Hortator' → LOCAL_HEAVY."""
        router = TieredModelRouter(
            cascade=_make_cascade_mock(model="hermes3:70b"),
        )
        result = await router.route("Plan the optimal path to become Hortator")
        assert result["tier"] == TierLabel.LOCAL_HEAVY

    async def test_t1_low_quality_escalates_to_t2(self):
        """Failed Tier-1 response auto-escalates to Tier-2."""
        call_models = []
        cascade = MagicMock()

        # Plain async function instead of AsyncMock so the return value can
        # depend on call order (first call low quality, second call good).
        async def complete_side_effect(messages, model, temperature, max_tokens):
            call_models.append(model)
            # First call (T1) returns a low-quality response
            if len(call_models) == 1:
                return {
                    "content": "I don't know.",
                    "provider": "ollama",
                    "model": model,
                    "latency_ms": 50,
                }
            # Second call (T2) returns a good response
            return {
                "content": "You move to the northern passage, passing through the Dunmer stronghold.",
                "provider": "ollama",
                "model": model,
                "latency_ms": 800,
            }

        cascade.complete = complete_side_effect
        router = TieredModelRouter(cascade=cascade, auto_escalate=True)
        result = await router.route("go north")
        assert len(call_models) == 2, "Should have called twice (T1 escalated to T2)"
        assert result["tier"] == TierLabel.LOCAL_HEAVY

    async def test_auto_escalate_false_no_escalation(self):
        """With auto_escalate=False, low-quality T1 response is returned as-is."""
        call_count = {"n": 0}
        cascade = MagicMock()

        async def complete_side_effect(**kwargs):
            call_count["n"] += 1
            return {
                "content": "I don't know.",
                "provider": "ollama",
                "model": "llama3.1:8b",
                "latency_ms": 50,
            }

        cascade.complete = AsyncMock(side_effect=complete_side_effect)
        router = TieredModelRouter(cascade=cascade, auto_escalate=False)
        result = await router.route("go north")
        # Exactly one model call: escalation was disabled.
        assert call_count["n"] == 1
        assert result["tier"] == TierLabel.LOCAL_FAST

    async def test_t2_failure_escalates_to_cloud(self):
        """Tier-2 failure escalates to Cloud API (when budget allows)."""
        cascade = MagicMock()
        call_models = []

        async def complete_side_effect(messages, model, temperature, max_tokens):
            call_models.append(model)
            # Any hermes/70B model errors out, forcing the cloud fallback.
            if "hermes3" in model or "70b" in model.lower():
                raise RuntimeError("Tier-2 model unavailable")
            return {
                "content": "Cloud response here.",
                "provider": "anthropic",
                "model": model,
                "latency_ms": 1200,
            }

        cascade.complete = complete_side_effect
        budget = _make_budget_mock(allowed=True)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
        result = await router.route("plan my route", context={"require_t2": True})
        assert result["tier"] == TierLabel.CLOUD_API

    async def test_cloud_blocked_by_budget_raises(self):
        """Cloud tier blocked when budget is exhausted."""
        cascade = MagicMock()
        cascade.complete = AsyncMock(side_effect=RuntimeError("T2 fail"))
        budget = _make_budget_mock(allowed=False)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
        # Router surfaces a RuntimeError mentioning the budget limit.
        with pytest.raises(RuntimeError, match="budget limit"):
            await router.route("plan my route", context={"require_t2": True})

    async def test_explicit_cloud_tier_uses_cloud_model(self):
        # context={"require_cloud": True} routes straight to the cloud tier.
        cascade = _make_cascade_mock(model="claude-haiku-4-5")
        budget = _make_budget_mock(allowed=True)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
        result = await router.route("go north", context={"require_cloud": True})
        assert result["tier"] == TierLabel.CLOUD_API

    async def test_cloud_spend_recorded_with_usage(self):
        """Cloud spend is recorded when the response includes usage info."""
        cascade = MagicMock()
        cascade.complete = AsyncMock(
            return_value={
                "content": "Cloud answer.",
                "provider": "anthropic",
                "model": "claude-haiku-4-5",
                "latency_ms": 900,
                "usage": {"prompt_tokens": 50, "completion_tokens": 100},
            }
        )
        budget = _make_budget_mock(allowed=True)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
        result = await router.route("go north", context={"require_cloud": True})
        budget.record_spend.assert_called_once()
        assert "cost_usd" in result

    async def test_cloud_spend_not_recorded_without_usage(self):
        """Cloud spend is not recorded when usage info is absent."""
        cascade = MagicMock()
        cascade.complete = AsyncMock(
            return_value={
                "content": "Cloud answer.",
                "provider": "anthropic",
                "model": "claude-haiku-4-5",
                "latency_ms": 900,
                # no "usage" key
            }
        )
        budget = _make_budget_mock(allowed=True)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
        result = await router.route("go north", context={"require_cloud": True})
        budget.record_spend.assert_not_called()
        assert "cost_usd" not in result

    async def test_custom_tier_models_respected(self):
        # A tier_models override replaces the default model for that tier.
        cascade = _make_cascade_mock()
        router = TieredModelRouter(
            cascade=cascade,
            tier_models={TierLabel.LOCAL_FAST: "llama3.2:3b"},
        )
        await router.route("go north")
        call_kwargs = cascade.complete.call_args
        assert call_kwargs.kwargs["model"] == "llama3.2:3b"

    async def test_messages_override_used_when_provided(self):
        # Caller-supplied messages bypass the router's own prompt assembly.
        cascade = _make_cascade_mock()
        router = TieredModelRouter(cascade=cascade)
        custom_msgs = [{"role": "user", "content": "custom message"}]
        await router.route("go north", messages=custom_msgs)
        call_kwargs = cascade.complete.call_args
        assert call_kwargs.kwargs["messages"] == custom_msgs

    async def test_temperature_forwarded(self):
        # temperature is passed through to cascade.complete unchanged.
        cascade = _make_cascade_mock()
        router = TieredModelRouter(cascade=cascade)
        await router.route("go north", temperature=0.7)
        call_kwargs = cascade.complete.call_args
        assert call_kwargs.kwargs["temperature"] == 0.7

    async def test_max_tokens_forwarded(self):
        # max_tokens is passed through to cascade.complete unchanged.
        cascade = _make_cascade_mock()
        router = TieredModelRouter(cascade=cascade)
        await router.route("go north", max_tokens=128)
        call_kwargs = cascade.complete.call_args
        assert call_kwargs.kwargs["max_tokens"] == 128
class TestTieredModelRouterClassify:
    """Router.classify() is a thin wrapper over the module-level classify_tier()."""

    def test_classify_delegates_to_classify_tier(self):
        router = TieredModelRouter(cascade=MagicMock())
        for prompt in ("go north", "plan the quest"):
            assert router.classify(prompt) == classify_tier(prompt)
class TestGetTieredRouterSingleton:
    """get_tiered_router() lazily builds and caches one module-level instance."""

    def test_returns_tiered_router_instance(self):
        import infrastructure.models.router as router_module

        router_module._tiered_router = None  # reset the cached singleton
        assert isinstance(get_tiered_router(), TieredModelRouter)

    def test_singleton_returns_same_instance(self):
        import infrastructure.models.router as router_module

        router_module._tiered_router = None  # force a fresh instance
        first = get_tiered_router()
        second = get_tiered_router()
        assert first is second

View File

View File

@@ -0,0 +1,379 @@
"""Tests for the sovereignty perception cache (template matching).
Refs: #1261
"""
import json
from unittest.mock import patch
import numpy as np
class TestTemplate:
    """Tests for the Template dataclass."""

    def test_template_default_values(self):
        """A Template keeps its name/image and defaults threshold to 0.85."""
        from timmy.sovereignty.perception_cache import Template

        img = np.array([[1, 2], [3, 4]])
        tpl = Template(name="test_template", image=img)
        assert tpl.name == "test_template"
        assert np.array_equal(tpl.image, img)
        assert tpl.threshold == 0.85

    def test_template_custom_threshold(self):
        """An explicit threshold overrides the 0.85 default."""
        from timmy.sovereignty.perception_cache import Template

        tpl = Template(
            name="test_template",
            image=np.array([[1, 2], [3, 4]]),
            threshold=0.95,
        )
        assert tpl.threshold == 0.95
class TestCacheResult:
    """Tests for the CacheResult dataclass."""

    def test_cache_result_with_state(self):
        """Both confidence and state are stored as given."""
        from timmy.sovereignty.perception_cache import CacheResult

        res = CacheResult(confidence=0.92, state={"template_name": "test"})
        assert res.state == {"template_name": "test"}
        assert res.confidence == 0.92

    def test_cache_result_no_state(self):
        """A None state is a valid (no-match) result."""
        from timmy.sovereignty.perception_cache import CacheResult

        res = CacheResult(confidence=0.5, state=None)
        assert res.state is None
        assert res.confidence == 0.5
class TestPerceptionCacheInit:
    """Tests for PerceptionCache construction and template loading."""

    def test_init_creates_empty_cache_when_no_file(self, tmp_path):
        """A missing templates file yields an empty cache."""
        from timmy.sovereignty.perception_cache import PerceptionCache

        path = tmp_path / "nonexistent_templates.json"
        cache = PerceptionCache(templates_path=path)
        assert cache.templates == []
        assert cache.templates_path == path

    def test_init_loads_existing_templates(self, tmp_path):
        """Templates persisted as JSON are restored on construction."""
        from timmy.sovereignty.perception_cache import PerceptionCache

        path = tmp_path / "templates.json"
        path.write_text(json.dumps([
            {"name": "template1", "threshold": 0.85},
            {"name": "template2", "threshold": 0.90},
        ]))
        cache = PerceptionCache(templates_path=path)
        loaded = cache.templates
        assert len(loaded) == 2
        assert (loaded[0].name, loaded[0].threshold) == ("template1", 0.85)
        assert (loaded[1].name, loaded[1].threshold) == ("template2", 0.90)

    def test_init_with_string_path(self, tmp_path):
        """A plain string is accepted as templates_path."""
        from timmy.sovereignty.perception_cache import PerceptionCache

        path_str = str(tmp_path / "templates.json")
        cache = PerceptionCache(templates_path=path_str)
        assert str(cache.templates_path) == path_str
class TestPerceptionCacheMatch:
    """Tests for PerceptionCache.match() template matching.

    cv2 is patched out entirely; minMaxLoc's mocked max_val drives the
    confidence the cache reports, so no real image correlation happens.
    """

    def test_match_no_templates_returns_low_confidence(self, tmp_path):
        """Matching with no templates returns low confidence and None state."""
        from timmy.sovereignty.perception_cache import PerceptionCache
        cache = PerceptionCache(templates_path=tmp_path / "templates.json")
        screenshot = np.array([[1, 2], [3, 4]])
        result = cache.match(screenshot)
        assert result.confidence == 0.0
        assert result.state is None

    @patch("timmy.sovereignty.perception_cache.cv2")
    def test_match_finds_best_template(self, mock_cv2, tmp_path):
        """Match returns the best matching template above threshold."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template
        # Setup mock cv2 behavior
        mock_cv2.matchTemplate.return_value = np.array([[0.5, 0.6], [0.7, 0.8]])
        mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
        # minMaxLoc returns (min_val, max_val, min_loc, max_loc); only max_val is read.
        mock_cv2.minMaxLoc.return_value = (None, 0.92, None, None)
        cache = PerceptionCache(templates_path=tmp_path / "templates.json")
        template = Template(name="best_match", image=np.array([[1, 2], [3, 4]]))
        cache.add([template])
        screenshot = np.array([[5, 6], [7, 8]])
        result = cache.match(screenshot)
        assert result.confidence == 0.92
        assert result.state == {"template_name": "best_match"}

    @patch("timmy.sovereignty.perception_cache.cv2")
    def test_match_respects_global_threshold(self, mock_cv2, tmp_path):
        """Match returns None state when confidence is below threshold."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template
        # Setup mock cv2 to return confidence below 0.85 threshold
        mock_cv2.matchTemplate.return_value = np.array([[0.1, 0.2], [0.3, 0.4]])
        mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
        mock_cv2.minMaxLoc.return_value = (None, 0.75, None, None)
        cache = PerceptionCache(templates_path=tmp_path / "templates.json")
        template = Template(name="low_match", image=np.array([[1, 2], [3, 4]]))
        cache.add([template])
        screenshot = np.array([[5, 6], [7, 8]])
        result = cache.match(screenshot)
        # Confidence is recorded but state is None (below threshold)
        assert result.confidence == 0.75
        assert result.state is None

    @patch("timmy.sovereignty.perception_cache.cv2")
    def test_match_selects_highest_confidence(self, mock_cv2, tmp_path):
        """Match selects template with highest confidence across all templates."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template
        mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
        # Each template will return a different confidence
        # (side_effect order mirrors the order templates were added).
        mock_cv2.minMaxLoc.side_effect = [
            (None, 0.70, None, None),  # template1
            (None, 0.95, None, None),  # template2 (best)
            (None, 0.80, None, None),  # template3
        ]
        cache = PerceptionCache(templates_path=tmp_path / "templates.json")
        templates = [
            Template(name="template1", image=np.array([[1, 2], [3, 4]])),
            Template(name="template2", image=np.array([[5, 6], [7, 8]])),
            Template(name="template3", image=np.array([[9, 10], [11, 12]])),
        ]
        cache.add(templates)
        screenshot = np.array([[13, 14], [15, 16]])
        result = cache.match(screenshot)
        assert result.confidence == 0.95
        assert result.state == {"template_name": "template2"}

    @patch("timmy.sovereignty.perception_cache.cv2")
    def test_match_exactly_at_threshold(self, mock_cv2, tmp_path):
        """Match returns state when confidence is exactly at threshold boundary."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template
        mock_cv2.matchTemplate.return_value = np.array([[0.1]])
        mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
        mock_cv2.minMaxLoc.return_value = (None, 0.85, None, None)  # Exactly at threshold
        cache = PerceptionCache(templates_path=tmp_path / "templates.json")
        template = Template(name="threshold_match", image=np.array([[1, 2], [3, 4]]))
        cache.add([template])
        screenshot = np.array([[5, 6], [7, 8]])
        result = cache.match(screenshot)
        # Note: current implementation uses > 0.85, so exactly 0.85 returns None state
        assert result.confidence == 0.85
        assert result.state is None

    @patch("timmy.sovereignty.perception_cache.cv2")
    def test_match_just_above_threshold(self, mock_cv2, tmp_path):
        """Match returns state when confidence is just above threshold."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template
        mock_cv2.matchTemplate.return_value = np.array([[0.1]])
        mock_cv2.TM_CCOEFF_NORMED = "TM_CCOEFF_NORMED"
        mock_cv2.minMaxLoc.return_value = (None, 0.851, None, None)  # Just above threshold
        cache = PerceptionCache(templates_path=tmp_path / "templates.json")
        template = Template(name="above_threshold", image=np.array([[1, 2], [3, 4]]))
        cache.add([template])
        screenshot = np.array([[5, 6], [7, 8]])
        result = cache.match(screenshot)
        assert result.confidence == 0.851
        assert result.state == {"template_name": "above_threshold"}
class TestPerceptionCacheAdd:
    """Tests for PerceptionCache.add()."""

    def test_add_single_template(self, tmp_path):
        """A one-element list is appended to the cache."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template

        cache = PerceptionCache(templates_path=tmp_path / "templates.json")
        cache.add([Template(name="new_template", image=np.array([[1, 2], [3, 4]]))])
        assert [t.name for t in cache.templates] == ["new_template"]

    def test_add_multiple_templates(self, tmp_path):
        """Several templates can be added in one call, order preserved."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template

        cache = PerceptionCache(templates_path=tmp_path / "templates.json")
        cache.add([
            Template(name="template1", image=np.array([[1, 2], [3, 4]])),
            Template(name="template2", image=np.array([[5, 6], [7, 8]])),
        ])
        assert [t.name for t in cache.templates] == ["template1", "template2"]

    def test_add_templates_accumulate(self, tmp_path):
        """Repeated add() calls extend the cache rather than replace it."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template

        cache = PerceptionCache(templates_path=tmp_path / "templates.json")
        cache.add([Template(name="first", image=np.array([[1]]))])
        cache.add([Template(name="second", image=np.array([[2]]))])
        assert len(cache.templates) == 2
class TestPerceptionCachePersist:
    """Tests for PerceptionCache.persist()."""

    def test_persist_creates_file(self, tmp_path):
        """persist() writes the JSON file (target lives in a new subdir)."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template

        target = tmp_path / "subdir" / "templates.json"
        cache = PerceptionCache(templates_path=target)
        cache.add([Template(name="persisted", image=np.array([[1, 2], [3, 4]]))])
        cache.persist()
        assert target.exists()

    def test_persist_stores_template_names(self, tmp_path):
        """Names and thresholds round-trip through the JSON file."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template

        target = tmp_path / "templates.json"
        cache = PerceptionCache(templates_path=target)
        cache.add([
            Template(name="template1", image=np.array([[1]]), threshold=0.85),
            Template(name="template2", image=np.array([[2]]), threshold=0.90),
        ])
        cache.persist()
        data = json.loads(target.read_text())
        assert [(entry["name"], entry["threshold"]) for entry in data] == [
            ("template1", 0.85),
            ("template2", 0.90),
        ]

    def test_persist_does_not_store_image_data(self, tmp_path):
        """Only metadata is persisted; image arrays never hit disk."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template

        target = tmp_path / "templates.json"
        cache = PerceptionCache(templates_path=target)
        cache.add([Template(name="no_image", image=np.array([[1, 2, 3], [4, 5, 6]]))])
        cache.persist()
        data = json.loads(target.read_text())
        assert set(data[0].keys()) == {"name", "threshold"}
        assert "image" not in data[0]
class TestPerceptionCacheLoad:
    """Tests for restoring previously persisted templates."""

    def test_load_from_existing_file(self, tmp_path):
        """A fresh cache restores metadata persisted by another instance."""
        from timmy.sovereignty.perception_cache import PerceptionCache, Template

        target = tmp_path / "templates.json"
        writer = PerceptionCache(templates_path=target)
        writer.add([Template(name="loaded", image=np.array([[1]]), threshold=0.88)])
        writer.persist()

        reader = PerceptionCache(templates_path=target)
        assert len(reader.templates) == 1
        restored = reader.templates[0]
        assert restored.name == "loaded"
        assert restored.threshold == 0.88
        # Images are not persisted, so they come back as empty arrays.
        assert restored.image.size == 0

    def test_load_empty_file(self, tmp_path):
        """An empty JSON list loads as an empty template cache."""
        from timmy.sovereignty.perception_cache import PerceptionCache

        target = tmp_path / "templates.json"
        target.write_text(json.dumps([]))
        cache = PerceptionCache(templates_path=target)
        assert cache.templates == []
class TestCrystallizePerception:
    """Tests for the crystallize_perception placeholder."""

    def test_crystallize_returns_empty_list(self, tmp_path):
        """The placeholder implementation always returns []."""
        from timmy.sovereignty.perception_cache import crystallize_perception

        frame = np.array([[1, 2], [3, 4]])
        assert crystallize_perception(frame, {"some": "response"}) == []

    def test_crystallize_accepts_any_vlm_response(self, tmp_path):
        """Any vlm_response shape is tolerated without raising."""
        from timmy.sovereignty.perception_cache import crystallize_perception

        frame = np.array([[1, 2], [3, 4]])
        for response in (None, {}, {"items": []}, "string response"):
            assert crystallize_perception(frame, response) == []

View File

@@ -0,0 +1,643 @@
"""Unit tests for timmy.kimi_delegation — Kimi research delegation pipeline."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# exceeds_local_capacity
# ---------------------------------------------------------------------------
class TestExceedsLocalCapacity:
    """Heuristics deciding when a task should be delegated to Kimi."""

    def test_heavy_keyword_triggers_delegation(self):
        from timmy.kimi_delegation import exceeds_local_capacity

        assert exceeds_local_capacity("Do a comprehensive review of the codebase") is True

    def test_all_heavy_keywords_detected(self):
        from timmy.kimi_delegation import _HEAVY_RESEARCH_KEYWORDS, exceeds_local_capacity

        for keyword in _HEAVY_RESEARCH_KEYWORDS:
            task = f"Please {keyword} the topic"
            assert exceeds_local_capacity(task) is True, f"Missed keyword: {keyword}"

    def test_long_task_triggers_delegation(self):
        from timmy.kimi_delegation import _HEAVY_WORD_THRESHOLD, exceeds_local_capacity

        words = ["word"] * (_HEAVY_WORD_THRESHOLD + 1)
        assert exceeds_local_capacity(" ".join(words)) is True

    def test_short_simple_task_returns_false(self):
        from timmy.kimi_delegation import exceeds_local_capacity

        assert exceeds_local_capacity("Fix the typo in README") is False

    def test_exactly_at_word_threshold_triggers(self):
        from timmy.kimi_delegation import _HEAVY_WORD_THRESHOLD, exceeds_local_capacity

        words = ["word"] * _HEAVY_WORD_THRESHOLD
        assert exceeds_local_capacity(" ".join(words)) is True

    def test_keyword_case_insensitive(self):
        from timmy.kimi_delegation import exceeds_local_capacity

        assert exceeds_local_capacity("Run a COMPREHENSIVE analysis") is True

    def test_empty_string_returns_false(self):
        from timmy.kimi_delegation import exceeds_local_capacity

        assert exceeds_local_capacity("") is False
# ---------------------------------------------------------------------------
# _slugify
# ---------------------------------------------------------------------------
class TestSlugify:
    """Tests for the _slugify helper."""

    def test_basic_text(self):
        from timmy.kimi_delegation import _slugify

        assert _slugify("Hello World") == "hello-world"

    def test_special_characters_removed(self):
        from timmy.kimi_delegation import _slugify

        assert _slugify("Research: AI & ML!") == "research-ai--ml"

    def test_underscores_become_dashes(self):
        from timmy.kimi_delegation import _slugify

        assert _slugify("some_snake_case") == "some-snake-case"

    def test_long_text_truncated_to_60(self):
        from timmy.kimi_delegation import _slugify

        assert len(_slugify("a" * 100)) <= 60

    def test_leading_trailing_dashes_stripped(self):
        from timmy.kimi_delegation import _slugify

        slug = _slugify(" hello ")
        assert not (slug.startswith("-") or slug.endswith("-"))

    def test_multiple_spaces_become_single_dash(self):
        from timmy.kimi_delegation import _slugify

        assert _slugify("one two") == "one-two"
# ---------------------------------------------------------------------------
# _build_research_template
# ---------------------------------------------------------------------------
class TestBuildResearchTemplate:
    """Tests for the research-issue markdown template builder."""

    def test_contains_task_title(self):
        from timmy.kimi_delegation import _build_research_template

        assert "My Task" in _build_research_template("My Task", "background", "the question?")

    def test_contains_question(self):
        from timmy.kimi_delegation import _build_research_template

        assert "What is X?" in _build_research_template("task", "context", "What is X?")

    def test_contains_context(self):
        from timmy.kimi_delegation import _build_research_template

        assert "some context here" in _build_research_template("task", "some context here", "q?")

    def test_default_priority_normal(self):
        from timmy.kimi_delegation import _build_research_template

        assert "normal" in _build_research_template("task", "ctx", "q?")

    def test_custom_priority_included(self):
        from timmy.kimi_delegation import _build_research_template

        assert "high" in _build_research_template("task", "ctx", "q?", priority="high")

    def test_kimi_label_mentioned(self):
        from timmy.kimi_delegation import KIMI_READY_LABEL, _build_research_template

        assert KIMI_READY_LABEL in _build_research_template("task", "ctx", "q?")

    def test_slugified_task_in_artifact_path(self):
        from timmy.kimi_delegation import _build_research_template

        assert "my-research-task" in _build_research_template("My Research Task", "ctx", "q?")

    def test_sections_present(self):
        from timmy.kimi_delegation import _build_research_template

        body = _build_research_template("task", "ctx", "q?")
        for heading in (
            "## Research Request",
            "### Research Question",
            "### Background / Context",
            "### Deliverables",
        ):
            assert heading in body
# ---------------------------------------------------------------------------
# _extract_action_items
# ---------------------------------------------------------------------------
class TestExtractActionItems:
    """Tests for action-item extraction from free-form research text."""

    def test_checkbox_items_extracted(self):
        from timmy.kimi_delegation import _extract_action_items

        items = _extract_action_items("- [ ] Fix the bug\n- [ ] Write tests\n")
        assert "Fix the bug" in items
        assert "Write tests" in items

    def test_numbered_list_extracted(self):
        from timmy.kimi_delegation import _extract_action_items

        items = _extract_action_items("1. Deploy to staging\n2. Run smoke tests\n")
        assert "Deploy to staging" in items
        assert "Run smoke tests" in items

    def test_action_prefix_extracted(self):
        from timmy.kimi_delegation import _extract_action_items

        assert "Update the config file" in _extract_action_items("Action: Update the config file\n")

    def test_todo_prefix_extracted(self):
        from timmy.kimi_delegation import _extract_action_items

        assert "Add error handling" in _extract_action_items("TODO: Add error handling\n")

    def test_next_step_prefix_extracted(self):
        from timmy.kimi_delegation import _extract_action_items

        assert "Validate results" in _extract_action_items("Next step: Validate results\n")

    def test_case_insensitive_prefixes(self):
        from timmy.kimi_delegation import _extract_action_items

        items = _extract_action_items("todo: lowercase todo\nACTION: uppercase action\n")
        assert "lowercase todo" in items
        assert "uppercase action" in items

    def test_deduplication(self):
        from timmy.kimi_delegation import _extract_action_items

        items = _extract_action_items("1. Do the thing\n2. Do the thing\n")
        assert items.count("Do the thing") == 1

    def test_empty_text_returns_empty_list(self):
        from timmy.kimi_delegation import _extract_action_items

        assert _extract_action_items("") == []

    def test_no_action_items_returns_empty_list(self):
        from timmy.kimi_delegation import _extract_action_items

        prose = "This is just plain prose with no action items here."
        assert _extract_action_items(prose) == []

    def test_mixed_sources_combined(self):
        from timmy.kimi_delegation import _extract_action_items

        items = _extract_action_items("- [ ] checkbox item\n1. numbered item\nAction: action item\n")
        assert len(items) == 3
# ---------------------------------------------------------------------------
# _get_or_create_label (async)
# ---------------------------------------------------------------------------
class TestGetOrCreateLabel:
    """Tests for the async Gitea label lookup/creation helper."""

    @pytest.mark.asyncio
    async def test_returns_existing_label_id(self):
        from timmy.kimi_delegation import KIMI_READY_LABEL, _get_or_create_label

        listing = MagicMock(status_code=200)
        listing.json.return_value = [{"name": KIMI_READY_LABEL, "id": 42}]
        client = MagicMock(get=AsyncMock(return_value=listing))
        label_id = await _get_or_create_label(
            client, "http://git", {"Authorization": "token x"}, "owner/repo"
        )
        assert label_id == 42

    @pytest.mark.asyncio
    async def test_creates_label_when_missing(self):
        from timmy.kimi_delegation import _get_or_create_label

        listing = MagicMock(status_code=200)
        listing.json.return_value = []  # repo has no labels yet
        created = MagicMock(status_code=201)
        created.json.return_value = {"id": 99}
        client = MagicMock(
            get=AsyncMock(return_value=listing),
            post=AsyncMock(return_value=created),
        )
        label_id = await _get_or_create_label(
            client, "http://git", {"Authorization": "token x"}, "owner/repo"
        )
        assert label_id == 99

    @pytest.mark.asyncio
    async def test_returns_none_on_list_exception(self):
        from timmy.kimi_delegation import _get_or_create_label

        client = MagicMock(get=AsyncMock(side_effect=Exception("network error")))
        assert await _get_or_create_label(client, "http://git", {}, "owner/repo") is None

    @pytest.mark.asyncio
    async def test_returns_none_on_create_exception(self):
        from timmy.kimi_delegation import _get_or_create_label

        listing = MagicMock(status_code=200)
        listing.json.return_value = []
        client = MagicMock(
            get=AsyncMock(return_value=listing),
            post=AsyncMock(side_effect=Exception("create failed")),
        )
        assert await _get_or_create_label(client, "http://git", {}, "owner/repo") is None
# ---------------------------------------------------------------------------
# create_kimi_research_issue (async)
# ---------------------------------------------------------------------------
class TestCreateKimiResearchIssue:
    """Tests for create_kimi_research_issue().

    httpx.AsyncClient is replaced by a MagicMock/AsyncMock that also acts as
    an async context manager (__aenter__/__aexit__), so no network I/O runs.
    """

    @pytest.mark.asyncio
    async def test_returns_error_when_gitea_disabled(self):
        # Gitea integration off → immediate failure dict, no HTTP at all.
        from timmy.kimi_delegation import create_kimi_research_issue
        with patch("timmy.kimi_delegation.settings") as mock_settings:
            mock_settings.gitea_enabled = False
            mock_settings.gitea_token = ""
            result = await create_kimi_research_issue("task", "ctx", "q?")
            assert result["success"] is False
            assert "not configured" in result["error"]

    @pytest.mark.asyncio
    async def test_returns_error_when_no_token(self):
        # Enabled but token missing is also treated as "not configured".
        from timmy.kimi_delegation import create_kimi_research_issue
        with patch("timmy.kimi_delegation.settings") as mock_settings:
            mock_settings.gitea_enabled = True
            mock_settings.gitea_token = ""
            result = await create_kimi_research_issue("task", "ctx", "q?")
            assert result["success"] is False

    @pytest.mark.asyncio
    async def test_successful_issue_creation(self):
        # Happy path: label already exists (GET), issue POST returns 201.
        from timmy.kimi_delegation import create_kimi_research_issue
        mock_settings = MagicMock()
        mock_settings.gitea_enabled = True
        mock_settings.gitea_token = "tok"
        mock_settings.gitea_url = "http://git"
        mock_settings.gitea_repo = "owner/repo"
        label_resp = MagicMock()
        label_resp.status_code = 200
        label_resp.json.return_value = [{"name": "kimi-ready", "id": 5}]
        issue_resp = MagicMock()
        issue_resp.status_code = 201
        issue_resp.json.return_value = {"number": 42, "html_url": "http://git/issues/42"}
        async_client = AsyncMock()
        async_client.get = AsyncMock(return_value=label_resp)
        async_client.post = AsyncMock(return_value=issue_resp)
        # Make the mock usable as `async with httpx.AsyncClient() as client:`.
        async_client.__aenter__ = AsyncMock(return_value=async_client)
        async_client.__aexit__ = AsyncMock(return_value=False)
        with (
            patch("timmy.kimi_delegation.settings", mock_settings),
            patch("timmy.kimi_delegation.httpx") as mock_httpx,
        ):
            mock_httpx.AsyncClient.return_value = async_client
            result = await create_kimi_research_issue("task", "ctx", "q?")
            assert result["success"] is True
            assert result["issue_number"] == 42
            assert "http://git/issues/42" in result["issue_url"]

    @pytest.mark.asyncio
    async def test_api_error_returns_failure(self):
        # Label must be created (first POST), then issue POST fails with 500;
        # side_effect list delivers the two POST responses in that order.
        from timmy.kimi_delegation import create_kimi_research_issue
        mock_settings = MagicMock()
        mock_settings.gitea_enabled = True
        mock_settings.gitea_token = "tok"
        mock_settings.gitea_url = "http://git"
        mock_settings.gitea_repo = "owner/repo"
        label_resp = MagicMock()
        label_resp.status_code = 200
        label_resp.json.return_value = []
        create_label_resp = MagicMock()
        create_label_resp.status_code = 201
        create_label_resp.json.return_value = {"id": 1}
        issue_resp = MagicMock()
        issue_resp.status_code = 500
        issue_resp.text = "Internal Server Error"
        async_client = AsyncMock()
        async_client.get = AsyncMock(return_value=label_resp)
        async_client.post = AsyncMock(side_effect=[create_label_resp, issue_resp])
        async_client.__aenter__ = AsyncMock(return_value=async_client)
        async_client.__aexit__ = AsyncMock(return_value=False)
        with (
            patch("timmy.kimi_delegation.settings", mock_settings),
            patch("timmy.kimi_delegation.httpx") as mock_httpx,
        ):
            mock_httpx.AsyncClient.return_value = async_client
            result = await create_kimi_research_issue("task", "ctx", "q?")
            assert result["success"] is False
            assert "500" in result["error"]

    @pytest.mark.asyncio
    async def test_exception_returns_failure(self):
        # Entering the client context raises → failure dict with an error set.
        from timmy.kimi_delegation import create_kimi_research_issue
        mock_settings = MagicMock()
        mock_settings.gitea_enabled = True
        mock_settings.gitea_token = "tok"
        mock_settings.gitea_url = "http://git"
        mock_settings.gitea_repo = "owner/repo"
        async_client = AsyncMock()
        async_client.__aenter__ = AsyncMock(side_effect=Exception("connection refused"))
        async_client.__aexit__ = AsyncMock(return_value=False)
        with (
            patch("timmy.kimi_delegation.settings", mock_settings),
            patch("timmy.kimi_delegation.httpx") as mock_httpx,
        ):
            mock_httpx.AsyncClient.return_value = async_client
            result = await create_kimi_research_issue("task", "ctx", "q?")
            assert result["success"] is False
            assert result["error"] != ""
# ---------------------------------------------------------------------------
# poll_kimi_issue (async)
# ---------------------------------------------------------------------------
class TestPollKimiIssue:
    """Tests for poll_kimi_issue() — waiting for a Kimi research issue to close."""

    @pytest.mark.asyncio
    async def test_returns_error_when_gitea_not_configured(self):
        # No Gitea config → immediate not-completed result, no polling loop.
        from timmy.kimi_delegation import poll_kimi_issue
        with patch("timmy.kimi_delegation.settings") as mock_settings:
            mock_settings.gitea_enabled = False
            mock_settings.gitea_token = ""
            result = await poll_kimi_issue(123)
            assert result["completed"] is False
            assert "not configured" in result["error"]

    @pytest.mark.asyncio
    async def test_returns_completed_when_issue_closed(self):
        # A "closed" issue on the first poll completes with the issue body.
        from timmy.kimi_delegation import poll_kimi_issue
        mock_settings = MagicMock()
        mock_settings.gitea_enabled = True
        mock_settings.gitea_token = "tok"
        mock_settings.gitea_url = "http://git"
        mock_settings.gitea_repo = "owner/repo"
        resp = MagicMock()
        resp.status_code = 200
        resp.json.return_value = {"state": "closed", "body": "Done!"}
        async_client = AsyncMock()
        async_client.get = AsyncMock(return_value=resp)
        # Make the mock usable as an async context manager.
        async_client.__aenter__ = AsyncMock(return_value=async_client)
        async_client.__aexit__ = AsyncMock(return_value=False)
        with (
            patch("timmy.kimi_delegation.settings", mock_settings),
            patch("timmy.kimi_delegation.httpx") as mock_httpx,
        ):
            mock_httpx.AsyncClient.return_value = async_client
            result = await poll_kimi_issue(42, poll_interval=0, max_wait=1)
            assert result["completed"] is True
            assert result["state"] == "closed"
            assert result["body"] == "Done!"

    @pytest.mark.asyncio
    async def test_times_out_when_issue_stays_open(self):
        # asyncio.sleep is patched so the timeout path runs instantly.
        from timmy.kimi_delegation import poll_kimi_issue
        mock_settings = MagicMock()
        mock_settings.gitea_enabled = True
        mock_settings.gitea_token = "tok"
        mock_settings.gitea_url = "http://git"
        mock_settings.gitea_repo = "owner/repo"
        resp = MagicMock()
        resp.status_code = 200
        resp.json.return_value = {"state": "open", "body": ""}
        async_client = AsyncMock()
        async_client.get = AsyncMock(return_value=resp)
        async_client.__aenter__ = AsyncMock(return_value=async_client)
        async_client.__aexit__ = AsyncMock(return_value=False)
        with (
            patch("timmy.kimi_delegation.settings", mock_settings),
            patch("timmy.kimi_delegation.httpx") as mock_httpx,
            patch("timmy.kimi_delegation.asyncio.sleep", new_callable=AsyncMock),
        ):
            mock_httpx.AsyncClient.return_value = async_client
            # poll_interval > max_wait so it exits immediately after first sleep
            result = await poll_kimi_issue(42, poll_interval=10, max_wait=5)
            assert result["completed"] is False
            assert result["state"] == "timeout"
# ---------------------------------------------------------------------------
# index_kimi_artifact (async)
# ---------------------------------------------------------------------------
class TestIndexKimiArtifact:
    """Tests for ``index_kimi_artifact``: storing a research artifact in memory.

    The memory write happens via ``asyncio.to_thread``, which is patched so no
    real storage backend is touched.
    """

    @pytest.mark.asyncio
    async def test_empty_artifact_returns_error(self):
        # Whitespace-only content is rejected before any storage call.
        from timmy.kimi_delegation import index_kimi_artifact

        result = await index_kimi_artifact(1, "title", " ")
        assert result["success"] is False
        assert "Empty artifact" in result["error"]

    @pytest.mark.asyncio
    async def test_successful_indexing(self):
        # The id of the created memory entry is returned as "memory_id".
        from timmy.kimi_delegation import index_kimi_artifact

        mock_entry = MagicMock()
        mock_entry.id = "mem-123"
        with patch("timmy.kimi_delegation.asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            mock_thread.return_value = mock_entry
            result = await index_kimi_artifact(42, "My Research", "Some research content here")
            assert result["success"] is True
            assert result["memory_id"] == "mem-123"

    @pytest.mark.asyncio
    async def test_exception_returns_failure(self):
        # Storage failures are caught and reported, not raised to the caller.
        from timmy.kimi_delegation import index_kimi_artifact

        with patch("timmy.kimi_delegation.asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
            mock_thread.side_effect = Exception("DB error")
            result = await index_kimi_artifact(42, "title", "some content")
            assert result["success"] is False
            assert result["error"] != ""
# ---------------------------------------------------------------------------
# extract_and_create_followups (async)
# ---------------------------------------------------------------------------
class TestExtractAndCreateFollowups:
    """Tests for ``extract_and_create_followups``: turning numbered action items
    from a research artifact into follow-up Gitea issues."""

    @pytest.mark.asyncio
    async def test_no_action_items_returns_empty_created(self):
        # Prose with no numbered list items succeeds with nothing created.
        from timmy.kimi_delegation import extract_and_create_followups

        result = await extract_and_create_followups("Plain prose, nothing to do.", 1)
        assert result["success"] is True
        assert result["created"] == []

    @pytest.mark.asyncio
    async def test_gitea_not_configured_returns_error(self):
        # Action items exist but Gitea is disabled: the call must fail.
        from timmy.kimi_delegation import extract_and_create_followups

        text = "1. Do something important\n"
        with patch("timmy.kimi_delegation.settings") as mock_settings:
            mock_settings.gitea_enabled = False
            mock_settings.gitea_token = ""
            result = await extract_and_create_followups(text, 5)
            assert result["success"] is False

    @pytest.mark.asyncio
    async def test_creates_followup_issues(self):
        # Two numbered items -> two POSTed issues -> two entries in "created".
        from timmy.kimi_delegation import extract_and_create_followups

        text = "1. Deploy the service\n2. Run integration tests\n"
        mock_settings = MagicMock()
        mock_settings.gitea_enabled = True
        mock_settings.gitea_token = "tok"
        mock_settings.gitea_url = "http://git"
        mock_settings.gitea_repo = "owner/repo"
        issue_resp = MagicMock()
        issue_resp.status_code = 201
        issue_resp.json.return_value = {"number": 10}
        async_client = AsyncMock()
        async_client.post = AsyncMock(return_value=issue_resp)
        async_client.__aenter__ = AsyncMock(return_value=async_client)
        async_client.__aexit__ = AsyncMock(return_value=False)
        with (
            patch("timmy.kimi_delegation.settings", mock_settings),
            patch("timmy.kimi_delegation.httpx") as mock_httpx,
        ):
            mock_httpx.AsyncClient.return_value = async_client
            result = await extract_and_create_followups(text, 5)
            assert result["success"] is True
            assert len(result["created"]) == 2
# ---------------------------------------------------------------------------
# delegate_research_to_kimi (async)
# ---------------------------------------------------------------------------
class TestDelegateResearchToKimi:
    """Tests for ``delegate_research_to_kimi``: input validation plus delegation
    to ``create_kimi_research_issue``."""

    @pytest.mark.asyncio
    async def test_empty_task_returns_error(self):
        from timmy.kimi_delegation import delegate_research_to_kimi

        result = await delegate_research_to_kimi("", "ctx", "q?")
        assert result["success"] is False
        assert "required" in result["error"]

    @pytest.mark.asyncio
    async def test_whitespace_task_returns_error(self):
        # A task made only of whitespace is treated the same as an empty task.
        from timmy.kimi_delegation import delegate_research_to_kimi

        result = await delegate_research_to_kimi(" ", "ctx", "q?")
        assert result["success"] is False
        assert "required" in result["error"]

    @pytest.mark.asyncio
    async def test_empty_question_returns_error(self):
        from timmy.kimi_delegation import delegate_research_to_kimi

        result = await delegate_research_to_kimi("valid task", "ctx", "")
        assert result["success"] is False
        assert "required" in result["error"]

    @pytest.mark.asyncio
    async def test_delegates_to_create_issue(self):
        # Valid input is forwarded positionally (task, context, question, priority).
        from timmy.kimi_delegation import delegate_research_to_kimi

        with patch(
            "timmy.kimi_delegation.create_kimi_research_issue",
            new_callable=AsyncMock,
        ) as mock_create:
            mock_create.return_value = {"success": True, "issue_number": 7, "issue_url": "http://x", "error": None}
            result = await delegate_research_to_kimi("Research X", "ctx", "What is X?", priority="high")
            assert result["success"] is True
            assert result["issue_number"] == 7
            mock_create.assert_awaited_once_with("Research X", "ctx", "What is X?", "high")

View File

@@ -1,839 +0,0 @@
"""Unit tests for timmy.quest_system."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
import timmy.quest_system as qs
from timmy.quest_system import (
QuestDefinition,
QuestProgress,
QuestStatus,
QuestType,
_get_progress_key,
_get_target_value,
_is_on_cooldown,
check_daily_run_quest,
check_issue_count_quest,
check_issue_reduce_quest,
claim_quest_reward,
evaluate_quest_progress,
get_active_quests,
get_agent_quests_status,
get_or_create_progress,
get_quest_definition,
get_quest_definitions,
get_quest_leaderboard,
get_quest_progress,
load_quest_config,
reset_quest_progress,
update_quest_progress,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_quest(
    quest_id: str = "test_quest",
    quest_type: QuestType = QuestType.ISSUE_COUNT,
    reward_tokens: int = 10,
    enabled: bool = True,
    repeatable: bool = False,
    cooldown_hours: int = 0,
    criteria: dict[str, Any] | None = None,
) -> QuestDefinition:
    """Build a ``QuestDefinition`` with sensible test defaults.

    Any field can be overridden via keyword argument. ``criteria`` defaults to
    ``{"target_count": 3}`` only when *not passed at all* (``None``): tests such
    as ``test_missing_criteria_key_defaults_to_one`` deliberately pass an empty
    dict ``{}`` and expect it to be preserved, so we must test against ``None``
    rather than truthiness (``criteria or {...}`` would clobber ``{}``).
    """
    if criteria is None:
        criteria = {"target_count": 3}
    return QuestDefinition(
        id=quest_id,
        name=f"Quest {quest_id}",
        description="Test quest",
        reward_tokens=reward_tokens,
        quest_type=quest_type,
        enabled=enabled,
        repeatable=repeatable,
        cooldown_hours=cooldown_hours,
        criteria=criteria,
        notification_message="Quest Complete! You earned {tokens} tokens.",
    )
@pytest.fixture(autouse=True)
def clean_state():
    """Wipe module-level quest state before and after every test."""

    def _wipe() -> None:
        # Clear progress records plus the loaded definitions/settings caches.
        reset_quest_progress()
        qs._quest_definitions.clear()
        qs._quest_settings.clear()

    _wipe()
    yield
    _wipe()
# ---------------------------------------------------------------------------
# QuestDefinition
# ---------------------------------------------------------------------------
class TestQuestDefinition:
    """Tests for ``QuestDefinition.from_dict`` parsing of config entries."""

    def test_from_dict_minimal(self):
        # Only "id" is required; everything else falls back to defaults.
        data = {"id": "q1"}
        defn = QuestDefinition.from_dict(data)
        assert defn.id == "q1"
        assert defn.name == "Unnamed Quest"
        assert defn.reward_tokens == 0
        assert defn.quest_type == QuestType.CUSTOM
        assert defn.enabled is True
        assert defn.repeatable is False
        assert defn.cooldown_hours == 0

    def test_from_dict_full(self):
        # Every supported key round-trips onto the definition.
        data = {
            "id": "q2",
            "name": "Full Quest",
            "description": "A full quest",
            "reward_tokens": 50,
            "type": "issue_count",
            "enabled": False,
            "repeatable": True,
            "cooldown_hours": 24,
            "criteria": {"target_count": 5},
            "notification_message": "You earned {tokens}!",
        }
        defn = QuestDefinition.from_dict(data)
        assert defn.id == "q2"
        assert defn.name == "Full Quest"
        assert defn.reward_tokens == 50
        assert defn.quest_type == QuestType.ISSUE_COUNT
        assert defn.enabled is False
        assert defn.repeatable is True
        assert defn.cooldown_hours == 24
        assert defn.criteria == {"target_count": 5}
        assert defn.notification_message == "You earned {tokens}!"

    def test_from_dict_invalid_type_raises(self):
        # Unknown quest type strings raise (QuestType enum lookup fails).
        data = {"id": "q3", "type": "not_a_real_type"}
        with pytest.raises(ValueError):
            QuestDefinition.from_dict(data)
# ---------------------------------------------------------------------------
# QuestProgress
# ---------------------------------------------------------------------------
class TestQuestProgress:
    """Tests for ``QuestProgress.to_dict`` serialization."""

    def test_to_dict_roundtrip(self):
        # All explicitly-set fields appear in the serialized form; the enum
        # status is rendered as its string value.
        progress = QuestProgress(
            quest_id="q1",
            agent_id="agent_a",
            status=QuestStatus.IN_PROGRESS,
            current_value=2,
            target_value=5,
            started_at="2026-01-01T00:00:00",
            metadata={"key": "val"},
        )
        d = progress.to_dict()
        assert d["quest_id"] == "q1"
        assert d["agent_id"] == "agent_a"
        assert d["status"] == "in_progress"
        assert d["current_value"] == 2
        assert d["target_value"] == 5
        assert d["metadata"] == {"key": "val"}

    def test_to_dict_defaults(self):
        # Unset fields serialize to zero / empty-string defaults.
        progress = QuestProgress(
            quest_id="q1",
            agent_id="agent_a",
            status=QuestStatus.NOT_STARTED,
        )
        d = progress.to_dict()
        assert d["completion_count"] == 0
        assert d["started_at"] == ""
        assert d["completed_at"] == ""
# ---------------------------------------------------------------------------
# _get_progress_key
# ---------------------------------------------------------------------------
def test_get_progress_key():
    """Progress keys are namespaced as ``<agent_id>:<quest_id>``."""
    key = _get_progress_key("q1", "agent_a")
    assert key == "agent_a:q1"
def test_get_progress_key_different_agents():
    """The same quest yields distinct keys for distinct agents."""
    assert _get_progress_key("q1", "agent_a") != _get_progress_key("q1", "agent_b")
# ---------------------------------------------------------------------------
# load_quest_config
# ---------------------------------------------------------------------------
class TestLoadQuestConfig:
    """Tests for ``load_quest_config``: YAML parsing with graceful degradation.

    ``QUEST_CONFIG_PATH`` is patched to a tmp_path file so no real config is read.
    """

    def test_missing_file_returns_empty(self, tmp_path):
        # A nonexistent config file yields empty definitions and settings.
        missing = tmp_path / "nonexistent.yaml"
        with patch.object(qs, "QUEST_CONFIG_PATH", missing):
            defs, settings = load_quest_config()
        assert defs == {}
        assert settings == {}

    def test_valid_yaml_loads_quests(self, tmp_path):
        # A well-formed config populates both quest definitions and settings.
        config_path = tmp_path / "quests.yaml"
        config_path.write_text(
            """
quests:
  first_quest:
    name: First Quest
    description: Do stuff
    reward_tokens: 25
    type: issue_count
    enabled: true
    repeatable: false
    cooldown_hours: 0
    criteria:
      target_count: 3
    notification_message: "Done! {tokens} tokens"
settings:
  some_setting: true
"""
        )
        with patch.object(qs, "QUEST_CONFIG_PATH", config_path):
            defs, settings = load_quest_config()
        assert "first_quest" in defs
        assert defs["first_quest"].name == "First Quest"
        assert defs["first_quest"].reward_tokens == 25
        assert settings == {"some_setting": True}

    def test_invalid_yaml_returns_empty(self, tmp_path):
        # Unparseable YAML is swallowed: loader returns empties, not raises.
        config_path = tmp_path / "quests.yaml"
        config_path.write_text(":: not valid yaml ::")
        with patch.object(qs, "QUEST_CONFIG_PATH", config_path):
            defs, settings = load_quest_config()
        assert defs == {}
        assert settings == {}

    def test_non_dict_yaml_returns_empty(self, tmp_path):
        # Valid YAML that isn't a mapping (e.g. a list) is also rejected.
        config_path = tmp_path / "quests.yaml"
        config_path.write_text("- item1\n- item2\n")
        with patch.object(qs, "QUEST_CONFIG_PATH", config_path):
            defs, settings = load_quest_config()
        assert defs == {}
        assert settings == {}

    def test_bad_quest_entry_is_skipped(self, tmp_path):
        # One malformed quest entry must not poison the rest of the config.
        config_path = tmp_path / "quests.yaml"
        config_path.write_text(
            """
quests:
  good_quest:
    name: Good
    type: issue_count
    reward_tokens: 10
    enabled: true
    repeatable: false
    cooldown_hours: 0
    criteria: {}
    notification_message: "{tokens}"
  bad_quest:
    type: invalid_type_that_does_not_exist
"""
        )
        with patch.object(qs, "QUEST_CONFIG_PATH", config_path):
            defs, _ = load_quest_config()
        assert "good_quest" in defs
        assert "bad_quest" not in defs
# ---------------------------------------------------------------------------
# get_quest_definitions / get_quest_definition / get_active_quests
# ---------------------------------------------------------------------------
class TestQuestLookup:
    """Tests for get_quest_definitions / get_quest_definition / get_active_quests."""

    def setup_method(self):
        # One enabled and one disabled quest; the autouse clean_state fixture
        # clears the registry before this runs.
        q1 = _make_quest("q1", enabled=True)
        q2 = _make_quest("q2", enabled=False)
        qs._quest_definitions.update({"q1": q1, "q2": q2})

    def test_get_quest_definitions_returns_all(self):
        defs = get_quest_definitions()
        assert "q1" in defs
        assert "q2" in defs

    def test_get_quest_definition_found(self):
        defn = get_quest_definition("q1")
        assert defn is not None
        assert defn.id == "q1"

    def test_get_quest_definition_not_found(self):
        assert get_quest_definition("missing") is None

    def test_get_active_quests_only_enabled(self):
        # Disabled quests are filtered out of the "active" view.
        active = get_active_quests()
        ids = [q.id for q in active]
        assert "q1" in ids
        assert "q2" not in ids
# ---------------------------------------------------------------------------
# _get_target_value
# ---------------------------------------------------------------------------
class TestGetTargetValue:
    """Tests for ``_get_target_value``: each quest type reads its own criteria
    key, with a fallback of 1 when the key (or type mapping) is absent."""

    def test_issue_count(self):
        q = _make_quest(quest_type=QuestType.ISSUE_COUNT, criteria={"target_count": 7})
        assert _get_target_value(q) == 7

    def test_issue_reduce(self):
        q = _make_quest(quest_type=QuestType.ISSUE_REDUCE, criteria={"target_reduction": 5})
        assert _get_target_value(q) == 5

    def test_daily_run(self):
        q = _make_quest(quest_type=QuestType.DAILY_RUN, criteria={"min_sessions": 3})
        assert _get_target_value(q) == 3

    def test_docs_update(self):
        q = _make_quest(quest_type=QuestType.DOCS_UPDATE, criteria={"min_files_changed": 2})
        assert _get_target_value(q) == 2

    def test_test_improve(self):
        q = _make_quest(quest_type=QuestType.TEST_IMPROVE, criteria={"min_new_tests": 4})
        assert _get_target_value(q) == 4

    def test_custom_defaults_to_one(self):
        q = _make_quest(quest_type=QuestType.CUSTOM, criteria={})
        assert _get_target_value(q) == 1

    def test_missing_criteria_key_defaults_to_one(self):
        # NOTE: relies on _make_quest preserving an explicitly-passed empty dict.
        q = _make_quest(quest_type=QuestType.ISSUE_COUNT, criteria={})
        assert _get_target_value(q) == 1
# ---------------------------------------------------------------------------
# get_or_create_progress / get_quest_progress
# ---------------------------------------------------------------------------
class TestProgressCreation:
    """Tests for get_or_create_progress / get_quest_progress."""

    def setup_method(self):
        qs._quest_definitions["q1"] = _make_quest("q1", criteria={"target_count": 5})

    def test_creates_new_progress(self):
        # A fresh record starts at NOT_STARTED with target from the criteria.
        progress = get_or_create_progress("q1", "agent_a")
        assert progress.quest_id == "q1"
        assert progress.agent_id == "agent_a"
        assert progress.status == QuestStatus.NOT_STARTED
        assert progress.target_value == 5
        assert progress.current_value == 0

    def test_returns_existing_progress(self):
        # Same (quest, agent) pair returns the very same object, not a copy.
        p1 = get_or_create_progress("q1", "agent_a")
        p1.current_value = 3
        p2 = get_or_create_progress("q1", "agent_a")
        assert p2.current_value == 3
        assert p1 is p2

    def test_raises_for_unknown_quest(self):
        with pytest.raises(ValueError, match="Quest unknown not found"):
            get_or_create_progress("unknown", "agent_a")

    def test_get_quest_progress_none_before_creation(self):
        # The read-only accessor never creates records as a side effect.
        assert get_quest_progress("q1", "agent_a") is None

    def test_get_quest_progress_after_creation(self):
        get_or_create_progress("q1", "agent_a")
        progress = get_quest_progress("q1", "agent_a")
        assert progress is not None
# ---------------------------------------------------------------------------
# update_quest_progress
# ---------------------------------------------------------------------------
class TestUpdateQuestProgress:
    """Tests for ``update_quest_progress``: value updates, completion transition,
    and metadata merging."""

    def setup_method(self):
        qs._quest_definitions["q1"] = _make_quest("q1", criteria={"target_count": 3})

    def test_updates_current_value(self):
        # Below target: value changes but status stays NOT_STARTED.
        progress = update_quest_progress("q1", "agent_a", 2)
        assert progress.current_value == 2
        assert progress.status == QuestStatus.NOT_STARTED

    def test_marks_completed_when_target_reached(self):
        progress = update_quest_progress("q1", "agent_a", 3)
        assert progress.status == QuestStatus.COMPLETED
        assert progress.completed_at != ""

    def test_marks_completed_when_value_exceeds_target(self):
        progress = update_quest_progress("q1", "agent_a", 10)
        assert progress.status == QuestStatus.COMPLETED

    def test_does_not_re_complete_already_completed(self):
        # Completion timestamp is written once, not refreshed on later updates.
        p = update_quest_progress("q1", "agent_a", 3)
        first_completed_at = p.completed_at
        p2 = update_quest_progress("q1", "agent_a", 5)
        # should not change completed_at again
        assert p2.completed_at == first_completed_at

    def test_does_not_re_complete_claimed_quest(self):
        # A CLAIMED quest must not regress to COMPLETED via further updates.
        p = update_quest_progress("q1", "agent_a", 3)
        p.status = QuestStatus.CLAIMED
        p2 = update_quest_progress("q1", "agent_a", 5)
        assert p2.status == QuestStatus.CLAIMED

    def test_updates_metadata(self):
        progress = update_quest_progress("q1", "agent_a", 1, metadata={"info": "value"})
        assert progress.metadata["info"] == "value"

    def test_merges_metadata(self):
        # New metadata keys are merged into, not replacing, existing metadata.
        update_quest_progress("q1", "agent_a", 1, metadata={"a": 1})
        progress = update_quest_progress("q1", "agent_a", 2, metadata={"b": 2})
        assert progress.metadata["a"] == 1
        assert progress.metadata["b"] == 2
# ---------------------------------------------------------------------------
# _is_on_cooldown
# ---------------------------------------------------------------------------
class TestIsOnCooldown:
    """Tests for ``_is_on_cooldown``: cooldown applies only to repeatable quests
    with a positive cooldown and a parseable last-completion timestamp."""

    def test_non_repeatable_never_on_cooldown(self):
        quest = _make_quest(repeatable=False, cooldown_hours=24)
        progress = QuestProgress(
            quest_id="q1",
            agent_id="agent_a",
            status=QuestStatus.CLAIMED,
            last_completed_at=datetime.now(UTC).isoformat(),
        )
        assert _is_on_cooldown(progress, quest) is False

    def test_no_last_completed_not_on_cooldown(self):
        # Never completed -> nothing to cool down from.
        quest = _make_quest(repeatable=True, cooldown_hours=24)
        progress = QuestProgress(
            quest_id="q1",
            agent_id="agent_a",
            status=QuestStatus.NOT_STARTED,
            last_completed_at="",
        )
        assert _is_on_cooldown(progress, quest) is False

    def test_zero_cooldown_not_on_cooldown(self):
        quest = _make_quest(repeatable=True, cooldown_hours=0)
        progress = QuestProgress(
            quest_id="q1",
            agent_id="agent_a",
            status=QuestStatus.CLAIMED,
            last_completed_at=datetime.now(UTC).isoformat(),
        )
        assert _is_on_cooldown(progress, quest) is False

    def test_recent_completion_is_on_cooldown(self):
        # Completed 1h ago with a 24h cooldown -> still cooling down.
        quest = _make_quest(repeatable=True, cooldown_hours=24)
        recent = datetime.now(UTC) - timedelta(hours=1)
        progress = QuestProgress(
            quest_id="q1",
            agent_id="agent_a",
            status=QuestStatus.NOT_STARTED,
            last_completed_at=recent.isoformat(),
        )
        assert _is_on_cooldown(progress, quest) is True

    def test_expired_cooldown_not_on_cooldown(self):
        # Completed 25h ago with a 24h cooldown -> cooldown has lapsed.
        quest = _make_quest(repeatable=True, cooldown_hours=24)
        old = datetime.now(UTC) - timedelta(hours=25)
        progress = QuestProgress(
            quest_id="q1",
            agent_id="agent_a",
            status=QuestStatus.NOT_STARTED,
            last_completed_at=old.isoformat(),
        )
        assert _is_on_cooldown(progress, quest) is False

    def test_invalid_last_completed_returns_false(self):
        # A corrupt timestamp fails open (no cooldown) rather than raising.
        quest = _make_quest(repeatable=True, cooldown_hours=24)
        progress = QuestProgress(
            quest_id="q1",
            agent_id="agent_a",
            status=QuestStatus.NOT_STARTED,
            last_completed_at="not-a-date",
        )
        assert _is_on_cooldown(progress, quest) is False
# ---------------------------------------------------------------------------
# claim_quest_reward
# ---------------------------------------------------------------------------
class TestClaimQuestReward:
    """Tests for ``claim_quest_reward``: token payout via the invoice ledger
    (mocked), claim state transitions, repeatable resets, and cooldown gating."""

    def setup_method(self):
        qs._quest_definitions["q1"] = _make_quest("q1", reward_tokens=25)

    def test_returns_none_if_no_progress(self):
        assert claim_quest_reward("q1", "agent_a") is None

    def test_returns_none_if_not_completed(self):
        # Progress exists but is still NOT_STARTED -> nothing to claim.
        get_or_create_progress("q1", "agent_a")
        assert claim_quest_reward("q1", "agent_a") is None

    def test_returns_none_if_quest_not_found(self):
        assert claim_quest_reward("nonexistent", "agent_a") is None

    def test_successful_claim(self):
        # Completed quest + working ledger -> payout summary dict.
        progress = get_or_create_progress("q1", "agent_a")
        progress.status = QuestStatus.COMPLETED
        progress.completed_at = datetime.now(UTC).isoformat()
        mock_invoice = MagicMock()
        mock_invoice.payment_hash = "quest_q1_agent_a_123"
        with (
            patch("timmy.quest_system.create_invoice_entry", return_value=mock_invoice),
            patch("timmy.quest_system.mark_settled"),
        ):
            result = claim_quest_reward("q1", "agent_a")
        assert result is not None
        assert result["tokens_awarded"] == 25
        assert result["quest_id"] == "q1"
        assert result["agent_id"] == "agent_a"
        assert result["completion_count"] == 1

    def test_successful_claim_marks_claimed(self):
        progress = get_or_create_progress("q1", "agent_a")
        progress.status = QuestStatus.COMPLETED
        progress.completed_at = datetime.now(UTC).isoformat()
        mock_invoice = MagicMock()
        mock_invoice.payment_hash = "phash"
        with (
            patch("timmy.quest_system.create_invoice_entry", return_value=mock_invoice),
            patch("timmy.quest_system.mark_settled"),
        ):
            claim_quest_reward("q1", "agent_a")
        assert progress.status == QuestStatus.CLAIMED

    def test_repeatable_quest_resets_after_claim(self):
        # Repeatable quests go back to NOT_STARTED with zeroed progress so the
        # agent can earn them again.
        qs._quest_definitions["rep"] = _make_quest(
            "rep", repeatable=True, cooldown_hours=0, reward_tokens=10
        )
        progress = get_or_create_progress("rep", "agent_a")
        progress.status = QuestStatus.COMPLETED
        progress.completed_at = datetime.now(UTC).isoformat()
        progress.current_value = 5
        mock_invoice = MagicMock()
        mock_invoice.payment_hash = "phash"
        with (
            patch("timmy.quest_system.create_invoice_entry", return_value=mock_invoice),
            patch("timmy.quest_system.mark_settled"),
        ):
            result = claim_quest_reward("rep", "agent_a")
        assert result is not None
        assert progress.status == QuestStatus.NOT_STARTED
        assert progress.current_value == 0
        assert progress.completed_at == ""

    def test_on_cooldown_returns_none(self):
        # A repeatable quest completed 1h ago with a 24h cooldown cannot be claimed.
        qs._quest_definitions["rep"] = _make_quest("rep", repeatable=True, cooldown_hours=24)
        progress = get_or_create_progress("rep", "agent_a")
        progress.status = QuestStatus.COMPLETED
        recent = datetime.now(UTC) - timedelta(hours=1)
        progress.last_completed_at = recent.isoformat()
        assert claim_quest_reward("rep", "agent_a") is None

    def test_ledger_error_returns_none(self):
        # A ledger failure aborts the claim instead of propagating the exception.
        progress = get_or_create_progress("q1", "agent_a")
        progress.status = QuestStatus.COMPLETED
        progress.completed_at = datetime.now(UTC).isoformat()
        with patch("timmy.quest_system.create_invoice_entry", side_effect=Exception("ledger error")):
            result = claim_quest_reward("q1", "agent_a")
        assert result is None
# ---------------------------------------------------------------------------
# check_issue_count_quest
# ---------------------------------------------------------------------------
class TestCheckIssueCountQuest:
    """Tests for ``check_issue_count_quest``: counting closed issues whose
    labels match the quest's ``issue_labels`` filter."""

    def setup_method(self):
        qs._quest_definitions["iq"] = _make_quest(
            "iq", quest_type=QuestType.ISSUE_COUNT, criteria={"target_count": 2, "issue_labels": ["bug"]}
        )

    def test_counts_matching_issues(self):
        # Two issues carry the "bug" label, meeting the target of 2.
        issues = [
            {"labels": [{"name": "bug"}]},
            {"labels": [{"name": "bug"}, {"name": "priority"}]},
            {"labels": [{"name": "feature"}]},  # doesn't match
        ]
        progress = check_issue_count_quest(
            qs._quest_definitions["iq"], "agent_a", issues
        )
        assert progress.current_value == 2
        assert progress.status == QuestStatus.COMPLETED

    def test_empty_issues_returns_zero(self):
        progress = check_issue_count_quest(qs._quest_definitions["iq"], "agent_a", [])
        assert progress.current_value == 0

    def test_no_labels_filter_counts_all_labeled(self):
        # Empty issue_labels criterion -> no filtering, every issue counts.
        q = _make_quest(
            "nolabel",
            quest_type=QuestType.ISSUE_COUNT,
            criteria={"target_count": 1, "issue_labels": []},
        )
        qs._quest_definitions["nolabel"] = q
        issues = [
            {"labels": [{"name": "bug"}]},
            {"labels": [{"name": "feature"}]},
        ]
        progress = check_issue_count_quest(q, "agent_a", issues)
        assert progress.current_value == 2
# ---------------------------------------------------------------------------
# check_issue_reduce_quest
# ---------------------------------------------------------------------------
class TestCheckIssueReduceQuest:
    """``check_issue_reduce_quest``: progress is the drop in open-issue count,
    floored at zero when the count grew or stayed flat."""

    def setup_method(self):
        qs._quest_definitions["ir"] = _make_quest(
            "ir", quest_type=QuestType.ISSUE_REDUCE, criteria={"target_reduction": 5}
        )

    def _check(self, before: int, after: int):
        """Run the quest check for the registered "ir" quest."""
        return check_issue_reduce_quest(
            qs._quest_definitions["ir"], "agent_a", before, after
        )

    def test_computes_reduction(self):
        outcome = self._check(20, 15)
        assert outcome.current_value == 5
        assert outcome.status == QuestStatus.COMPLETED

    def test_negative_reduction_treated_as_zero(self):
        assert self._check(10, 15).current_value == 0

    def test_no_change_yields_zero(self):
        assert self._check(10, 10).current_value == 0
# ---------------------------------------------------------------------------
# check_daily_run_quest
# ---------------------------------------------------------------------------
class TestCheckDailyRunQuest:
    """``check_daily_run_quest``: session count drives progress/completion."""

    def setup_method(self):
        qs._quest_definitions["dr"] = _make_quest(
            "dr", quest_type=QuestType.DAILY_RUN, criteria={"min_sessions": 2}
        )

    def test_tracks_sessions(self):
        # Meeting min_sessions completes the quest.
        outcome = check_daily_run_quest(qs._quest_definitions["dr"], "agent_a", 2)
        assert outcome.current_value == 2
        assert outcome.status == QuestStatus.COMPLETED

    def test_incomplete_sessions(self):
        # One session short: value tracked, but not completed.
        outcome = check_daily_run_quest(qs._quest_definitions["dr"], "agent_a", 1)
        assert outcome.current_value == 1
        assert outcome.status != QuestStatus.COMPLETED
# ---------------------------------------------------------------------------
# evaluate_quest_progress
# ---------------------------------------------------------------------------
class TestEvaluateQuestProgress:
    """Tests for ``evaluate_quest_progress``: dispatch by quest type over a
    context dict, honoring enabled flags and cooldowns."""

    def setup_method(self):
        qs._quest_definitions["iq"] = _make_quest(
            "iq", quest_type=QuestType.ISSUE_COUNT, criteria={"target_count": 1}
        )
        qs._quest_definitions["dis"] = _make_quest("dis", enabled=False)

    def test_disabled_quest_returns_none(self):
        result = evaluate_quest_progress("dis", "agent_a", {})
        assert result is None

    def test_missing_quest_returns_none(self):
        result = evaluate_quest_progress("nonexistent", "agent_a", {})
        assert result is None

    def test_issue_count_quest_evaluated(self):
        # ISSUE_COUNT quests read "closed_issues" from the context.
        context = {"closed_issues": [{"labels": [{"name": "bug"}]}]}
        result = evaluate_quest_progress("iq", "agent_a", context)
        assert result is not None
        assert result.current_value == 1

    def test_issue_reduce_quest_evaluated(self):
        # ISSUE_REDUCE quests read previous/current issue counts.
        qs._quest_definitions["ir"] = _make_quest(
            "ir", quest_type=QuestType.ISSUE_REDUCE, criteria={"target_reduction": 3}
        )
        context = {"previous_issue_count": 10, "current_issue_count": 7}
        result = evaluate_quest_progress("ir", "agent_a", context)
        assert result is not None
        assert result.current_value == 3

    def test_daily_run_quest_evaluated(self):
        # DAILY_RUN quests read "sessions_completed".
        qs._quest_definitions["dr"] = _make_quest(
            "dr", quest_type=QuestType.DAILY_RUN, criteria={"min_sessions": 1}
        )
        context = {"sessions_completed": 2}
        result = evaluate_quest_progress("dr", "agent_a", context)
        assert result is not None
        assert result.current_value == 2

    def test_custom_quest_returns_existing_progress(self):
        qs._quest_definitions["cust"] = _make_quest("cust", quest_type=QuestType.CUSTOM)
        # No progress yet => None (custom quests don't auto-create progress here)
        result = evaluate_quest_progress("cust", "agent_a", {})
        assert result is None

    def test_cooldown_prevents_evaluation(self):
        # A quest still on cooldown is returned untouched, not re-evaluated.
        q = _make_quest("rep_iq", quest_type=QuestType.ISSUE_COUNT, repeatable=True, cooldown_hours=24, criteria={"target_count": 1})
        qs._quest_definitions["rep_iq"] = q
        progress = get_or_create_progress("rep_iq", "agent_a")
        recent = datetime.now(UTC) - timedelta(hours=1)
        progress.last_completed_at = recent.isoformat()
        context = {"closed_issues": [{"labels": [{"name": "bug"}]}]}
        result = evaluate_quest_progress("rep_iq", "agent_a", context)
        # Should return existing progress without updating
        assert result is progress
# ---------------------------------------------------------------------------
# reset_quest_progress
# ---------------------------------------------------------------------------
class TestResetQuestProgress:
    """Tests for ``reset_quest_progress``: bulk and filtered deletion of
    progress records; returns the number of records removed."""

    def setup_method(self):
        qs._quest_definitions["q1"] = _make_quest("q1")
        qs._quest_definitions["q2"] = _make_quest("q2")

    def test_reset_all(self):
        get_or_create_progress("q1", "agent_a")
        get_or_create_progress("q2", "agent_a")
        count = reset_quest_progress()
        assert count == 2
        assert get_quest_progress("q1", "agent_a") is None
        assert get_quest_progress("q2", "agent_a") is None

    def test_reset_specific_quest(self):
        # quest_id filter removes only that quest's records.
        get_or_create_progress("q1", "agent_a")
        get_or_create_progress("q2", "agent_a")
        count = reset_quest_progress(quest_id="q1")
        assert count == 1
        assert get_quest_progress("q1", "agent_a") is None
        assert get_quest_progress("q2", "agent_a") is not None

    def test_reset_specific_agent(self):
        # agent_id filter removes only that agent's records.
        get_or_create_progress("q1", "agent_a")
        get_or_create_progress("q1", "agent_b")
        count = reset_quest_progress(agent_id="agent_a")
        assert count == 1
        assert get_quest_progress("q1", "agent_a") is None
        assert get_quest_progress("q1", "agent_b") is not None

    def test_reset_specific_quest_and_agent(self):
        get_or_create_progress("q1", "agent_a")
        get_or_create_progress("q1", "agent_b")
        count = reset_quest_progress(quest_id="q1", agent_id="agent_a")
        assert count == 1

    def test_reset_empty_returns_zero(self):
        count = reset_quest_progress()
        assert count == 0
# ---------------------------------------------------------------------------
# get_quest_leaderboard
# ---------------------------------------------------------------------------
class TestGetQuestLeaderboard:
    """Tests for ``get_quest_leaderboard``: per-agent aggregation of completions
    into tokens, sorted by total tokens descending."""

    def setup_method(self):
        qs._quest_definitions["q1"] = _make_quest("q1", reward_tokens=10)
        qs._quest_definitions["q2"] = _make_quest("q2", reward_tokens=20)

    def test_empty_progress_returns_empty(self):
        assert get_quest_leaderboard() == []

    def test_leaderboard_sorted_by_tokens(self):
        p_a = get_or_create_progress("q1", "agent_a")
        p_a.completion_count = 1
        p_b = get_or_create_progress("q2", "agent_b")
        p_b.completion_count = 2
        board = get_quest_leaderboard()
        assert board[0]["agent_id"] == "agent_b"  # 40 tokens
        assert board[1]["agent_id"] == "agent_a"  # 10 tokens

    def test_leaderboard_aggregates_multiple_quests(self):
        # One agent with completions across two quests collapses to one row.
        p1 = get_or_create_progress("q1", "agent_a")
        p1.completion_count = 2  # 20 tokens
        p2 = get_or_create_progress("q2", "agent_a")
        p2.completion_count = 1  # 20 tokens
        board = get_quest_leaderboard()
        assert len(board) == 1
        assert board[0]["total_tokens"] == 40
        assert board[0]["total_completions"] == 3

    def test_leaderboard_counts_unique_quests(self):
        p1 = get_or_create_progress("q1", "agent_a")
        p1.completion_count = 2
        p2 = get_or_create_progress("q2", "agent_a")
        p2.completion_count = 1
        board = get_quest_leaderboard()
        assert board[0]["unique_quests_completed"] == 2
# ---------------------------------------------------------------------------
# get_agent_quests_status
# ---------------------------------------------------------------------------
class TestGetAgentQuestsStatus:
    """Tests for ``get_agent_quests_status``: the per-agent summary dict with
    per-quest info, token totals, and cooldown countdowns."""

    def setup_method(self):
        qs._quest_definitions["q1"] = _make_quest("q1", reward_tokens=10)

    def test_returns_status_structure(self):
        # Top-level shape of the summary dict.
        result = get_agent_quests_status("agent_a")
        assert result["agent_id"] == "agent_a"
        assert isinstance(result["quests"], list)
        assert "total_tokens_earned" in result
        assert "total_quests_completed" in result
        assert "active_quests_count" in result

    def test_includes_quest_info(self):
        result = get_agent_quests_status("agent_a")
        quest_info = result["quests"][0]
        assert quest_info["quest_id"] == "q1"
        assert quest_info["reward_tokens"] == 10
        assert quest_info["status"] == QuestStatus.NOT_STARTED.value

    def test_accumulates_tokens_from_completions(self):
        # 3 completions x 10 tokens -> 30 earned.
        p = get_or_create_progress("q1", "agent_a")
        p.completion_count = 3
        result = get_agent_quests_status("agent_a")
        assert result["total_tokens_earned"] == 30
        assert result["total_quests_completed"] == 3

    def test_cooldown_hours_remaining_calculated(self):
        # Completed 2h ago with a 24h cooldown: still on cooldown with a
        # positive remaining-hours figure.
        q = _make_quest("qcool", repeatable=True, cooldown_hours=24, reward_tokens=5)
        qs._quest_definitions["qcool"] = q
        p = get_or_create_progress("qcool", "agent_a")
        recent = datetime.now(UTC) - timedelta(hours=2)
        p.last_completed_at = recent.isoformat()
        p.completion_count = 1
        result = get_agent_quests_status("agent_a")
        qcool_info = next(qi for qi in result["quests"] if qi["quest_id"] == "qcool")
        assert qcool_info["on_cooldown"] is True
        assert qcool_info["cooldown_hours_remaining"] > 0