# File: tests/test_error_classifier.py
"""
Tests for error classification (#752).
"""

import pytest
from tools.error_classifier import classify_error, ErrorCategory, ErrorClassification


class TestErrorClassification:
    """Checks that classify_error sorts errors into the expected category.

    Each test builds a plain ``Exception`` (optionally with an HTTP response
    code) and verifies both the category and the retry decision.
    """

    def test_timeout_is_retryable(self):
        # No response code: classification relies on the message text alone.
        outcome = classify_error(Exception("Connection timed out"))
        assert outcome.category == ErrorCategory.RETRYABLE
        assert outcome.should_retry is True

    def test_429_is_retryable(self):
        rate_limited = Exception("Rate limit exceeded")
        outcome = classify_error(rate_limited, response_code=429)
        assert outcome.category == ErrorCategory.RETRYABLE
        assert outcome.should_retry is True

    def test_404_is_permanent(self):
        missing = Exception("Not found")
        outcome = classify_error(missing, response_code=404)
        assert outcome.category == ErrorCategory.PERMANENT
        assert outcome.should_retry is False

    def test_403_is_permanent(self):
        denied = Exception("Forbidden")
        outcome = classify_error(denied, response_code=403)
        assert outcome.category == ErrorCategory.PERMANENT
        assert outcome.should_retry is False

    def test_500_is_retryable(self):
        server_fault = Exception("Internal server error")
        outcome = classify_error(server_fault, response_code=500)
        assert outcome.category == ErrorCategory.RETRYABLE
        assert outcome.should_retry is True

    def test_schema_error_is_permanent(self):
        # No response code: the message text drives the decision.
        outcome = classify_error(Exception("Schema validation failed"))
        assert outcome.category == ErrorCategory.PERMANENT
        assert outcome.should_retry is False

    def test_unknown_is_retryable_with_caution(self):
        # Unrecognised errors are still retried, but only once.
        outcome = classify_error(Exception("Some unknown error"))
        assert outcome.category == ErrorCategory.UNKNOWN
        assert outcome.should_retry is True
        assert outcome.max_retries == 1


if __name__ == "__main__":
    pytest.main([__file__])
# File: tools/error_classifier.py
"""
Tool Error Classification — Retryable vs Permanent.

Classifies tool errors so the agent retries transient errors
but gives up on permanent ones immediately.
"""

import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Any

logger = logging.getLogger(__name__)


class ErrorCategory(Enum):
    """Error category classification."""
    RETRYABLE = "retryable"
    PERMANENT = "permanent"
    UNKNOWN = "unknown"


@dataclass
class ErrorClassification:
    """Result of error classification.

    Attributes:
        category: Which bucket the error fell into.
        reason: Short human-readable explanation of the decision.
        should_retry: Whether the caller should attempt the operation again.
        max_retries: Upper bound on retries suggested for this error.
        backoff_seconds: Suggested base delay between retries.
        error_code: HTTP response code, when one was supplied by the caller.
        error_type: Class name of the classified exception.
    """
    category: ErrorCategory
    reason: str
    should_retry: bool
    max_retries: int
    backoff_seconds: float
    error_code: Optional[int] = None
    error_type: Optional[str] = None


# Retryable error patterns: (regex, reason, max_retries, backoff_seconds)
_RETRYABLE_PATTERNS = [
    # HTTP status codes
    (r"\b429\b", "rate limit", 3, 5.0),
    (r"\b500\b", "server error", 3, 2.0),
    (r"\b502\b", "bad gateway", 3, 2.0),
    (r"\b503\b", "service unavailable", 3, 5.0),
    (r"\b504\b", "gateway timeout", 3, 5.0),

    # Timeout patterns
    (r"timeout", "timeout", 3, 2.0),
    (r"timed out", "timeout", 3, 2.0),
    (r"TimeoutExpired", "timeout", 3, 2.0),

    # Connection errors
    (r"connection refused", "connection refused", 2, 5.0),
    (r"connection reset", "connection reset", 2, 2.0),
    (r"network unreachable", "network unreachable", 2, 10.0),
    (r"DNS", "DNS error", 2, 5.0),

    # Transient errors
    (r"temporary", "temporary error", 2, 2.0),
    (r"transient", "transient error", 2, 2.0),
    (r"retry", "retryable", 2, 2.0),
]

# Permanent error patterns: (regex, short label, reason)
_PERMANENT_PATTERNS = [
    # HTTP status codes
    (r"\b400\b", "bad request", "Invalid request parameters"),
    (r"\b401\b", "unauthorized", "Authentication failed"),
    (r"\b403\b", "forbidden", "Access denied"),
    (r"\b404\b", "not found", "Resource not found"),
    (r"\b405\b", "method not allowed", "HTTP method not supported"),
    (r"\b409\b", "conflict", "Resource conflict"),
    (r"\b422\b", "unprocessable", "Validation error"),

    # Schema/validation errors
    (r"schema", "schema error", "Invalid data schema"),
    (r"validation", "validation error", "Input validation failed"),
    (r"invalid.*json", "JSON error", "Invalid JSON"),
    (r"JSONDecodeError", "JSON error", "JSON parsing failed"),

    # Authentication
    (r"api.?key", "API key error", "Invalid or missing API key"),
    (r"token.*expir", "token expired", "Authentication token expired"),
    (r"permission", "permission error", "Insufficient permissions"),

    # Not found patterns
    (r"not found", "not found", "Resource does not exist"),
    (r"does not exist", "not found", "Resource does not exist"),
    (r"no such file", "file not found", "File does not exist"),

    # Quota/billing
    (r"quota", "quota exceeded", "Usage quota exceeded"),
    (r"billing", "billing error", "Billing issue"),
    (r"insufficient.*funds", "billing error", "Insufficient funds"),
]

