# NOTE: the following lines are a file-listing header left over from extraction,
# kept as comments so the module stays valid Python:
#   Files
#   timmy-home/uniwizard/test_task_classifier.py
#   502 lines, 17 KiB, Python
"""
Unit tests for the TaskClassifier module.
Run with: python -m pytest test_task_classifier.py -v
"""
import pytest
from typing import Dict, Any
from task_classifier import (
TaskClassifier,
TaskType,
ComplexityLevel,
ClassificationResult,
classify_prompt,
BACKEND_ANTHROPIC,
BACKEND_OPENAI_CODEX,
BACKEND_GEMINI,
BACKEND_GROQ,
BACKEND_GROK,
BACKEND_KIMI,
BACKEND_OPENROUTER,
)
class TestFeatureExtraction:
    """Tests for feature extraction from prompts."""

    def test_extract_basic_features(self):
        """Test basic feature extraction (counts for a trivial two-word prompt)."""
        classifier = TaskClassifier()
        features = classifier._extract_features("Hello world")
        assert features["char_count"] == 11
        assert features["word_count"] == 2
        assert features["line_count"] == 1
        assert features["url_count"] == 0
        assert features["code_block_count"] == 0
        assert features["has_code"] is False

    def test_extract_url_features(self):
        """Test URL detection in features."""
        classifier = TaskClassifier()
        features = classifier._extract_features(
            "Check out https://example.com and http://test.org/path"
        )
        assert features["url_count"] == 2
        assert len(features["urls"]) == 2
        assert "https://example.com" in features["urls"]

    def test_extract_code_block_features(self):
        """Test fenced code block detection (two complete ``` pairs)."""
        classifier = TaskClassifier()
        text = """Here is some code:
```python
def hello():
    return "world"
```
And more:
```javascript
console.log("hi");
```
"""
        features = classifier._extract_features(text)
        assert features["code_block_count"] == 2  # Two complete ``` pairs
        assert features["has_code"] is True
        # May detect inline code in text, just ensure has_code is True
        assert features["inline_code_count"] >= 0

    def test_extract_inline_code_features(self):
        """Test inline (backtick) code detection."""
        classifier = TaskClassifier()
        features = classifier._extract_features(
            "Use the `print()` function and `len()` method"
        )
        assert features["inline_code_count"] == 2
        assert features["has_code"] is True

    def test_extract_multiline_features(self):
        """Test line counting for multiline text."""
        classifier = TaskClassifier()
        features = classifier._extract_features("Line 1\nLine 2\nLine 3")
        assert features["line_count"] == 3
class TestComplexityAssessment:
    """Tests for complexity level assessment from pre-built feature dicts."""

    def test_low_complexity_short_text(self):
        """Short, plain text should be assessed LOW."""
        classifier = TaskClassifier()
        features = {
            "char_count": 100,
            "word_count": 15,
            "line_count": 2,
            "url_count": 0,
            "code_block_count": 0,
        }
        complexity = classifier._assess_complexity(features)
        assert complexity == ComplexityLevel.LOW

    def test_medium_complexity_moderate_text(self):
        """Moderate length with one URL should be assessed MEDIUM."""
        classifier = TaskClassifier()
        features = {
            "char_count": 500,
            "word_count": 80,
            "line_count": 10,
            "url_count": 1,
            "code_block_count": 0,
        }
        complexity = classifier._assess_complexity(features)
        assert complexity == ComplexityLevel.MEDIUM

    def test_high_complexity_long_text(self):
        """Long text with several URLs should be assessed HIGH."""
        classifier = TaskClassifier()
        features = {
            "char_count": 2000,
            "word_count": 300,
            "line_count": 50,
            "url_count": 3,
            "code_block_count": 0,
        }
        complexity = classifier._assess_complexity(features)
        assert complexity == ComplexityLevel.HIGH

    def test_high_complexity_multiple_code_blocks(self):
        """Many code blocks should push an otherwise-moderate prompt to HIGH."""
        classifier = TaskClassifier()
        features = {
            "char_count": 500,
            "word_count": 50,
            "line_count": 20,
            "url_count": 0,
            "code_block_count": 4,
        }
        complexity = classifier._assess_complexity(features)
        assert complexity == ComplexityLevel.HIGH
