Compare commits

..

2 Commits

Author SHA1 Message Date
10d7cd7d0c test(#752): Add tests for error classification
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 44s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 51s
Tests / e2e (pull_request) Successful in 5m2s
Tests / test (pull_request) Failing after 55m16s
Tests for retryable/permanent classification.
Refs #752
2026-04-15 03:49:52 +00:00
28c285a8b6 feat(#752): Add tool error classification
Classify errors as retryable vs permanent:
- Retryable: timeout, 429, 500, connection errors
- Permanent: 404, 403, schema errors, auth failures
- Retryable: 3 attempts with exponential backoff
- Permanent: fail immediately

Resolves #752
2026-04-15 03:49:31 +00:00
4 changed files with 288 additions and 328 deletions

View File

@@ -1,223 +0,0 @@
"""
Session Model Metadata — Persist model context info per session
When a session switches models mid-conversation, context length and
token budget need to be updated to prevent silent truncation.
Issue: #741
"""
import json
import logging
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
HERMES_HOME = Path.home() / ".hermes"
# Common model context lengths (tokens).
# Keys are lowercase model identifiers; values are the context window size in
# tokens. The "default" entry is the fallback used when no other key matches.
KNOWN_CONTEXT_LENGTHS = {
    # Anthropic
    "claude-opus-4-6": 200000,
    "claude-sonnet-4": 200000,
    "claude-3.5-sonnet": 200000,
    "claude-3-haiku": 200000,
    # OpenAI
    "gpt-4o": 128000,
    "gpt-4-turbo": 128000,
    "gpt-4": 8192,
    "gpt-3.5-turbo": 16385,
    # Nous / open models
    "hermes-3-llama-3.1-405b": 131072,
    "hermes-3-llama-3.1-70b": 131072,
    "deepseek-r1": 131072,
    "deepseek-v3": 131072,
    # Local
    "llama-3.1-8b": 131072,
    "llama-3.1-70b": 131072,
    "qwen-2.5-72b": 131072,
    # Xiaomi
    "mimo-v2-pro": 131072,
    "mimo-v2-flash": 131072,
    # Defaults
    "default": 4096,
}
# Reserve tokens for system prompt, response, and overhead.
# Subtracted from a model's context length to get the budget available for input.
TOKEN_RESERVE = 2000
@dataclass
class ModelMetadata:
    """Context-budget snapshot for the model a session is currently using."""

    model: str
    provider: str
    context_length: int
    available_for_input: int  # context_length - reserve
    current_tokens_used: int = 0

    @property
    def remaining_tokens(self) -> int:
        """Input tokens still available; never negative."""
        headroom = self.available_for_input - self.current_tokens_used
        return headroom if headroom > 0 else 0

    @property
    def utilization_pct(self) -> float:
        """Percentage of the input budget already consumed (0.0 when the budget is zero)."""
        budget = self.available_for_input
        if budget == 0:
            return 0.0
        return (self.current_tokens_used / budget) * 100

    def to_dict(self) -> Dict[str, Any]:
        """Return the dataclass fields as a plain dict."""
        return asdict(self)
def get_context_length(model: str) -> int:
    """Return the context window size, in tokens, for ``model``.

    Lookup is case-insensitive and proceeds in three steps:

    1. exact match on a key of ``KNOWN_CONTEXT_LENGTHS``;
    2. substring match, so provider-prefixed or suffixed names such as
       ``"anthropic/claude-sonnet-4"`` resolve to a known key;
    3. the ``"default"`` entry.

    Args:
        model: Model identifier as supplied by the caller.

    Returns:
        The context length in tokens.
    """
    model_lower = model.lower()

    # Exact match wins outright.
    exact = KNOWN_CONTEXT_LENGTHS.get(model_lower)
    if exact is not None:
        return exact

    # Substring match. The "default" sentinel is not a model name, so it is
    # excluded here; it is only ever used as the final fallback.
    candidates = [
        key
        for key in KNOWN_CONTEXT_LENGTHS
        if key != "default" and key in model_lower
    ]
    if candidates:
        # Prefer the most specific (longest) key so that e.g. "gpt-4o-mini"
        # resolves via "gpt-4o" rather than "gpt-4", regardless of the order
        # in which the table happens to be written. max() keeps the first of
        # equally long keys, so ties still follow table order.
        return KNOWN_CONTEXT_LENGTHS[max(candidates, key=len)]

    return KNOWN_CONTEXT_LENGTHS["default"]
def create_metadata(model: str, provider: str = "", current_tokens: int = 0) -> ModelMetadata:
    """Build a ModelMetadata for ``model`` with its input budget precomputed.

    The input budget is the model's context length minus ``TOKEN_RESERVE``,
    floored at zero.
    """
    window = get_context_length(model)
    usable = window - TOKEN_RESERVE
    if usable < 0:
        usable = 0
    return ModelMetadata(
        model=model,
        provider=provider,
        context_length=window,
        available_for_input=usable,
        current_tokens_used=current_tokens,
    )
def check_model_switch(
    old_model: str,
    new_model: str,
    current_tokens: int
) -> Dict[str, Any]:
    """
    Check impact of switching models mid-session.

    Args:
        old_model: Model the session is currently using.
        new_model: Model the session is about to switch to.
        current_tokens: Tokens of history currently held by the session.

    Returns:
        Dict with the switch analysis:
        ``old_model``, ``new_model``, ``old_context``, ``new_context``,
        ``current_tokens``, ``fits_in_new`` (bool), ``truncation_needed``
        (tokens that exceed the new model's input budget, 0 if none) and
        ``warning`` (a message, or None when the switch is harmless).
    """
    old_ctx = get_context_length(old_model)
    new_ctx = get_context_length(new_model)
    # Input budget of the target model; floored at zero like create_metadata().
    new_available = max(0, new_ctx - TOKEN_RESERVE)
    overflow = max(0, current_tokens - new_available)

    result = {
        "old_model": old_model,
        "new_model": new_model,
        "old_context": old_ctx,
        "new_context": new_ctx,
        "current_tokens": current_tokens,
        "fits_in_new": current_tokens <= new_available,
        "truncation_needed": overflow,
        "warning": None,
    }

    if not result["fits_in_new"]:
        # Truncation is the more severe condition, so it takes priority over
        # the generic "smaller context" notice below and must not be
        # overwritten by it.
        result["warning"] = (
            f"Switching to {new_model} ({new_ctx:,} ctx) with {current_tokens:,} tokens "
            f"will truncate {result['truncation_needed']:,} tokens of history. "
            f"Consider starting a new session."
        )
    elif new_ctx < old_ctx:
        reduction = old_ctx - new_ctx
        result["warning"] = (
            f"New model has {reduction:,} fewer tokens of context. "
            f"({old_ctx:,} -> {new_ctx:,})"
        )
    return result
class SessionModelTracker:
    """Holds the active model's metadata for one session and logs model switches."""

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.metadata: Optional[ModelMetadata] = None
        self.history: list = []  # one entry per model switch

    def set_model(self, model: str, provider: str = "", tokens_used: int = 0):
        """Make ``model`` the session's current model, recording a switch if it changed."""
        previous = self.metadata.model if self.metadata else None
        self.metadata = create_metadata(model, provider, tokens_used)

        # Only a change from one named model to a different one counts as a switch.
        if previous and previous != model:
            switch_record = {
                "from": previous,
                "to": model,
                "tokens_at_switch": tokens_used,
                "context_length": self.metadata.context_length
            }
            self.history.append(switch_record)

        logger.info(
            "Session %s: model=%s context=%d available=%d",
            self.session_id[:12], model,
            self.metadata.context_length,
            self.metadata.available_for_input
        )

    def update_tokens(self, tokens: int):
        """Record the session's current token usage (no-op until a model is set)."""
        if not self.metadata:
            return
        self.metadata.current_tokens_used = tokens

    def get_remaining(self) -> int:
        """Return the tokens still available, or 0 when no model is set."""
        return self.metadata.remaining_tokens if self.metadata else 0

    def can_fit(self, additional_tokens: int) -> bool:
        """Tell whether ``additional_tokens`` fit in the remaining budget."""
        if self.metadata:
            return self.metadata.remaining_tokens >= additional_tokens
        return False

    def get_warning(self) -> Optional[str]:
        """Return a warning string once utilization passes 75%, else None."""
        if not self.metadata:
            return None
        pct_used = self.metadata.utilization_pct
        if pct_used > 90:
            return f"Context {pct_used:.0f}% full. Consider compression or new session."
        if pct_used > 75:
            return f"Context {pct_used:.0f}% full."
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Export the tracker state as a plain dict."""
        exported_meta = self.metadata.to_dict() if self.metadata else None
        return {
            "session_id": self.session_id,
            "metadata": exported_meta,
            "history": self.history
        }

View File

@@ -0,0 +1,55 @@
"""
Tests for error classification (#752).
"""
import pytest
from tools.error_classifier import classify_error, ErrorCategory, ErrorClassification
class TestErrorClassification:
    """Checks classify_error's category and retry decision for representative failures."""

    def test_timeout_is_retryable(self):
        verdict = classify_error(Exception("Connection timed out"))
        assert verdict.category == ErrorCategory.RETRYABLE
        assert verdict.should_retry is True

    def test_429_is_retryable(self):
        verdict = classify_error(Exception("Rate limit exceeded"), response_code=429)
        assert verdict.category == ErrorCategory.RETRYABLE
        assert verdict.should_retry is True

    def test_404_is_permanent(self):
        verdict = classify_error(Exception("Not found"), response_code=404)
        assert verdict.category == ErrorCategory.PERMANENT
        assert verdict.should_retry is False

    def test_403_is_permanent(self):
        verdict = classify_error(Exception("Forbidden"), response_code=403)
        assert verdict.category == ErrorCategory.PERMANENT
        assert verdict.should_retry is False

    def test_500_is_retryable(self):
        verdict = classify_error(Exception("Internal server error"), response_code=500)
        assert verdict.category == ErrorCategory.RETRYABLE
        assert verdict.should_retry is True

    def test_schema_error_is_permanent(self):
        verdict = classify_error(Exception("Schema validation failed"))
        assert verdict.category == ErrorCategory.PERMANENT
        assert verdict.should_retry is False

    def test_unknown_is_retryable_with_caution(self):
        # Unrecognised errors are retried, but only once.
        verdict = classify_error(Exception("Some unknown error"))
        assert verdict.category == ErrorCategory.UNKNOWN
        assert verdict.should_retry is True
        assert verdict.max_retries == 1
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,105 +0,0 @@
"""
Tests for session model metadata
Issue: #741
"""
import unittest
from agent.session_model_metadata import (
get_context_length,
create_metadata,
check_model_switch,
SessionModelTracker,
)
class TestContextLength(unittest.TestCase):
    """Context-length lookup by model name."""

    def test_known_model(self):
        self.assertEqual(get_context_length("claude-opus-4-6"), 200000)

    def test_partial_match(self):
        # A provider-prefixed name still resolves to the known model.
        self.assertEqual(get_context_length("anthropic/claude-sonnet-4"), 200000)

    def test_unknown_model(self):
        self.assertEqual(get_context_length("unknown-model-xyz"), 4096)
class TestModelMetadata(unittest.TestCase):
    """Metadata construction and derived budget figures."""

    def test_create(self):
        built = create_metadata("gpt-4o", "openai", 1000)
        self.assertEqual(built.context_length, 128000)
        self.assertEqual(built.current_tokens_used, 1000)
        self.assertGreater(built.remaining_tokens, 0)

    def test_utilization(self):
        half_full = create_metadata("gpt-4o", "openai", 64000)
        self.assertAlmostEqual(half_full.utilization_pct, 50.0, delta=1)
class TestModelSwitch(unittest.TestCase):
    """Analysis returned by check_model_switch for safe and lossy switches."""

    def test_safe_switch(self):
        analysis = check_model_switch("gpt-3.5-turbo", "gpt-4o", 5000)
        self.assertTrue(analysis["fits_in_new"])
        self.assertIsNone(analysis["warning"])

    def test_truncation_warning(self):
        analysis = check_model_switch("gpt-4o", "gpt-3.5-turbo", 20000)
        self.assertFalse(analysis["fits_in_new"])
        message = analysis["warning"]
        self.assertIsNotNone(message)
        self.assertIn("truncate", message.lower())

    def test_downgrade_warning(self):
        analysis = check_model_switch("claude-opus-4-6", "gpt-4", 5000)
        self.assertIsNotNone(analysis["warning"])
class TestSessionModelTracker(unittest.TestCase):
    """Behaviour of SessionModelTracker across model changes and token updates."""

    def test_set_model(self):
        session = SessionModelTracker("test")
        session.set_model("gpt-4o", "openai")
        self.assertEqual(session.metadata.model, "gpt-4o")

    def test_update_tokens(self):
        session = SessionModelTracker("test")
        session.set_model("gpt-4o")
        session.update_tokens(5000)
        self.assertEqual(session.metadata.current_tokens_used, 5000)

    def test_remaining(self):
        session = SessionModelTracker("test")
        session.set_model("gpt-4o")
        session.update_tokens(10000)
        self.assertGreater(session.get_remaining(), 0)

    def test_can_fit(self):
        session = SessionModelTracker("test")
        session.set_model("gpt-4o")
        session.update_tokens(10000)
        self.assertTrue(session.can_fit(5000))
        self.assertFalse(session.can_fit(200000))

    def test_warning_low_context(self):
        session = SessionModelTracker("test")
        session.set_model("gpt-4o")
        session.update_tokens(115000)  # ~90% used
        self.assertIsNotNone(session.get_warning())

    def test_model_switch_history(self):
        session = SessionModelTracker("test")
        session.set_model("gpt-4o", "openai")
        session.update_tokens(5000)
        session.set_model("claude-opus-4-6", "anthropic")
        self.assertEqual(len(session.history), 1)
        first_switch = session.history[0]
        self.assertEqual(first_switch["from"], "gpt-4o")
if __name__ == "__main__":
unittest.main()

233
tools/error_classifier.py Normal file
View File

@@ -0,0 +1,233 @@
"""
Tool Error Classification — Retryable vs Permanent.
Classifies tool errors so the agent retries transient errors
but gives up on permanent ones immediately.
"""
import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
class ErrorCategory(Enum):
    """Error category classification."""
    # Matched a transient-failure response code or pattern; classify_error marks these should_retry=True.
    RETRYABLE = "retryable"
    # Matched a client-error response code or permanent pattern; classify_error marks these should_retry=False.
    PERMANENT = "permanent"
    # Matched nothing; classify_error still allows a single retry for these.
    UNKNOWN = "unknown"
@dataclass
class ErrorClassification:
    """Result of error classification."""
    # Which bucket the error fell into.
    category: ErrorCategory
    # Human-readable explanation of the classification.
    reason: str
    # Whether the caller should attempt the operation again.
    should_retry: bool
    # Suggested number of retries for this kind of error (0 for permanent errors).
    max_retries: int
    # Suggested wait before retrying, in seconds (0 for permanent errors).
    backoff_seconds: float
    # HTTP response code, set only when the classification was made from one.
    error_code: Optional[int] = None
    # Class name of the classified exception.
    error_type: Optional[str] = None
# Retryable error patterns.
# Each entry is (regex, reason, max_retries, backoff_seconds). classify_error
# searches these case-insensitively, in order, and returns on the first match.
_RETRYABLE_PATTERNS = [
    # HTTP status codes
    (r"\b429\b", "rate limit", 3, 5.0),
    (r"\b500\b", "server error", 3, 2.0),
    (r"\b502\b", "bad gateway", 3, 2.0),
    (r"\b503\b", "service unavailable", 3, 5.0),
    (r"\b504\b", "gateway timeout", 3, 5.0),
    # Timeout patterns
    (r"timeout", "timeout", 3, 2.0),
    (r"timed out", "timeout", 3, 2.0),
    (r"TimeoutExpired", "timeout", 3, 2.0),
    # Connection errors
    (r"connection refused", "connection refused", 2, 5.0),
    (r"connection reset", "connection reset", 2, 2.0),
    (r"network unreachable", "network unreachable", 2, 10.0),
    (r"DNS", "DNS error", 2, 5.0),
    # Transient errors
    (r"temporary", "temporary error", 2, 2.0),
    (r"transient", "transient error", 2, 2.0),
    (r"retry", "retryable", 2, 2.0),
]
# Permanent error patterns.
# Each entry is (regex, short label, reason). classify_error searches these
# case-insensitively, in order, after the retryable patterns; only the regex
# and the reason are used, the short label is informational.
_PERMANENT_PATTERNS = [
    # HTTP status codes
    (r"\b400\b", "bad request", "Invalid request parameters"),
    (r"\b401\b", "unauthorized", "Authentication failed"),
    (r"\b403\b", "forbidden", "Access denied"),
    (r"\b404\b", "not found", "Resource not found"),
    (r"\b405\b", "method not allowed", "HTTP method not supported"),
    (r"\b409\b", "conflict", "Resource conflict"),
    (r"\b422\b", "unprocessable", "Validation error"),
    # Schema/validation errors
    (r"schema", "schema error", "Invalid data schema"),
    (r"validation", "validation error", "Input validation failed"),
    (r"invalid.*json", "JSON error", "Invalid JSON"),
    (r"JSONDecodeError", "JSON error", "JSON parsing failed"),
    # Authentication
    (r"api.?key", "API key error", "Invalid or missing API key"),
    (r"token.*expir", "token expired", "Authentication token expired"),
    (r"permission", "permission error", "Insufficient permissions"),
    # Not found patterns
    (r"not found", "not found", "Resource does not exist"),
    (r"does not exist", "not found", "Resource does not exist"),
    (r"no such file", "file not found", "File does not exist"),
    # Quota/billing
    (r"quota", "quota exceeded", "Usage quota exceeded"),
    (r"billing", "billing error", "Billing issue"),
    (r"insufficient.*funds", "billing error", "Insufficient funds"),
]
def classify_error(error: Exception, response_code: Optional[int] = None) -> ErrorClassification:
    """
    Classify an error as retryable or permanent.

    Checks are applied in this order, returning on the first hit:

    1. ``response_code`` against the known transient (429/5xx) and client
       (4xx) status codes;
    2. the retryable patterns;
    3. the permanent patterns;
    4. otherwise UNKNOWN, which allows a single cautious retry.

    Patterns are matched case-insensitively against the exception's class
    name followed by its message, so entries such as ``TimeoutExpired`` or
    ``JSONDecodeError`` match exceptions of that type even when the message
    itself does not mention the type.

    Args:
        error: The exception that occurred
        response_code: HTTP response code if available

    Returns:
        ErrorClassification with retry guidance
    """
    error_type = type(error).__name__
    error_str = str(error).lower()
    # Several patterns name exception classes; str(error) alone never contains
    # the class name, so include it in the text that is searched.
    searchable = f"{error_type}: {error_str}"

    # Check response code first
    if response_code:
        if response_code in (429, 500, 502, 503, 504):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=f"HTTP {response_code} - transient server error",
                should_retry=True,
                max_retries=3,
                # Rate limiting warrants a longer pause than a server hiccup.
                backoff_seconds=5.0 if response_code == 429 else 2.0,
                error_code=response_code,
                error_type=error_type,
            )
        elif response_code in (400, 401, 403, 404, 405, 409, 422):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=f"HTTP {response_code} - client error",
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_code=response_code,
                error_type=error_type,
            )
        # Any other code falls through to pattern matching.

    # Check retryable patterns
    for pattern, reason, max_retries, backoff in _RETRYABLE_PATTERNS:
        if re.search(pattern, searchable, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=reason,
                should_retry=True,
                max_retries=max_retries,
                backoff_seconds=backoff,
                error_type=error_type,
            )

    # Check permanent patterns (the middle element is a short label, unused here)
    for pattern, _label, reason in _PERMANENT_PATTERNS:
        if re.search(pattern, searchable, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=reason,
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_type=error_type,
            )

    # Default: unknown, treat as retryable with caution
    return ErrorClassification(
        category=ErrorCategory.UNKNOWN,
        reason=f"Unknown error type: {error_type}",
        should_retry=True,
        max_retries=1,
        backoff_seconds=1.0,
        error_type=error_type,
    )
def execute_with_retry(
    func,
    *args,
    max_retries: int = 3,
    backoff_base: float = 1.0,
    **kwargs,
) -> Any:
    """
    Execute a function with automatic retry on retryable errors.

    Every failure is classified with ``classify_error``. Permanent errors are
    re-raised immediately. Otherwise the call is retried, up to the smaller of
    ``max_retries`` and the retry count the classifier recommends for that
    error (so an unrecognised error is retried at most once). The wait before
    retry number ``k`` (0-based) is ``backoff_base * 2 ** k`` seconds.

    Args:
        func: Function to execute
        *args: Function arguments
        max_retries: Maximum retry attempts; negative values are treated as 0
        backoff_base: Base backoff time in seconds
        **kwargs: Function keyword arguments

    Returns:
        Function result

    Raises:
        Exception: The original exception from ``func``, if it is permanent
            or the retry limit is reached
    """
    # A negative limit would otherwise skip the call entirely and fall
    # through with no exception to raise.
    max_retries = max(0, max_retries)
    attempt = 0
    while True:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Classify the error
            classification = classify_error(e)
            # Honour the classifier's per-error retry guidance, but never
            # exceed what the caller asked for.
            allowed_retries = min(max_retries, classification.max_retries)
            logger.info(
                "Attempt %d/%d failed: %s (%s, retryable: %s)",
                attempt + 1, allowed_retries + 1,
                classification.reason,
                classification.category.value,
                classification.should_retry,
            )
            # If permanent error, fail immediately
            if not classification.should_retry:
                logger.error("Permanent error: %s", classification.reason)
                raise
            # If this was the last attempt, raise
            if attempt >= allowed_retries:
                logger.error("Max retries (%d) exceeded", allowed_retries)
                raise
            # Calculate backoff with exponential increase
            backoff = backoff_base * (2 ** attempt)
            logger.info("Retrying in %.1fs...", backoff)
            time.sleep(backoff)
            attempt += 1
def format_error_report(classification: ErrorClassification) -> str:
    """Format error classification as a report string.

    The report is ``"<category>: <reason>"``, prefixed with a retry icon when
    the error should be retried. Non-retryable errors get no prefix, so the
    string never starts with a stray space.

    Args:
        classification: The classification to describe.

    Returns:
        A single-line summary.
    """
    summary = f"{classification.category.value}: {classification.reason}"
    if classification.should_retry:
        return f"🔄 {summary}"
    return summary