"""Tests for the three-tier metabolic LLM router (issue #966)."""

from unittest.mock import AsyncMock, MagicMock

import pytest

from infrastructure.router.metabolic import (
    DEFAULT_TIER_MODELS,
    MetabolicRouter,
    ModelTier,
    build_prompt,
    classify_complexity,
    get_metabolic_router,
)

pytestmark = pytest.mark.unit


# ── classify_complexity ──────────────────────────────────────────────────────


class TestClassifyComplexity:
    """Verify tier classification for representative task / state pairs."""

    # ── T1: Routine ─────────────────────────────────────────────────────────

    def test_simple_navigation_is_t1(self):
        assert classify_complexity("go north", {}) == ModelTier.T1_ROUTINE

    def test_single_action_is_t1(self):
        assert classify_complexity("open door", {}) == ModelTier.T1_ROUTINE

    def test_t1_with_extra_words_stays_t1(self):
        # 5 words (within the 6-word T1 limit), all T1 territory, no active context
        assert classify_complexity("go south and take it", {}) == ModelTier.T1_ROUTINE

    def test_t1_long_task_upgrades_to_t2(self):
        # More than 6 words → not T1 even with nav words
        assert (
            classify_complexity("go north and then move east and pick up the sword", {})
            != ModelTier.T1_ROUTINE
        )

    def test_active_quest_upgrades_t1_to_t2(self):
        state = {"active_quests": ["Rescue the Mage"]}
        assert classify_complexity("go north", state) == ModelTier.T2_MEDIUM

    def test_dialogue_active_upgrades_t1_to_t2(self):
        state = {"dialogue_active": True}
        assert classify_complexity("yes", state) == ModelTier.T2_MEDIUM

    def test_combat_active_upgrades_t1_to_t2(self):
        state = {"combat_active": True}
        assert classify_complexity("attack", state) == ModelTier.T2_MEDIUM

    # ── T2: Medium ──────────────────────────────────────────────────────────

    def test_default_is_t2(self):
        assert classify_complexity("what do I have in my inventory", {}) == ModelTier.T2_MEDIUM

    def test_dialogue_response_is_t2(self):
        state = {"dialogue_active": True, "dialogue_npc": "Caius Cosades"}
        result = classify_complexity("I'm looking for Caius Cosades", state)
        assert result == ModelTier.T2_MEDIUM

    # ── T3: Complex ─────────────────────────────────────────────────────────

    def test_quest_planning_is_t3(self):
        assert classify_complexity("plan my quest route", {}) == ModelTier.T3_COMPLEX

    def test_strategy_keyword_is_t3(self):
        assert classify_complexity("what is the best strategy", {}) == ModelTier.T3_COMPLEX

    def test_stuck_keyword_is_t3(self):
        assert classify_complexity("I am stuck", {}) == ModelTier.T3_COMPLEX

    def test_stuck_state_is_t3(self):
        assert classify_complexity("help me", {"stuck": True}) == ModelTier.T3_COMPLEX

    def test_require_t3_flag_forces_t3(self):
        state = {"require_t3": True}
        assert classify_complexity("go north", state) == ModelTier.T3_COMPLEX

    def test_optimize_keyword_is_t3(self):
        assert classify_complexity("optimize my skill build", {}) == ModelTier.T3_COMPLEX

    def test_multi_word_t3_phrase(self):
        assert classify_complexity("how do i get past the guards", {}) == ModelTier.T3_COMPLEX

    def test_case_insensitive(self):
        assert classify_complexity("PLAN my route", {}) == ModelTier.T3_COMPLEX


# ── build_prompt ─────────────────────────────────────────────────────────────


