# Co-authored-by: Kimi Claw <kimi@timmytime.ai>
# Co-committed-by: Kimi Claw <kimi@timmytime.ai>
"""
Unit tests for the TaskClassifier module.

Run with: python -m pytest test_task_classifier.py -v
"""
import pytest
from typing import Dict, Any

from task_classifier import (
    TaskClassifier,
    TaskType,
    ComplexityLevel,
    ClassificationResult,
    classify_prompt,
    BACKEND_ANTHROPIC,
    BACKEND_OPENAI_CODEX,
    BACKEND_GEMINI,
    BACKEND_GROQ,
    BACKEND_GROK,
    BACKEND_KIMI,
    BACKEND_OPENROUTER,
)


class TestFeatureExtraction:
    """Feature-extraction tests for TaskClassifier prompts."""

    def test_extract_basic_features(self):
        """A short plain sentence yields simple counts and no code flags."""
        clf = TaskClassifier()
        feats = clf._extract_features("Hello world")

        assert feats["char_count"] == 11
        assert feats["word_count"] == 2
        assert feats["line_count"] == 1
        assert feats["url_count"] == 0
        assert feats["code_block_count"] == 0
        assert feats["has_code"] is False

    def test_extract_url_features(self):
        """Both http and https URLs are counted and captured verbatim."""
        clf = TaskClassifier()
        feats = clf._extract_features(
            "Check out https://example.com and http://test.org/path"
        )

        assert feats["url_count"] == 2
        assert len(feats["urls"]) == 2
        assert "https://example.com" in feats["urls"]

    def test_extract_code_block_features(self):
        """Fenced ``` blocks are counted and flip has_code on."""
        clf = TaskClassifier()
        sample = """Here is some code:
```python
def hello():
    return "world"
```
And more:
```javascript
console.log("hi");
```
"""
        feats = clf._extract_features(sample)

        # Two complete ``` fence pairs in the sample text.
        assert feats["code_block_count"] == 2
        assert feats["has_code"] is True
        # Inline code may or may not be detected alongside the fences;
        # only has_code is pinned here.
        assert feats["inline_code_count"] >= 0

    def test_extract_inline_code_features(self):
        """Backtick-quoted snippets count as inline code."""
        clf = TaskClassifier()
        feats = clf._extract_features(
            "Use the `print()` function and `len()` method"
        )

        assert feats["inline_code_count"] == 2
        assert feats["has_code"] is True

    def test_extract_multiline_features(self):
        """Newlines split the prompt into the expected line count."""
        clf = TaskClassifier()
        feats = clf._extract_features("Line 1\nLine 2\nLine 3")

        assert feats["line_count"] == 3


|
|
|
class TestComplexityAssessment:
    """Tests for complexity level assessment."""

    @staticmethod
    def _features(chars, words, lines, urls, blocks):
        """Build the minimal feature dict _assess_complexity consumes."""
        return {
            "char_count": chars,
            "word_count": words,
            "line_count": lines,
            "url_count": urls,
            "code_block_count": blocks,
        }

    def test_low_complexity_short_text(self):
        """Short, plain text rates LOW."""
        level = TaskClassifier()._assess_complexity(
            self._features(100, 15, 2, 0, 0)
        )
        assert level == ComplexityLevel.LOW

    def test_medium_complexity_moderate_text(self):
        """Moderate-length text with one URL rates MEDIUM."""
        level = TaskClassifier()._assess_complexity(
            self._features(500, 80, 10, 1, 0)
        )
        assert level == ComplexityLevel.MEDIUM

    def test_high_complexity_long_text(self):
        """Long, link-heavy text rates HIGH."""
        level = TaskClassifier()._assess_complexity(
            self._features(2000, 300, 50, 3, 0)
        )
        assert level == ComplexityLevel.HIGH

    def test_high_complexity_multiple_code_blocks(self):
        """Several code blocks alone push an otherwise modest prompt to HIGH."""
        level = TaskClassifier()._assess_complexity(
            self._features(500, 50, 20, 0, 4)
        )
        assert level == ComplexityLevel.HIGH


