Compare commits

..

2 Commits

Author SHA1 Message Date
10d7cd7d0c test(#752): Add tests for error classification
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 44s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 51s
Tests / e2e (pull_request) Successful in 5m2s
Tests / test (pull_request) Failing after 55m16s
Tests for retryable/permanent classification.
Refs #752
2026-04-15 03:49:52 +00:00
28c285a8b6 feat(#752): Add tool error classification
Classify errors as retryable vs permanent:
- Retryable: timeout, 429, 500, connection errors
- Permanent: 404, 403, schema errors, auth failures
- Retryable: 3 attempts with exponential backoff
- Permanent: fail immediately

Resolves #752
2026-04-15 03:49:31 +00:00
4 changed files with 288 additions and 233 deletions

View File

@@ -1,41 +0,0 @@
"""
Tests for cost estimator tool (#745).
"""
import pytest
from tools.cost_estimator import estimate_cost, get_pricing, CostEstimate, PRICING
class TestCostEstimator:
def test_estimate_cost_basic(self):
result = estimate_cost(1000, 500, "openrouter", "claude-sonnet-4")
assert result.input_tokens == 1000
assert result.output_tokens == 500
assert result.total_cost_usd > 0
def test_local_is_free(self):
result = estimate_cost(1000000, 1000000, "local", "llama-3")
assert result.total_cost_usd == 0.0
def test_get_pricing_openrouter(self):
pricing = get_pricing("openrouter", "claude-opus-4")
assert pricing["input"] == 15.0
assert pricing["output"] == 75.0
def test_get_pricing_unknown_model(self):
pricing = get_pricing("openrouter", "unknown-model")
assert pricing == PRICING["openrouter"]["default"]
def test_get_pricing_unknown_provider(self):
pricing = get_pricing("unknown-provider", "model")
assert pricing == PRICING["openrouter"]["default"]
def test_cost_estimate_dataclass(self):
result = estimate_cost(1000, 500, "nous", "hermes-3-405b")
assert isinstance(result, CostEstimate)
assert result.provider == "nous"
assert result.model == "hermes-3-405b"
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,55 @@
"""
Tests for error classification (#752).
"""
import pytest
from tools.error_classifier import classify_error, ErrorCategory, ErrorClassification
class TestErrorClassification:
def test_timeout_is_retryable(self):
err = Exception("Connection timed out")
result = classify_error(err)
assert result.category == ErrorCategory.RETRYABLE
assert result.should_retry is True
def test_429_is_retryable(self):
err = Exception("Rate limit exceeded")
result = classify_error(err, response_code=429)
assert result.category == ErrorCategory.RETRYABLE
assert result.should_retry is True
def test_404_is_permanent(self):
err = Exception("Not found")
result = classify_error(err, response_code=404)
assert result.category == ErrorCategory.PERMANENT
assert result.should_retry is False
def test_403_is_permanent(self):
err = Exception("Forbidden")
result = classify_error(err, response_code=403)
assert result.category == ErrorCategory.PERMANENT
assert result.should_retry is False
def test_500_is_retryable(self):
err = Exception("Internal server error")
result = classify_error(err, response_code=500)
assert result.category == ErrorCategory.RETRYABLE
assert result.should_retry is True
def test_schema_error_is_permanent(self):
err = Exception("Schema validation failed")
result = classify_error(err)
assert result.category == ErrorCategory.PERMANENT
assert result.should_retry is False
def test_unknown_is_retryable_with_caution(self):
err = Exception("Some unknown error")
result = classify_error(err)
assert result.category == ErrorCategory.UNKNOWN
assert result.should_retry is True
assert result.max_retries == 1
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,192 +0,0 @@
"""
Provider Cost Estimator — Estimate API costs from token counts.
Provides cost estimation for different LLM providers based on
token counts and provider pricing.
"""
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
@dataclass
class CostEstimate:
"""Cost estimate for a request."""
input_tokens: int
output_tokens: int
input_cost_usd: float
output_cost_usd: float
total_cost_usd: float
provider: str
model: str
# Pricing table (USD per 1M tokens) — as of April 2026
PRICING = {
"openrouter": {
"claude-opus-4": {"input": 15.0, "output": 75.0},
"claude-sonnet-4": {"input": 3.0, "output": 15.0},
"claude-haiku-3.5": {"input": 0.80, "output": 4.0},
"gpt-4o": {"input": 2.50, "output": 10.0},
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
"gemini-2.5-pro": {"input": 1.25, "output": 10.0},
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
"llama-4-scout": {"input": 0.20, "output": 0.80},
"llama-4-maverick": {"input": 0.50, "output": 2.0},
"default": {"input": 1.0, "output": 3.0},
},
"nous": {
"hermes-3-405b": {"input": 5.0, "output": 5.0},
"mixtral-8x22b": {"input": 2.0, "output": 2.0},
"hermes-2-mixtral-8x7b": {"input": 0.90, "output": 0.90},
"default": {"input": 2.0, "output": 2.0},
},
"anthropic": {
"claude-opus-4": {"input": 15.0, "output": 75.0},
"claude-sonnet-4": {"input": 3.0, "output": 15.0},
"claude-haiku-3.5": {"input": 0.80, "output": 4.0},
"default": {"input": 3.0, "output": 15.0},
},
"local": {
# Local models are free (electricity only)
"default": {"input": 0.0, "output": 0.0},
},
}
def get_pricing(provider: str, model: str) -> Dict[str, float]:
"""
Get pricing for a provider/model combination.
Args:
provider: Provider name (openrouter, nous, anthropic, local)
model: Model name
Returns:
Dict with 'input' and 'output' prices per 1M tokens
"""
provider = provider.lower().strip()
model = model.lower().strip()
provider_pricing = PRICING.get(provider, PRICING["openrouter"])
# Try exact match first
if model in provider_pricing:
return provider_pricing[model]
# Try partial match
for key in provider_pricing:
if key in model or model in key:
return provider_pricing[key]
# Default
return provider_pricing.get("default", {"input": 1.0, "output": 3.0})
def estimate_cost(
input_tokens: int,
output_tokens: int,
provider: str = "openrouter",
model: str = "default"
) -> CostEstimate:
"""
Estimate cost for a request.
Args:
input_tokens: Number of input tokens
output_tokens: Number of output tokens
provider: Provider name
model: Model name
Returns:
CostEstimate with breakdown
"""
pricing = get_pricing(provider, model)
# Calculate costs (pricing is per 1M tokens)
input_cost = (input_tokens / 1_000_000) * pricing["input"]
output_cost = (output_tokens / 1_000_000) * pricing["output"]
total_cost = input_cost + output_cost
return CostEstimate(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_cost_usd=input_cost,
output_cost_usd=output_cost,
total_cost_usd=total_cost,
provider=provider,
model=model,
)
def estimate_session_cost(messages: list, provider: str = "openrouter", model: str = "default") -> CostEstimate:
"""
Estimate cost for a session based on message count.
Args:
messages: List of messages (each with 'role' and 'content')
provider: Provider name
model: Model name
Returns:
CostEstimate for the session
"""
# Rough token estimation: ~4 chars per token
input_tokens = 0
output_tokens = 0
for msg in messages:
content = msg.get("content", "")
if isinstance(content, str):
tokens = len(content) // 4
if msg.get("role") == "user":
input_tokens += tokens
elif msg.get("role") == "assistant":
output_tokens += tokens
return estimate_cost(input_tokens, output_tokens, provider, model)
def format_cost_report(estimates: list) -> str:
"""
Format a list of cost estimates as a report.
Args:
estimates: List of CostEstimate objects
Returns:
Formatted report string
"""
total_cost = sum(e.total_cost_usd for e in estimates)
total_input = sum(e.input_tokens for e in estimates)
total_output = sum(e.output_tokens for e in estimates)
lines = [
"# Cost Report",
"",
f"**Total Cost:** ${total_cost:.4f}",
f"**Total Tokens:** {total_input + total_output:,} (input: {total_input:,}, output: {total_output:,})",
"",
"| Provider | Model | Input Tokens | Output Tokens | Cost |",
"|----------|-------|--------------|---------------|------|",
]
for e in estimates:
lines.append(f"| {e.provider} | {e.model} | {e.input_tokens:,} | {e.output_tokens:,} | ${e.total_cost_usd:.4f} |")
lines.append("")
lines.append(f"*Generated by cost_estimator.py*")
return "\n".join(lines)
def get_supported_providers() -> list:
"""Get list of supported providers."""
return list(PRICING.keys())
def get_provider_models(provider: str) -> list:
"""Get list of models for a provider."""
provider = provider.lower().strip()
provider_pricing = PRICING.get(provider, {})
return [k for k in provider_pricing.keys() if k != "default"]

233
tools/error_classifier.py Normal file
View File

@@ -0,0 +1,233 @@
"""
Tool Error Classification — Retryable vs Permanent.
Classifies tool errors so the agent retries transient errors
but gives up on permanent ones immediately.
"""
import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
class ErrorCategory(Enum):
"""Error category classification."""
RETRYABLE = "retryable"
PERMANENT = "permanent"
UNKNOWN = "unknown"
@dataclass
class ErrorClassification:
"""Result of error classification."""
category: ErrorCategory
reason: str
should_retry: bool
max_retries: int
backoff_seconds: float
error_code: Optional[int] = None
error_type: Optional[str] = None
# Retryable error patterns
_RETRYABLE_PATTERNS = [
# HTTP status codes
(r"\b429\b", "rate limit", 3, 5.0),
(r"\b500\b", "server error", 3, 2.0),
(r"\b502\b", "bad gateway", 3, 2.0),
(r"\b503\b", "service unavailable", 3, 5.0),
(r"\b504\b", "gateway timeout", 3, 5.0),
# Timeout patterns
(r"timeout", "timeout", 3, 2.0),
(r"timed out", "timeout", 3, 2.0),
(r"TimeoutExpired", "timeout", 3, 2.0),
# Connection errors
(r"connection refused", "connection refused", 2, 5.0),
(r"connection reset", "connection reset", 2, 2.0),
(r"network unreachable", "network unreachable", 2, 10.0),
(r"DNS", "DNS error", 2, 5.0),
# Transient errors
(r"temporary", "temporary error", 2, 2.0),
(r"transient", "transient error", 2, 2.0),
(r"retry", "retryable", 2, 2.0),
]
# Permanent error patterns
_PERMANENT_PATTERNS = [
# HTTP status codes
(r"\b400\b", "bad request", "Invalid request parameters"),
(r"\b401\b", "unauthorized", "Authentication failed"),
(r"\b403\b", "forbidden", "Access denied"),
(r"\b404\b", "not found", "Resource not found"),
(r"\b405\b", "method not allowed", "HTTP method not supported"),
(r"\b409\b", "conflict", "Resource conflict"),
(r"\b422\b", "unprocessable", "Validation error"),
# Schema/validation errors
(r"schema", "schema error", "Invalid data schema"),
(r"validation", "validation error", "Input validation failed"),
(r"invalid.*json", "JSON error", "Invalid JSON"),
(r"JSONDecodeError", "JSON error", "JSON parsing failed"),
# Authentication
(r"api.?key", "API key error", "Invalid or missing API key"),
(r"token.*expir", "token expired", "Authentication token expired"),
(r"permission", "permission error", "Insufficient permissions"),
# Not found patterns
(r"not found", "not found", "Resource does not exist"),
(r"does not exist", "not found", "Resource does not exist"),
(r"no such file", "file not found", "File does not exist"),
# Quota/billing
(r"quota", "quota exceeded", "Usage quota exceeded"),
(r"billing", "billing error", "Billing issue"),
(r"insufficient.*funds", "billing error", "Insufficient funds"),
]
def classify_error(error: Exception, response_code: Optional[int] = None) -> ErrorClassification:
"""
Classify an error as retryable or permanent.
Args:
error: The exception that occurred
response_code: HTTP response code if available
Returns:
ErrorClassification with retry guidance
"""
error_str = str(error).lower()
error_type = type(error).__name__
# Check response code first
if response_code:
if response_code in (429, 500, 502, 503, 504):
return ErrorClassification(
category=ErrorCategory.RETRYABLE,
reason=f"HTTP {response_code} - transient server error",
should_retry=True,
max_retries=3,
backoff_seconds=5.0 if response_code == 429 else 2.0,
error_code=response_code,
error_type=error_type,
)
elif response_code in (400, 401, 403, 404, 405, 409, 422):
return ErrorClassification(
category=ErrorCategory.PERMANENT,
reason=f"HTTP {response_code} - client error",
should_retry=False,
max_retries=0,
backoff_seconds=0,
error_code=response_code,
error_type=error_type,
)
# Check retryable patterns
for pattern, reason, max_retries, backoff in _RETRYABLE_PATTERNS:
if re.search(pattern, error_str, re.IGNORECASE):
return ErrorClassification(
category=ErrorCategory.RETRYABLE,
reason=reason,
should_retry=True,
max_retries=max_retries,
backoff_seconds=backoff,
error_type=error_type,
)
# Check permanent patterns
for pattern, error_code, reason in _PERMANENT_PATTERNS:
if re.search(pattern, error_str, re.IGNORECASE):
return ErrorClassification(
category=ErrorCategory.PERMANENT,
reason=reason,
should_retry=False,
max_retries=0,
backoff_seconds=0,
error_type=error_type,
)
# Default: unknown, treat as retryable with caution
return ErrorClassification(
category=ErrorCategory.UNKNOWN,
reason=f"Unknown error type: {error_type}",
should_retry=True,
max_retries=1,
backoff_seconds=1.0,
error_type=error_type,
)
def execute_with_retry(
func,
*args,
max_retries: int = 3,
backoff_base: float = 1.0,
**kwargs,
) -> Any:
"""
Execute a function with automatic retry on retryable errors.
Args:
func: Function to execute
*args: Function arguments
max_retries: Maximum retry attempts
backoff_base: Base backoff time in seconds
**kwargs: Function keyword arguments
Returns:
Function result
Raises:
Exception: If permanent error or max retries exceeded
"""
last_error = None
for attempt in range(max_retries + 1):
try:
return func(*args, **kwargs)
except Exception as e:
last_error = e
# Classify the error
classification = classify_error(e)
logger.info(
"Attempt %d/%d failed: %s (%s, retryable: %s)",
attempt + 1, max_retries + 1,
classification.reason,
classification.category.value,
classification.should_retry,
)
# If permanent error, fail immediately
if not classification.should_retry:
logger.error("Permanent error: %s", classification.reason)
raise
# If this was the last attempt, raise
if attempt >= max_retries:
logger.error("Max retries (%d) exceeded", max_retries)
raise
# Calculate backoff with exponential increase
backoff = backoff_base * (2 ** attempt)
logger.info("Retrying in %.1fs...", backoff)
time.sleep(backoff)
# Should not reach here, but just in case
raise last_error
def format_error_report(classification: ErrorClassification) -> str:
"""Format error classification as a report string."""
icon = "🔄" if classification.should_retry else ""
return f"{icon} {classification.category.value}: {classification.reason}"