"""
Unit tests for the TaskClassifier module.

Run with:
    python -m pytest test_task_classifier.py -v
"""

import pytest
from typing import Dict, Any

from task_classifier import (
    TaskClassifier,
    TaskType,
    ComplexityLevel,
    ClassificationResult,
    classify_prompt,
    BACKEND_ANTHROPIC,
    BACKEND_OPENAI_CODEX,
    BACKEND_GEMINI,
    BACKEND_GROQ,
    BACKEND_GROK,
    BACKEND_KIMI,
    BACKEND_OPENROUTER,
)


# Every backend imported above. Tests that expect a complete ranking compare
# against this tuple's length rather than a bare magic number.
ALL_BACKENDS = (
    BACKEND_ANTHROPIC,
    BACKEND_OPENAI_CODEX,
    BACKEND_GEMINI,
    BACKEND_GROQ,
    BACKEND_GROK,
    BACKEND_KIMI,
    BACKEND_OPENROUTER,
)


def _run_task_type_classifier(classifier, prompt):
    """Classify *prompt* with ``classifier._classify_task_type``.

    Features are extracted from the same prompt via
    ``classifier._extract_features`` first, mirroring how the classifier is
    exercised throughout this suite.

    Returns:
        The ``(task_type, confidence, reason)`` tuple produced by the classifier.
    """
    return classifier._classify_task_type(prompt, classifier._extract_features(prompt))


class TestFeatureExtraction:
    """Tests for feature extraction from prompts."""

    def test_extract_basic_features(self):
        """Test basic feature extraction."""
        classifier = TaskClassifier()
        features = classifier._extract_features("Hello world")

        assert features["char_count"] == 11
        assert features["word_count"] == 2
        assert features["line_count"] == 1
        assert features["url_count"] == 0
        assert features["code_block_count"] == 0
        assert features["has_code"] is False

    def test_extract_url_features(self):
        """Test URL detection in features."""
        classifier = TaskClassifier()
        features = classifier._extract_features(
            "Check out https://example.com and http://test.org/path"
        )

        assert features["url_count"] == 2
        assert len(features["urls"]) == 2
        assert "https://example.com" in features["urls"]

    def test_extract_code_block_features(self):
        """Test code block detection."""
        classifier = TaskClassifier()
        text = (
            "Here is some code:\n"
            "```python\n"
            "def hello():\n"
            '    return "world"\n'
            "```\n"
            "\n"
            "And more:\n"
            "```javascript\n"
            'console.log("hi");\n'
            "```\n"
        )
        features = classifier._extract_features(text)

        assert features["code_block_count"] == 2  # Two complete ``` pairs
        assert features["has_code"] is True
        # The fences themselves may be picked up as inline code, so the exact
        # count is deliberately not pinned; only require the feature to exist.
        assert "inline_code_count" in features

    def test_extract_inline_code_features(self):
        """Test inline code detection."""
        classifier = TaskClassifier()
        features = classifier._extract_features(
            "Use the `print()` function and `len()` method"
        )

        assert features["inline_code_count"] == 2
        assert features["has_code"] is True

    def test_extract_multiline_features(self):
        """Test line counting for multiline text."""
        classifier = TaskClassifier()
        features = classifier._extract_features("Line 1\nLine 2\nLine 3")

        assert features["line_count"] == 3


class TestComplexityAssessment:
    """Tests for complexity level assessment."""

    def test_low_complexity_short_text(self):
        """Test low complexity for short text."""
        classifier = TaskClassifier()
        features: Dict[str, Any] = {
            "char_count": 100,
            "word_count": 15,
            "line_count": 2,
            "url_count": 0,
            "code_block_count": 0,
        }

        assert classifier._assess_complexity(features) == ComplexityLevel.LOW

    def test_medium_complexity_moderate_text(self):
        """Test medium complexity for moderate text."""
        classifier = TaskClassifier()
        features: Dict[str, Any] = {
            "char_count": 500,
            "word_count": 80,
            "line_count": 10,
            "url_count": 1,
            "code_block_count": 0,
        }

        assert classifier._assess_complexity(features) == ComplexityLevel.MEDIUM

    def test_high_complexity_long_text(self):
        """Test high complexity for long text."""
        classifier = TaskClassifier()
        features: Dict[str, Any] = {
            "char_count": 2000,
            "word_count": 300,
            "line_count": 50,
            "url_count": 3,
            "code_block_count": 0,
        }

        assert classifier._assess_complexity(features) == ComplexityLevel.HIGH

    def test_high_complexity_multiple_code_blocks(self):
        """Test high complexity for multiple code blocks."""
        classifier = TaskClassifier()
        features: Dict[str, Any] = {
            "char_count": 500,
            "word_count": 50,
            "line_count": 20,
            "url_count": 0,
            "code_block_count": 4,
        }

        assert classifier._assess_complexity(features) == ComplexityLevel.HIGH