|
|
|
class TestTaskTypeClassification:
    """Tests for task type classification."""

    @staticmethod
    def _classify(classifier, prompt):
        """Run the private type classifier; returns (task_type, confidence, reason)."""
        return classifier._classify_task_type(
            prompt, classifier._extract_features(prompt)
        )

    def test_classify_code_task(self):
        """Code-centric prompts classify as CODE with nonzero confidence."""
        clf = TaskClassifier()
        prompts = [
            "Implement a function to sort a list",
            "Debug this Python error",
            "Refactor the database query",
            "Write a test for the API endpoint",
            "Fix the bug in the authentication middleware",
        ]
        for prompt in prompts:
            task_type, confidence, _reason = self._classify(clf, prompt)
            assert task_type == TaskType.CODE, f"Failed for: {prompt}"
            assert confidence > 0, f"Zero confidence for: {prompt}"

    def test_classify_reasoning_task(self):
        """Reasoning prompts land in an accepted set of types."""
        clf = TaskClassifier()
        prompts = [
            "Compare and evaluate different approaches",
            "Evaluate the security implications",
            "Think through the logical steps",
            "Step by step, deduce the cause",
            "Analyze the pros and cons",
        ]
        # Allow REASONING or other valid classifications.
        allowed = (TaskType.REASONING, TaskType.CODE, TaskType.UNKNOWN)
        for prompt in prompts:
            task_type, _confidence, _reason = self._classify(clf, prompt)
            assert task_type in allowed, f"Failed for: {prompt}"

    def test_classify_research_task(self):
        """Research prompts land in RESEARCH or a plausible overlap type."""
        clf = TaskClassifier()
        prompts = [
            "Research the latest AI papers on arxiv",
            "Find studies about neural networks",
            "Search for benchmarks on https://example.com/benchmarks",
            "Survey existing literature on distributed systems",
            "Study the published papers on machine learning",
        ]
        # RESEARCH or other valid classifications.
        allowed = (TaskType.RESEARCH, TaskType.FAST_OPS, TaskType.CODE)
        for prompt in prompts:
            task_type, _confidence, _reason = self._classify(clf, prompt)
            assert task_type in allowed, f"Got {task_type} for: {prompt}"

    def test_classify_creative_task(self):
        """Creative-writing prompts classify as CREATIVE."""
        clf = TaskClassifier()
        prompts = [
            "Write a creative story about AI",
            "Design a logo concept",
            "Compose a poem about programming",
            "Brainstorm marketing slogans",
            "Create a character for a novel",
        ]
        for prompt in prompts:
            task_type, _confidence, _reason = self._classify(clf, prompt)
            assert task_type == TaskType.CREATIVE, f"Failed for: {prompt}"

    def test_classify_fast_ops_task(self):
        """Trivial one-word prompts classify as FAST_OPS."""
        clf = TaskClassifier()
        # These should be truly simple with no other task indicators.
        prompts = ["Hi", "Hello", "Thanks", "Bye", "Yes", "No"]
        for prompt in prompts:
            task_type, _confidence, _reason = self._classify(clf, prompt)
            assert task_type == TaskType.FAST_OPS, f"Failed for: {prompt}"

    def test_classify_tool_use_task(self):
        """Tool-invocation prompts land in TOOL_USE or an overlap type."""
        clf = TaskClassifier()
        prompts = [
            "Execute the shell command",
            "Use the browser to navigate to google.com",
            "Call the API endpoint",
            "Invoke the deployment tool",
            "Run this terminal command",
        ]
        # Tool use often overlaps with code or research (search).
        allowed = (TaskType.TOOL_USE, TaskType.CODE, TaskType.RESEARCH)
        for prompt in prompts:
            task_type, _confidence, _reason = self._classify(clf, prompt)
            assert task_type in allowed, f"Got {task_type} for: {prompt}"


|
|
|
class TestBackendSelection:
    """Tests for backend selection logic."""

    def test_code_task_prefers_codex(self):
        """Code tasks rank OpenAI Codex first."""
        res = TaskClassifier().classify("Implement a Python class")

        assert res.task_type == TaskType.CODE
        assert res.preferred_backends[0] == BACKEND_OPENAI_CODEX

    def test_reasoning_task_prefers_anthropic(self):
        """Reasoning tasks rank Anthropic first."""
        res = TaskClassifier().classify("Analyze the architectural trade-offs")

        assert res.task_type == TaskType.REASONING
        assert res.preferred_backends[0] == BACKEND_ANTHROPIC

    def test_research_task_prefers_gemini(self):
        """Research tasks rank Gemini first."""
        res = TaskClassifier().classify("Research the latest papers on transformers")

        assert res.task_type == TaskType.RESEARCH
        assert res.preferred_backends[0] == BACKEND_GEMINI

    def test_creative_task_prefers_grok(self):
        """Creative tasks rank Grok first."""
        res = TaskClassifier().classify("Write a creative story")

        assert res.task_type == TaskType.CREATIVE
        assert res.preferred_backends[0] == BACKEND_GROK

    def test_fast_ops_task_prefers_groq(self):
        """Fast-ops tasks rank Groq first."""
        res = TaskClassifier().classify("Quick status check")

        assert res.task_type == TaskType.FAST_OPS
        assert res.preferred_backends[0] == BACKEND_GROQ

    def test_tool_use_task_prefers_anthropic(self):
        """Tool-use tasks keep Anthropic near the top of the ranking."""
        res = TaskClassifier().classify("Execute the shell command and use tools")

        # Tool use may overlap with code, but anthropic should be near top.
        assert res.task_type in (TaskType.TOOL_USE, TaskType.CODE)
        assert BACKEND_ANTHROPIC in res.preferred_backends[:2]


