forked from Rockachopa/Timmy-time-dashboard
Co-authored-by: Claude (Opus 4.6) <claude@hermes.local> Co-committed-by: Claude (Opus 4.6) <claude@hermes.local>
440 lines
17 KiB
Python
440 lines
17 KiB
Python
"""Tests for the three-tier metabolic LLM router (issue #966)."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from infrastructure.router.metabolic import (
|
|
DEFAULT_TIER_MODELS,
|
|
MetabolicRouter,
|
|
ModelTier,
|
|
build_prompt,
|
|
classify_complexity,
|
|
get_metabolic_router,
|
|
)
|
|
|
|
# Module-level pytest marker: every test collected from this file carries the `unit` mark.
pytestmark = pytest.mark.unit
|
|
|
|
# ── classify_complexity ──────────────────────────────────────────────────────
|
|
|
|
|
|
class TestClassifyComplexity:
    """Check the tier that classify_complexity assigns to sample task / state pairs."""

    # ── T1: Routine ─────────────────────────────────────────────────────────

    def test_simple_navigation_is_t1(self):
        """A bare movement command with an empty state is routine."""
        tier = classify_complexity("go north", {})
        assert tier == ModelTier.T1_ROUTINE

    def test_single_action_is_t1(self):
        """A two-word action with an empty state is routine."""
        tier = classify_complexity("open door", {})
        assert tier == ModelTier.T1_ROUTINE

    def test_t1_with_extra_words_stays_t1(self):
        """A short multi-step command with an empty state remains routine."""
        # Five words, no active context.
        task = "go south and take it"
        assert classify_complexity(task, {}) == ModelTier.T1_ROUTINE

    def test_t1_long_task_upgrades_to_t2(self):
        """A long command must not be classified as routine."""
        # More than 6 words (11 here) → not T1 even with nav words.
        # NOTE(review): only "not T1" is asserted; the resulting tier is not pinned.
        task = "go north and then move east and pick up the sword"
        tier = classify_complexity(task, {})
        assert tier != ModelTier.T1_ROUTINE

    def test_active_quest_upgrades_t1_to_t2(self):
        """An active quest in the state lifts a routine command to medium."""
        questing = {"active_quests": ["Rescue the Mage"]}
        tier = classify_complexity("go north", questing)
        assert tier == ModelTier.T2_MEDIUM

    def test_dialogue_active_upgrades_t1_to_t2(self):
        """An active dialogue lifts a one-word reply to medium."""
        in_dialogue = {"dialogue_active": True}
        tier = classify_complexity("yes", in_dialogue)
        assert tier == ModelTier.T2_MEDIUM

    def test_combat_active_upgrades_t1_to_t2(self):
        """Active combat lifts a one-word action to medium."""
        in_combat = {"combat_active": True}
        tier = classify_complexity("attack", in_combat)
        assert tier == ModelTier.T2_MEDIUM

    # ── T2: Medium ──────────────────────────────────────────────────────────

    def test_default_is_t2(self):
        """A task matching neither the T1 nor the T3 cases lands on medium."""
        tier = classify_complexity("what do I have in my inventory", {})
        assert tier == ModelTier.T2_MEDIUM

    def test_dialogue_response_is_t2(self):
        """A dialogue reply while talking to an NPC is medium."""
        talking = {"dialogue_active": True, "dialogue_npc": "Caius Cosades"}
        assert classify_complexity("I'm looking for Caius Cosades", talking) == ModelTier.T2_MEDIUM

    # ── T3: Complex ─────────────────────────────────────────────────────────

    def test_quest_planning_is_t3(self):
        """A planning request is complex."""
        tier = classify_complexity("plan my quest route", {})
        assert tier == ModelTier.T3_COMPLEX

    def test_strategy_keyword_is_t3(self):
        """A task mentioning strategy is complex."""
        tier = classify_complexity("what is the best strategy", {})
        assert tier == ModelTier.T3_COMPLEX

    def test_stuck_keyword_is_t3(self):
        """A task saying the player is stuck is complex."""
        tier = classify_complexity("I am stuck", {})
        assert tier == ModelTier.T3_COMPLEX

    def test_stuck_state_is_t3(self):
        """A ``stuck`` flag in the state makes the task complex."""
        tier = classify_complexity("help me", {"stuck": True})
        assert tier == ModelTier.T3_COMPLEX

    def test_require_t3_flag_forces_t3(self):
        """The ``require_t3`` flag overrides an otherwise routine command."""
        forced = {"require_t3": True}
        tier = classify_complexity("go north", forced)
        assert tier == ModelTier.T3_COMPLEX

    def test_optimize_keyword_is_t3(self):
        """A task mentioning optimisation is complex."""
        tier = classify_complexity("optimize my skill build", {})
        assert tier == ModelTier.T3_COMPLEX

    def test_multi_word_t3_phrase(self):
        """A "how do i ..." question is complex."""
        tier = classify_complexity("how do i get past the guards", {})
        assert tier == ModelTier.T3_COMPLEX

    def test_case_insensitive(self):
        """An upper-case keyword is classified like its lower-case form."""
        tier = classify_complexity("PLAN my route", {})
        assert tier == ModelTier.T3_COMPLEX
|
|
|
|
|
|
# ── build_prompt ─────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestBuildPrompt:
    """Verify prompt structure and content assembly.

    Every test calls ``build_prompt(state, ui_state, task, ...)`` and inspects
    the returned message list: index 0 is the system message, index 1 the
    user message.
    """

    def test_returns_two_messages(self):
        """The prompt is exactly a system message followed by a user message."""
        msgs = build_prompt({}, {}, "go north")
        assert len(msgs) == 2
        system_msg, user_msg = msgs
        assert system_msg["role"] == "system"
        assert user_msg["role"] == "user"

    def test_user_message_contains_task(self):
        """The user message content is the task text, unchanged."""
        user_msg = build_prompt({}, {}, "pick up the sword")[1]
        assert user_msg["content"] == "pick up the sword"

    def test_location_in_system(self):
        """The location from the game state shows up in the system message."""
        system = build_prompt({"location": "Balmora"}, {}, "look around")[0]["content"]
        assert "Balmora" in system

    def test_health_in_system(self):
        """The health value from the game state shows up in the system message."""
        system = build_prompt({"health": 42}, {}, "rest")[0]["content"]
        assert "42" in system

    def test_inventory_in_system(self):
        """Inventory item names show up in the system message."""
        state = {"inventory": ["iron sword", "bread"]}
        system = build_prompt(state, {}, "use item")[0]["content"]
        assert "iron sword" in system

    def test_inventory_truncated_to_10(self):
        """Only the first ten of twenty inventory items are rendered."""
        inventory = [f"item{i}" for i in range(20)]
        system = build_prompt({"inventory": inventory}, {}, "check")[0]["content"]
        # Guard against a vacuous pass: the tenth item must actually be rendered,
        # otherwise a prompt with no inventory at all would satisfy this test.
        assert "item9" in system
        # Everything from the eleventh item onwards must be cut.
        leaked = [name for name in inventory[10:] if name in system]
        assert not leaked, f"items beyond the first 10 leaked into the prompt: {leaked}"

    def test_active_quests_in_system(self):
        """Active quest names show up in the system message."""
        state = {"active_quests": ["Morrowind Main Quest"]}
        system = build_prompt(state, {}, "help")[0]["content"]
        assert "Morrowind Main Quest" in system

    def test_stuck_indicator_in_system(self):
        """A stuck state adds a STUCK marker to the system message."""
        system = build_prompt({"stuck": True}, {}, "what now")[0]["content"]
        assert "STUCK" in system

    def test_dialogue_npc_in_system(self):
        """The NPC named in the UI state shows up in the system message."""
        ui_state = {"dialogue_active": True, "dialogue_npc": "Vivec"}
        system = build_prompt({}, ui_state, "hello")[0]["content"]
        assert "Vivec" in system

    def test_menu_open_in_system(self):
        """The name of the open menu shows up in the system message."""
        system = build_prompt({}, {"menu_open": "inventory"}, "check items")[0]["content"]
        assert "inventory" in system

    def test_combat_active_in_system(self):
        """Active combat adds a COMBAT marker to the system message."""
        system = build_prompt({}, {"combat_active": True}, "attack")[0]["content"]
        assert "COMBAT" in system

    def test_visual_context_in_system(self):
        """The optional visual context text is included in the system message."""
        msgs = build_prompt({}, {}, "where am I", visual_context="A dark dungeon corridor")
        assert "dungeon corridor" in msgs[0]["content"]

    def test_missing_optional_fields_omitted(self):
        """With empty state, no labels for absent fields are emitted."""
        system = build_prompt({}, {}, "move forward")[0]["content"]
        for label in ("Health:", "Inventory:", "Active quests:"):
            assert label not in system

    def test_inventory_dict_items(self):
        """Inventory entries given as dicts are rendered by their ``name``."""
        inventory = [{"name": "silver dagger"}, {"name": "potion"}]
        system = build_prompt({"inventory": inventory}, {}, "use")[0]["content"]
        assert "silver dagger" in system

    def test_quest_dict_items(self):
        """Quest entries given as dicts are rendered by their ``name``."""
        quests = [{"name": "The Warlord"}, {"name": "Lost in Translation"}]
        system = build_prompt({"active_quests": quests}, {}, "help")[0]["content"]
        assert "The Warlord" in system
|
|
|
|
|
|
# ── MetabolicRouter ──────────────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
class TestMetabolicRouter:
    """Test MetabolicRouter routing, tier labelling, and T3 world-pause logic.

    The cascade router is always replaced by a mock whose ``complete``
    coroutine resolves to a canned reply dict, so no model is ever called.
    """

    def _make_cascade(self, content="ok", provider="ollama", model="qwen3:8b", latency_ms=100):
        """Build a mock cascade whose ``complete`` resolves to a canned reply.

        Args:
            content: Value of the reply's ``content`` key.
            provider: Value of the reply's ``provider`` key.
            model: Value of the reply's ``model`` key.
            latency_ms: Value of the reply's ``latency_ms`` key.

        Returns:
            A ``MagicMock`` with an ``AsyncMock`` ``complete`` attribute.
        """
        cascade = MagicMock()
        cascade.complete = AsyncMock(
            return_value={
                "content": content,
                "provider": provider,
                "model": model,
                "latency_ms": latency_ms,
            }
        )
        return cascade

    def _make_router(self, mock_cascade=None):
        """Create a MetabolicRouter with a mocked CascadeRouter.

        Args:
            mock_cascade: Cascade mock to inject; a default one is built when omitted.
        """
        if mock_cascade is None:
            mock_cascade = self._make_cascade(
                content="Move north confirmed.",
                provider="ollama-local",
                model="qwen3:8b",
                latency_ms=120.0,
            )
        return MetabolicRouter(cascade=mock_cascade)

    async def test_route_returns_tier_in_result(self):
        """The routing result carries the tier that was selected."""
        router = self._make_router()
        result = await router.route("go north", state={})
        assert "tier" in result
        assert result["tier"] == ModelTier.T1_ROUTINE

    async def test_t1_uses_t1_model(self):
        """A routine task is sent to the default T1 model."""
        cascade = self._make_cascade(provider="ollama-local", model="qwen3:8b", latency_ms=100)
        router = self._make_router(cascade)
        await router.route("go north", state={})
        sent = cascade.complete.call_args.kwargs
        assert sent["model"] == DEFAULT_TIER_MODELS[ModelTier.T1_ROUTINE]

    async def test_t2_uses_t2_model(self):
        """A medium task is sent to the default T2 model."""
        cascade = self._make_cascade(provider="ollama-local", model="qwen3:14b", latency_ms=300)
        router = self._make_router(cascade)
        await router.route("what should I say to the innkeeper", state={})
        sent = cascade.complete.call_args.kwargs
        assert sent["model"] == DEFAULT_TIER_MODELS[ModelTier.T2_MEDIUM]

    async def test_t3_uses_t3_model(self):
        """A complex task is sent to the default T3 model."""
        cascade = self._make_cascade(provider="ollama-local", model="qwen3:30b", latency_ms=2000)
        router = self._make_router(cascade)
        await router.route("plan the optimal quest route", state={})
        sent = cascade.complete.call_args.kwargs
        assert sent["model"] == DEFAULT_TIER_MODELS[ModelTier.T3_COMPLEX]

    async def test_custom_tier_models_respected(self):
        """A ``tier_models`` override replaces the default model for that tier."""
        cascade = self._make_cascade(provider="test", model="custom-8b", latency_ms=100)
        custom = {ModelTier.T1_ROUTINE: "custom-8b"}
        router = MetabolicRouter(cascade=cascade, tier_models=custom)
        await router.route("go north", state={})
        sent = cascade.complete.call_args.kwargs
        assert sent["model"] == "custom-8b"

    async def test_t3_pauses_world_before_inference(self):
        """For T3 the world is paused before the LLM call and unpaused after it."""
        # Single timeline shared by the world and the cascade so that the
        # relative ORDER of pause / inference / unpause can be asserted,
        # not only how often each happened.
        events = []
        reply = {
            "content": "ok",
            "provider": "ollama",
            "model": "qwen3:30b",
            "latency_ms": 1500,
        }

        def record_inference(*args, **kwargs):
            events.append("infer")
            return reply

        mock_cascade = MagicMock()
        mock_cascade.complete = AsyncMock(side_effect=record_inference)
        router = MetabolicRouter(cascade=mock_cascade)

        def track_act(cmd):
            if cmd.action == "pause":
                events.append("pause")
            elif cmd.action == "unpause":
                events.append("unpause")

        mock_world = MagicMock()
        mock_world.act = track_act
        router.set_world(mock_world)

        await router.route("plan the quest", state={})

        assert events.count("pause") == 1, "world.pause() should be called once for T3"
        assert events.count("unpause") == 1, "world.unpause() should be called once for T3"
        assert "infer" in events, "the LLM should be called for T3"
        assert events[0] == "pause", "world.pause() must happen before inference"
        assert events[-1] == "unpause", "world.unpause() must happen after inference"

    async def test_t3_unpauses_world_even_on_llm_error(self):
        """world.unpause() must be called even when the LLM raises."""
        mock_cascade = MagicMock()
        mock_cascade.complete = AsyncMock(side_effect=RuntimeError("LLM failed"))
        router = MetabolicRouter(cascade=mock_cascade)

        unpause_calls = []

        def track_unpause(cmd):
            if cmd.action == "unpause":
                unpause_calls.append(cmd)

        mock_world = MagicMock()
        mock_world.act = track_unpause
        router.set_world(mock_world)

        with pytest.raises(RuntimeError, match="LLM failed"):
            await router.route("plan the quest", state={})

        assert len(unpause_calls) == 1, "world.unpause() must run even when LLM errors"

    async def test_t1_does_not_pause_world(self):
        """A routine task never touches the world adapter."""
        cascade = self._make_cascade(provider="ollama", model="qwen3:8b", latency_ms=120)
        router = self._make_router(cascade)

        pause_calls = []
        mock_world = MagicMock()
        # Any command sent to the world counts as a violation here.
        mock_world.act = pause_calls.append
        router.set_world(mock_world)

        await router.route("go north", state={})

        assert len(pause_calls) == 0, "world.pause() must NOT be called for T1"

    async def test_t2_does_not_pause_world(self):
        """A medium task never touches the world adapter."""
        cascade = self._make_cascade(provider="ollama", model="qwen3:14b", latency_ms=350)
        router = self._make_router(cascade)

        pause_calls = []
        mock_world = MagicMock()
        # Any command sent to the world counts as a violation here.
        mock_world.act = pause_calls.append
        router.set_world(mock_world)

        await router.route("talk to the merchant", state={})

        assert len(pause_calls) == 0, "world.pause() must NOT be called for T2"

    async def test_broken_world_adapter_degrades_gracefully(self):
        """If world.act() raises, inference must still complete."""
        cascade = self._make_cascade(
            content="done", provider="ollama", model="qwen3:30b", latency_ms=2000
        )
        router = self._make_router(cascade)

        mock_world = MagicMock()
        mock_world.act = MagicMock(side_effect=RuntimeError("world broken"))
        router.set_world(mock_world)

        # Should not raise — degradation only logs a warning
        result = await router.route("plan the quest", state={})
        assert result["content"] == "done"

    async def test_no_world_adapter_t3_still_works(self):
        """T3 routing succeeds when no world adapter was ever attached."""
        cascade = self._make_cascade(
            content="plan done", provider="ollama", model="qwen3:30b", latency_ms=2000
        )
        # No set_world() called
        router = self._make_router(cascade)

        result = await router.route("plan the quest route", state={})
        assert result["content"] == "plan done"
        assert result["tier"] == ModelTier.T3_COMPLEX

    async def test_classify_delegates_to_module_function(self):
        """``router.classify`` agrees with the module-level classify_complexity."""
        router = MetabolicRouter(cascade=MagicMock())
        for task in ("go north", "plan the quest"):
            assert router.classify(task, {}) == classify_complexity(task, {})

    async def test_ui_state_defaults_to_empty_dict(self):
        """Calling route without ui_state should not raise."""
        cascade = self._make_cascade(provider="ollama", model="qwen3:8b", latency_ms=100)
        router = self._make_router(cascade)
        # No ui_state argument
        result = await router.route("go north", state={})
        assert result["content"] == "ok"

    async def test_temperature_and_max_tokens_forwarded(self):
        """Sampling options given to ``route`` reach ``cascade.complete`` unchanged."""
        cascade = self._make_cascade(provider="ollama", model="qwen3:14b", latency_ms=200)
        router = self._make_router(cascade)
        await router.route("describe the scene", state={}, temperature=0.1, max_tokens=50)
        sent = cascade.complete.call_args.kwargs
        assert sent["temperature"] == 0.1
        assert sent["max_tokens"] == 50
|
|
|
|
|
|
class TestGetMetabolicRouter:
    """Test module-level singleton."""

    @pytest.fixture(autouse=True)
    def _reset_singleton(self, monkeypatch):
        """Clear the cached router before each test and restore it afterwards.

        ``monkeypatch`` undoes the attribute change at teardown, so a router
        created here does not leak into tests that run later in the session.
        """
        import infrastructure.router.metabolic as m_module

        monkeypatch.setattr(m_module, "_metabolic_router", None)

    def test_returns_metabolic_router_instance(self):
        """With no cached instance, the accessor builds a MetabolicRouter."""
        router = get_metabolic_router()
        assert isinstance(router, MetabolicRouter)

    def test_singleton_returns_same_instance(self):
        """Two consecutive calls hand back the very same object."""
        first = get_metabolic_router()
        second = get_metabolic_router()
        assert first is second
|