From 6c5f55230beac452022f3df072d3a2a1c89c34b2 Mon Sep 17 00:00:00 2001 From: Alexander Whitestone Date: Mon, 23 Mar 2026 14:41:42 -0400 Subject: [PATCH 1/2] WIP: Claude Code progress on #1065 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Automated salvage commit — agent session ended (exit 124). Work in progress, may need continuation. --- config/providers.yaml | 27 +++ src/config.py | 7 + src/infrastructure/router/__init__.py | 3 + src/infrastructure/router/cascade.py | 71 ++++++- src/infrastructure/router/classifier.py | 166 +++++++++++++++ tests/infrastructure/test_router_cascade.py | 192 ++++++++++++++++++ .../infrastructure/test_router_classifier.py | 134 ++++++++++++ 7 files changed, 596 insertions(+), 4 deletions(-) create mode 100644 src/infrastructure/router/classifier.py create mode 100644 tests/infrastructure/test_router_classifier.py diff --git a/config/providers.yaml b/config/providers.yaml index 33fa0ca6..f629c022 100644 --- a/config/providers.yaml +++ b/config/providers.yaml @@ -25,6 +25,19 @@ providers: tier: local url: "http://localhost:11434" models: + # ── Dual-model routing: Qwen3-8B (fast) + Qwen3-14B (quality) ────────── + # Both models fit simultaneously: ~6.6 GB + ~10.5 GB = ~17 GB combined. + # Requires OLLAMA_MAX_LOADED_MODELS=2 (set in .env) to stay hot. + # Ref: issue #1065 — Qwen3-8B/14B dual-model routing strategy + - name: qwen3:8b + context_window: 32768 + capabilities: [text, tools, json, streaming, routine] + description: "Qwen3-8B Q6_K — fast router for routine tasks (~6.6 GB, 45-55 tok/s)" + - name: qwen3:14b + context_window: 40960 + capabilities: [text, tools, json, streaming, complex, reasoning] + description: "Qwen3-14B Q5_K_M — complex reasoning and planning (~10.5 GB, 20-28 tok/s)" + # Text + Tools models - name: qwen3:30b default: true @@ -187,6 +200,20 @@ fallback_chains: - dolphin3 # base Dolphin 3.0 8B (uncensored, no custom system prompt) - qwen3:30b # primary fallback — usually sufficient with a good system prompt + # ── Complexity-based routing chains (issue #1065) ─────────────────────── + # Routine tasks: prefer Qwen3-8B for low latency (~45-55 tok/s) + routine: + - qwen3:8b # Primary fast model + - llama3.1:8b-instruct # Fallback fast model + - llama3.2:3b # Smallest available + + # Complex tasks: prefer Qwen3-14B for quality (~20-28 tok/s) + complex: + - qwen3:14b # Primary quality model + - hermes4-14b # Native tool calling, hybrid reasoning + - qwen3:30b # Highest local quality + - qwen2.5:14b # Additional fallback + # ── Custom Models ─────────────────────────────────────────────────────────── # Register custom model weights for per-agent assignment. # Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs. diff --git a/src/config.py b/src/config.py index 712e5750..36788e42 100644 --- a/src/config.py +++ b/src/config.py @@ -41,6 +41,13 @@ class Settings(BaseSettings): # 4096 keeps memory at ~19GB. Set to 0 to use model defaults. ollama_num_ctx: int = 4096 + # Maximum models loaded simultaneously in Ollama — override with OLLAMA_MAX_LOADED_MODELS + # Set to 2 so Qwen3-8B and Qwen3-14B can stay hot concurrently (~17 GB combined). + # Requires Ollama ≥ 0.1.33. Export this to the Ollama process environment: + # OLLAMA_MAX_LOADED_MODELS=2 ollama serve + # or add it to your systemd/launchd unit before starting the harness. + ollama_max_loaded_models: int = 2 + # Fallback model chains — override with FALLBACK_MODELS / VISION_FALLBACK_MODELS # as comma-separated strings, e.g. FALLBACK_MODELS="qwen3:30b,llama3.1" # Or edit config/providers.yaml → fallback_chains for the canonical source. diff --git a/src/infrastructure/router/__init__.py b/src/infrastructure/router/__init__.py index dfe39c12..326e16e1 100644 --- a/src/infrastructure/router/__init__.py +++ b/src/infrastructure/router/__init__.py @@ -2,6 +2,7 @@ from .api import router from .cascade import CascadeRouter, Provider, ProviderStatus, get_router +from .classifier import TaskComplexity, classify_task from .history import HealthHistoryStore, get_history_store __all__ = [ @@ -12,4 +13,6 @@ __all__ = [ "router", "HealthHistoryStore", "get_history_store", + "TaskComplexity", + "classify_task", ] diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index 84f07e90..2de133d6 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -528,6 +528,30 @@ class CascadeRouter: return True + def _get_model_for_complexity( + self, provider: Provider, complexity: "TaskComplexity" + ) -> str | None: + """Return the best model on *provider* for the given complexity tier. + + Checks fallback chains first (routine / complex), then falls back to + any model with the matching capability tag, then the provider default. + """ + from infrastructure.router.classifier import TaskComplexity + + chain_key = "routine" if complexity == TaskComplexity.SIMPLE else "complex" + + # Walk the capability fallback chain — first model present on this provider wins + for model_name in self.config.fallback_chains.get(chain_key, []): + if any(m["name"] == model_name for m in provider.models): + return model_name + + # Direct capability lookup as a secondary pass + cap_model = provider.get_model_with_capability(chain_key) + if cap_model: + return cap_model + + return None # Caller will use provider default + async def complete( self, messages: list[dict], @@ -535,6 +559,7 @@ class CascadeRouter: temperature: float = 0.7, max_tokens: int | None = None, cascade_tier: str | None = None, + complexity_hint: str | None = None, ) -> dict: """Complete a chat conversation with automatic failover. @@ -543,24 +568,48 @@ class CascadeRouter: - Falls back to vision-capable models when needed - Supports image URLs, paths, and base64 encoding + Complexity-based routing (issue #1065): + - ``complexity_hint="simple"`` → routes to Qwen3-8B (low-latency) + - ``complexity_hint="complex"`` → routes to Qwen3-14B (quality) + - ``complexity_hint=None`` (default) → auto-classifies from messages + Args: messages: List of message dicts with role and content - model: Preferred model (tries this first, then provider defaults) + model: Preferred model (tries this first; complexity routing is + skipped when an explicit model is given) temperature: Sampling temperature max_tokens: Maximum tokens to generate cascade_tier: If specified, filters providers by this tier. - "frontier_required": Uses only Anthropic provider for top-tier models. + complexity_hint: "simple", "complex", or None (auto-detect). Returns: - Dict with content, provider_used, and metrics + Dict with content, provider_used, model, latency_ms, + is_fallback_model, and complexity fields. Raises: RuntimeError: If all providers fail """ + from infrastructure.router.classifier import TaskComplexity, classify_task + content_type = self._detect_content_type(messages) if content_type != ContentType.TEXT: logger.debug("Detected %s content, selecting appropriate model", content_type.value) + # Resolve task complexity ───────────────────────────────────────────── + # Skip complexity routing when caller explicitly specifies a model. + complexity: TaskComplexity | None = None + if model is None: + if complexity_hint is not None: + try: + complexity = TaskComplexity(complexity_hint.lower()) + except ValueError: + logger.warning("Unknown complexity_hint %r, auto-classifying", complexity_hint) + complexity = classify_task(messages) + else: + complexity = classify_task(messages) + logger.debug("Task complexity: %s", complexity.value) + errors = [] providers = self.providers @@ -573,7 +622,6 @@ class CascadeRouter: if not providers: raise RuntimeError(f"No providers found for tier: {cascade_tier}") - for provider in providers: if not self._is_provider_available(provider): continue @@ -587,7 +635,21 @@ class CascadeRouter: ) continue - selected_model, is_fallback_model = self._select_model(provider, model, content_type) + # Complexity-based model selection (only when no explicit model) ── + effective_model = model + if effective_model is None and complexity is not None: + effective_model = self._get_model_for_complexity(provider, complexity) + if effective_model: + logger.debug( + "Complexity routing [%s]: %s → %s", + complexity.value, + provider.name, + effective_model, + ) + + selected_model, is_fallback_model = self._select_model( + provider, effective_model, content_type + ) try: result = await self._attempt_with_retry( @@ -610,6 +672,7 @@ class CascadeRouter: "model": result.get("model", selected_model or provider.get_default_model()), "latency_ms": result.get("latency_ms", 0), "is_fallback_model": is_fallback_model, + "complexity": complexity.value if complexity is not None else None, } raise RuntimeError(f"All providers failed: {'; '.join(errors)}") diff --git a/src/infrastructure/router/classifier.py b/src/infrastructure/router/classifier.py new file mode 100644 index 00000000..26e2fdc2 --- /dev/null +++ b/src/infrastructure/router/classifier.py @@ -0,0 +1,166 @@ +"""Task complexity classifier for Qwen3 dual-model routing. + +Classifies incoming tasks as SIMPLE (route to Qwen3-8B for low-latency) +or COMPLEX (route to Qwen3-14B for quality-sensitive work). + +Classification is fully heuristic — no LLM inference required. +""" + +import re +from enum import Enum + + +class TaskComplexity(Enum): + """Task complexity tier for model routing.""" + + SIMPLE = "simple" # Qwen3-8B Q6_K: routine, latency-sensitive + COMPLEX = "complex" # Qwen3-14B Q5_K_M: quality-sensitive, multi-step + + +# Keywords strongly associated with complex tasks +_COMPLEX_KEYWORDS: frozenset[str] = frozenset( + [ + "plan", + "review", + "analyze", + "analyse", + "triage", + "refactor", + "design", + "architecture", + "implement", + "compare", + "debug", + "explain", + "prioritize", + "prioritise", + "strategy", + "optimize", + "optimise", + "evaluate", + "assess", + "brainstorm", + "outline", + "summarize", + "summarise", + "generate code", + "write a", + "write the", + "code review", + "pull request", + "multi-step", + "multi step", + "step by step", + "backlog prioriti", + "issue triage", + "root cause", + "how does", + "why does", + "what are the", + ] +) + +# Keywords strongly associated with simple/routine tasks +_SIMPLE_KEYWORDS: frozenset[str] = frozenset( + [ + "status", + "list ", + "show ", + "what is", + "how many", + "ping", + "run ", + "execute ", + "ls ", + "cat ", + "ps ", + "fetch ", + "count ", + "tail ", + "head ", + "grep ", + "find file", + "read file", + "get ", + "query ", + "check ", + "yes", + "no", + "ok", + "done", + "thanks", + ] +) + +# Content longer than this is treated as complex regardless of keywords +_COMPLEX_CHAR_THRESHOLD = 500 + +# Short content defaults to simple +_SIMPLE_CHAR_THRESHOLD = 150 + +# More than this many messages suggests an ongoing complex conversation +_COMPLEX_CONVERSATION_DEPTH = 6 + + +def classify_task(messages: list[dict]) -> TaskComplexity: + """Classify task complexity from a list of messages. + + Uses heuristic rules — no LLM call required. Errs toward COMPLEX + when uncertain so that quality is preserved. + + Args: + messages: List of message dicts with ``role`` and ``content`` keys. + + Returns: + TaskComplexity.SIMPLE or TaskComplexity.COMPLEX + """ + if not messages: + return TaskComplexity.SIMPLE + + # Concatenate all user-turn content for analysis + user_content = " ".join( + msg.get("content", "") + for msg in messages + if msg.get("role") in ("user", "human") + and isinstance(msg.get("content"), str) + ).lower().strip() + + if not user_content: + return TaskComplexity.SIMPLE + + # Complexity signals override everything ----------------------------------- + + # Explicit complex keywords + for kw in _COMPLEX_KEYWORDS: + if kw in user_content: + return TaskComplexity.COMPLEX + + # Numbered / multi-step instruction list: "1. do this 2. do that" + if re.search(r"\b\d+\.\s+\w", user_content): + return TaskComplexity.COMPLEX + + # Code blocks embedded in messages + if "```" in user_content: + return TaskComplexity.COMPLEX + + # Long content → complex reasoning likely required + if len(user_content) > _COMPLEX_CHAR_THRESHOLD: + return TaskComplexity.COMPLEX + + # Deep conversation → complex ongoing task + if len(messages) > _COMPLEX_CONVERSATION_DEPTH: + return TaskComplexity.COMPLEX + + # Simplicity signals ------------------------------------------------------- + + # Explicit simple keywords + for kw in _SIMPLE_KEYWORDS: + if kw in user_content: + return TaskComplexity.SIMPLE + + # Short single-sentence messages default to simple + if len(user_content) <= _SIMPLE_CHAR_THRESHOLD: + return TaskComplexity.SIMPLE + + # When uncertain, prefer quality (complex model) + return TaskComplexity.COMPLEX diff --git a/tests/infrastructure/test_router_cascade.py b/tests/infrastructure/test_router_cascade.py index ca881c6a..cef48e32 100644 --- a/tests/infrastructure/test_router_cascade.py +++ b/tests/infrastructure/test_router_cascade.py @@ -968,3 +968,195 @@ class TestCascadeRouterReload: assert router.providers[0].name == "low-priority" assert router.providers[1].name == "high-priority" + + +class TestComplexityRouting: + """Tests for Qwen3-8B / Qwen3-14B dual-model routing (issue #1065).""" + + def _make_dual_model_provider(self) -> Provider: + """Build an Ollama provider with both Qwen3 models registered.""" + return Provider( + name="ollama-local", + type="ollama", + enabled=True, + priority=1, + url="http://localhost:11434", + models=[ + { + "name": "qwen3:8b", + "capabilities": ["text", "tools", "json", "streaming", "routine"], + }, + { + "name": "qwen3:14b", + "default": True, + "capabilities": ["text", "tools", "json", "streaming", "complex", "reasoning"], + }, + ], + ) + + def test_get_model_for_complexity_simple_returns_8b(self): + """Simple tasks should select the model with 'routine' capability.""" + from infrastructure.router.classifier import TaskComplexity + + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + provider = self._make_dual_model_provider() + + model = router._get_model_for_complexity(provider, TaskComplexity.SIMPLE) + assert model == "qwen3:8b" + + def test_get_model_for_complexity_complex_returns_14b(self): + """Complex tasks should select the model with 'complex' capability.""" + from infrastructure.router.classifier import TaskComplexity + + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + provider = self._make_dual_model_provider() + + model = router._get_model_for_complexity(provider, TaskComplexity.COMPLEX) + assert model == "qwen3:14b" + + def test_get_model_for_complexity_returns_none_when_no_match(self): + """Returns None when provider has no matching model in chain.""" + from infrastructure.router.classifier import TaskComplexity + + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = {} # empty chains + + provider = Provider( + name="test", + type="ollama", + enabled=True, + priority=1, + models=[{"name": "llama3.2:3b", "default": True, "capabilities": ["text"]}], + ) + + # No 'routine' or 'complex' model available + model = router._get_model_for_complexity(provider, TaskComplexity.SIMPLE) + assert model is None + + @pytest.mark.asyncio + async def test_complete_with_simple_hint_routes_to_8b(self): + """complexity_hint='simple' should use qwen3:8b.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "fast answer", "model": "qwen3:8b"} + result = await router.complete( + messages=[{"role": "user", "content": "list tasks"}], + complexity_hint="simple", + ) + + assert result["model"] == "qwen3:8b" + assert result["complexity"] == "simple" + + @pytest.mark.asyncio + async def test_complete_with_complex_hint_routes_to_14b(self): + """complexity_hint='complex' should use qwen3:14b.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "detailed answer", "model": "qwen3:14b"} + result = await router.complete( + messages=[{"role": "user", "content": "review this PR"}], + complexity_hint="complex", + ) + + assert result["model"] == "qwen3:14b" + assert result["complexity"] == "complex" + + @pytest.mark.asyncio + async def test_explicit_model_bypasses_complexity_routing(self): + """When model is explicitly provided, complexity routing is skipped.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "response", "model": "qwen3:14b"} + result = await router.complete( + messages=[{"role": "user", "content": "list tasks"}], + model="qwen3:14b", # explicit override + ) + + # Explicit model wins — complexity field is None + assert result["model"] == "qwen3:14b" + assert result["complexity"] is None + + @pytest.mark.asyncio + async def test_auto_classification_routes_simple_message(self): + """Short, simple messages should auto-classify as SIMPLE → 8B.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "ok", "model": "qwen3:8b"} + result = await router.complete( + messages=[{"role": "user", "content": "status"}], + # no complexity_hint — auto-classify + ) + + assert result["complexity"] == "simple" + assert result["model"] == "qwen3:8b" + + @pytest.mark.asyncio + async def test_auto_classification_routes_complex_message(self): + """Complex messages should auto-classify → 14B.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "deep analysis", "model": "qwen3:14b"} + result = await router.complete( + messages=[{"role": "user", "content": "analyze and prioritize the backlog"}], + ) + + assert result["complexity"] == "complex" + assert result["model"] == "qwen3:14b" + + @pytest.mark.asyncio + async def test_invalid_complexity_hint_falls_back_to_auto(self): + """Invalid complexity_hint should log a warning and auto-classify.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "ok", "model": "qwen3:8b"} + # Should not raise + result = await router.complete( + messages=[{"role": "user", "content": "status"}], + complexity_hint="INVALID_HINT", + ) + + assert result["complexity"] in ("simple", "complex") # auto-classified diff --git a/tests/infrastructure/test_router_classifier.py b/tests/infrastructure/test_router_classifier.py new file mode 100644 index 00000000..1e9d4df7 --- /dev/null +++ b/tests/infrastructure/test_router_classifier.py @@ -0,0 +1,134 @@ +"""Tests for Qwen3 dual-model task complexity classifier.""" + +import pytest + +from infrastructure.router.classifier import TaskComplexity, classify_task + + +class TestClassifyTask: + """Tests for classify_task heuristics.""" + + # ── Simple / routine tasks ────────────────────────────────────────────── + + def test_empty_messages_is_simple(self): + assert classify_task([]) == TaskComplexity.SIMPLE + + def test_no_user_content_is_simple(self): + messages = [{"role": "system", "content": "You are Timmy."}] + assert classify_task(messages) == TaskComplexity.SIMPLE + + def test_short_status_query_is_simple(self): + messages = [{"role": "user", "content": "status"}] + assert classify_task(messages) == TaskComplexity.SIMPLE + + def test_list_command_is_simple(self): + messages = [{"role": "user", "content": "list all tasks"}] + assert classify_task(messages) == TaskComplexity.SIMPLE + + def test_get_command_is_simple(self): + messages = [{"role": "user", "content": "get the latest log entry"}] + assert classify_task(messages) == TaskComplexity.SIMPLE + + def test_short_message_under_threshold_is_simple(self): + messages = [{"role": "user", "content": "run the build"}] + assert classify_task(messages) == TaskComplexity.SIMPLE + + def test_affirmation_is_simple(self): + messages = [{"role": "user", "content": "yes"}] + assert classify_task(messages) == TaskComplexity.SIMPLE + + # ── Complex / quality-sensitive tasks ────────────────────────────────── + + def test_plan_keyword_is_complex(self): + messages = [{"role": "user", "content": "plan the sprint"}] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_review_keyword_is_complex(self): + messages = [{"role": "user", "content": "review this code"}] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_analyze_keyword_is_complex(self): + messages = [{"role": "user", "content": "analyze performance"}] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_triage_keyword_is_complex(self): + messages = [{"role": "user", "content": "triage the open issues"}] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_refactor_keyword_is_complex(self): + messages = [{"role": "user", "content": "refactor the auth module"}] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_explain_keyword_is_complex(self): + messages = [{"role": "user", "content": "explain how the router works"}] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_prioritize_keyword_is_complex(self): + messages = [{"role": "user", "content": "prioritize the backlog"}] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_long_message_is_complex(self): + long_msg = "do something " * 50 # > 500 chars + messages = [{"role": "user", "content": long_msg}] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_numbered_list_is_complex(self): + messages = [ + { + "role": "user", + "content": "1. Read the file 2. Analyze it 3. Write a report", + } + ] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_code_block_is_complex(self): + messages = [ + {"role": "user", "content": "Here is the code:\n```python\nprint('hello')\n```"} + ] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_deep_conversation_is_complex(self): + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "ok"}, + {"role": "assistant", "content": "yes"}, + {"role": "user", "content": "ok"}, + {"role": "assistant", "content": "yes"}, + {"role": "user", "content": "now do the thing"}, + ] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_analyse_british_spelling_is_complex(self): + messages = [{"role": "user", "content": "analyse this dataset"}] + assert classify_task(messages) == TaskComplexity.COMPLEX + + def test_non_string_content_is_ignored(self): + """Non-string content should not crash the classifier.""" + messages = [{"role": "user", "content": ["part1", "part2"]}] + # Should not raise; result doesn't matter — just must not blow up + result = classify_task(messages) + assert isinstance(result, TaskComplexity) + + def test_system_message_not_counted_as_user(self): + """System message alone should not trigger complex keywords.""" + messages = [ + {"role": "system", "content": "analyze everything carefully"}, + {"role": "user", "content": "yes"}, + ] + # "analyze" is in system message (not user) — user says "yes" → simple + assert classify_task(messages) == TaskComplexity.SIMPLE + + +class TestTaskComplexityEnum: + """Tests for TaskComplexity enum values.""" + + def test_simple_value(self): + assert TaskComplexity.SIMPLE.value == "simple" + + def test_complex_value(self): + assert TaskComplexity.COMPLEX.value == "complex" + + def test_lookup_by_value(self): + assert TaskComplexity("simple") == TaskComplexity.SIMPLE + assert TaskComplexity("complex") == TaskComplexity.COMPLEX -- 2.43.0 From 0b284972cbfad4296ec729e35cc99ad964e509b8 Mon Sep 17 00:00:00 2001 From: Alexander Whitestone Date: Mon, 23 Mar 2026 15:30:23 -0400 Subject: [PATCH 2/2] fix: correct complexity routing to not fall back to default model `_get_model_for_complexity` was calling `get_model_with_capability`, which silently falls back to the provider default when no model has the requested capability tag. This caused the method to return a generic model instead of None when neither the fallback chain nor any explicit capability tag matched, misleading callers into skipping the provider default logic. Replace the call with an explicit next() comprehension that returns None when no model explicitly carries the 'routine' or 'complex' capability. Refs #1065 Co-Authored-By: Claude Sonnet 4.6 --- src/infrastructure/router/cascade.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index 2de133d6..6f158b00 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -545,8 +545,12 @@ class CascadeRouter: if any(m["name"] == model_name for m in provider.models): return model_name - # Direct capability lookup as a secondary pass - cap_model = provider.get_model_with_capability(chain_key) + # Direct capability lookup — only return if a model explicitly has the tag + # (do not use get_model_with_capability here as it falls back to the default) + cap_model = next( + (m["name"] for m in provider.models if chain_key in m.get("capabilities", [])), + None, + ) if cap_model: return cap_model -- 2.43.0