forked from Rockachopa/Timmy-time-dashboard
Co-authored-by: Claude (Opus 4.6) <claude@hermes.local> Co-committed-by: Claude (Opus 4.6) <claude@hermes.local>
This commit is contained in:
@@ -3,6 +3,14 @@
|
||||
from .api import router
|
||||
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
|
||||
from .history import HealthHistoryStore, get_history_store
|
||||
from .metabolic import (
|
||||
DEFAULT_TIER_MODELS,
|
||||
MetabolicRouter,
|
||||
ModelTier,
|
||||
build_prompt,
|
||||
classify_complexity,
|
||||
get_metabolic_router,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CascadeRouter",
|
||||
@@ -12,4 +20,11 @@ __all__ = [
|
||||
"router",
|
||||
"HealthHistoryStore",
|
||||
"get_history_store",
|
||||
# Metabolic router
|
||||
"MetabolicRouter",
|
||||
"ModelTier",
|
||||
"DEFAULT_TIER_MODELS",
|
||||
"classify_complexity",
|
||||
"build_prompt",
|
||||
"get_metabolic_router",
|
||||
]
|
||||
|
||||
381
src/infrastructure/router/metabolic.py
Normal file
381
src/infrastructure/router/metabolic.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Three-tier metabolic LLM router.
|
||||
|
||||
Routes queries to the cheapest-sufficient model tier using MLX for all
|
||||
inference on Apple Silicon GPU:
|
||||
|
||||
T1 — Routine (Qwen3-8B Q6_K, ~45-55 tok/s): Simple navigation, basic choices.
|
||||
T2 — Medium (Qwen3-14B Q5_K_M, ~20-28 tok/s): Dialogue, inventory management.
|
||||
T3 — Complex (Qwen3-32B Q4_K_M, ~8-12 tok/s): Quest planning, stuck recovery.
|
||||
|
||||
Memory budget:
|
||||
- T1+T2 always loaded (~8.5 GB combined)
|
||||
- T3 loaded on demand (+20 GB) — game pauses during inference
|
||||
|
||||
Design notes:
|
||||
- 70% of game ticks never reach the LLM (handled upstream by behavior trees)
|
||||
- T3 pauses the game world before inference and unpauses after (graceful if no world)
|
||||
- All inference via vllm-mlx / Ollama — local-first, no cloud for game ticks
|
||||
|
||||
References:
|
||||
- Issue #966 — Three-Tier Metabolic LLM Router
|
||||
- Issue #1063 — Best Local Uncensored Agent Model for M3 Max 36GB
|
||||
- Issue #1075 — Claude Quota Monitor + Metabolic Protocol
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelTier(StrEnum):
|
||||
"""Three metabolic model tiers ordered by cost and capability.
|
||||
|
||||
Tier selection is driven by classify_complexity(). The cheapest
|
||||
sufficient tier is always chosen — T1 handles routine tasks, T2
|
||||
handles dialogue and management, T3 handles planning and recovery.
|
||||
"""
|
||||
|
||||
T1_ROUTINE = "t1_routine" # Fast, cheap — Qwen3-8B, always loaded
|
||||
T2_MEDIUM = "t2_medium" # Balanced — Qwen3-14B, always loaded
|
||||
T3_COMPLEX = "t3_complex" # Deep — Qwen3-32B, loaded on demand, pauses game
|
||||
|
||||
|
||||
# ── Classification vocabulary ────────────────────────────────────────────────
|
||||
|
||||
# T1: single-action navigation and binary-choice words
|
||||
_T1_KEYWORDS = frozenset(
|
||||
{
|
||||
"go", "move", "walk", "run", "north", "south", "east", "west",
|
||||
"up", "down", "left", "right", "yes", "no", "ok", "okay",
|
||||
"open", "close", "take", "drop", "look", "pick", "use",
|
||||
"wait", "rest", "save", "attack", "flee", "jump", "crouch",
|
||||
}
|
||||
)
|
||||
|
||||
# T3: planning, optimisation, or recovery signals
|
||||
_T3_KEYWORDS = frozenset(
|
||||
{
|
||||
"plan", "strategy", "optimize", "optimise", "quest", "stuck",
|
||||
"recover", "multi-step", "long-term", "negotiate", "persuade",
|
||||
"faction", "reputation", "best", "optimal", "recommend",
|
||||
"analyze", "analyse", "evaluate", "decide", "complex", "how do i",
|
||||
"what should i do", "help me figure", "what is the best",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def classify_complexity(task: str, state: dict) -> ModelTier:
|
||||
"""Classify a task to the cheapest-sufficient model tier.
|
||||
|
||||
Classification priority (highest wins):
|
||||
1. T3 — any T3 keyword, stuck indicator, or ``state["require_t3"] = True``
|
||||
2. T1 — short task with only T1 keywords and no active context
|
||||
3. T2 — everything else (safe default)
|
||||
|
||||
Args:
|
||||
task: Natural-language task description or player input.
|
||||
state: Current game state dict. Recognised keys:
|
||||
``stuck`` (bool), ``require_t3`` (bool),
|
||||
``active_quests`` (list), ``dialogue_active`` (bool).
|
||||
|
||||
Returns:
|
||||
ModelTier appropriate for the task.
|
||||
"""
|
||||
task_lower = task.lower()
|
||||
words = set(task_lower.split())
|
||||
|
||||
# ── T3 signals ──────────────────────────────────────────────────────────
|
||||
t3_keyword_hit = bool(words & _T3_KEYWORDS)
|
||||
# Check multi-word T3 phrases
|
||||
t3_phrase_hit = any(phrase in task_lower for phrase in _T3_KEYWORDS if " " in phrase)
|
||||
is_stuck = bool(state.get("stuck", False))
|
||||
explicit_t3 = bool(state.get("require_t3", False))
|
||||
|
||||
if t3_keyword_hit or t3_phrase_hit or is_stuck or explicit_t3:
|
||||
logger.debug(
|
||||
"classify_complexity → T3 (keywords=%s stuck=%s explicit=%s)",
|
||||
t3_keyword_hit or t3_phrase_hit,
|
||||
is_stuck,
|
||||
explicit_t3,
|
||||
)
|
||||
return ModelTier.T3_COMPLEX
|
||||
|
||||
# ── T1 signals ──────────────────────────────────────────────────────────
|
||||
t1_keyword_hit = bool(words & _T1_KEYWORDS)
|
||||
task_short = len(task.split()) <= 6
|
||||
no_active_context = (
|
||||
not state.get("active_quests")
|
||||
and not state.get("dialogue_active")
|
||||
and not state.get("combat_active")
|
||||
)
|
||||
|
||||
if t1_keyword_hit and task_short and no_active_context:
|
||||
logger.debug("classify_complexity → T1 (keywords=%s short=%s)", t1_keyword_hit, task_short)
|
||||
return ModelTier.T1_ROUTINE
|
||||
|
||||
# ── Default: T2 ─────────────────────────────────────────────────────────
|
||||
logger.debug("classify_complexity → T2 (default)")
|
||||
return ModelTier.T2_MEDIUM
|
||||
|
||||
|
||||
def build_prompt(
|
||||
state: dict,
|
||||
ui_state: dict,
|
||||
text: str,
|
||||
visual_context: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""Build an OpenAI-compatible messages list from game context.
|
||||
|
||||
Assembles a system message from structured game state and a user
|
||||
message from the player's text input. This format is accepted by
|
||||
CascadeRouter.complete() directly.
|
||||
|
||||
Args:
|
||||
state: Current game state dict. Common keys:
|
||||
``location`` (str), ``health`` (int/float),
|
||||
``inventory`` (list), ``active_quests`` (list),
|
||||
``stuck`` (bool).
|
||||
ui_state: Current UI state dict. Common keys:
|
||||
``dialogue_active`` (bool), ``dialogue_npc`` (str),
|
||||
``menu_open`` (str), ``combat_active`` (bool).
|
||||
text: Player text or task description (becomes user message).
|
||||
visual_context: Optional free-text description of the current screen
|
||||
or scene — from a vision model or rule-based extractor.
|
||||
|
||||
Returns:
|
||||
List of message dicts: [{"role": "system", ...}, {"role": "user", ...}]
|
||||
"""
|
||||
context_lines: list[str] = []
|
||||
|
||||
location = state.get("location", "unknown")
|
||||
context_lines.append(f"Location: {location}")
|
||||
|
||||
health = state.get("health")
|
||||
if health is not None:
|
||||
context_lines.append(f"Health: {health}")
|
||||
|
||||
inventory = state.get("inventory", [])
|
||||
if inventory:
|
||||
items = [i if isinstance(i, str) else i.get("name", str(i)) for i in inventory[:10]]
|
||||
context_lines.append(f"Inventory: {', '.join(items)}")
|
||||
|
||||
active_quests = state.get("active_quests", [])
|
||||
if active_quests:
|
||||
names = [
|
||||
q if isinstance(q, str) else q.get("name", str(q)) for q in active_quests[:5]
|
||||
]
|
||||
context_lines.append(f"Active quests: {', '.join(names)}")
|
||||
|
||||
if state.get("stuck"):
|
||||
context_lines.append("Status: STUCK — need recovery strategy")
|
||||
|
||||
if ui_state.get("dialogue_active"):
|
||||
npc = ui_state.get("dialogue_npc", "NPC")
|
||||
context_lines.append(f"In dialogue with: {npc}")
|
||||
|
||||
if ui_state.get("menu_open"):
|
||||
context_lines.append(f"Menu open: {ui_state['menu_open']}")
|
||||
|
||||
if ui_state.get("combat_active"):
|
||||
context_lines.append("Status: IN COMBAT")
|
||||
|
||||
if visual_context:
|
||||
context_lines.append(f"Scene: {visual_context}")
|
||||
|
||||
system_content = (
|
||||
"You are Timmy, an AI game agent. "
|
||||
"Respond with valid game commands only.\n\n"
|
||||
+ "\n".join(context_lines)
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": system_content},
|
||||
{"role": "user", "content": text},
|
||||
]
|
||||
|
||||
|
||||
# ── Default model assignments ────────────────────────────────────────────────
|
||||
# Overridable per deployment via MetabolicRouter(tier_models={...}).
|
||||
# Model benchmarks (M3 Max 36 GB, issue #1063):
|
||||
# Qwen3-8B Q6_K — 0.933 F1 tool calling, ~45-55 tok/s (~6 GB)
|
||||
# Qwen3-14B Q5_K_M — 0.971 F1 tool calling, ~20-28 tok/s (~9.5 GB)
|
||||
# Qwen3-32B Q4_K_M — highest quality, ~8-12 tok/s (~20 GB, on demand)
|
||||
DEFAULT_TIER_MODELS: dict[ModelTier, str] = {
|
||||
ModelTier.T1_ROUTINE: "qwen3:8b",
|
||||
ModelTier.T2_MEDIUM: "qwen3:14b",
|
||||
ModelTier.T3_COMPLEX: "qwen3:30b", # Closest Ollama tag to 32B Q4
|
||||
}
|
||||
|
||||
|
||||
class MetabolicRouter:
|
||||
"""Routes LLM requests to the cheapest-sufficient model tier.
|
||||
|
||||
Wraps CascadeRouter with:
|
||||
- Complexity classification via classify_complexity()
|
||||
- Prompt assembly via build_prompt()
|
||||
- T3 world-pause / world-unpause (graceful if no world adapter)
|
||||
|
||||
Usage::
|
||||
|
||||
router = MetabolicRouter()
|
||||
|
||||
# Simple route call — classification + prompt + inference in one step
|
||||
result = await router.route(
|
||||
task="Go north",
|
||||
state={"location": "Balmora"},
|
||||
ui_state={},
|
||||
)
|
||||
print(result["content"], result["tier"])
|
||||
|
||||
# Pre-classify if you need the tier for telemetry
|
||||
tier = router.classify("Plan the best path to Vivec", game_state)
|
||||
|
||||
# Wire in world adapter for T3 pause/unpause
|
||||
router.set_world(world_adapter)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cascade: Any | None = None,
|
||||
tier_models: dict[ModelTier, str] | None = None,
|
||||
) -> None:
|
||||
"""Initialise the metabolic router.
|
||||
|
||||
Args:
|
||||
cascade: CascadeRouter instance to use. If None, the
|
||||
singleton returned by get_router() is used lazily.
|
||||
tier_models: Override default model names per tier.
|
||||
"""
|
||||
self._cascade = cascade
|
||||
self._tier_models: dict[ModelTier, str] = dict(DEFAULT_TIER_MODELS)
|
||||
if tier_models:
|
||||
self._tier_models.update(tier_models)
|
||||
self._world: Any | None = None
|
||||
|
||||
def set_world(self, world: Any) -> None:
|
||||
"""Wire in a world adapter for T3 pause / unpause support.
|
||||
|
||||
The adapter only needs to implement ``act(CommandInput)`` — the full
|
||||
WorldInterface contract is not required. A missing or broken world
|
||||
adapter degrades gracefully (logs a warning, inference continues).
|
||||
|
||||
Args:
|
||||
world: Any object with an ``act(CommandInput)`` method.
|
||||
"""
|
||||
self._world = world
|
||||
|
||||
def _get_cascade(self) -> Any:
|
||||
"""Return the CascadeRouter, creating the singleton if needed."""
|
||||
if self._cascade is None:
|
||||
from infrastructure.router.cascade import get_router
|
||||
|
||||
self._cascade = get_router()
|
||||
return self._cascade
|
||||
|
||||
def classify(self, task: str, state: dict) -> ModelTier:
|
||||
"""Classify task complexity. Delegates to classify_complexity()."""
|
||||
return classify_complexity(task, state)
|
||||
|
||||
async def _pause_world(self) -> None:
|
||||
"""Pause the game world before T3 inference (graceful degradation)."""
|
||||
if self._world is None:
|
||||
return
|
||||
try:
|
||||
from infrastructure.world.types import CommandInput
|
||||
|
||||
await asyncio.to_thread(self._world.act, CommandInput(action="pause"))
|
||||
logger.debug("MetabolicRouter: world paused for T3 inference")
|
||||
except Exception as exc:
|
||||
logger.warning("world.pause() failed — continuing without pause: %s", exc)
|
||||
|
||||
async def _unpause_world(self) -> None:
|
||||
"""Unpause the game world after T3 inference (always called, even on error)."""
|
||||
if self._world is None:
|
||||
return
|
||||
try:
|
||||
from infrastructure.world.types import CommandInput
|
||||
|
||||
await asyncio.to_thread(self._world.act, CommandInput(action="unpause"))
|
||||
logger.debug("MetabolicRouter: world unpaused after T3 inference")
|
||||
except Exception as exc:
|
||||
logger.warning("world.unpause() failed — game may remain paused: %s", exc)
|
||||
|
||||
async def route(
|
||||
self,
|
||||
task: str,
|
||||
state: dict,
|
||||
ui_state: dict | None = None,
|
||||
visual_context: str | None = None,
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int | None = None,
|
||||
) -> dict:
|
||||
"""Route a task to the appropriate model tier and return the LLM response.
|
||||
|
||||
Selects the tier via classify_complexity(), assembles the prompt via
|
||||
build_prompt(), and dispatches to CascadeRouter. For T3, the game
|
||||
world is paused before inference and unpaused after (in a finally block).
|
||||
|
||||
Args:
|
||||
task: Natural-language task description or player input.
|
||||
state: Current game state dict.
|
||||
ui_state: Current UI state dict (optional, defaults to {}).
|
||||
visual_context: Optional screen/scene description from vision model.
|
||||
temperature: Sampling temperature (default 0.3 for game commands).
|
||||
max_tokens: Maximum tokens to generate.
|
||||
|
||||
Returns:
|
||||
Dict with keys: ``content``, ``provider``, ``model``, ``tier``,
|
||||
``latency_ms``, plus any extra keys from CascadeRouter.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all providers fail (propagated from CascadeRouter).
|
||||
"""
|
||||
ui_state = ui_state or {}
|
||||
tier = self.classify(task, state)
|
||||
model = self._tier_models[tier]
|
||||
messages = build_prompt(state, ui_state, task, visual_context)
|
||||
cascade = self._get_cascade()
|
||||
|
||||
logger.info(
|
||||
"MetabolicRouter: tier=%s model=%s task=%r",
|
||||
tier,
|
||||
model,
|
||||
task[:80],
|
||||
)
|
||||
|
||||
if tier == ModelTier.T3_COMPLEX:
|
||||
await self._pause_world()
|
||||
try:
|
||||
result = await cascade.complete(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
finally:
|
||||
await self._unpause_world()
|
||||
else:
|
||||
result = await cascade.complete(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
result["tier"] = tier
|
||||
return result
|
||||
|
||||
|
||||
# ── Module-level singleton ────────────────────────────────────────────────────
|
||||
_metabolic_router: MetabolicRouter | None = None
|
||||
|
||||
|
||||
def get_metabolic_router() -> MetabolicRouter:
|
||||
"""Get or create the MetabolicRouter singleton."""
|
||||
global _metabolic_router
|
||||
if _metabolic_router is None:
|
||||
_metabolic_router = MetabolicRouter()
|
||||
return _metabolic_router
|
||||
386
tests/infrastructure/test_metabolic_router.py
Normal file
386
tests/infrastructure/test_metabolic_router.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""Tests for the three-tier metabolic LLM router (issue #966)."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
from infrastructure.router.metabolic import (
|
||||
DEFAULT_TIER_MODELS,
|
||||
MetabolicRouter,
|
||||
ModelTier,
|
||||
build_prompt,
|
||||
classify_complexity,
|
||||
get_metabolic_router,
|
||||
)
|
||||
|
||||
|
||||
# ── classify_complexity ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestClassifyComplexity:
|
||||
"""Verify tier classification for representative task / state pairs."""
|
||||
|
||||
# ── T1: Routine ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_simple_navigation_is_t1(self):
|
||||
assert classify_complexity("go north", {}) == ModelTier.T1_ROUTINE
|
||||
|
||||
def test_single_action_is_t1(self):
|
||||
assert classify_complexity("open door", {}) == ModelTier.T1_ROUTINE
|
||||
|
||||
def test_t1_with_extra_words_stays_t1(self):
|
||||
# 6 words, all T1 territory, no active context
|
||||
assert classify_complexity("go south and take it", {}) == ModelTier.T1_ROUTINE
|
||||
|
||||
def test_t1_long_task_upgrades_to_t2(self):
|
||||
# More than 6 words → not T1 even with nav words
|
||||
assert (
|
||||
classify_complexity("go north and then move east and pick up the sword", {})
|
||||
!= ModelTier.T1_ROUTINE
|
||||
)
|
||||
|
||||
def test_active_quest_upgrades_t1_to_t2(self):
|
||||
state = {"active_quests": ["Rescue the Mage"]}
|
||||
assert classify_complexity("go north", state) == ModelTier.T2_MEDIUM
|
||||
|
||||
def test_dialogue_active_upgrades_t1_to_t2(self):
|
||||
state = {"dialogue_active": True}
|
||||
assert classify_complexity("yes", state) == ModelTier.T2_MEDIUM
|
||||
|
||||
def test_combat_active_upgrades_t1_to_t2(self):
|
||||
state = {"combat_active": True}
|
||||
assert classify_complexity("attack", state) == ModelTier.T2_MEDIUM
|
||||
|
||||
# ── T2: Medium ──────────────────────────────────────────────────────────
|
||||
|
||||
def test_default_is_t2(self):
|
||||
assert classify_complexity("what do I have in my inventory", {}) == ModelTier.T2_MEDIUM
|
||||
|
||||
def test_dialogue_response_is_t2(self):
|
||||
state = {"dialogue_active": True, "dialogue_npc": "Caius Cosades"}
|
||||
result = classify_complexity("I'm looking for Caius Cosades", state)
|
||||
assert result == ModelTier.T2_MEDIUM
|
||||
|
||||
# ── T3: Complex ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_quest_planning_is_t3(self):
|
||||
assert classify_complexity("plan my quest route", {}) == ModelTier.T3_COMPLEX
|
||||
|
||||
def test_strategy_keyword_is_t3(self):
|
||||
assert classify_complexity("what is the best strategy", {}) == ModelTier.T3_COMPLEX
|
||||
|
||||
def test_stuck_keyword_is_t3(self):
|
||||
assert classify_complexity("I am stuck", {}) == ModelTier.T3_COMPLEX
|
||||
|
||||
def test_stuck_state_is_t3(self):
|
||||
assert classify_complexity("help me", {"stuck": True}) == ModelTier.T3_COMPLEX
|
||||
|
||||
def test_require_t3_flag_forces_t3(self):
|
||||
state = {"require_t3": True}
|
||||
assert classify_complexity("go north", state) == ModelTier.T3_COMPLEX
|
||||
|
||||
def test_optimize_keyword_is_t3(self):
|
||||
assert classify_complexity("optimize my skill build", {}) == ModelTier.T3_COMPLEX
|
||||
|
||||
def test_multi_word_t3_phrase(self):
|
||||
assert classify_complexity("how do i get past the guards", {}) == ModelTier.T3_COMPLEX
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert classify_complexity("PLAN my route", {}) == ModelTier.T3_COMPLEX
|
||||
|
||||
|
||||
# ── build_prompt ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBuildPrompt:
|
||||
"""Verify prompt structure and content assembly."""
|
||||
|
||||
def test_returns_two_messages(self):
|
||||
msgs = build_prompt({}, {}, "go north")
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["role"] == "system"
|
||||
assert msgs[1]["role"] == "user"
|
||||
|
||||
def test_user_message_contains_task(self):
|
||||
msgs = build_prompt({}, {}, "pick up the sword")
|
||||
assert msgs[1]["content"] == "pick up the sword"
|
||||
|
||||
def test_location_in_system(self):
|
||||
msgs = build_prompt({"location": "Balmora"}, {}, "look around")
|
||||
assert "Balmora" in msgs[0]["content"]
|
||||
|
||||
def test_health_in_system(self):
|
||||
msgs = build_prompt({"health": 42}, {}, "rest")
|
||||
assert "42" in msgs[0]["content"]
|
||||
|
||||
def test_inventory_in_system(self):
|
||||
msgs = build_prompt({"inventory": ["iron sword", "bread"]}, {}, "use item")
|
||||
assert "iron sword" in msgs[0]["content"]
|
||||
|
||||
def test_inventory_truncated_to_10(self):
|
||||
inventory = [f"item{i}" for i in range(20)]
|
||||
msgs = build_prompt({"inventory": inventory}, {}, "check")
|
||||
# Only first 10 should appear in the system message
|
||||
assert "item10" not in msgs[0]["content"]
|
||||
|
||||
def test_active_quests_in_system(self):
|
||||
msgs = build_prompt({"active_quests": ["Morrowind Main Quest"]}, {}, "help")
|
||||
assert "Morrowind Main Quest" in msgs[0]["content"]
|
||||
|
||||
def test_stuck_indicator_in_system(self):
|
||||
msgs = build_prompt({"stuck": True}, {}, "what now")
|
||||
assert "STUCK" in msgs[0]["content"]
|
||||
|
||||
def test_dialogue_npc_in_system(self):
|
||||
msgs = build_prompt({}, {"dialogue_active": True, "dialogue_npc": "Vivec"}, "hello")
|
||||
assert "Vivec" in msgs[0]["content"]
|
||||
|
||||
def test_menu_open_in_system(self):
|
||||
msgs = build_prompt({}, {"menu_open": "inventory"}, "check items")
|
||||
assert "inventory" in msgs[0]["content"]
|
||||
|
||||
def test_combat_active_in_system(self):
|
||||
msgs = build_prompt({}, {"combat_active": True}, "attack")
|
||||
assert "COMBAT" in msgs[0]["content"]
|
||||
|
||||
def test_visual_context_in_system(self):
|
||||
msgs = build_prompt({}, {}, "where am I", visual_context="A dark dungeon corridor")
|
||||
assert "dungeon corridor" in msgs[0]["content"]
|
||||
|
||||
def test_missing_optional_fields_omitted(self):
|
||||
msgs = build_prompt({}, {}, "move forward")
|
||||
system = msgs[0]["content"]
|
||||
assert "Health:" not in system
|
||||
assert "Inventory:" not in system
|
||||
assert "Active quests:" not in system
|
||||
|
||||
def test_inventory_dict_items(self):
|
||||
inventory = [{"name": "silver dagger"}, {"name": "potion"}]
|
||||
msgs = build_prompt({"inventory": inventory}, {}, "use")
|
||||
assert "silver dagger" in msgs[0]["content"]
|
||||
|
||||
def test_quest_dict_items(self):
|
||||
quests = [{"name": "The Warlord"}, {"name": "Lost in Translation"}]
|
||||
msgs = build_prompt({"active_quests": quests}, {}, "help")
|
||||
assert "The Warlord" in msgs[0]["content"]
|
||||
|
||||
|
||||
# ── MetabolicRouter ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestMetabolicRouter:
|
||||
"""Test MetabolicRouter routing, tier labelling, and T3 world-pause logic."""
|
||||
|
||||
def _make_router(self, mock_cascade=None):
|
||||
"""Create a MetabolicRouter with a mocked CascadeRouter."""
|
||||
if mock_cascade is None:
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={
|
||||
"content": "Move north confirmed.",
|
||||
"provider": "ollama-local",
|
||||
"model": "qwen3:8b",
|
||||
"latency_ms": 120.0,
|
||||
}
|
||||
)
|
||||
return MetabolicRouter(cascade=mock_cascade)
|
||||
|
||||
async def test_route_returns_tier_in_result(self):
|
||||
router = self._make_router()
|
||||
result = await router.route("go north", state={})
|
||||
assert "tier" in result
|
||||
assert result["tier"] == ModelTier.T1_ROUTINE
|
||||
|
||||
async def test_t1_uses_t1_model(self):
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "ok", "provider": "ollama-local", "model": "qwen3:8b", "latency_ms": 100}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
await router.route("go north", state={})
|
||||
call_kwargs = mock_cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_TIER_MODELS[ModelTier.T1_ROUTINE]
|
||||
|
||||
async def test_t2_uses_t2_model(self):
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "ok", "provider": "ollama-local", "model": "qwen3:14b", "latency_ms": 300}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
await router.route("what should I say to the innkeeper", state={})
|
||||
call_kwargs = mock_cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_TIER_MODELS[ModelTier.T2_MEDIUM]
|
||||
|
||||
async def test_t3_uses_t3_model(self):
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "ok", "provider": "ollama-local", "model": "qwen3:30b", "latency_ms": 2000}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
await router.route("plan the optimal quest route", state={})
|
||||
call_kwargs = mock_cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["model"] == DEFAULT_TIER_MODELS[ModelTier.T3_COMPLEX]
|
||||
|
||||
async def test_custom_tier_models_respected(self):
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "ok", "provider": "test", "model": "custom-8b", "latency_ms": 100}
|
||||
)
|
||||
custom = {ModelTier.T1_ROUTINE: "custom-8b"}
|
||||
router = MetabolicRouter(cascade=mock_cascade, tier_models=custom)
|
||||
await router.route("go north", state={})
|
||||
call_kwargs = mock_cascade.complete.call_args
|
||||
assert call_kwargs.kwargs["model"] == "custom-8b"
|
||||
|
||||
async def test_t3_pauses_world_before_inference(self):
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "ok", "provider": "ollama", "model": "qwen3:30b", "latency_ms": 1500}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
|
||||
pause_calls = []
|
||||
unpause_calls = []
|
||||
|
||||
mock_world = MagicMock()
|
||||
|
||||
def track_act(cmd):
|
||||
if cmd.action == "pause":
|
||||
pause_calls.append(cmd)
|
||||
elif cmd.action == "unpause":
|
||||
unpause_calls.append(cmd)
|
||||
|
||||
mock_world.act = track_act
|
||||
router.set_world(mock_world)
|
||||
|
||||
await router.route("plan the quest", state={})
|
||||
|
||||
assert len(pause_calls) == 1, "world.pause() should be called once for T3"
|
||||
assert len(unpause_calls) == 1, "world.unpause() should be called once for T3"
|
||||
|
||||
async def test_t3_unpauses_world_even_on_llm_error(self):
|
||||
"""world.unpause() must be called even when the LLM raises."""
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(side_effect=RuntimeError("LLM failed"))
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
|
||||
unpause_calls = []
|
||||
mock_world = MagicMock()
|
||||
mock_world.act = lambda cmd: unpause_calls.append(cmd) if cmd.action == "unpause" else None
|
||||
router.set_world(mock_world)
|
||||
|
||||
with pytest.raises(RuntimeError, match="LLM failed"):
|
||||
await router.route("plan the quest", state={})
|
||||
|
||||
assert len(unpause_calls) == 1, "world.unpause() must run even when LLM errors"
|
||||
|
||||
async def test_t1_does_not_pause_world(self):
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "ok", "provider": "ollama", "model": "qwen3:8b", "latency_ms": 120}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
|
||||
pause_calls = []
|
||||
mock_world = MagicMock()
|
||||
mock_world.act = lambda cmd: pause_calls.append(cmd)
|
||||
router.set_world(mock_world)
|
||||
|
||||
await router.route("go north", state={})
|
||||
|
||||
assert len(pause_calls) == 0, "world.pause() must NOT be called for T1"
|
||||
|
||||
async def test_t2_does_not_pause_world(self):
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "ok", "provider": "ollama", "model": "qwen3:14b", "latency_ms": 350}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
|
||||
pause_calls = []
|
||||
mock_world = MagicMock()
|
||||
mock_world.act = lambda cmd: pause_calls.append(cmd)
|
||||
router.set_world(mock_world)
|
||||
|
||||
await router.route("talk to the merchant", state={})
|
||||
|
||||
assert len(pause_calls) == 0, "world.pause() must NOT be called for T2"
|
||||
|
||||
async def test_broken_world_adapter_degrades_gracefully(self):
|
||||
"""If world.act() raises, inference must still complete."""
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "done", "provider": "ollama", "model": "qwen3:30b", "latency_ms": 2000}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
|
||||
mock_world = MagicMock()
|
||||
mock_world.act = MagicMock(side_effect=RuntimeError("world broken"))
|
||||
router.set_world(mock_world)
|
||||
|
||||
# Should not raise — degradation only logs a warning
|
||||
result = await router.route("plan the quest", state={})
|
||||
assert result["content"] == "done"
|
||||
|
||||
async def test_no_world_adapter_t3_still_works(self):
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "plan done", "provider": "ollama", "model": "qwen3:30b", "latency_ms": 2000}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
# No set_world() called
|
||||
|
||||
result = await router.route("plan the quest route", state={})
|
||||
assert result["content"] == "plan done"
|
||||
assert result["tier"] == ModelTier.T3_COMPLEX
|
||||
|
||||
async def test_classify_delegates_to_module_function(self):
|
||||
router = MetabolicRouter(cascade=MagicMock())
|
||||
assert router.classify("go north", {}) == classify_complexity("go north", {})
|
||||
assert router.classify("plan the quest", {}) == classify_complexity("plan the quest", {})
|
||||
|
||||
async def test_ui_state_defaults_to_empty_dict(self):
|
||||
"""Calling route without ui_state should not raise."""
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "ok", "provider": "ollama", "model": "qwen3:8b", "latency_ms": 100}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
# No ui_state argument
|
||||
result = await router.route("go north", state={})
|
||||
assert result["content"] == "ok"
|
||||
|
||||
async def test_temperature_and_max_tokens_forwarded(self):
|
||||
mock_cascade = MagicMock()
|
||||
mock_cascade.complete = AsyncMock(
|
||||
return_value={"content": "ok", "provider": "ollama", "model": "qwen3:14b", "latency_ms": 200}
|
||||
)
|
||||
router = MetabolicRouter(cascade=mock_cascade)
|
||||
await router.route("describe the scene", state={}, temperature=0.1, max_tokens=50)
|
||||
call_kwargs = mock_cascade.complete.call_args.kwargs
|
||||
assert call_kwargs["temperature"] == 0.1
|
||||
assert call_kwargs["max_tokens"] == 50
|
||||
|
||||
|
||||
class TestGetMetabolicRouter:
|
||||
"""Test module-level singleton."""
|
||||
|
||||
def test_returns_metabolic_router_instance(self):
|
||||
import infrastructure.router.metabolic as m_module
|
||||
|
||||
# Reset singleton for clean test
|
||||
m_module._metabolic_router = None
|
||||
router = get_metabolic_router()
|
||||
assert isinstance(router, MetabolicRouter)
|
||||
|
||||
def test_singleton_returns_same_instance(self):
|
||||
import infrastructure.router.metabolic as m_module
|
||||
|
||||
m_module._metabolic_router = None
|
||||
r1 = get_metabolic_router()
|
||||
r2 = get_metabolic_router()
|
||||
assert r1 is r2
|
||||
Reference in New Issue
Block a user