class TestTaskTypeClassification:
    """Tests for task type classification."""

    def test_classify_code_task(self):
        """Code-related prompts should classify as CODE with nonzero confidence."""
        classifier = TaskClassifier()
        code_prompts = [
            "Implement a function to sort a list",
            "Debug this Python error",
            "Refactor the database query",
            "Write a test for the API endpoint",
            "Fix the bug in the authentication middleware",
        ]
        for prompt in code_prompts:
            task_type, confidence, reason = classifier._classify_task_type(
                prompt,
                classifier._extract_features(prompt),
            )
            assert task_type == TaskType.CODE, f"Failed for: {prompt}"
            assert confidence > 0, f"Zero confidence for: {prompt}"

    def test_classify_reasoning_task(self):
        """Reasoning prompts should classify as REASONING (overlap tolerated)."""
        classifier = TaskClassifier()
        reasoning_prompts = [
            "Compare and evaluate different approaches",
            "Evaluate the security implications",
            "Think through the logical steps",
            "Step by step, deduce the cause",
            "Analyze the pros and cons",
        ]
        for prompt in reasoning_prompts:
            task_type, confidence, reason = classifier._classify_task_type(
                prompt,
                classifier._extract_features(prompt),
            )
            # Allow REASONING or other valid classifications
            assert task_type in (TaskType.REASONING, TaskType.CODE, TaskType.UNKNOWN), f"Failed for: {prompt}"

    def test_classify_research_task(self):
        """Research prompts should classify as RESEARCH (overlap tolerated)."""
        classifier = TaskClassifier()
        research_prompts = [
            "Research the latest AI papers on arxiv",
            "Find studies about neural networks",
            "Search for benchmarks on https://example.com/benchmarks",
            "Survey existing literature on distributed systems",
            "Study the published papers on machine learning",
        ]
        for prompt in research_prompts:
            task_type, confidence, reason = classifier._classify_task_type(
                prompt,
                classifier._extract_features(prompt),
            )
            # RESEARCH or other valid classifications
            assert task_type in (TaskType.RESEARCH, TaskType.FAST_OPS, TaskType.CODE), f"Got {task_type} for: {prompt}"

    def test_classify_creative_task(self):
        """Creative-writing prompts should classify as CREATIVE."""
        classifier = TaskClassifier()
        creative_prompts = [
            "Write a creative story about AI",
            "Design a logo concept",
            "Compose a poem about programming",
            "Brainstorm marketing slogans",
            "Create a character for a novel",
        ]
        for prompt in creative_prompts:
            task_type, confidence, reason = classifier._classify_task_type(
                prompt,
                classifier._extract_features(prompt),
            )
            assert task_type == TaskType.CREATIVE, f"Failed for: {prompt}"

    def test_classify_fast_ops_task(self):
        """Trivial greetings with no other task indicators should be FAST_OPS."""
        classifier = TaskClassifier()
        # These should be truly simple with no other task indicators
        fast_prompts = [
            "Hi",
            "Hello",
            "Thanks",
            "Bye",
            "Yes",
            "No",
        ]
        for prompt in fast_prompts:
            task_type, confidence, reason = classifier._classify_task_type(
                prompt,
                classifier._extract_features(prompt),
            )
            assert task_type == TaskType.FAST_OPS, f"Failed for: {prompt}"

    def test_classify_tool_use_task(self):
        """Tool-invocation prompts should be TOOL_USE (overlap tolerated)."""
        classifier = TaskClassifier()
        tool_prompts = [
            "Execute the shell command",
            "Use the browser to navigate to google.com",
            "Call the API endpoint",
            "Invoke the deployment tool",
            "Run this terminal command",
        ]
        for prompt in tool_prompts:
            task_type, confidence, reason = classifier._classify_task_type(
                prompt,
                classifier._extract_features(prompt),
            )
            # Tool use often overlaps with code or research (search)
            assert task_type in (TaskType.TOOL_USE, TaskType.CODE, TaskType.RESEARCH), f"Got {task_type} for: {prompt}"