class TestTaskTypeClassification:
    """Tests for task type classification."""

    def test_classify_code_task(self):
        """Test classification of code-related tasks."""
        classifier = TaskClassifier()
        code_prompts = [
            "Implement a function to sort a list",
            "Debug this Python error",
            "Refactor the database query",
            "Write a test for the API endpoint",
            "Fix the bug in the authentication middleware",
        ]

        for prompt in code_prompts:
            task_type, confidence, _ = _run_task_type_classifier(classifier, prompt)
            assert task_type == TaskType.CODE, f"Failed for: {prompt}"
            assert confidence > 0, f"Zero confidence for: {prompt}"

    def test_classify_reasoning_task(self):
        """Test classification of reasoning tasks."""
        classifier = TaskClassifier()
        reasoning_prompts = [
            "Compare and evaluate different approaches",
            "Evaluate the security implications",
            "Think through the logical steps",
            "Step by step, deduce the cause",
            "Analyze the pros and cons",
        ]

        for prompt in reasoning_prompts:
            task_type, _, _ = _run_task_type_classifier(classifier, prompt)
            # Allow REASONING or other valid classifications
            assert task_type in (
                TaskType.REASONING,
                TaskType.CODE,
                TaskType.UNKNOWN,
            ), f"Failed for: {prompt}"

    def test_classify_research_task(self):
        """Test classification of research tasks."""
        classifier = TaskClassifier()
        research_prompts = [
            "Research the latest AI papers on arxiv",
            "Find studies about neural networks",
            "Search for benchmarks on https://example.com/benchmarks",
            "Survey existing literature on distributed systems",
            "Study the published papers on machine learning",
        ]

        for prompt in research_prompts:
            task_type, _, _ = _run_task_type_classifier(classifier, prompt)
            # RESEARCH or other valid classifications
            assert task_type in (
                TaskType.RESEARCH,
                TaskType.FAST_OPS,
                TaskType.CODE,
            ), f"Got {task_type} for: {prompt}"

    def test_classify_creative_task(self):
        """Test classification of creative tasks."""
        classifier = TaskClassifier()
        creative_prompts = [
            "Write a creative story about AI",
            "Design a logo concept",
            "Compose a poem about programming",
            "Brainstorm marketing slogans",
            "Create a character for a novel",
        ]

        for prompt in creative_prompts:
            task_type, _, _ = _run_task_type_classifier(classifier, prompt)
            assert task_type == TaskType.CREATIVE, f"Failed for: {prompt}"

    def test_classify_fast_ops_task(self):
        """Test classification of fast operations tasks."""
        classifier = TaskClassifier()
        # These should be truly simple with no other task indicators
        fast_prompts = ["Hi", "Hello", "Thanks", "Bye", "Yes", "No"]

        for prompt in fast_prompts:
            task_type, _, _ = _run_task_type_classifier(classifier, prompt)
            assert task_type == TaskType.FAST_OPS, f"Failed for: {prompt}"

    def test_classify_tool_use_task(self):
        """Test classification of tool use tasks."""
        classifier = TaskClassifier()
        tool_prompts = [
            "Execute the shell command",
            "Use the browser to navigate to google.com",
            "Call the API endpoint",
            "Invoke the deployment tool",
            "Run this terminal command",
        ]

        for prompt in tool_prompts:
            task_type, _, _ = _run_task_type_classifier(classifier, prompt)
            # Tool use often overlaps with code or research (search)
            assert task_type in (
                TaskType.TOOL_USE,
                TaskType.CODE,
                TaskType.RESEARCH,
            ), f"Got {task_type} for: {prompt}"