|
|
|
class TestComplexityAdjustments:
    """Tests for complexity-based backend adjustments."""

    def test_high_complexity_boosts_kimi_for_research(self):
        """High-complexity research prompts pull Kimi into the top three."""
        # Long research prompt engineered to trip the HIGH threshold.
        result = TaskClassifier().classify("Research " + "machine learning " * 200)

        # Only meaningful when the prompt actually classified as HIGH research.
        if result.task_type != TaskType.RESEARCH:
            return
        if result.complexity != ComplexityLevel.HIGH:
            return
        # Kimi should be in top 3 for high complexity research.
        assert BACKEND_KIMI in result.preferred_backends[:3]

    def test_code_blocks_boost_codex(self):
        """Embedded code fences push Codex up even for non-code prompts."""
        prompt = """Tell me a story about:
```python
def hello():
    pass
```
"""
        result = TaskClassifier().classify(prompt)

        # Codex should land in the top 3 purely from the code presence.
        assert BACKEND_OPENAI_CODEX in result.preferred_backends[:3]


|
|
|
class TestEdgeCases:
    """Tests for edge cases."""

    def test_empty_prompt(self):
        """An empty prompt yields UNKNOWN/LOW with zero confidence."""
        res = TaskClassifier().classify("")

        assert res.task_type == TaskType.UNKNOWN
        assert res.complexity == ComplexityLevel.LOW
        assert res.confidence == 0.0

    def test_whitespace_only_prompt(self):
        """Whitespace-only input is treated like an empty prompt."""
        res = TaskClassifier().classify(" \n\t ")

        assert res.task_type == TaskType.UNKNOWN

    def test_very_long_prompt(self):
        """A huge prompt rates HIGH and still ranks all seven backends."""
        res = TaskClassifier().classify("word " * 10000)

        assert res.complexity == ComplexityLevel.HIGH
        assert len(res.preferred_backends) == 7

    def test_mixed_task_indicators(self):
        """Prompts mixing code and creative cues still resolve to one type."""
        # This has both code and creative indicators.
        res = TaskClassifier().classify(
            "Write a creative Python script that generates poetry"
        )

        # Should pick one task type with reasonable confidence.
        assert res.confidence > 0
        assert res.task_type in (TaskType.CODE, TaskType.CREATIVE)


|
|
|
class TestDictionaryOutput:
    """Tests for dictionary output format."""

    def test_to_dict_output(self):
        """to_dict exposes every expected key with the right value types."""
        clf = TaskClassifier()
        output = clf.to_dict(clf.classify("Implement a function"))

        for key in (
            "task_type",
            "preferred_backends",
            "complexity",
            "reason",
            "confidence",
            "features",
        ):
            assert key in output

        for key, expected_type in (
            ("task_type", str),
            ("preferred_backends", list),
            ("complexity", str),
            ("confidence", float),
        ):
            assert isinstance(output[key], expected_type)

    def test_classify_prompt_convenience_function(self):
        """The module-level helper returns a sane dict directly."""
        output = classify_prompt("Debug this error")

        assert output["task_type"] == "code"
        assert len(output["preferred_backends"]) > 0
        assert output["complexity"] in ("low", "medium", "high")
        assert "reason" in output


|
|
|
class TestClassificationResult:
    """Tests for the ClassificationResult dataclass."""

    def test_result_creation(self):
        """ClassificationResult stores every field it is constructed with."""
        fields = dict(
            task_type=TaskType.CODE,
            preferred_backends=[BACKEND_OPENAI_CODEX, BACKEND_ANTHROPIC],
            complexity=ComplexityLevel.MEDIUM,
            reason="Contains code keywords",
            confidence=0.85,
            features={"word_count": 50},
        )
        result = ClassificationResult(**fields)

        assert result.task_type == TaskType.CODE
        assert result.preferred_backends[0] == BACKEND_OPENAI_CODEX
        assert result.complexity == ComplexityLevel.MEDIUM
        assert result.confidence == 0.85


|
|
|
# Integration tests
class TestIntegration:
    """Integration tests with realistic prompts."""

    def test_code_review_scenario(self):
        """A review request with a fenced snippet classifies sensibly."""
        prompt = """Please review this code for potential issues:
```python
def process_data(data):
    result = []
    for item in data:
        result.append(item * 2)
    return result
```

I'm concerned about memory usage with large datasets."""
        out = classify_prompt(prompt)

        assert out["task_type"] in ("code", "reasoning")
        assert out["complexity"] in ("medium", "high")
        assert len(out["preferred_backends"]) == 7
        assert out["confidence"] > 0

    def test_research_with_urls_scenario(self):
        """A literature request with paper URLs classifies as research."""
        prompt = """Research the findings from these papers:
- https://arxiv.org/abs/2301.00001
- https://papers.nips.cc/paper/2022/hash/xxx

Summarize the key contributions and compare methodologies."""
        out = classify_prompt(prompt)

        assert out["task_type"] == "research"
        assert out["features"]["url_count"] == 2
        assert out["complexity"] in ("medium", "high")

    def test_simple_greeting_scenario(self):
        """A plain greeting routes to Groq as fast_ops."""
        out = classify_prompt("Hello! How are you doing today?")

        assert out["task_type"] == "fast_ops"
        assert out["complexity"] == "low"
        assert out["preferred_backends"][0] == BACKEND_GROQ


|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status: pytest.main() returns an exit code,
    # and silently discarding it makes a failing run look successful to
    # shells and CI when this file is executed directly.
    raise SystemExit(pytest.main([__file__, "-v"]))