class TestBackendSelection:
    """Tests for backend selection logic (task type -> ranked backend list)."""

    def test_code_task_prefers_codex(self):
        """Code tasks should rank OpenAI Codex first."""
        classifier = TaskClassifier()
        result = classifier.classify("Implement a Python class")
        assert result.task_type == TaskType.CODE
        assert result.preferred_backends[0] == BACKEND_OPENAI_CODEX

    def test_reasoning_task_prefers_anthropic(self):
        """Reasoning tasks should rank Anthropic first."""
        classifier = TaskClassifier()
        result = classifier.classify("Analyze the architectural trade-offs")
        assert result.task_type == TaskType.REASONING
        assert result.preferred_backends[0] == BACKEND_ANTHROPIC

    def test_research_task_prefers_gemini(self):
        """Research tasks should rank Gemini first."""
        classifier = TaskClassifier()
        result = classifier.classify("Research the latest papers on transformers")
        assert result.task_type == TaskType.RESEARCH
        assert result.preferred_backends[0] == BACKEND_GEMINI

    def test_creative_task_prefers_grok(self):
        """Creative tasks should rank Grok first."""
        classifier = TaskClassifier()
        result = classifier.classify("Write a creative story")
        assert result.task_type == TaskType.CREATIVE
        assert result.preferred_backends[0] == BACKEND_GROK

    def test_fast_ops_task_prefers_groq(self):
        """Fast-ops tasks should rank Groq first."""
        classifier = TaskClassifier()
        result = classifier.classify("Quick status check")
        assert result.task_type == TaskType.FAST_OPS
        assert result.preferred_backends[0] == BACKEND_GROQ

    def test_tool_use_task_prefers_anthropic(self):
        """Tool-use tasks should keep Anthropic near the top of the ranking."""
        classifier = TaskClassifier()
        result = classifier.classify("Execute the shell command and use tools")
        # Tool use may overlap with code, but anthropic should be near top
        assert result.task_type in (TaskType.TOOL_USE, TaskType.CODE)
        assert BACKEND_ANTHROPIC in result.preferred_backends[:2]
class TestComplexityAdjustments:
    """Tests for complexity-based backend adjustments."""

    def test_high_complexity_boosts_kimi_for_research(self):
        """High-complexity research should boost Kimi into the top 3."""
        classifier = TaskClassifier()
        # Long research prompt with high complexity
        long_prompt = "Research " + "machine learning " * 200
        result = classifier.classify(long_prompt)
        # NOTE(review): the assertion is conditional, so the test passes
        # vacuously if classification drifts — consider asserting the
        # task_type/complexity preconditions directly.
        if result.task_type == TaskType.RESEARCH and result.complexity == ComplexityLevel.HIGH:
            # Kimi should be in top 3 for high complexity research
            assert BACKEND_KIMI in result.preferred_backends[:3]

    def test_code_blocks_boost_codex(self):
        """Embedded code should boost Codex even for non-code tasks."""
        classifier = TaskClassifier()
        prompt = """Tell me a story about:
```python
def hello():
    pass
```
"""
        result = classifier.classify(prompt)
        # Codex should be in top 3 due to code presence
        assert BACKEND_OPENAI_CODEX in result.preferred_backends[:3]