class TestBuildPrompt:
    """Verify prompt structure and content assembly."""

    def test_returns_two_messages(self):
        msgs = build_prompt({}, {}, "go north")
        assert len(msgs) == 2
        assert msgs[0]["role"] == "system"
        assert msgs[1]["role"] == "user"

    def test_user_message_contains_task(self):
        msgs = build_prompt({}, {}, "pick up the sword")
        assert msgs[1]["content"] == "pick up the sword"

    def test_location_in_system(self):
        msgs = build_prompt({"location": "Balmora"}, {}, "look around")
        assert "Balmora" in msgs[0]["content"]

    def test_health_in_system(self):
        msgs = build_prompt({"health": 42}, {}, "rest")
        assert "42" in msgs[0]["content"]

    def test_inventory_in_system(self):
        msgs = build_prompt({"inventory": ["iron sword", "bread"]}, {}, "use item")
        assert "iron sword" in msgs[0]["content"]

    def test_inventory_truncated_to_10(self):
        inventory = [f"item{i}" for i in range(20)]
        msgs = build_prompt({"inventory": inventory}, {}, "check")
        system = msgs[0]["content"]
        # Only first 10 should appear in the system message.  Checking that the
        # last kept item is present stops the test from passing vacuously when
        # the inventory is missing altogether ("item9" is not a substring of
        # any other generated item name).
        assert "item9" in system
        assert "item10" not in system

    def test_active_quests_in_system(self):
        msgs = build_prompt({"active_quests": ["Morrowind Main Quest"]}, {}, "help")
        assert "Morrowind Main Quest" in msgs[0]["content"]

    def test_stuck_indicator_in_system(self):
        msgs = build_prompt({"stuck": True}, {}, "what now")
        assert "STUCK" in msgs[0]["content"]

    def test_dialogue_npc_in_system(self):
        msgs = build_prompt({}, {"dialogue_active": True, "dialogue_npc": "Vivec"}, "hello")
        assert "Vivec" in msgs[0]["content"]

    def test_menu_open_in_system(self):
        msgs = build_prompt({}, {"menu_open": "inventory"}, "check items")
        assert "inventory" in msgs[0]["content"]

    def test_combat_active_in_system(self):
        msgs = build_prompt({}, {"combat_active": True}, "attack")
        assert "COMBAT" in msgs[0]["content"]

    def test_visual_context_in_system(self):
        msgs = build_prompt({}, {}, "where am I", visual_context="A dark dungeon corridor")
        assert "dungeon corridor" in msgs[0]["content"]

    def test_missing_optional_fields_omitted(self):
        msgs = build_prompt({}, {}, "move forward")
        system = msgs[0]["content"]
        assert "Health:" not in system
        assert "Inventory:" not in system
        assert "Active quests:" not in system

    def test_inventory_dict_items(self):
        inventory = [{"name": "silver dagger"}, {"name": "potion"}]
        msgs = build_prompt({"inventory": inventory}, {}, "use")
        assert "silver dagger" in msgs[0]["content"]

    def test_quest_dict_items(self):
        quests = [{"name": "The Warlord"}, {"name": "Lost in Translation"}]
        msgs = build_prompt({"active_quests": quests}, {}, "help")
        assert "The Warlord" in msgs[0]["content"]


# ── MetabolicRouter ──────────────────────────────────────────────────────────