class TestBackendSelection:
    """Tests for backend selection logic."""

    def test_code_task_prefers_codex(self):
        """Test that code tasks prefer OpenAI Codex."""
        classifier = TaskClassifier()
        result = classifier.classify("Implement a Python class")

        assert result.task_type == TaskType.CODE
        assert result.preferred_backends[0] == BACKEND_OPENAI_CODEX

    def test_reasoning_task_prefers_anthropic(self):
        """Test that reasoning tasks prefer Anthropic."""
        classifier = TaskClassifier()
        result = classifier.classify("Analyze the architectural trade-offs")

        assert result.task_type == TaskType.REASONING
        assert result.preferred_backends[0] == BACKEND_ANTHROPIC

    def test_research_task_prefers_gemini(self):
        """Test that research tasks prefer Gemini."""
        classifier = TaskClassifier()
        result = classifier.classify("Research the latest papers on transformers")

        assert result.task_type == TaskType.RESEARCH
        assert result.preferred_backends[0] == BACKEND_GEMINI

    def test_creative_task_prefers_grok(self):
        """Test that creative tasks prefer Grok."""
        classifier = TaskClassifier()
        result = classifier.classify("Write a creative story")

        assert result.task_type == TaskType.CREATIVE
        assert result.preferred_backends[0] == BACKEND_GROK

    def test_fast_ops_task_prefers_groq(self):
        """Test that fast ops tasks prefer Groq."""
        classifier = TaskClassifier()
        result = classifier.classify("Quick status check")

        assert result.task_type == TaskType.FAST_OPS
        assert result.preferred_backends[0] == BACKEND_GROQ

    def test_tool_use_task_prefers_anthropic(self):
        """Test that tool use tasks prefer Anthropic."""
        classifier = TaskClassifier()
        result = classifier.classify("Execute the shell command and use tools")

        # Tool use may overlap with code, but anthropic should be near top
        assert result.task_type in (TaskType.TOOL_USE, TaskType.CODE)
        assert BACKEND_ANTHROPIC in result.preferred_backends[:2]


class TestComplexityAdjustments:
    """Tests for complexity-based backend adjustments."""

    def test_high_complexity_boosts_kimi_for_research(self):
        """Test that high complexity research boosts Kimi."""
        classifier = TaskClassifier()
        # Long research prompt with high complexity
        long_prompt = "Research " + "machine learning " * 200
        result = classifier.classify(long_prompt)

        # The Kimi boost only concerns high-complexity research. If the prompt
        # is classified differently the premise does not hold; report a skip
        # rather than silently passing without asserting anything.
        is_high_research = (
            result.task_type == TaskType.RESEARCH
            and result.complexity == ComplexityLevel.HIGH
        )
        if not is_high_research:
            pytest.skip(
                f"precondition not met: got {result.task_type} / {result.complexity}"
            )

        # Kimi should be in top 3 for high complexity research
        assert BACKEND_KIMI in result.preferred_backends[:3]

    def test_code_blocks_boost_codex(self):
        """Test that code presence boosts Codex even for non-code tasks."""
        classifier = TaskClassifier()
        prompt = (
            "Tell me a story about:\n"
            "```python\n"
            "def hello():\n"
            "    pass\n"
            "```\n"
        )
        result = classifier.classify(prompt)

        # Codex should be in top 3 due to code presence
        assert BACKEND_OPENAI_CODEX in result.preferred_backends[:3]