# HTTP status codes that decide the classification on their own.
_RETRYABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504})
_PERMANENT_STATUS_CODES = frozenset({400, 401, 403, 404, 405, 409, 422})

# Compile once at import time instead of on every classify_error() call.
_COMPILED_RETRYABLE = [
    (re.compile(pattern, re.IGNORECASE), reason, max_retries, backoff)
    for pattern, reason, max_retries, backoff in _RETRYABLE_PATTERNS
]
_COMPILED_PERMANENT = [
    (re.compile(pattern, re.IGNORECASE), reason)
    for pattern, _label, reason in _PERMANENT_PATTERNS
]


def classify_error(error: Exception, response_code: Optional[int] = None) -> ErrorClassification:
    """
    Classify an error as retryable or permanent.

    Checks run in this order, and the first match wins:
      1. ``response_code`` against the known retryable / permanent HTTP codes.
      2. Retryable patterns.
      3. Permanent patterns.
      4. Otherwise UNKNOWN, which allows a single cautious retry.

    Patterns are matched case-insensitively against the exception's class
    name *and* its message, so errors such as ``TimeoutError()`` (empty
    message) or ``JSONDecodeError`` are recognised by type.

    Args:
        error: The exception that occurred
        response_code: HTTP response code if available

    Returns:
        ErrorClassification with retry guidance
    """
    error_type = type(error).__name__
    # Include the type name: several patterns (TimeoutExpired,
    # JSONDecodeError) name exception classes, not message fragments.
    searchable = f"{error_type}: {error}"

    # Check response code first
    if response_code is not None:
        if response_code in _RETRYABLE_STATUS_CODES:
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=f"HTTP {response_code} - transient server error",
                should_retry=True,
                max_retries=3,
                # Rate limits get a longer pause than server faults.
                backoff_seconds=5.0 if response_code == 429 else 2.0,
                error_code=response_code,
                error_type=error_type,
            )
        if response_code in _PERMANENT_STATUS_CODES:
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=f"HTTP {response_code} - client error",
                should_retry=False,
                max_retries=0,
                backoff_seconds=0.0,
                error_code=response_code,
                error_type=error_type,
            )
        # Any other code falls through to the text-based checks below.

    # Check retryable patterns
    for regex, reason, max_retries, backoff in _COMPILED_RETRYABLE:
        if regex.search(searchable):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=reason,
                should_retry=True,
                max_retries=max_retries,
                backoff_seconds=backoff,
                error_type=error_type,
            )

    # Check permanent patterns
    for regex, reason in _COMPILED_PERMANENT:
        if regex.search(searchable):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=reason,
                should_retry=False,
                max_retries=0,
                backoff_seconds=0.0,
                error_type=error_type,
            )

    # Default: unknown, treat as retryable with caution
    return ErrorClassification(
        category=ErrorCategory.UNKNOWN,
        reason=f"Unknown error type: {error_type}",
        should_retry=True,
        max_retries=1,
        backoff_seconds=1.0,
        error_type=error_type,
    )


def execute_with_retry(
    func,
    *args,
    max_retries: int = 3,
    backoff_base: float = 1.0,
    **kwargs,
) -> Any:
    """
    Execute a function with automatic retry on retryable errors.

    Each failure is classified with ``classify_error``. Permanent errors are
    re-raised immediately. Other errors are retried until either the
    caller's ``max_retries`` or the classification's own ``max_retries``
    (whichever is smaller) is used up. The delay doubles after every failed
    attempt, starting at ``backoff_base``; the classification's
    ``backoff_seconds`` is not used here.

    Args:
        func: Function to execute
        *args: Function arguments
        max_retries: Maximum retry attempts (must be >= 0)
        backoff_base: Base backoff time in seconds
        **kwargs: Function keyword arguments

    Returns:
        Function result

    Raises:
        ValueError: If ``max_retries`` is negative
        Exception: If permanent error or max retries exceeded
    """
    if max_retries < 0:
        raise ValueError("max_retries must be >= 0")

    attempt = 0
    while True:
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            # Classify the error
            classification = classify_error(exc)

            # Never retry more often than the classification recommends,
            # e.g. unknown errors are only worth a single extra attempt.
            retry_limit = min(max_retries, classification.max_retries)

            logger.info(
                "Attempt %d/%d failed: %s (%s, retryable: %s)",
                attempt + 1, retry_limit + 1,
                classification.reason,
                classification.category.value,
                classification.should_retry,
            )

            # If permanent error, fail immediately
            if not classification.should_retry:
                logger.error("Permanent error: %s", classification.reason)
                raise

            # If this was the last attempt, raise
            if attempt >= retry_limit:
                logger.error("Max retries (%d) exceeded", retry_limit)
                raise

            # Calculate backoff with exponential increase
            backoff = backoff_base * (2 ** attempt)
            logger.info("Retrying in %.1fs...", backoff)
            time.sleep(backoff)
            attempt += 1


def format_error_report(classification: ErrorClassification) -> str:
    """Format error classification as a one-line report string.

    The line starts with a retry icon when the error should be retried and a
    cross otherwise, followed by the category value and the reason.
    """
    icon = "🔄" if classification.should_retry else "❌"
    return f"{icon} {classification.category.value}: {classification.reason}"