@pytest.mark.asyncio
class TestMetabolicRouter:
    """Test MetabolicRouter routing, tier labelling, and T3 world-pause logic."""

    @staticmethod
    def _make_cascade(**response):
        """Return a mocked cascade whose async ``complete`` resolves to *response*.

        The keyword arguments become the response dict verbatim, so each test
        states only the payload it cares about.
        """
        cascade = MagicMock()
        cascade.complete = AsyncMock(return_value=response)
        return cascade

    def _make_router(self, mock_cascade=None):
        """Create a MetabolicRouter with a mocked CascadeRouter.

        When *mock_cascade* is omitted, a default cascade answering with a
        canned "Move north confirmed." response is used.
        """
        if mock_cascade is None:
            mock_cascade = self._make_cascade(
                content="Move north confirmed.",
                provider="ollama-local",
                model="qwen3:8b",
                latency_ms=120.0,
            )
        return MetabolicRouter(cascade=mock_cascade)

    async def test_route_returns_tier_in_result(self):
        router = self._make_router()
        result = await router.route("go north", state={})
        assert "tier" in result
        assert result["tier"] == ModelTier.T1_ROUTINE

    async def test_t1_uses_t1_model(self):
        mock_cascade = self._make_cascade(
            content="ok", provider="ollama-local", model="qwen3:8b", latency_ms=100
        )
        router = self._make_router(mock_cascade)
        await router.route("go north", state={})
        call_kwargs = mock_cascade.complete.call_args
        assert call_kwargs.kwargs["model"] == DEFAULT_TIER_MODELS[ModelTier.T1_ROUTINE]

    async def test_t2_uses_t2_model(self):
        mock_cascade = self._make_cascade(
            content="ok", provider="ollama-local", model="qwen3:14b", latency_ms=300
        )
        router = self._make_router(mock_cascade)
        await router.route("what should I say to the innkeeper", state={})
        call_kwargs = mock_cascade.complete.call_args
        assert call_kwargs.kwargs["model"] == DEFAULT_TIER_MODELS[ModelTier.T2_MEDIUM]

    async def test_t3_uses_t3_model(self):
        mock_cascade = self._make_cascade(
            content="ok", provider="ollama-local", model="qwen3:30b", latency_ms=2000
        )
        router = self._make_router(mock_cascade)
        await router.route("plan the optimal quest route", state={})
        call_kwargs = mock_cascade.complete.call_args
        assert call_kwargs.kwargs["model"] == DEFAULT_TIER_MODELS[ModelTier.T3_COMPLEX]

    async def test_custom_tier_models_respected(self):
        mock_cascade = self._make_cascade(
            content="ok", provider="test", model="custom-8b", latency_ms=100
        )
        custom = {ModelTier.T1_ROUTINE: "custom-8b"}
        router = MetabolicRouter(cascade=mock_cascade, tier_models=custom)
        await router.route("go north", state={})
        call_kwargs = mock_cascade.complete.call_args
        assert call_kwargs.kwargs["model"] == "custom-8b"

    async def test_t3_pauses_world_before_inference(self):
        mock_cascade = self._make_cascade(
            content="ok", provider="ollama", model="qwen3:30b", latency_ms=1500
        )
        router = self._make_router(mock_cascade)

        pause_calls = []
        unpause_calls = []
        mock_world = MagicMock()

        def track_act(cmd):
            # Sort each command sent to the world adapter by its action name.
            if cmd.action == "pause":
                pause_calls.append(cmd)
            elif cmd.action == "unpause":
                unpause_calls.append(cmd)

        mock_world.act = track_act
        router.set_world(mock_world)

        await router.route("plan the quest", state={})

        assert len(pause_calls) == 1, "world.pause() should be called once for T3"
        assert len(unpause_calls) == 1, "world.unpause() should be called once for T3"

    async def test_t3_unpauses_world_even_on_llm_error(self):
        """world.unpause() must be called even when the LLM raises."""
        mock_cascade = MagicMock()
        mock_cascade.complete = AsyncMock(side_effect=RuntimeError("LLM failed"))
        router = self._make_router(mock_cascade)

        unpause_calls = []
        mock_world = MagicMock()
        mock_world.act = lambda cmd: unpause_calls.append(cmd) if cmd.action == "unpause" else None
        router.set_world(mock_world)

        with pytest.raises(RuntimeError, match="LLM failed"):
            await router.route("plan the quest", state={})

        assert len(unpause_calls) == 1, "world.unpause() must run even when LLM errors"

    async def test_t1_does_not_pause_world(self):
        mock_cascade = self._make_cascade(
            content="ok", provider="ollama", model="qwen3:8b", latency_ms=120
        )
        router = self._make_router(mock_cascade)

        # Any command at all reaching the world adapter counts as a failure here.
        pause_calls = []
        mock_world = MagicMock()
        mock_world.act = lambda cmd: pause_calls.append(cmd)
        router.set_world(mock_world)

        await router.route("go north", state={})

        assert len(pause_calls) == 0, "world.pause() must NOT be called for T1"

    async def test_t2_does_not_pause_world(self):
        mock_cascade = self._make_cascade(
            content="ok", provider="ollama", model="qwen3:14b", latency_ms=350
        )
        router = self._make_router(mock_cascade)

        # Any command at all reaching the world adapter counts as a failure here.
        pause_calls = []
        mock_world = MagicMock()
        mock_world.act = lambda cmd: pause_calls.append(cmd)
        router.set_world(mock_world)

        await router.route("talk to the merchant", state={})

        assert len(pause_calls) == 0, "world.pause() must NOT be called for T2"

    async def test_broken_world_adapter_degrades_gracefully(self):
        """If world.act() raises, inference must still complete."""
        mock_cascade = self._make_cascade(
            content="done", provider="ollama", model="qwen3:30b", latency_ms=2000
        )
        router = self._make_router(mock_cascade)

        mock_world = MagicMock()
        mock_world.act = MagicMock(side_effect=RuntimeError("world broken"))
        router.set_world(mock_world)

        # Should not raise — degradation only logs a warning
        result = await router.route("plan the quest", state={})

        assert result["content"] == "done"

    async def test_no_world_adapter_t3_still_works(self):
        mock_cascade = self._make_cascade(
            content="plan done", provider="ollama", model="qwen3:30b", latency_ms=2000
        )
        router = self._make_router(mock_cascade)
        # No set_world() called

        result = await router.route("plan the quest route", state={})

        assert result["content"] == "plan done"
        assert result["tier"] == ModelTier.T3_COMPLEX

    async def test_classify_delegates_to_module_function(self):
        router = MetabolicRouter(cascade=MagicMock())
        assert router.classify("go north", {}) == classify_complexity("go north", {})
        assert router.classify("plan the quest", {}) == classify_complexity("plan the quest", {})

    async def test_ui_state_defaults_to_empty_dict(self):
        """Calling route without ui_state should not raise."""
        mock_cascade = self._make_cascade(
            content="ok", provider="ollama", model="qwen3:8b", latency_ms=100
        )
        router = self._make_router(mock_cascade)

        # No ui_state argument
        result = await router.route("go north", state={})

        assert result["content"] == "ok"

    async def test_temperature_and_max_tokens_forwarded(self):
        mock_cascade = self._make_cascade(
            content="ok", provider="ollama", model="qwen3:14b", latency_ms=200
        )
        router = self._make_router(mock_cascade)

        await router.route("describe the scene", state={}, temperature=0.1, max_tokens=50)

        call_kwargs = mock_cascade.complete.call_args.kwargs
        assert call_kwargs["temperature"] == 0.1
        assert call_kwargs["max_tokens"] == 50


class TestGetMetabolicRouter:
    """Test module-level singleton."""

    @pytest.fixture(autouse=True)
    def _reset_singleton(self, monkeypatch):
        """Start every test with no cached router and restore the old value after.

        A plain assignment would leave the router built by the test in the
        module global for every later test; ``monkeypatch`` undoes the change
        at teardown.  ``raising=False`` keeps the plain-assignment behaviour of
        not failing when the attribute does not exist yet.
        """
        import infrastructure.router.metabolic as m_module

        monkeypatch.setattr(m_module, "_metabolic_router", None, raising=False)

    def test_returns_metabolic_router_instance(self):
        router = get_metabolic_router()
        assert isinstance(router, MetabolicRouter)

    def test_singleton_returns_same_instance(self):
        r1 = get_metabolic_router()
        r2 = get_metabolic_router()
        assert r1 is r2