class TestEdgeCases:
    """Tests for edge cases (empty, whitespace, extreme length, ambiguity)."""

    def test_empty_prompt(self):
        """An empty prompt should be UNKNOWN / LOW with zero confidence."""
        classifier = TaskClassifier()
        result = classifier.classify("")
        assert result.task_type == TaskType.UNKNOWN
        assert result.complexity == ComplexityLevel.LOW
        assert result.confidence == 0.0

    def test_whitespace_only_prompt(self):
        """A whitespace-only prompt should also be UNKNOWN."""
        classifier = TaskClassifier()
        result = classifier.classify(" \n\t ")
        assert result.task_type == TaskType.UNKNOWN

    def test_very_long_prompt(self):
        """A very long prompt should be HIGH complexity and rank all 7 backends."""
        classifier = TaskClassifier()
        long_prompt = "word " * 10000
        result = classifier.classify(long_prompt)
        assert result.complexity == ComplexityLevel.HIGH
        assert len(result.preferred_backends) == 7

    def test_mixed_task_indicators(self):
        """Mixed code/creative indicators should still yield one confident type."""
        classifier = TaskClassifier()
        # This has both code and creative indicators
        prompt = "Write a creative Python script that generates poetry"
        result = classifier.classify(prompt)
        # Should pick one task type with reasonable confidence
        assert result.confidence > 0
        assert result.task_type in (TaskType.CODE, TaskType.CREATIVE)
class TestDictionaryOutput:
    """Tests for dictionary output format."""

    def test_to_dict_output(self):
        """to_dict should expose all result fields with JSON-friendly types."""
        classifier = TaskClassifier()
        result = classifier.classify("Implement a function")
        output = classifier.to_dict(result)
        assert "task_type" in output
        assert "preferred_backends" in output
        assert "complexity" in output
        assert "reason" in output
        assert "confidence" in output
        assert "features" in output
        assert isinstance(output["task_type"], str)
        assert isinstance(output["preferred_backends"], list)
        assert isinstance(output["complexity"], str)
        assert isinstance(output["confidence"], float)

    def test_classify_prompt_convenience_function(self):
        """The module-level classify_prompt helper should return a full dict."""
        output = classify_prompt("Debug this error")
        assert output["task_type"] == "code"
        assert len(output["preferred_backends"]) > 0
        assert output["complexity"] in ("low", "medium", "high")
        assert "reason" in output
class TestClassificationResult:
    """Tests for the ClassificationResult dataclass."""

    def test_result_creation(self):
        """ClassificationResult should store fields exactly as constructed."""
        result = ClassificationResult(
            task_type=TaskType.CODE,
            preferred_backends=[BACKEND_OPENAI_CODEX, BACKEND_ANTHROPIC],
            complexity=ComplexityLevel.MEDIUM,
            reason="Contains code keywords",
            confidence=0.85,
            features={"word_count": 50},
        )
        assert result.task_type == TaskType.CODE
        assert result.preferred_backends[0] == BACKEND_OPENAI_CODEX
        assert result.complexity == ComplexityLevel.MEDIUM
        assert result.confidence == 0.85
# Integration tests
class TestIntegration:
    """Integration tests with realistic prompts."""

    def test_code_review_scenario(self):
        """A code-review prompt should be code/reasoning at medium+ complexity."""
        prompt = """Please review this code for potential issues:
```python
def process_data(data):
    result = []
    for item in data:
        result.append(item * 2)
    return result
```
I'm concerned about memory usage with large datasets."""
        result = classify_prompt(prompt)
        assert result["task_type"] in ("code", "reasoning")
        assert result["complexity"] in ("medium", "high")
        assert len(result["preferred_backends"]) == 7
        assert result["confidence"] > 0

    def test_research_with_urls_scenario(self):
        """A research prompt with two URLs should detect both and be research."""
        prompt = """Research the findings from these papers:
- https://arxiv.org/abs/2301.00001
- https://papers.nips.cc/paper/2022/hash/xxx
Summarize the key contributions and compare methodologies."""
        result = classify_prompt(prompt)
        assert result["task_type"] == "research"
        assert result["features"]["url_count"] == 2
        assert result["complexity"] in ("medium", "high")

    def test_simple_greeting_scenario(self):
        """A simple greeting should route to fast_ops on Groq at low complexity."""
        result = classify_prompt("Hello! How are you doing today?")
        assert result["task_type"] == "fast_ops"
        assert result["complexity"] == "low"
        assert result["preferred_backends"][0] == BACKEND_GROQ
# Allow running this file directly (equivalent to `pytest <file> -v`).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])