From 77b70943da2d67fd8fe5c39d547245dbaee61249 Mon Sep 17 00:00:00 2001 From: Alexander Whitestone Date: Mon, 23 Mar 2026 18:55:02 -0400 Subject: [PATCH] feat: Qwen3 two-model routing via task complexity classifier (#1065) - Add src/infrastructure/router/classifier.py with TaskComplexity enum and classify_task() heuristic for simple/complex routing - Add _get_model_for_complexity() to CascadeRouter for fallback chain lookup - Extend complete() with complexity_hint parameter and auto-classification - Add fallback_chains config to providers.yaml for qwen3:8b/14b routing - Fix: use explicit next() comprehension instead of get_model_with_capability to avoid silent fallback to provider default (0b28497) - 361 unit tests passing Fixes #1065 Co-Authored-By: Claude Sonnet 4.6 --- config/providers.yaml | 27 +++ src/config.py | 7 + src/infrastructure/router/__init__.py | 4 + src/infrastructure/router/cascade.py | 111 +++++++++- src/infrastructure/router/classifier.py | 166 +++++++++++++++ tests/infrastructure/test_router_cascade.py | 192 ++++++++++++++++++ .../infrastructure/test_router_classifier.py | 134 ++++++++++++ 7 files changed, 635 insertions(+), 6 deletions(-) create mode 100644 src/infrastructure/router/classifier.py create mode 100644 tests/infrastructure/test_router_classifier.py diff --git a/config/providers.yaml b/config/providers.yaml index 33fa0ca6..f629c022 100644 --- a/config/providers.yaml +++ b/config/providers.yaml @@ -25,6 +25,19 @@ providers: tier: local url: "http://localhost:11434" models: + # ── Dual-model routing: Qwen3-8B (fast) + Qwen3-14B (quality) ────────── + # Both models fit simultaneously: ~6.6 GB + ~10.5 GB = ~17 GB combined. + # Requires OLLAMA_MAX_LOADED_MODELS=2 (set in .env) to stay hot. 
+ # Ref: issue #1065 — Qwen3-8B/14B dual-model routing strategy + - name: qwen3:8b + context_window: 32768 + capabilities: [text, tools, json, streaming, routine] + description: "Qwen3-8B Q6_K — fast router for routine tasks (~6.6 GB, 45-55 tok/s)" + - name: qwen3:14b + context_window: 40960 + capabilities: [text, tools, json, streaming, complex, reasoning] + description: "Qwen3-14B Q5_K_M — complex reasoning and planning (~10.5 GB, 20-28 tok/s)" + # Text + Tools models - name: qwen3:30b default: true @@ -187,6 +200,20 @@ fallback_chains: - dolphin3 # base Dolphin 3.0 8B (uncensored, no custom system prompt) - qwen3:30b # primary fallback — usually sufficient with a good system prompt + # ── Complexity-based routing chains (issue #1065) ─────────────────────── + # Routine tasks: prefer Qwen3-8B for low latency (~45-55 tok/s) + routine: + - qwen3:8b # Primary fast model + - llama3.1:8b-instruct # Fallback fast model + - llama3.2:3b # Smallest available + + # Complex tasks: prefer Qwen3-14B for quality (~20-28 tok/s) + complex: + - qwen3:14b # Primary quality model + - hermes4-14b # Native tool calling, hybrid reasoning + - qwen3:30b # Highest local quality + - qwen2.5:14b # Additional fallback + # ── Custom Models ─────────────────────────────────────────────────────────── # Register custom model weights for per-agent assignment. # Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs. diff --git a/src/config.py b/src/config.py index 5256582a..5e120e26 100644 --- a/src/config.py +++ b/src/config.py @@ -51,6 +51,13 @@ class Settings(BaseSettings): # Set to 0 to use model defaults. ollama_num_ctx: int = 32768 + # Maximum models loaded simultaneously in Ollama — override with OLLAMA_MAX_LOADED_MODELS + # Set to 2 so Qwen3-8B and Qwen3-14B can stay hot concurrently (~17 GB combined). + # Requires Ollama ≥ 0.1.33. 
Export this to the Ollama process environment: + # OLLAMA_MAX_LOADED_MODELS=2 ollama serve + # or add it to your systemd/launchd unit before starting the harness. + ollama_max_loaded_models: int = 2 + # Fallback model chains — override with FALLBACK_MODELS / VISION_FALLBACK_MODELS # as comma-separated strings, e.g. FALLBACK_MODELS="qwen3:8b,qwen2.5:14b" # Or edit config/providers.yaml → fallback_chains for the canonical source. diff --git a/src/infrastructure/router/__init__.py b/src/infrastructure/router/__init__.py index f7f7ac25..e00d4583 100644 --- a/src/infrastructure/router/__init__.py +++ b/src/infrastructure/router/__init__.py @@ -2,6 +2,7 @@ from .api import router from .cascade import CascadeRouter, Provider, ProviderStatus, get_router +from .classifier import TaskComplexity, classify_task from .history import HealthHistoryStore, get_history_store from .metabolic import ( DEFAULT_TIER_MODELS, @@ -27,4 +28,7 @@ __all__ = [ "classify_complexity", "build_prompt", "get_metabolic_router", + # Classifier + "TaskComplexity", + "classify_task", ] diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index be85939f..7789645b 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -593,6 +593,34 @@ class CascadeRouter: "is_fallback_model": is_fallback_model, } + def _get_model_for_complexity( + self, provider: Provider, complexity: "TaskComplexity" + ) -> str | None: + """Return the best model on *provider* for the given complexity tier. + + Checks fallback chains first (routine / complex), then falls back to + any model with the matching capability tag, then the provider default. 
+ """ + from infrastructure.router.classifier import TaskComplexity + + chain_key = "routine" if complexity == TaskComplexity.SIMPLE else "complex" + + # Walk the capability fallback chain — first model present on this provider wins + for model_name in self.config.fallback_chains.get(chain_key, []): + if any(m["name"] == model_name for m in provider.models): + return model_name + + # Direct capability lookup — only return if a model explicitly has the tag + # (do not use get_model_with_capability here as it falls back to the default) + cap_model = next( + (m["name"] for m in provider.models if chain_key in m.get("capabilities", [])), + None, + ) + if cap_model: + return cap_model + + return None # Caller will use provider default + async def complete( self, messages: list[dict], @@ -600,6 +628,7 @@ class CascadeRouter: temperature: float = 0.7, max_tokens: int | None = None, cascade_tier: str | None = None, + complexity_hint: str | None = None, ) -> dict: """Complete a chat conversation with automatic failover. @@ -608,33 +637,103 @@ class CascadeRouter: - Falls back to vision-capable models when needed - Supports image URLs, paths, and base64 encoding + Complexity-based routing (issue #1065): + - ``complexity_hint="simple"`` → routes to Qwen3-8B (low-latency) + - ``complexity_hint="complex"`` → routes to Qwen3-14B (quality) + - ``complexity_hint=None`` (default) → auto-classifies from messages + Args: messages: List of message dicts with role and content - model: Preferred model (tries this first, then provider defaults) + model: Preferred model (tries this first; complexity routing is + skipped when an explicit model is given) temperature: Sampling temperature max_tokens: Maximum tokens to generate cascade_tier: If specified, filters providers by this tier. - "frontier_required": Uses only Anthropic provider for top-tier models. + complexity_hint: "simple", "complex", or None (auto-detect). 
Returns: - Dict with content, provider_used, and metrics + Dict with content, provider_used, model, latency_ms, + is_fallback_model, and complexity fields. Raises: RuntimeError: If all providers fail """ + from infrastructure.router.classifier import TaskComplexity, classify_task + content_type = self._detect_content_type(messages) if content_type != ContentType.TEXT: logger.debug("Detected %s content, selecting appropriate model", content_type.value) + # Resolve task complexity ───────────────────────────────────────────── + # Skip complexity routing when caller explicitly specifies a model. + complexity: TaskComplexity | None = None + if model is None: + if complexity_hint is not None: + try: + complexity = TaskComplexity(complexity_hint.lower()) + except ValueError: + logger.warning("Unknown complexity_hint %r, auto-classifying", complexity_hint) + complexity = classify_task(messages) + else: + complexity = classify_task(messages) + logger.debug("Task complexity: %s", complexity.value) + errors: list[str] = [] providers = self._filter_providers(cascade_tier) for provider in providers: - result = await self._try_single_provider( - provider, messages, model, temperature, max_tokens, content_type, errors + if not self._is_provider_available(provider): + continue + + # Metabolic protocol: skip cloud providers when quota is low + if provider.type in ("anthropic", "openai", "grok"): + if not self._quota_allows_cloud(provider): + logger.info( + "Metabolic protocol: skipping cloud provider %s (quota too low)", + provider.name, + ) + continue + + # Complexity-based model selection (only when no explicit model) ── + effective_model = model + if effective_model is None and complexity is not None: + effective_model = self._get_model_for_complexity(provider, complexity) + if effective_model: + logger.debug( + "Complexity routing [%s]: %s → %s", + complexity.value, + provider.name, + effective_model, + ) + + selected_model, is_fallback_model = self._select_model( + provider, 
effective_model, content_type ) - if result is not None: - return result + + try: + result = await self._attempt_with_retry( + provider, + messages, + selected_model, + temperature, + max_tokens, + content_type, + ) + except RuntimeError as exc: + errors.append(str(exc)) + self._record_failure(provider) + continue + + self._record_success(provider, result.get("latency_ms", 0)) + return { + "content": result["content"], + "provider": provider.name, + "model": result.get("model", selected_model or provider.get_default_model()), + "latency_ms": result.get("latency_ms", 0), + "is_fallback_model": is_fallback_model, + "complexity": complexity.value if complexity is not None else None, + } raise RuntimeError(f"All providers failed: {'; '.join(errors)}") diff --git a/src/infrastructure/router/classifier.py b/src/infrastructure/router/classifier.py new file mode 100644 index 00000000..26e2fdc2 --- /dev/null +++ b/src/infrastructure/router/classifier.py @@ -0,0 +1,166 @@ +"""Task complexity classifier for Qwen3 dual-model routing. + +Classifies incoming tasks as SIMPLE (route to Qwen3-8B for low-latency) +or COMPLEX (route to Qwen3-14B for quality-sensitive work). + +Classification is fully heuristic — no LLM inference required. 
import re
from enum import Enum


class TaskComplexity(Enum):
    """Task complexity tier for model routing."""

    SIMPLE = "simple"  # Qwen3-8B Q6_K: routine, latency-sensitive
    COMPLEX = "complex"  # Qwen3-14B Q5_K_M: quality-sensitive, multi-step


# Keywords strongly associated with complex tasks. These are matched as plain
# substrings on purpose: some entries are stems ("backlog prioriti"), and an
# accidental hit only routes to the larger model, which is the safe direction.
_COMPLEX_KEYWORDS: frozenset[str] = frozenset(
    [
        "plan",
        "review",
        "analyze",
        "analyse",
        "triage",
        "refactor",
        "design",
        "architecture",
        "implement",
        "compare",
        "debug",
        "explain",
        "prioritize",
        "prioritise",
        "strategy",
        "optimize",
        "optimise",
        "evaluate",
        "assess",
        "brainstorm",
        "outline",
        "summarize",
        "summarise",
        "generate code",
        "write a",
        "write the",
        "code review",
        "pull request",
        "multi-step",
        "multi step",
        "step by step",
        "backlog prioriti",
        "issue triage",
        "root cause",
        "how does",
        "why does",
        "what are the",
    ]
)

# Keywords strongly associated with simple/routine tasks. A trailing space
# means "must be followed by an argument" (e.g. "run the build").
_SIMPLE_KEYWORDS: frozenset[str] = frozenset(
    [
        "status",
        "list ",
        "show ",
        "what is",
        "how many",
        "ping",
        "run ",
        "execute ",
        "ls ",
        "cat ",
        "ps ",
        "fetch ",
        "count ",
        "tail ",
        "head ",
        "grep ",
        "find file",
        "read file",
        "get ",
        "query ",
        "check ",
        "yes",
        "no",
        "ok",
        "done",
        "thanks",
    ]
)


def _simple_keyword_pattern(keyword: str) -> str:
    """Return a regex fragment matching *keyword* only at word boundaries.

    A keyword that ends in a letter or digit also gets a trailing ``\\b``;
    one that ends in a space keeps that literal space as its terminator.
    """
    tail = r"\b" if keyword[-1:].isalnum() else ""
    return rf"\b{re.escape(keyword)}{tail}"


# Simple keywords must match whole words. As raw substrings, "no" fires on
# "not"/"know", "ok" on "token", "ps " on "steps ", so nearly every mid-length
# message was routed to the small model and the COMPLEX fallback never ran.
_SIMPLE_KEYWORD_RE: re.Pattern[str] = re.compile(
    "|".join(_simple_keyword_pattern(kw) for kw in sorted(_SIMPLE_KEYWORDS))
)

# Numbered / multi-step instruction list: "1. do this 2. do that"
_NUMBERED_STEP_RE: re.Pattern[str] = re.compile(r"\b\d+\.\s+\w")

# Content longer than this is treated as complex regardless of keywords
_COMPLEX_CHAR_THRESHOLD = 500

# Short content defaults to simple
_SIMPLE_CHAR_THRESHOLD = 150

# More than this many messages suggests an ongoing complex conversation
_COMPLEX_CONVERSATION_DEPTH = 6


def classify_task(messages: list[dict]) -> TaskComplexity:
    """Classify task complexity from a list of chat messages.

    Purely heuristic, no LLM call. Only string ``content`` of messages whose
    ``role`` is ``"user"`` or ``"human"`` is inspected; it is joined with
    spaces, lower-cased and stripped before the checks run.

    Complexity signals are checked first and win over everything else:
    a complex keyword, a numbered step list, a code fence, more than
    ``_COMPLEX_CHAR_THRESHOLD`` characters, or more than
    ``_COMPLEX_CONVERSATION_DEPTH`` messages (all roles counted).
    Otherwise a whole-word simple keyword or a length of at most
    ``_SIMPLE_CHAR_THRESHOLD`` characters yields SIMPLE. Anything left is
    uncertain and is classified COMPLEX so that quality is preserved.

    Args:
        messages: List of message dicts with ``role`` and ``content`` keys.

    Returns:
        ``TaskComplexity.SIMPLE`` or ``TaskComplexity.COMPLEX``. An empty
        list, or one without any string user content, is SIMPLE.
    """
    if not messages:
        return TaskComplexity.SIMPLE

    # Non-string content (e.g. multi-part payloads) is skipped, not coerced.
    user_content = " ".join(
        msg["content"]
        for msg in messages
        if msg.get("role") in ("user", "human")
        and isinstance(msg.get("content"), str)
    ).lower().strip()

    if not user_content:
        return TaskComplexity.SIMPLE

    # Complexity signals override everything -----------------------------
    if any(kw in user_content for kw in _COMPLEX_KEYWORDS):
        return TaskComplexity.COMPLEX

    if _NUMBERED_STEP_RE.search(user_content):
        return TaskComplexity.COMPLEX

    # Code blocks embedded in messages
    if "```" in user_content:
        return TaskComplexity.COMPLEX

    # Long content: complex reasoning is likely required
    if len(user_content) > _COMPLEX_CHAR_THRESHOLD:
        return TaskComplexity.COMPLEX

    # Deep conversation: an ongoing complex task
    if len(messages) > _COMPLEX_CONVERSATION_DEPTH:
        return TaskComplexity.COMPLEX

    # Simplicity signals -------------------------------------------------
    if _SIMPLE_KEYWORD_RE.search(user_content):
        return TaskComplexity.SIMPLE

    # Short single-sentence messages default to simple
    if len(user_content) <= _SIMPLE_CHAR_THRESHOLD:
        return TaskComplexity.SIMPLE

    # When uncertain, prefer quality (complex model)
    return TaskComplexity.COMPLEX
dual-model routing (issue #1065).""" + + def _make_dual_model_provider(self) -> Provider: + """Build an Ollama provider with both Qwen3 models registered.""" + return Provider( + name="ollama-local", + type="ollama", + enabled=True, + priority=1, + url="http://localhost:11434", + models=[ + { + "name": "qwen3:8b", + "capabilities": ["text", "tools", "json", "streaming", "routine"], + }, + { + "name": "qwen3:14b", + "default": True, + "capabilities": ["text", "tools", "json", "streaming", "complex", "reasoning"], + }, + ], + ) + + def test_get_model_for_complexity_simple_returns_8b(self): + """Simple tasks should select the model with 'routine' capability.""" + from infrastructure.router.classifier import TaskComplexity + + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + provider = self._make_dual_model_provider() + + model = router._get_model_for_complexity(provider, TaskComplexity.SIMPLE) + assert model == "qwen3:8b" + + def test_get_model_for_complexity_complex_returns_14b(self): + """Complex tasks should select the model with 'complex' capability.""" + from infrastructure.router.classifier import TaskComplexity + + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + provider = self._make_dual_model_provider() + + model = router._get_model_for_complexity(provider, TaskComplexity.COMPLEX) + assert model == "qwen3:14b" + + def test_get_model_for_complexity_returns_none_when_no_match(self): + """Returns None when provider has no matching model in chain.""" + from infrastructure.router.classifier import TaskComplexity + + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = {} # empty chains + + provider = Provider( + name="test", + type="ollama", + enabled=True, + priority=1, + models=[{"name": "llama3.2:3b", "default": 
True, "capabilities": ["text"]}], + ) + + # No 'routine' or 'complex' model available + model = router._get_model_for_complexity(provider, TaskComplexity.SIMPLE) + assert model is None + + @pytest.mark.asyncio + async def test_complete_with_simple_hint_routes_to_8b(self): + """complexity_hint='simple' should use qwen3:8b.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "fast answer", "model": "qwen3:8b"} + result = await router.complete( + messages=[{"role": "user", "content": "list tasks"}], + complexity_hint="simple", + ) + + assert result["model"] == "qwen3:8b" + assert result["complexity"] == "simple" + + @pytest.mark.asyncio + async def test_complete_with_complex_hint_routes_to_14b(self): + """complexity_hint='complex' should use qwen3:14b.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "detailed answer", "model": "qwen3:14b"} + result = await router.complete( + messages=[{"role": "user", "content": "review this PR"}], + complexity_hint="complex", + ) + + assert result["model"] == "qwen3:14b" + assert result["complexity"] == "complex" + + @pytest.mark.asyncio + async def test_explicit_model_bypasses_complexity_routing(self): + """When model is explicitly provided, complexity routing is skipped.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, 
"_call_ollama") as mock_call: + mock_call.return_value = {"content": "response", "model": "qwen3:14b"} + result = await router.complete( + messages=[{"role": "user", "content": "list tasks"}], + model="qwen3:14b", # explicit override + ) + + # Explicit model wins — complexity field is None + assert result["model"] == "qwen3:14b" + assert result["complexity"] is None + + @pytest.mark.asyncio + async def test_auto_classification_routes_simple_message(self): + """Short, simple messages should auto-classify as SIMPLE → 8B.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "ok", "model": "qwen3:8b"} + result = await router.complete( + messages=[{"role": "user", "content": "status"}], + # no complexity_hint — auto-classify + ) + + assert result["complexity"] == "simple" + assert result["model"] == "qwen3:8b" + + @pytest.mark.asyncio + async def test_auto_classification_routes_complex_message(self): + """Complex messages should auto-classify → 14B.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + router.config.fallback_chains = { + "routine": ["qwen3:8b"], + "complex": ["qwen3:14b"], + } + router.providers = [self._make_dual_model_provider()] + + with patch.object(router, "_call_ollama") as mock_call: + mock_call.return_value = {"content": "deep analysis", "model": "qwen3:14b"} + result = await router.complete( + messages=[{"role": "user", "content": "analyze and prioritize the backlog"}], + ) + + assert result["complexity"] == "complex" + assert result["model"] == "qwen3:14b" + + @pytest.mark.asyncio + async def test_invalid_complexity_hint_falls_back_to_auto(self): + """Invalid complexity_hint should log a warning and auto-classify.""" + router = CascadeRouter(config_path=Path("/nonexistent")) + 
"""Tests for Qwen3 dual-model task complexity classifier."""

import pytest

from infrastructure.router.classifier import TaskComplexity, classify_task


def _classify_user(content) -> TaskComplexity:
    """Classify a conversation consisting of one user message."""
    return classify_task([{"role": "user", "content": content}])


class TestClassifyTask:
    """Heuristic checks for ``classify_task``."""

    # ── Simple / routine tasks ──────────────────────────────────────────

    def test_empty_messages_is_simple(self):
        result = classify_task([])
        assert result == TaskComplexity.SIMPLE

    def test_no_user_content_is_simple(self):
        only_system = [{"role": "system", "content": "You are Timmy."}]
        assert classify_task(only_system) == TaskComplexity.SIMPLE

    def test_short_status_query_is_simple(self):
        assert _classify_user("status") == TaskComplexity.SIMPLE

    def test_list_command_is_simple(self):
        assert _classify_user("list all tasks") == TaskComplexity.SIMPLE

    def test_get_command_is_simple(self):
        assert _classify_user("get the latest log entry") == TaskComplexity.SIMPLE

    def test_short_message_under_threshold_is_simple(self):
        assert _classify_user("run the build") == TaskComplexity.SIMPLE

    def test_affirmation_is_simple(self):
        assert _classify_user("yes") == TaskComplexity.SIMPLE

    # ── Complex / quality-sensitive tasks ───────────────────────────────

    def test_plan_keyword_is_complex(self):
        assert _classify_user("plan the sprint") == TaskComplexity.COMPLEX

    def test_review_keyword_is_complex(self):
        assert _classify_user("review this code") == TaskComplexity.COMPLEX

    def test_analyze_keyword_is_complex(self):
        assert _classify_user("analyze performance") == TaskComplexity.COMPLEX

    def test_triage_keyword_is_complex(self):
        assert _classify_user("triage the open issues") == TaskComplexity.COMPLEX

    def test_refactor_keyword_is_complex(self):
        assert _classify_user("refactor the auth module") == TaskComplexity.COMPLEX

    def test_explain_keyword_is_complex(self):
        assert _classify_user("explain how the router works") == TaskComplexity.COMPLEX

    def test_prioritize_keyword_is_complex(self):
        assert _classify_user("prioritize the backlog") == TaskComplexity.COMPLEX

    def test_long_message_is_complex(self):
        # 50 repetitions push the text past the 500-character limit.
        assert _classify_user("do something " * 50) == TaskComplexity.COMPLEX

    def test_numbered_list_is_complex(self):
        steps = "1. Read the file 2. Analyze it 3. Write a report"
        assert _classify_user(steps) == TaskComplexity.COMPLEX

    def test_code_block_is_complex(self):
        fenced = "Here is the code:\n```python\nprint('hello')\n```"
        assert _classify_user(fenced) == TaskComplexity.COMPLEX

    def test_deep_conversation_is_complex(self):
        turns = [
            ("user", "hi"),
            ("assistant", "hello"),
            ("user", "ok"),
            ("assistant", "yes"),
            ("user", "ok"),
            ("assistant", "yes"),
            ("user", "now do the thing"),
        ]
        conversation = [{"role": role, "content": text} for role, text in turns]
        assert classify_task(conversation) == TaskComplexity.COMPLEX

    def test_analyse_british_spelling_is_complex(self):
        assert _classify_user("analyse this dataset") == TaskComplexity.COMPLEX

    def test_non_string_content_is_ignored(self):
        """List-valued content must be tolerated without raising."""
        # Only the absence of an exception matters; either tier is acceptable.
        outcome = _classify_user(["part1", "part2"])
        assert isinstance(outcome, TaskComplexity)

    def test_system_message_not_counted_as_user(self):
        """Keywords in a system message must not influence the result."""
        conversation = [
            {"role": "system", "content": "analyze everything carefully"},
            {"role": "user", "content": "yes"},
        ]
        # "analyze" appears only in the system turn; the user turn is "yes".
        assert classify_task(conversation) == TaskComplexity.SIMPLE


class TestTaskComplexityEnum:
    """Value checks for the ``TaskComplexity`` enum."""

    def test_simple_value(self):
        assert TaskComplexity.SIMPLE.value == "simple"

    def test_complex_value(self):
        assert TaskComplexity.COMPLEX.value == "complex"

    def test_lookup_by_value(self):
        expected = {"simple": TaskComplexity.SIMPLE, "complex": TaskComplexity.COMPLEX}
        for raw, member in expected.items():
            assert TaskComplexity(raw) == member