# Co-authored-by: Kimi Claw <kimi@timmytime.ai>
# Co-committed-by: Kimi Claw <kimi@timmytime.ai>
"""
|
|
Enhanced Task Classifier for Uniwizard
|
|
|
|
Classifies incoming prompts into task types and maps them to ranked backend preferences.
|
|
Integrates with the 7-backend fallback chain defined in config.yaml.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
|
|
|
|
class TaskType(Enum):
|
|
"""Classification categories for incoming prompts."""
|
|
CODE = "code"
|
|
REASONING = "reasoning"
|
|
RESEARCH = "research"
|
|
CREATIVE = "creative"
|
|
FAST_OPS = "fast_ops"
|
|
TOOL_USE = "tool_use"
|
|
UNKNOWN = "unknown"
|
|
|
|
|
|
class ComplexityLevel(Enum):
|
|
"""Complexity tiers for prompt analysis."""
|
|
LOW = "low"
|
|
MEDIUM = "medium"
|
|
HIGH = "high"
|
|
|
|
|
|
# Backend identifiers (match fallback_providers chain order)
|
|
BACKEND_ANTHROPIC = "anthropic"
|
|
BACKEND_OPENAI_CODEX = "openai-codex"
|
|
BACKEND_GEMINI = "gemini"
|
|
BACKEND_GROQ = "groq"
|
|
BACKEND_GROK = "grok"
|
|
BACKEND_KIMI = "kimi-coding"
|
|
BACKEND_OPENROUTER = "openrouter"
|
|
|
|
ALL_BACKENDS = [
|
|
BACKEND_ANTHROPIC,
|
|
BACKEND_OPENAI_CODEX,
|
|
BACKEND_GEMINI,
|
|
BACKEND_GROQ,
|
|
BACKEND_GROK,
|
|
BACKEND_KIMI,
|
|
BACKEND_OPENROUTER,
|
|
]
|
|
|
|
# Task-specific keyword mappings
|
|
CODE_KEYWORDS: Set[str] = {
|
|
"code", "coding", "program", "programming", "function", "class",
|
|
"implement", "implementation", "refactor", "debug", "debugging",
|
|
"error", "exception", "traceback", "stacktrace", "test", "tests",
|
|
"pytest", "unittest", "import", "module", "package", "library",
|
|
"api", "endpoint", "route", "middleware", "database", "query",
|
|
"sql", "orm", "migration", "deploy", "docker", "kubernetes",
|
|
"k8s", "ci/cd", "pipeline", "build", "compile", "syntax",
|
|
"lint", "format", "black", "flake8", "mypy", "type", "typing",
|
|
"async", "await", "callback", "promise", "thread", "process",
|
|
"concurrency", "parallel", "optimization", "optimize", "performance",
|
|
"memory", "leak", "bug", "fix", "patch", "commit", "git",
|
|
"repository", "repo", "clone", "fork", "merge", "conflict",
|
|
"branch", "pull request", "pr", "review", "crud", "rest",
|
|
"graphql", "json", "xml", "yaml", "toml", "csv", "parse",
|
|
"regex", "regular expression", "string", "bytes", "encoding",
|
|
"decoding", "serialize", "deserialize", "marshal", "unmarshal",
|
|
"encrypt", "decrypt", "hash", "checksum", "signature", "jwt",
|
|
"oauth", "authentication", "authorization", "auth", "login",
|
|
"logout", "session", "cookie", "token", "permission", "role",
|
|
"rbac", "acl", "security", "vulnerability", "cve", "exploit",
|
|
"sandbox", "isolate", "container", "vm", "virtual machine",
|
|
}
|
|
|
|
REASONING_KEYWORDS: Set[str] = {
|
|
"analyze", "analysis", "investigate", "investigation",
|
|
"compare", "comparison", "contrast", "evaluate", "evaluation",
|
|
"assess", "assessment", "reason", "reasoning", "logic",
|
|
"logical", "deduce", "deduction", "infer", "inference",
|
|
"synthesize", "synthesis", "critique", "criticism", "review",
|
|
"argument", "premise", "conclusion", "evidence", "proof",
|
|
"theorem", "axiom", "corollary", "lemma", "proposition",
|
|
"hypothesis", "theory", "model", "framework", "paradigm",
|
|
"philosophy", "ethical", "ethics", "moral", "morality",
|
|
"implication", "consequence", "trade-off", "tradeoff",
|
|
"pros and cons", "advantage", "disadvantage", "benefit",
|
|
"drawback", "risk", "mitigation", "strategy", "strategic",
|
|
"plan", "planning", "design", "architecture", "system",
|
|
"complex", "complicated", "nuanced", "subtle", "sophisticated",
|
|
"rigorous", "thorough", "comprehensive", "exhaustive",
|
|
"step by step", "chain of thought", "think through",
|
|
"work through", "figure out", "understand", "comprehend",
|
|
}
|
|
|
|
RESEARCH_KEYWORDS: Set[str] = {
|
|
"research", "find", "search", "look up", "lookup",
|
|
"investigate", "study", "explore", "discover",
|
|
"paper", "publication", "journal", "article", "study",
|
|
"arxiv", "scholar", "academic", "scientific", "literature",
|
|
"review", "survey", "meta-analysis", "bibliography",
|
|
"citation", "reference", "source", "primary source",
|
|
"secondary source", "peer review", "empirical", "experiment",
|
|
"experimental", "observational", "longitudinal", "cross-sectional",
|
|
"qualitative", "quantitative", "mixed methods", "case study",
|
|
"dataset", "data", "statistics", "statistical", "correlation",
|
|
"causation", "regression", "machine learning", "ml", "ai",
|
|
"neural network", "deep learning", "transformer", "llm",
|
|
"benchmark", "evaluation", "metric", "sota", "state of the art",
|
|
"survey", "poll", "interview", "focus group", "ethnography",
|
|
"field work", "archive", "archival", "repository", "collection",
|
|
"index", "catalog", "database", "librar", "museum", "histor",
|
|
"genealogy", "ancestry", "patent", "trademark", "copyright",
|
|
"legislation", "regulation", "policy", "compliance",
|
|
}
|
|
|
|
CREATIVE_KEYWORDS: Set[str] = {
|
|
"create", "creative", "creativity", "design", "designer",
|
|
"art", "artistic", "artist", "paint", "painting", "draw",
|
|
"drawing", "sketch", "illustration", "illustrator", "graphic",
|
|
"visual", "image", "photo", "photography", "photographer",
|
|
"video", "film", "movie", "animation", "animate", "motion",
|
|
"music", "musical", "song", "lyric", "compose", "composition",
|
|
"melody", "harmony", "rhythm", "beat", "sound", "audio",
|
|
"write", "writing", "writer", "author", "story", "storytelling",
|
|
"narrative", "plot", "character", "dialogue", "scene",
|
|
"novel", "fiction", "short story", "poem", "poetry", "poet",
|
|
"verse", "prose", "essay", "blog", "article", "content",
|
|
"copy", "copywriting", "marketing", "brand", "branding",
|
|
"slogan", "tagline", "headline", "title", "name", "naming",
|
|
"brainstorm", "ideate", "concept", "conceptualize", "imagine",
|
|
"imagination", "inspire", "inspiration", "muse", "vision",
|
|
"aesthetic", "style", "theme", "mood", "tone", "voice",
|
|
"unique", "original", "fresh", "novel", "innovative",
|
|
"unconventional", "experimental", "avant-garde", "edgy",
|
|
"humor", "funny", "comedy", "satire", "parody", "wit",
|
|
"romance", "romantic", "drama", "dramatic", "thriller",
|
|
"mystery", "horror", "sci-fi", "science fiction", "fantasy",
|
|
"adventure", "action", "documentary", "biopic", "memoir",
|
|
}
|
|
|
|
FAST_OPS_KEYWORDS: Set[str] = {
|
|
"quick", "fast", "brief", "short", "simple", "easy",
|
|
"status", "check", "list", "ls", "show", "display",
|
|
"get", "fetch", "retrieve", "read", "cat", "view",
|
|
"summary", "summarize", "tl;dr", "tldr", "overview",
|
|
"count", "number", "how many", "total", "sum", "average",
|
|
"min", "max", "sort", "filter", "grep", "search",
|
|
"find", "locate", "which", "where", "what is", "what's",
|
|
"who", "when", "yes/no", "confirm", "verify", "validate",
|
|
"ping", "health", "alive", "up", "running", "online",
|
|
"date", "time", "timezone", "clock", "timer", "alarm",
|
|
"remind", "reminder", "note", "jot", "save", "store",
|
|
"delete", "remove", "rm", "clean", "clear", "purge",
|
|
"start", "stop", "restart", "enable", "disable", "toggle",
|
|
"on", "off", "open", "close", "switch", "change", "set",
|
|
"update", "upgrade", "install", "uninstall", "download",
|
|
"upload", "sync", "backup", "restore", "export", "import",
|
|
"convert", "transform", "format", "parse", "extract",
|
|
"compress", "decompress", "zip", "unzip", "tar", "archive",
|
|
"copy", "cp", "move", "mv", "rename", "link", "symlink",
|
|
"permission", "chmod", "chown", "access", "ownership",
|
|
"hello", "hi", "hey", "greeting", "thanks", "thank you",
|
|
"bye", "goodbye", "help", "?", "how to", "how do i",
|
|
}
|
|
|
|
TOOL_USE_KEYWORDS: Set[str] = {
|
|
"tool", "tools", "use tool", "call tool", "invoke",
|
|
"run command", "execute", "terminal", "shell", "bash",
|
|
"zsh", "powershell", "cmd", "command line", "cli",
|
|
"file", "files", "directory", "folder", "path", "fs",
|
|
"read file", "write file", "edit file", "patch file",
|
|
"search files", "find files", "grep", "rg", "ack",
|
|
"browser", "web", "navigate", "click", "scroll",
|
|
"screenshot", "vision", "image", "analyze image",
|
|
"delegate", "subagent", "agent", "spawn", "task",
|
|
"mcp", "server", "mcporter", "protocol",
|
|
"process", "background", "kill", "signal", "pid",
|
|
"git", "commit", "push", "pull", "clone", "branch",
|
|
"docker", "container", "compose", "dockerfile",
|
|
"kubernetes", "kubectl", "k8s", "pod", "deployment",
|
|
"aws", "gcp", "azure", "cloud", "s3", "bucket",
|
|
"database", "db", "sql", "query", "migrate", "seed",
|
|
"api", "endpoint", "request", "response", "curl",
|
|
"http", "https", "rest", "graphql", "websocket",
|
|
"json", "xml", "yaml", "csv", "parse", "serialize",
|
|
"scrap", "crawl", "extract", "parse html", "xpath",
|
|
"schedule", "cron", "job", "task queue", "worker",
|
|
"notification", "alert", "webhook", "event", "trigger",
|
|
}
|
|
|
|
# URL pattern for detecting web/research tasks
|
|
_URL_PATTERN = re.compile(
|
|
r'https?://(?:[-\w.])+(?:[:\d]+)?(?:/(?:[\w/_.])*(?:\?(?:[\w&=%.])*)?(?:#(?:[\w.])*)?)?',
|
|
re.IGNORECASE
|
|
)
|
|
|
|
# Code block detection (count ``` blocks, not individual lines)
|
|
_CODE_BLOCK_PATTERN = re.compile(r'```[\w]*\n', re.MULTILINE)
|
|
|
|
|
|
def _count_code_blocks(text: str) -> int:
|
|
"""Count complete code blocks (opening ``` to closing ```)."""
|
|
# Count pairs of ``` - each pair is one code block
|
|
fence_count = text.count('```')
|
|
return fence_count // 2
|
|
_INLINE_CODE_PATTERN = re.compile(r'`[^`]+`')
|
|
|
|
# Complexity thresholds
|
|
COMPLEXITY_THRESHOLDS = {
|
|
"chars": {"low": 200, "medium": 800},
|
|
"words": {"low": 35, "medium": 150},
|
|
"lines": {"low": 3, "medium": 15},
|
|
"urls": {"low": 0, "medium": 2},
|
|
"code_blocks": {"low": 0, "medium": 1},
|
|
}
|
|
|
|
|
|
@dataclass
|
|
class ClassificationResult:
|
|
"""Result of task classification."""
|
|
task_type: TaskType
|
|
preferred_backends: List[str]
|
|
complexity: ComplexityLevel
|
|
reason: str
|
|
confidence: float
|
|
features: Dict[str, Any]
|
|
|
|
|
|
class TaskClassifier:
|
|
"""
|
|
Enhanced task classifier for routing prompts to appropriate backends.
|
|
|
|
Maps task types to ranked backend preferences based on:
|
|
- Backend strengths (coding, reasoning, speed, context length, etc.)
|
|
- Message complexity (length, structure, keywords)
|
|
- Detected features (URLs, code blocks, specific terminology)
|
|
"""
|
|
|
|
# Backend preference rankings by task type
|
|
# Order matters: first is most preferred
|
|
TASK_BACKEND_MAP: Dict[TaskType, List[str]] = {
|
|
TaskType.CODE: [
|
|
BACKEND_OPENAI_CODEX, # Best for code generation
|
|
BACKEND_ANTHROPIC, # Excellent for code review, complex analysis
|
|
BACKEND_KIMI, # Long context for large codebases
|
|
BACKEND_GEMINI, # Good multimodal code understanding
|
|
BACKEND_GROQ, # Fast for simple code tasks
|
|
BACKEND_OPENROUTER, # Overflow option
|
|
BACKEND_GROK, # General knowledge backup
|
|
],
|
|
TaskType.REASONING: [
|
|
BACKEND_ANTHROPIC, # Deep reasoning champion
|
|
BACKEND_GEMINI, # Strong analytical capabilities
|
|
BACKEND_KIMI, # Long context for complex reasoning chains
|
|
BACKEND_GROK, # Broad knowledge for reasoning
|
|
BACKEND_OPENAI_CODEX, # Structured reasoning
|
|
BACKEND_OPENROUTER, # Overflow
|
|
BACKEND_GROQ, # Fast fallback
|
|
],
|
|
TaskType.RESEARCH: [
|
|
BACKEND_GEMINI, # Research and multimodal leader
|
|
BACKEND_KIMI, # 262K context for long documents
|
|
BACKEND_ANTHROPIC, # Deep analysis
|
|
BACKEND_GROK, # Broad knowledge
|
|
BACKEND_OPENROUTER, # Broadest model access
|
|
BACKEND_OPENAI_CODEX, # Structured research
|
|
BACKEND_GROQ, # Fast triage
|
|
],
|
|
TaskType.CREATIVE: [
|
|
BACKEND_GROK, # Creative writing and drafting
|
|
BACKEND_ANTHROPIC, # Nuanced creative work
|
|
BACKEND_GEMINI, # Multimodal creativity
|
|
BACKEND_OPENAI_CODEX, # Creative coding
|
|
BACKEND_KIMI, # Long-form creative
|
|
BACKEND_OPENROUTER, # Variety of creative models
|
|
BACKEND_GROQ, # Fast creative ops
|
|
],
|
|
TaskType.FAST_OPS: [
|
|
BACKEND_GROQ, # 284ms response time champion
|
|
BACKEND_OPENROUTER, # Fast mini models
|
|
BACKEND_GEMINI, # Flash models
|
|
BACKEND_GROK, # Fast for simple queries
|
|
BACKEND_ANTHROPIC, # If precision needed
|
|
BACKEND_OPENAI_CODEX, # Structured ops
|
|
BACKEND_KIMI, # Overflow
|
|
],
|
|
TaskType.TOOL_USE: [
|
|
BACKEND_ANTHROPIC, # Excellent tool use capabilities
|
|
BACKEND_OPENAI_CODEX, # Good tool integration
|
|
BACKEND_GEMINI, # Multimodal tool use
|
|
BACKEND_GROQ, # Fast tool chaining
|
|
BACKEND_KIMI, # Long context tool sessions
|
|
BACKEND_OPENROUTER, # Overflow
|
|
BACKEND_GROK, # General tool use
|
|
],
|
|
TaskType.UNKNOWN: [
|
|
BACKEND_ANTHROPIC, # Default to strongest general model
|
|
BACKEND_GEMINI, # Good all-rounder
|
|
BACKEND_OPENAI_CODEX, # Structured approach
|
|
BACKEND_KIMI, # Long context safety
|
|
BACKEND_GROK, # Broad knowledge
|
|
BACKEND_GROQ, # Fast fallback
|
|
BACKEND_OPENROUTER, # Ultimate overflow
|
|
],
|
|
}
|
|
|
|
def __init__(self):
|
|
"""Initialize the classifier with compiled patterns."""
|
|
self.url_pattern = _URL_PATTERN
|
|
self.code_block_pattern = _CODE_BLOCK_PATTERN
|
|
self.inline_code_pattern = _INLINE_CODE_PATTERN
|
|
|
|
def classify(
|
|
self,
|
|
prompt: str,
|
|
context: Optional[Dict[str, Any]] = None
|
|
) -> ClassificationResult:
|
|
"""
|
|
Classify a prompt and return routing recommendation.
|
|
|
|
Args:
|
|
prompt: The user message to classify
|
|
context: Optional context (previous messages, session state, etc.)
|
|
|
|
Returns:
|
|
ClassificationResult with task type, preferred backends, complexity, and reasoning
|
|
"""
|
|
text = (prompt or "").strip()
|
|
if not text:
|
|
return self._default_result("Empty prompt")
|
|
|
|
# Extract features
|
|
features = self._extract_features(text)
|
|
|
|
# Determine complexity
|
|
complexity = self._assess_complexity(features)
|
|
|
|
# Classify task type
|
|
task_type, task_confidence, task_reason = self._classify_task_type(text, features)
|
|
|
|
# Get preferred backends
|
|
preferred_backends = self._get_backends_for_task(task_type, complexity, features)
|
|
|
|
# Build reason string
|
|
reason = self._build_reason(task_type, complexity, task_reason, features)
|
|
|
|
return ClassificationResult(
|
|
task_type=task_type,
|
|
preferred_backends=preferred_backends,
|
|
complexity=complexity,
|
|
reason=reason,
|
|
confidence=task_confidence,
|
|
features=features,
|
|
)
|
|
|
|
def _extract_features(self, text: str) -> Dict[str, Any]:
|
|
"""Extract features from the prompt text."""
|
|
lowered = text.lower()
|
|
words = set(token.strip(".,:;!?()[]{}\"'`") for token in lowered.split())
|
|
|
|
# Count code blocks (complete ``` pairs)
|
|
code_blocks = _count_code_blocks(text)
|
|
inline_code = len(self.inline_code_pattern.findall(text))
|
|
|
|
# Count URLs
|
|
urls = self.url_pattern.findall(text)
|
|
|
|
# Count lines
|
|
lines = text.count('\n') + 1
|
|
|
|
return {
|
|
"char_count": len(text),
|
|
"word_count": len(text.split()),
|
|
"line_count": lines,
|
|
"url_count": len(urls),
|
|
"urls": urls,
|
|
"code_block_count": code_blocks,
|
|
"inline_code_count": inline_code,
|
|
"has_code": code_blocks > 0 or inline_code > 0,
|
|
"unique_words": words,
|
|
"lowercased_text": lowered,
|
|
}
|
|
|
|
def _assess_complexity(self, features: Dict[str, Any]) -> ComplexityLevel:
|
|
"""Assess the complexity level of the prompt."""
|
|
scores = {
|
|
"chars": features["char_count"],
|
|
"words": features["word_count"],
|
|
"lines": features["line_count"],
|
|
"urls": features["url_count"],
|
|
"code_blocks": features["code_block_count"],
|
|
}
|
|
|
|
# Count how many metrics exceed medium threshold
|
|
medium_count = 0
|
|
high_count = 0
|
|
|
|
for metric, value in scores.items():
|
|
thresholds = COMPLEXITY_THRESHOLDS.get(metric, {"low": 0, "medium": 0})
|
|
if value > thresholds["medium"]:
|
|
high_count += 1
|
|
elif value > thresholds["low"]:
|
|
medium_count += 1
|
|
|
|
# Determine complexity
|
|
if high_count >= 2 or scores["code_blocks"] > 2:
|
|
return ComplexityLevel.HIGH
|
|
elif medium_count >= 2 or high_count >= 1:
|
|
return ComplexityLevel.MEDIUM
|
|
else:
|
|
return ComplexityLevel.LOW
|
|
|
|
def _classify_task_type(
|
|
self,
|
|
text: str,
|
|
features: Dict[str, Any]
|
|
) -> Tuple[TaskType, float, str]:
|
|
"""
|
|
Classify the task type based on keywords and features.
|
|
|
|
Returns:
|
|
Tuple of (task_type, confidence, reason)
|
|
"""
|
|
words = features["unique_words"]
|
|
lowered = features["lowercased_text"]
|
|
|
|
# Score each task type
|
|
scores: Dict[TaskType, float] = {task: 0.0 for task in TaskType}
|
|
reasons: Dict[TaskType, str] = {}
|
|
|
|
# CODE scoring
|
|
code_matches = words & CODE_KEYWORDS
|
|
if features["has_code"]:
|
|
scores[TaskType.CODE] += 2.0
|
|
reasons[TaskType.CODE] = "Contains code blocks"
|
|
if code_matches:
|
|
scores[TaskType.CODE] += min(len(code_matches) * 0.5, 3.0)
|
|
if TaskType.CODE not in reasons:
|
|
reasons[TaskType.CODE] = f"Code keywords: {', '.join(list(code_matches)[:3])}"
|
|
|
|
# REASONING scoring
|
|
reasoning_matches = words & REASONING_KEYWORDS
|
|
if reasoning_matches:
|
|
scores[TaskType.REASONING] += min(len(reasoning_matches) * 0.4, 2.5)
|
|
reasons[TaskType.REASONING] = f"Reasoning keywords: {', '.join(list(reasoning_matches)[:3])}"
|
|
if any(phrase in lowered for phrase in ["step by step", "chain of thought", "think through"]):
|
|
scores[TaskType.REASONING] += 1.5
|
|
reasons[TaskType.REASONING] = "Explicit reasoning request"
|
|
|
|
# RESEARCH scoring
|
|
research_matches = words & RESEARCH_KEYWORDS
|
|
if features["url_count"] > 0:
|
|
scores[TaskType.RESEARCH] += 1.5
|
|
reasons[TaskType.RESEARCH] = f"Contains {features['url_count']} URL(s)"
|
|
if research_matches:
|
|
scores[TaskType.RESEARCH] += min(len(research_matches) * 0.4, 2.0)
|
|
if TaskType.RESEARCH not in reasons:
|
|
reasons[TaskType.RESEARCH] = f"Research keywords: {', '.join(list(research_matches)[:3])}"
|
|
|
|
# CREATIVE scoring
|
|
creative_matches = words & CREATIVE_KEYWORDS
|
|
if creative_matches:
|
|
scores[TaskType.CREATIVE] += min(len(creative_matches) * 0.4, 2.5)
|
|
reasons[TaskType.CREATIVE] = f"Creative keywords: {', '.join(list(creative_matches)[:3])}"
|
|
|
|
# FAST_OPS scoring (simple queries) - ONLY if no other strong signals
|
|
fast_ops_matches = words & FAST_OPS_KEYWORDS
|
|
is_very_short = features["word_count"] <= 5 and features["char_count"] < 50
|
|
|
|
# Only score fast_ops if it's very short OR has no other task indicators
|
|
other_scores_possible = bool(
|
|
(words & CODE_KEYWORDS) or
|
|
(words & REASONING_KEYWORDS) or
|
|
(words & RESEARCH_KEYWORDS) or
|
|
(words & CREATIVE_KEYWORDS) or
|
|
(words & TOOL_USE_KEYWORDS) or
|
|
features["has_code"]
|
|
)
|
|
|
|
if is_very_short and not other_scores_possible:
|
|
scores[TaskType.FAST_OPS] += 1.5
|
|
reasons[TaskType.FAST_OPS] = "Very short, simple query"
|
|
elif not other_scores_possible and fast_ops_matches and features["word_count"] < 30:
|
|
scores[TaskType.FAST_OPS] += min(len(fast_ops_matches) * 0.3, 1.0)
|
|
reasons[TaskType.FAST_OPS] = f"Simple query keywords: {', '.join(list(fast_ops_matches)[:3])}"
|
|
|
|
# TOOL_USE scoring
|
|
tool_matches = words & TOOL_USE_KEYWORDS
|
|
if tool_matches:
|
|
scores[TaskType.TOOL_USE] += min(len(tool_matches) * 0.4, 2.0)
|
|
reasons[TaskType.TOOL_USE] = f"Tool keywords: {', '.join(list(tool_matches)[:3])}"
|
|
if any(cmd in lowered for cmd in ["run ", "execute ", "call ", "use "]):
|
|
scores[TaskType.TOOL_USE] += 0.5
|
|
|
|
# Find highest scoring task type
|
|
best_task = TaskType.UNKNOWN
|
|
best_score = 0.0
|
|
|
|
for task, score in scores.items():
|
|
if score > best_score:
|
|
best_score = score
|
|
best_task = task
|
|
|
|
# Calculate confidence
|
|
confidence = min(best_score / 4.0, 1.0) if best_score > 0 else 0.0
|
|
reason = reasons.get(best_task, "No strong indicators")
|
|
|
|
return best_task, confidence, reason
|
|
|
|
def _get_backends_for_task(
|
|
self,
|
|
task_type: TaskType,
|
|
complexity: ComplexityLevel,
|
|
features: Dict[str, Any]
|
|
) -> List[str]:
|
|
"""Get ranked list of preferred backends for the task."""
|
|
base_backends = self.TASK_BACKEND_MAP.get(task_type, self.TASK_BACKEND_MAP[TaskType.UNKNOWN])
|
|
|
|
# Adjust for complexity
|
|
if complexity == ComplexityLevel.HIGH and task_type in (TaskType.RESEARCH, TaskType.CODE):
|
|
# For high complexity, prioritize long-context models
|
|
if BACKEND_KIMI in base_backends:
|
|
# Move kimi earlier for long context
|
|
base_backends = self._prioritize_backend(base_backends, BACKEND_KIMI, 2)
|
|
if BACKEND_GEMINI in base_backends:
|
|
base_backends = self._prioritize_backend(base_backends, BACKEND_GEMINI, 3)
|
|
|
|
elif complexity == ComplexityLevel.LOW and task_type == TaskType.FAST_OPS:
|
|
# For simple ops, ensure GROQ is first
|
|
base_backends = self._prioritize_backend(base_backends, BACKEND_GROQ, 0)
|
|
|
|
# Adjust for code presence
|
|
if features["has_code"] and task_type != TaskType.CODE:
|
|
# Boost OpenAI Codex if there's code but not explicitly a code task
|
|
base_backends = self._prioritize_backend(base_backends, BACKEND_OPENAI_CODEX, 2)
|
|
|
|
return list(base_backends)
|
|
|
|
def _prioritize_backend(
|
|
self,
|
|
backends: List[str],
|
|
target: str,
|
|
target_index: int
|
|
) -> List[str]:
|
|
"""Move a backend to a specific index in the list."""
|
|
if target not in backends:
|
|
return backends
|
|
|
|
new_backends = list(backends)
|
|
new_backends.remove(target)
|
|
new_backends.insert(min(target_index, len(new_backends)), target)
|
|
return new_backends
|
|
|
|
def _build_reason(
|
|
self,
|
|
task_type: TaskType,
|
|
complexity: ComplexityLevel,
|
|
task_reason: str,
|
|
features: Dict[str, Any]
|
|
) -> str:
|
|
"""Build a human-readable reason string."""
|
|
parts = [
|
|
f"Task: {task_type.value}",
|
|
f"Complexity: {complexity.value}",
|
|
]
|
|
|
|
if task_reason:
|
|
parts.append(f"Indicators: {task_reason}")
|
|
|
|
# Add feature summary
|
|
feature_parts = []
|
|
if features["has_code"]:
|
|
feature_parts.append(f"{features['code_block_count']} code block(s)")
|
|
if features["url_count"] > 0:
|
|
feature_parts.append(f"{features['url_count']} URL(s)")
|
|
if features["word_count"] > 100:
|
|
feature_parts.append(f"{features['word_count']} words")
|
|
|
|
if feature_parts:
|
|
parts.append(f"Features: {', '.join(feature_parts)}")
|
|
|
|
return "; ".join(parts)
|
|
|
|
def _default_result(self, reason: str) -> ClassificationResult:
|
|
"""Return a default result for edge cases."""
|
|
return ClassificationResult(
|
|
task_type=TaskType.UNKNOWN,
|
|
preferred_backends=list(self.TASK_BACKEND_MAP[TaskType.UNKNOWN]),
|
|
complexity=ComplexityLevel.LOW,
|
|
reason=reason,
|
|
confidence=0.0,
|
|
features={},
|
|
)
|
|
|
|
def to_dict(self, result: ClassificationResult) -> Dict[str, Any]:
|
|
"""Convert classification result to dictionary format."""
|
|
return {
|
|
"task_type": result.task_type.value,
|
|
"preferred_backends": result.preferred_backends,
|
|
"complexity": result.complexity.value,
|
|
"reason": result.reason,
|
|
"confidence": round(result.confidence, 2),
|
|
"features": {
|
|
k: v for k, v in result.features.items()
|
|
if k not in ("unique_words", "lowercased_text", "urls")
|
|
},
|
|
}
|
|
|
|
|
|
# Convenience function for direct usage
|
|
def classify_prompt(
|
|
prompt: str,
|
|
context: Optional[Dict[str, Any]] = None
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
Classify a prompt and return routing recommendation as a dictionary.
|
|
|
|
Args:
|
|
prompt: The user message to classify
|
|
context: Optional context (previous messages, session state, etc.)
|
|
|
|
Returns:
|
|
Dictionary with task_type, preferred_backends, complexity, reason, confidence
|
|
"""
|
|
classifier = TaskClassifier()
|
|
result = classifier.classify(prompt, context)
|
|
return classifier.to_dict(result)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Example usage and quick test
|
|
test_prompts = [
|
|
"Hello, how are you?",
|
|
"Implement a Python function to calculate fibonacci numbers",
|
|
"Analyze the architectural trade-offs between microservices and monoliths",
|
|
"Research the latest papers on transformer architectures",
|
|
"Write a creative story about AI",
|
|
"Check the status of the server and list running processes",
|
|
"Use the browser to navigate to https://example.com and take a screenshot",
|
|
"Refactor this large codebase: [2000 lines of code]",
|
|
]
|
|
|
|
classifier = TaskClassifier()
|
|
|
|
for prompt in test_prompts:
|
|
result = classifier.classify(prompt)
|
|
print(f"\nPrompt: {prompt[:60]}...")
|
|
print(f" Type: {result.task_type.value}")
|
|
print(f" Complexity: {result.complexity.value}")
|
|
print(f" Confidence: {result.confidence:.2f}")
|
|
print(f" Backends: {', '.join(result.preferred_backends[:3])}")
|
|
print(f" Reason: {result.reason}")
|