"""Tests for the tiered model router (issue #882).

Covers:

- classify_tier() for Tier-1/2/3 routing
- TieredModelRouter.route() with mocked CascadeRouter + BudgetTracker
- Auto-escalation from Tier-1 on low-quality responses
- Cloud-tier budget guard
- Acceptance criteria from the issue:
  - "Walk to the next room" → LOCAL_FAST
  - "Plan the optimal path to become Hortator" → LOCAL_HEAVY
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from infrastructure.models.router import (
|
|
TieredModelRouter,
|
|
TierLabel,
|
|
_is_low_quality,
|
|
classify_tier,
|
|
get_tiered_router,
|
|
)
|
|
|
|
pytestmark = pytest.mark.unit  # every test in this module runs as a fast unit test
|
|
|
|
|
|
# ── classify_tier ─────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestClassifyTier:
    """Behavioural checks for classify_tier() across all three tiers."""

    # ── Tier-1 (LOCAL_FAST): trivial single-step commands ──────────────────

    def test_simple_navigation_is_local_fast(self):
        tier = classify_tier("walk to the next room")
        assert tier == TierLabel.LOCAL_FAST

    def test_go_north_is_local_fast(self):
        tier = classify_tier("go north")
        assert tier == TierLabel.LOCAL_FAST

    def test_single_binary_choice_is_local_fast(self):
        tier = classify_tier("yes")
        assert tier == TierLabel.LOCAL_FAST

    def test_open_door_is_local_fast(self):
        tier = classify_tier("open door")
        assert tier == TierLabel.LOCAL_FAST

    def test_attack_is_local_fast(self):
        # An explicitly empty context must behave like no context at all.
        tier = classify_tier("attack", {})
        assert tier == TierLabel.LOCAL_FAST

    # ── Tier-2 (LOCAL_HEAVY): planning / analysis / rich context ───────────

    def test_quest_planning_is_local_heavy(self):
        tier = classify_tier("plan the optimal path to become Hortator")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_strategy_keyword_is_local_heavy(self):
        tier = classify_tier("what is the best strategy")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_stuck_state_escalates_to_local_heavy(self):
        tier = classify_tier("help me", {"stuck": True})
        assert tier == TierLabel.LOCAL_HEAVY

    def test_require_t2_flag_is_local_heavy(self):
        tier = classify_tier("go north", {"require_t2": True})
        assert tier == TierLabel.LOCAL_HEAVY

    def test_long_input_is_local_heavy(self):
        long_task = "tell me about " + ("the dungeon " * 30)
        tier = classify_tier(long_task)
        assert tier == TierLabel.LOCAL_HEAVY

    def test_active_quests_upgrades_to_local_heavy(self):
        context = {"active_quests": ["Q1", "Q2", "Q3"]}
        tier = classify_tier("go north", context)
        assert tier == TierLabel.LOCAL_HEAVY

    def test_dialogue_active_upgrades_to_local_heavy(self):
        context = {"dialogue_active": True}
        tier = classify_tier("yes", context)
        assert tier == TierLabel.LOCAL_HEAVY

    def test_analyze_is_local_heavy(self):
        tier = classify_tier("analyze the situation")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_optimize_is_local_heavy(self):
        tier = classify_tier("optimize my build")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_negotiate_is_local_heavy(self):
        tier = classify_tier("negotiate with the Camonna Tong")
        assert tier == TierLabel.LOCAL_HEAVY

    def test_explain_is_local_heavy(self):
        tier = classify_tier("explain the faction system")
        assert tier == TierLabel.LOCAL_HEAVY

    # ── Tier-3 (CLOUD_API): explicit opt-in ────────────────────────────────

    def test_require_cloud_flag_is_cloud_api(self):
        tier = classify_tier("go north", {"require_cloud": True})
        assert tier == TierLabel.CLOUD_API

    def test_require_cloud_overrides_everything(self):
        tier = classify_tier("yes", {"require_cloud": True})
        assert tier == TierLabel.CLOUD_API

    # ── Edge cases ────────────────────────────────────────────────────────

    def test_empty_task_defaults_to_local_heavy(self):
        # An empty string matches neither the T1 nor the T3 rules.
        assert classify_tier("") == TierLabel.LOCAL_HEAVY

    def test_case_insensitive(self):
        assert classify_tier("PLAN my route") == TierLabel.LOCAL_HEAVY

    def test_combat_active_upgrades_t1_to_heavy(self):
        # "attack" alone is a Tier-1 word, but active combat must rule
        # out LOCAL_FAST routing.
        result = classify_tier("attack", {"combat_active": True})
        assert result != TierLabel.LOCAL_FAST
|
|
|
|
|
|
# ── _is_low_quality ───────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestIsLowQuality:
    """Checks for the _is_low_quality() response heuristic."""

    def test_empty_is_low_quality(self):
        assert _is_low_quality("", TierLabel.LOCAL_FAST) is True

    def test_whitespace_only_is_low_quality(self):
        assert _is_low_quality("   ", TierLabel.LOCAL_FAST) is True

    def test_very_short_is_low_quality(self):
        assert _is_low_quality("ok", TierLabel.LOCAL_FAST) is True

    def test_idontknow_is_low_quality(self):
        verdict = _is_low_quality("I don't know how to help with that.", TierLabel.LOCAL_FAST)
        assert verdict is True

    def test_not_sure_is_low_quality(self):
        verdict = _is_low_quality("I'm not sure about this.", TierLabel.LOCAL_FAST)
        assert verdict is True

    def test_as_an_ai_is_low_quality(self):
        verdict = _is_low_quality("As an AI, I cannot...", TierLabel.LOCAL_FAST)
        assert verdict is True

    def test_good_response_is_not_low_quality(self):
        response = "You move north into the Vivec Canton. The Ordinators watch your approach."
        assert _is_low_quality(response, TierLabel.LOCAL_FAST) is False

    def test_t1_short_response_triggers_escalation(self):
        # Shorter than _ESCALATION_MIN_CHARS, so a Tier-1 answer must escalate.
        assert _is_low_quality("OK, done.", TierLabel.LOCAL_FAST) is True

    def test_borderline_ok_for_t2_not_t1(self):
        # Length between _LOW_QUALITY_MIN_CHARS (20) and _ESCALATION_MIN_CHARS
        # (60): low quality only against the stricter Tier-1 threshold, while
        # T2/T3 accept it.
        response = "Done. The item is retrieved."  # 28 chars: ≥20, <60
        assert _is_low_quality(response, TierLabel.LOCAL_FAST) is True
        assert _is_low_quality(response, TierLabel.LOCAL_HEAVY) is False
|
|
|
|
|
|
# ── TieredModelRouter ─────────────────────────────────────────────────────────
|
|
|
|
|
|
# Canned "good" completion used as the cascade mock's default payload; long
# enough that no tier's quality heuristic treats it as low quality.
_GOOD_CONTENT = (
    "You move north through the doorway into the next room. "
    "The stone walls glisten with moisture."
)  # 90 chars — well above the escalation threshold
|
|
|
|
|
|
def _make_cascade_mock(content=_GOOD_CONTENT, model="llama3.1:8b"):
    """Build a CascadeRouter stand-in whose async complete() returns a fixed payload."""
    payload = {
        "content": content,
        "provider": "ollama-local",
        "model": model,
        "latency_ms": 150.0,
    }
    cascade = MagicMock()
    cascade.complete = AsyncMock(return_value=payload)
    return cascade
|
|
|
|
|
|
def _make_budget_mock(allowed=True):
|
|
mock = MagicMock()
|
|
mock.cloud_allowed = MagicMock(return_value=allowed)
|
|
mock.record_spend = MagicMock(return_value=0.001)
|
|
return mock
|
|
|
|
|
|
@pytest.mark.asyncio
class TestTieredModelRouterRoute:
    """End-to-end route() behaviour with a mocked cascade and budget tracker.

    Note: several tests define `complete_side_effect(messages, model,
    temperature, max_tokens)` — this pins the exact keyword shape the router
    uses when calling the cascade's complete().
    """

    async def test_route_returns_tier_in_result(self):
        """route() annotates its result dict with the tier it chose."""
        router = TieredModelRouter(cascade=_make_cascade_mock())
        result = await router.route("go north")
        assert "tier" in result
        assert result["tier"] == TierLabel.LOCAL_FAST

    async def test_acceptance_walk_to_room_is_local_fast(self):
        """Acceptance: 'Walk to the next room' → LOCAL_FAST."""
        router = TieredModelRouter(cascade=_make_cascade_mock())
        result = await router.route("Walk to the next room")
        assert result["tier"] == TierLabel.LOCAL_FAST

    async def test_acceptance_plan_hortator_is_local_heavy(self):
        """Acceptance: 'Plan the optimal path to become Hortator' → LOCAL_HEAVY."""
        router = TieredModelRouter(
            cascade=_make_cascade_mock(model="hermes3:70b"),
        )
        result = await router.route("Plan the optimal path to become Hortator")
        assert result["tier"] == TierLabel.LOCAL_HEAVY

    async def test_t1_low_quality_escalates_to_t2(self):
        """Failed Tier-1 response auto-escalates to Tier-2."""
        call_models = []  # records the model requested on each cascade call
        cascade = MagicMock()

        async def complete_side_effect(messages, model, temperature, max_tokens):
            call_models.append(model)
            # First call (T1) returns a low-quality response
            if len(call_models) == 1:
                return {
                    "content": "I don't know.",
                    "provider": "ollama",
                    "model": model,
                    "latency_ms": 50,
                }
            # Second call (T2) returns a good response
            return {
                "content": "You move to the northern passage, passing through the Dunmer stronghold.",
                "provider": "ollama",
                "model": model,
                "latency_ms": 800,
            }

        # Plain coroutine function (not AsyncMock) so both calls hit the closure.
        cascade.complete = complete_side_effect

        router = TieredModelRouter(cascade=cascade, auto_escalate=True)
        result = await router.route("go north")

        assert len(call_models) == 2, "Should have called twice (T1 escalated to T2)"
        assert result["tier"] == TierLabel.LOCAL_HEAVY

    async def test_auto_escalate_false_no_escalation(self):
        """With auto_escalate=False, low-quality T1 response is returned as-is."""
        call_count = {"n": 0}  # mutable holder so the closure can count calls
        cascade = MagicMock()

        async def complete_side_effect(**kwargs):
            call_count["n"] += 1
            return {
                "content": "I don't know.",
                "provider": "ollama",
                "model": "llama3.1:8b",
                "latency_ms": 50,
            }

        cascade.complete = AsyncMock(side_effect=complete_side_effect)
        router = TieredModelRouter(cascade=cascade, auto_escalate=False)
        result = await router.route("go north")
        assert call_count["n"] == 1
        assert result["tier"] == TierLabel.LOCAL_FAST

    async def test_t2_failure_escalates_to_cloud(self):
        """Tier-2 failure escalates to Cloud API (when budget allows)."""
        cascade = MagicMock()
        call_models = []

        async def complete_side_effect(messages, model, temperature, max_tokens):
            call_models.append(model)
            # Fail any local heavy (hermes3 / 70b) model to force escalation.
            if "hermes3" in model or "70b" in model.lower():
                raise RuntimeError("Tier-2 model unavailable")
            return {
                "content": "Cloud response here.",
                "provider": "anthropic",
                "model": model,
                "latency_ms": 1200,
            }

        cascade.complete = complete_side_effect

        budget = _make_budget_mock(allowed=True)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
        result = await router.route("plan my route", context={"require_t2": True})
        assert result["tier"] == TierLabel.CLOUD_API

    async def test_cloud_blocked_by_budget_raises(self):
        """Cloud tier blocked when budget is exhausted."""
        cascade = MagicMock()
        cascade.complete = AsyncMock(side_effect=RuntimeError("T2 fail"))

        budget = _make_budget_mock(allowed=False)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)

        # The router is expected to surface the budget guard as a RuntimeError
        # whose message mentions "budget limit".
        with pytest.raises(RuntimeError, match="budget limit"):
            await router.route("plan my route", context={"require_t2": True})

    async def test_explicit_cloud_tier_uses_cloud_model(self):
        """require_cloud forces the cloud tier even for a trivial task."""
        cascade = _make_cascade_mock(model="claude-haiku-4-5")
        budget = _make_budget_mock(allowed=True)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
        result = await router.route("go north", context={"require_cloud": True})
        assert result["tier"] == TierLabel.CLOUD_API

    async def test_cloud_spend_recorded_with_usage(self):
        """Cloud spend is recorded when the response includes usage info."""
        cascade = MagicMock()
        cascade.complete = AsyncMock(
            return_value={
                "content": "Cloud answer.",
                "provider": "anthropic",
                "model": "claude-haiku-4-5",
                "latency_ms": 900,
                "usage": {"prompt_tokens": 50, "completion_tokens": 100},
            }
        )
        budget = _make_budget_mock(allowed=True)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
        result = await router.route("go north", context={"require_cloud": True})
        budget.record_spend.assert_called_once()
        assert "cost_usd" in result

    async def test_cloud_spend_not_recorded_without_usage(self):
        """Cloud spend is not recorded when usage info is absent."""
        cascade = MagicMock()
        cascade.complete = AsyncMock(
            return_value={
                "content": "Cloud answer.",
                "provider": "anthropic",
                "model": "claude-haiku-4-5",
                "latency_ms": 900,
                # no "usage" key
            }
        )
        budget = _make_budget_mock(allowed=True)
        router = TieredModelRouter(cascade=cascade, budget_tracker=budget)
        result = await router.route("go north", context={"require_cloud": True})
        budget.record_spend.assert_not_called()
        assert "cost_usd" not in result

    async def test_custom_tier_models_respected(self):
        """A tier_models override replaces the default model for that tier."""
        cascade = _make_cascade_mock()
        router = TieredModelRouter(
            cascade=cascade,
            tier_models={TierLabel.LOCAL_FAST: "llama3.2:3b"},
        )
        await router.route("go north")
        call_kwargs = cascade.complete.call_args
        assert call_kwargs.kwargs["model"] == "llama3.2:3b"

    async def test_messages_override_used_when_provided(self):
        """Caller-supplied messages are forwarded verbatim to the cascade."""
        cascade = _make_cascade_mock()
        router = TieredModelRouter(cascade=cascade)
        custom_msgs = [{"role": "user", "content": "custom message"}]
        await router.route("go north", messages=custom_msgs)
        call_kwargs = cascade.complete.call_args
        assert call_kwargs.kwargs["messages"] == custom_msgs

    async def test_temperature_forwarded(self):
        """temperature passes through route() to the cascade unchanged."""
        cascade = _make_cascade_mock()
        router = TieredModelRouter(cascade=cascade)
        await router.route("go north", temperature=0.7)
        call_kwargs = cascade.complete.call_args
        assert call_kwargs.kwargs["temperature"] == 0.7

    async def test_max_tokens_forwarded(self):
        """max_tokens passes through route() to the cascade unchanged."""
        cascade = _make_cascade_mock()
        router = TieredModelRouter(cascade=cascade)
        await router.route("go north", max_tokens=128)
        call_kwargs = cascade.complete.call_args
        assert call_kwargs.kwargs["max_tokens"] == 128
|
|
|
|
|
|
class TestTieredModelRouterClassify:
    """The router's classify() method must mirror the module-level classifier."""

    def test_classify_delegates_to_classify_tier(self):
        router = TieredModelRouter(cascade=MagicMock())
        for task in ("go north", "plan the quest"):
            assert router.classify(task) == classify_tier(task)
|
|
|
|
|
|
class TestGetTieredRouterSingleton:
    """get_tiered_router() lazily builds and caches a single router."""

    def test_returns_tiered_router_instance(self):
        import infrastructure.models.router as rmod

        # Clear the cached instance so the factory must build a fresh one.
        rmod._tiered_router = None
        assert isinstance(get_tiered_router(), TieredModelRouter)

    def test_singleton_returns_same_instance(self):
        import infrastructure.models.router as rmod

        rmod._tiered_router = None
        first = get_tiered_router()
        second = get_tiered_router()
        assert first is second
|