class TestEdgeCases:
    """Tests for edge cases."""

    def test_empty_prompt(self):
        """Test handling of empty prompt."""
        classifier = TaskClassifier()
        result = classifier.classify("")

        assert result.task_type == TaskType.UNKNOWN
        assert result.complexity == ComplexityLevel.LOW
        assert result.confidence == 0.0

    def test_whitespace_only_prompt(self):
        """Test handling of whitespace-only prompt."""
        classifier = TaskClassifier()
        result = classifier.classify(" \n\t ")

        assert result.task_type == TaskType.UNKNOWN

    def test_very_long_prompt(self):
        """Test handling of very long prompt."""
        classifier = TaskClassifier()
        long_prompt = "word " * 10000
        result = classifier.classify(long_prompt)

        assert result.complexity == ComplexityLevel.HIGH
        # Every known backend should still be ranked.
        assert len(result.preferred_backends) == len(ALL_BACKENDS)

    def test_mixed_task_indicators(self):
        """Test handling of prompts with mixed task indicators."""
        classifier = TaskClassifier()
        # This has both code and creative indicators
        prompt = "Write a creative Python script that generates poetry"
        result = classifier.classify(prompt)

        # Should pick one task type with reasonable confidence
        assert result.confidence > 0
        assert result.task_type in (TaskType.CODE, TaskType.CREATIVE)


class TestDictionaryOutput:
    """Tests for dictionary output format."""

    def test_to_dict_output(self):
        """Test conversion to dictionary."""
        classifier = TaskClassifier()
        result = classifier.classify("Implement a function")
        output = classifier.to_dict(result)

        for key in ("task_type", "preferred_backends", "complexity",
                    "reason", "confidence", "features"):
            assert key in output, f"missing key: {key}"

        assert isinstance(output["task_type"], str)
        assert isinstance(output["preferred_backends"], list)
        assert isinstance(output["complexity"], str)
        assert isinstance(output["confidence"], float)

    def test_classify_prompt_convenience_function(self):
        """Test the convenience function."""
        output = classify_prompt("Debug this error")

        assert output["task_type"] == "code"
        assert len(output["preferred_backends"]) > 0
        assert output["complexity"] in ("low", "medium", "high")
        assert "reason" in output


class TestClassificationResult:
    """Tests for the ClassificationResult dataclass."""

    def test_result_creation(self):
        """Test creation of ClassificationResult."""
        result = ClassificationResult(
            task_type=TaskType.CODE,
            preferred_backends=[BACKEND_OPENAI_CODEX, BACKEND_ANTHROPIC],
            complexity=ComplexityLevel.MEDIUM,
            reason="Contains code keywords",
            confidence=0.85,
            features={"word_count": 50},
        )

        assert result.task_type == TaskType.CODE
        assert result.preferred_backends[0] == BACKEND_OPENAI_CODEX
        assert result.complexity == ComplexityLevel.MEDIUM
        assert result.confidence == 0.85


# Integration tests
class TestIntegration:
    """Integration tests with realistic prompts."""

    def test_code_review_scenario(self):
        """Test a code review scenario."""
        prompt = (
            "Please review this code for potential issues:\n"
            "\n"
            "```python\n"
            "def process_data(data):\n"
            "    result = []\n"
            "    for item in data:\n"
            "        result.append(item * 2)\n"
            "    return result\n"
            "```\n"
            "\n"
            "I'm concerned about memory usage with large datasets."
        )
        result = classify_prompt(prompt)

        assert result["task_type"] in ("code", "reasoning")
        assert result["complexity"] in ("medium", "high")
        assert len(result["preferred_backends"]) == len(ALL_BACKENDS)
        assert result["confidence"] > 0

    def test_research_with_urls_scenario(self):
        """Test a research scenario with URLs."""
        prompt = (
            "Research the findings from these papers:\n"
            "- https://arxiv.org/abs/2301.00001\n"
            "- https://papers.nips.cc/paper/2022/hash/xxx\n"
            "\n"
            "Summarize the key contributions and compare methodologies."
        )
        result = classify_prompt(prompt)

        assert result["task_type"] == "research"
        assert result["features"]["url_count"] == 2
        assert result["complexity"] in ("medium", "high")

    def test_simple_greeting_scenario(self):
        """Test a simple greeting."""
        result = classify_prompt("Hello! How are you doing today?")

        assert result["task_type"] == "fast_ops"
        assert result["complexity"] == "low"
        assert result["preferred_backends"][0] == BACKEND_GROQ


if __name__ == "__main__":
    # Propagate pytest's exit status so failures are visible to shells/CI
    # when this file is executed directly.
    raise SystemExit(pytest.main([__file__, "-v"]))