Compare commits

..

2 Commits

Author SHA1 Message Date
10d7cd7d0c test(#752): Add tests for error classification
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 44s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 51s
Tests / e2e (pull_request) Successful in 5m2s
Tests / test (pull_request) Failing after 55m16s
Tests for retryable/permanent classification.
Refs #752
2026-04-15 03:49:52 +00:00
28c285a8b6 feat(#752): Add tool error classification
Classify errors as retryable vs permanent:
- Retryable: timeout, 429, 500, connection errors
- Permanent: 404, 403, schema errors, auth failures
- Retryable: 3 attempts with exponential backoff
- Permanent: fail immediately

Resolves #752
2026-04-15 03:49:31 +00:00
4 changed files with 288 additions and 305 deletions

View File

@@ -1,221 +0,0 @@
"""
Session Compaction with Fact Extraction — #748
Before compressing a long conversation, extracts durable facts
(user preferences, corrections, project details) and saves them
to the fact store. Then compresses the conversation.
This ensures key information survives context limits.
Usage:
from agent.session_compaction import compact_session
# In the conversation loop, when context is near limit:
compact_session(messages, fact_store)
"""
import json
import re
from typing import Any, Dict, List, Optional, Tuple
# ---------------------------------------------------------------------------
# Fact Extraction Patterns
# ---------------------------------------------------------------------------
# Regexes that flag durable facts worth persisting across compaction.
# Each entry pairs a pattern (first capture group = the fact text) with the
# category label stored alongside it in the fact store.
_FACT_PATTERNS = [
    # User preferences
    (r"(?:i prefer|i like|i always|my preference is|remember that i)\s+(.+?)(?:\.|$)", "user_pref"),
    (r"(?:call me|my name is|i\'m)\s+([A-Z][a-z]+)", "user_name"),
    (r"(?:don\'t|do not|never)\s+(?:use|do|show|tell)\s+(.+?)(?:\.|$)", "user_constraint"),
    # Corrections
    (r"(?:actually|no,?|correction:?)\s+(.+?)(?:\.|$)", "correction"),
    (r"(?:that\'s wrong|not correct|i meant)\s+(.+?)(?:\.|$)", "correction"),
    # Project facts
    (r"(?:the project|this repo|the codebase)\s+(?:is|has|uses|runs)\s+(.+?)(?:\.|$)", "project_fact"),
    (r"(?:we use|our stack is|deployed on)\s+(.+?)(?:\.|$)", "project_fact"),
    # Technical facts
    (r"(?:the server|the service|the endpoint)\s+(?:is|runs on|listens on)\s+(.+?)(?:\.|$)", "technical"),
    (r"(?:port|url|address|host)\s*(?::|is|=)\s*(.+?)(?:\.|$)", "technical"),
]


def extract_facts_from_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Scan user messages for durable facts matching ``_FACT_PATTERNS``.

    Only ``role == "user"`` messages whose content is a string of at
    least 10 characters are considered. Matched fact texts are trimmed,
    bounded to 5-200 characters, and deduplicated case-insensitively
    per category.

    Returns:
        List of fact dicts (content/category/source/trust) suitable for
        the fact store.
    """
    collected: List[Dict[str, Any]] = []
    seen_keys = set()  # "category:text" keys, lowercased, for dedup

    user_texts = (m.get("content", "") for m in messages if m.get("role") == "user")
    for text in user_texts:
        if not isinstance(text, str) or len(text) < 10:
            continue
        for pattern, category in _FACT_PATTERNS:
            for hit in re.findall(pattern, text, re.IGNORECASE):
                # findall yields tuples when a pattern has several groups;
                # keep only the first group in that case.
                if isinstance(hit, tuple):
                    hit = hit[0] if hit else ""
                fact_text = hit.strip()
                if not (5 <= len(fact_text) <= 200):
                    continue
                dedup_key = f"{category}:{fact_text.lower()}"
                if dedup_key in seen_keys:
                    continue
                seen_keys.add(dedup_key)
                collected.append({
                    "content": fact_text,
                    "category": category,
                    "source": "session_compaction",
                    "trust": 0.7,  # Medium trust — extracted, not explicitly stated
                })
    return collected
def extract_preferences(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Pull user-preference statements out of user messages.

    Carries a higher trust (0.8) than the generic fact extraction because
    preference phrasing is comparatively unambiguous.
    """
    preference_patterns = (
        r"(?:i prefer|i like|i want|use|always)\s+(.+?)(?:\.|$)",
        r"(?:my (?:preferred|favorite|default))\s+(?:is|are)\s+(.+?)(?:\.|$)",
        r"(?:set|configure|make)\s+(?:it to|the default to)\s+(.+?)(?:\.|$)",
    )
    found: List[Dict[str, Any]] = []
    user_texts = (m.get("content", "") for m in messages if m.get("role") == "user")
    for text in user_texts:
        if not isinstance(text, str):
            continue
        for pattern in preference_patterns:
            for hit in re.findall(pattern, text, re.IGNORECASE):
                # Keep only plausible-length string captures.
                if isinstance(hit, str) and 5 < len(hit) < 200:
                    found.append({
                        "content": hit.strip(),
                        "category": "user_pref",
                        "source": "session_compaction",
                        "trust": 0.8,
                    })
    return found
def compact_session(
    messages: List[Dict[str, Any]],
    fact_store: Any = None,
    keep_recent: int = 10,
) -> Tuple[List[Dict[str, Any]], int]:
    """
    Compact a session by extracting facts and compressing old messages.

    Everything older than the trailing ``keep_recent`` messages is replaced
    by a short system summary plus (when facts were found) a system message
    listing up to 20 extracted facts, so key information survives
    context-window limits.

    Args:
        messages: Full conversation history.
        fact_store: Optional fact-store instance; if it exposes ``store``
            or ``add``, extracted facts are persisted to it.
        keep_recent: Number of recent messages to keep uncompressed.

    Returns:
        Tuple of (compacted_messages, facts_saved_count).
    """
    # Too short to be worth compacting — return unchanged.
    if len(messages) <= keep_recent * 2:
        return messages, 0

    cutoff = len(messages) - keep_recent
    head, tail = messages[:cutoff], messages[cutoff:]

    # Harvest durable facts from the portion we are about to discard.
    all_facts = extract_facts_from_messages(head) + extract_preferences(head)

    # Persist facts when a store is available; failures must never block
    # compaction itself, hence the broad best-effort except.
    saved_count = 0
    for fact in (all_facts if fact_store else []):
        try:
            if hasattr(fact_store, 'store'):
                fact_store.store(
                    content=fact["content"],
                    category=fact["category"],
                    tags=["session_compaction"],
                )
                saved_count += 1
            elif hasattr(fact_store, 'add'):
                fact_store.add(fact["content"])
                saved_count += 1
        except Exception:
            pass  # Don't let fact saving block compaction

    # Build a one-line summary of what was compressed away.
    summary_parts = []
    if saved_count > 0:
        summary_parts.append(f"[Session compacted: {saved_count} facts extracted and saved]")
    user_msgs = sum(1 for m in head if m.get("role") == "user")
    asst_msgs = sum(1 for m in head if m.get("role") == "assistant")
    summary_parts.append(f"[Previous conversation: {user_msgs} user messages, {asst_msgs} assistant responses]")
    summary = " ".join(summary_parts)

    compacted: List[Dict[str, Any]] = []
    if summary:
        compacted.append({
            "role": "system",
            "content": summary,
            "_compacted": True,
        })

    # Surface the extracted facts directly as system context.
    if all_facts:
        fact_lines = ["Known facts from previous conversation:\n"]
        for fact in all_facts[:20]:  # Limit to 20 facts
            fact_lines.append(f"- [{fact['category']}] {fact['content']}\n")
        compacted.append({
            "role": "system",
            "content": "".join(fact_lines),
            "_extracted_facts": True,
        })

    compacted.extend(tail)
    return compacted, saved_count
def should_compact(messages: List[Dict[str, Any]], max_tokens: int = 80000) -> bool:
    """
    Decide whether compaction is warranted.

    Heuristic only: never compact under 50 messages; otherwise compact once
    the estimated token count passes 80% of ``max_tokens``.
    """
    if len(messages) < 50:
        return False
    # Rough token estimate: ~4 characters per token.
    total_chars = sum(len(str(m.get("content", ""))) for m in messages)
    return (total_chars // 4) > max_tokens * 0.8

View File

@@ -0,0 +1,55 @@
"""
Tests for error classification (#752).
"""
import pytest
from tools.error_classifier import classify_error, ErrorCategory, ErrorClassification
class TestErrorClassification:
    """Behavioral tests for classify_error (#752): retryable vs permanent."""

    def test_timeout_is_retryable(self):
        outcome = classify_error(Exception("Connection timed out"))
        assert outcome.category == ErrorCategory.RETRYABLE
        assert outcome.should_retry is True

    def test_429_is_retryable(self):
        outcome = classify_error(Exception("Rate limit exceeded"), response_code=429)
        assert outcome.category == ErrorCategory.RETRYABLE
        assert outcome.should_retry is True

    def test_404_is_permanent(self):
        outcome = classify_error(Exception("Not found"), response_code=404)
        assert outcome.category == ErrorCategory.PERMANENT
        assert outcome.should_retry is False

    def test_403_is_permanent(self):
        outcome = classify_error(Exception("Forbidden"), response_code=403)
        assert outcome.category == ErrorCategory.PERMANENT
        assert outcome.should_retry is False

    def test_500_is_retryable(self):
        outcome = classify_error(Exception("Internal server error"), response_code=500)
        assert outcome.category == ErrorCategory.RETRYABLE
        assert outcome.should_retry is True

    def test_schema_error_is_permanent(self):
        outcome = classify_error(Exception("Schema validation failed"))
        assert outcome.category == ErrorCategory.PERMANENT
        assert outcome.should_retry is False

    def test_unknown_is_retryable_with_caution(self):
        # Unmatched errors get exactly one cautious retry.
        outcome = classify_error(Exception("Some unknown error"))
        assert outcome.category == ErrorCategory.UNKNOWN
        assert outcome.should_retry is True
        assert outcome.max_retries == 1
# Allow running this test file directly by delegating to pytest's runner.
if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -1,84 +0,0 @@
"""Tests for session compaction with fact extraction (#748)."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from agent.session_compaction import (
extract_facts_from_messages,
extract_preferences,
compact_session,
should_compact,
)
def test_extract_preferences():
    """Preference phrasing in user messages yields at least one pref."""
    conversation = [
        {"role": "user", "content": "I prefer using Python for this"},
        {"role": "assistant", "content": "OK"},
        {"role": "user", "content": "Always use tabs, not spaces"},
    ]
    assert len(extract_preferences(conversation)) >= 1
def test_extract_facts():
    """Technical facts are extracted; too-short messages are ignored."""
    conversation = [
        {"role": "user", "content": "The server runs on port 8080"},
        {"role": "user", "content": "Actually, the port is 8081"},
        {"role": "user", "content": "Hello"},  # Too short, should be skipped
    ]
    extracted = extract_facts_from_messages(conversation)
    assert len(extracted) >= 1
    assert any("technical" in fact["category"] for fact in extracted)
def test_extract_deduplicates():
    """Identical statements produce a single fact, not duplicates."""
    repeated = {"role": "user", "content": "I prefer Python"}
    extracted = extract_facts_from_messages([repeated, dict(repeated)])
    assert len(extracted) == 1
def test_compact_session():
    """Compaction shrinks a long conversation and reports a sane count."""
    history = []
    for i in range(30):
        history.append({"role": "user", "content": f"Message {i}: I prefer Python for server {i}"})
        history.append({"role": "assistant", "content": f"Response {i}"})
    compacted, saved = compact_session(history, keep_recent=10)
    assert len(compacted) < len(history)
    assert saved >= 0
def test_compact_keeps_recent():
    """Summary plus the recent window survives compaction."""
    history = []
    for i in range(30):
        history.append({"role": "user", "content": f"Message {i}"})
        history.append({"role": "assistant", "content": f"Response {i}"})
    compacted, _ = compact_session(history, keep_recent=10)
    # Should have summary + facts + 10 recent
    assert len(compacted) >= 10
def test_should_compact_short():
    """Fewer than 50 messages never triggers compaction."""
    assert not should_compact([{"role": "user", "content": "hi"} for _ in range(10)])
def test_should_compact_long():
    """Many long messages exceed the token budget and trigger compaction.

    Bug fix: the previous fixture used 100 messages of 1000 chars each,
    i.e. 100,000 chars ~= 25,000 estimated tokens (chars // 4) — below the
    should_compact threshold of 0.8 * 80,000 = 64,000 tokens, so the
    assertion always failed. 4000 chars/message gives ~100,000 estimated
    tokens, safely above the threshold.
    """
    messages = [{"role": "user", "content": "x" * 4000} for _ in range(100)]
    assert should_compact(messages)
if __name__ == "__main__":
    # Minimal stdlib-only runner: execute each test in order, stop on failure.
    for test_fn in (
        test_extract_preferences,
        test_extract_facts,
        test_extract_deduplicates,
        test_compact_session,
        test_compact_keeps_recent,
        test_should_compact_short,
        test_should_compact_long,
    ):
        print(f"Running {test_fn.__name__}...")
        test_fn()
        print(" PASS")
    print("\nAll tests passed.")

233
tools/error_classifier.py Normal file
View File

@@ -0,0 +1,233 @@
"""
Tool Error Classification — Retryable vs Permanent.
Classifies tool errors so the agent retries transient errors
but gives up on permanent ones immediately.
"""
import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
class ErrorCategory(Enum):
    """Error category classification."""
    # Transient failure — retried (timeouts, 5xx, rate limits).
    RETRYABLE = "retryable"
    # Definitive failure — retrying cannot help (4xx, schema/auth errors).
    PERMANENT = "permanent"
    # Unclassified — classify_error grants a single cautious retry.
    UNKNOWN = "unknown"
@dataclass
class ErrorClassification:
    """Result of error classification.

    Produced by classify_error; consumed by execute_with_retry (which
    currently honors should_retry but applies its own retry budget).
    """
    category: ErrorCategory  # retryable / permanent / unknown
    reason: str  # human-readable explanation of the verdict
    should_retry: bool  # whether the caller should retry at all
    max_retries: int  # suggested retry budget (0 when permanent)
    backoff_seconds: float  # suggested base backoff between attempts
    error_code: Optional[int] = None  # HTTP status code, when known
    error_type: Optional[str] = None  # exception class name, when known
# Retryable error patterns
# Each entry: (regex matched case-insensitively against str(error),
#              human-readable reason, suggested max retries,
#              suggested base backoff in seconds).
# classify_error checks these BEFORE the permanent patterns; first hit wins.
_RETRYABLE_PATTERNS = [
    # HTTP status codes
    (r"\b429\b", "rate limit", 3, 5.0),
    (r"\b500\b", "server error", 3, 2.0),
    (r"\b502\b", "bad gateway", 3, 2.0),
    (r"\b503\b", "service unavailable", 3, 5.0),
    (r"\b504\b", "gateway timeout", 3, 5.0),
    # Timeout patterns
    (r"timeout", "timeout", 3, 2.0),
    (r"timed out", "timeout", 3, 2.0),
    (r"TimeoutExpired", "timeout", 3, 2.0),
    # Connection errors
    (r"connection refused", "connection refused", 2, 5.0),
    (r"connection reset", "connection reset", 2, 2.0),
    (r"network unreachable", "network unreachable", 2, 10.0),
    (r"DNS", "DNS error", 2, 5.0),
    # Transient errors
    (r"temporary", "temporary error", 2, 2.0),
    (r"transient", "transient error", 2, 2.0),
    # NOTE(review): bare "retry" also matches messages like "do not retry" —
    # confirm that is intended or tighten the pattern.
    (r"retry", "retryable", 2, 2.0),
]
# Permanent error patterns
# Each entry: (regex matched case-insensitively against str(error),
#              short error-code label, human-readable reason).
# Checked only after no retryable pattern matched (see classify_error).
_PERMANENT_PATTERNS = [
    # HTTP status codes
    (r"\b400\b", "bad request", "Invalid request parameters"),
    (r"\b401\b", "unauthorized", "Authentication failed"),
    (r"\b403\b", "forbidden", "Access denied"),
    (r"\b404\b", "not found", "Resource not found"),
    (r"\b405\b", "method not allowed", "HTTP method not supported"),
    (r"\b409\b", "conflict", "Resource conflict"),
    (r"\b422\b", "unprocessable", "Validation error"),
    # Schema/validation errors
    (r"schema", "schema error", "Invalid data schema"),
    (r"validation", "validation error", "Input validation failed"),
    (r"invalid.*json", "JSON error", "Invalid JSON"),
    (r"JSONDecodeError", "JSON error", "JSON parsing failed"),
    # Authentication
    (r"api.?key", "API key error", "Invalid or missing API key"),
    (r"token.*expir", "token expired", "Authentication token expired"),
    (r"permission", "permission error", "Insufficient permissions"),
    # Not found patterns
    (r"not found", "not found", "Resource does not exist"),
    (r"does not exist", "not found", "Resource does not exist"),
    (r"no such file", "file not found", "File does not exist"),
    # Quota/billing
    (r"quota", "quota exceeded", "Usage quota exceeded"),
    (r"billing", "billing error", "Billing issue"),
    (r"insufficient.*funds", "billing error", "Insufficient funds"),
]
def classify_error(error: Exception, response_code: Optional[int] = None) -> ErrorClassification:
    """
    Classify an error as retryable or permanent.

    Precedence: an explicit HTTP ``response_code`` wins; then retryable
    message patterns; then permanent message patterns; anything left is
    UNKNOWN and gets one cautious retry.

    Args:
        error: The exception that occurred.
        response_code: HTTP response code if available.

    Returns:
        ErrorClassification with retry guidance.
    """
    message = str(error).lower()
    exc_name = type(error).__name__

    # Explicit HTTP status codes are the most reliable signal.
    if response_code:
        if response_code in (429, 500, 502, 503, 504):
            # Rate limits get a longer backoff than generic 5xx errors.
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=f"HTTP {response_code} - transient server error",
                should_retry=True,
                max_retries=3,
                backoff_seconds=5.0 if response_code == 429 else 2.0,
                error_code=response_code,
                error_type=exc_name,
            )
        if response_code in (400, 401, 403, 404, 405, 409, 422):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=f"HTTP {response_code} - client error",
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_code=response_code,
                error_type=exc_name,
            )

    # Fall back to message-pattern matching: retryable wins over permanent.
    for pattern, reason, retries, backoff in _RETRYABLE_PATTERNS:
        if re.search(pattern, message, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.RETRYABLE,
                reason=reason,
                should_retry=True,
                max_retries=retries,
                backoff_seconds=backoff,
                error_type=exc_name,
            )

    for pattern, _label, reason in _PERMANENT_PATTERNS:
        if re.search(pattern, message, re.IGNORECASE):
            return ErrorClassification(
                category=ErrorCategory.PERMANENT,
                reason=reason,
                should_retry=False,
                max_retries=0,
                backoff_seconds=0,
                error_type=exc_name,
            )

    # Nothing matched: unknown — allow a single cautious retry.
    return ErrorClassification(
        category=ErrorCategory.UNKNOWN,
        reason=f"Unknown error type: {exc_name}",
        should_retry=True,
        max_retries=1,
        backoff_seconds=1.0,
        error_type=exc_name,
    )
def execute_with_retry(
    func,
    *args,
    max_retries: int = 3,
    backoff_base: float = 1.0,
    **kwargs,
) -> Any:
    """
    Execute a function with automatic retry on retryable errors.

    Permanent errors propagate immediately; retryable (and unknown) errors
    are retried up to ``max_retries`` times with exponential backoff of
    ``backoff_base * 2**attempt`` seconds.

    NOTE: the classification's own max_retries/backoff_seconds suggestions
    are logged but not applied — the caller-supplied budget governs.

    Args:
        func: Function to execute.
        *args: Positional arguments for ``func``.
        max_retries: Maximum retry attempts.
        backoff_base: Base backoff time in seconds.
        **kwargs: Keyword arguments for ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        Exception: If permanent error or max retries exceeded.
    """
    last_error = None
    total_attempts = max_retries + 1
    for attempt in range(total_attempts):
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            last_error = exc
            verdict = classify_error(exc)
            logger.info(
                "Attempt %d/%d failed: %s (%s, retryable: %s)",
                attempt + 1, total_attempts,
                verdict.reason,
                verdict.category.value,
                verdict.should_retry,
            )
            # Permanent: retrying cannot help, surface immediately.
            if not verdict.should_retry:
                logger.error("Permanent error: %s", verdict.reason)
                raise
            # Budget exhausted: surface the last failure.
            if attempt >= max_retries:
                logger.error("Max retries (%d) exceeded", max_retries)
                raise
            delay = backoff_base * (2 ** attempt)
            logger.info("Retrying in %.1fs...", delay)
            time.sleep(delay)
    # Unreachable in practice; defensive re-raise.
    raise last_error
def format_error_report(classification: ErrorClassification) -> str:
    """Format error classification as a one-line report string.

    Bug fix: the previous version built ``f"{icon} {...}"`` with an empty
    icon for non-retryable errors, producing a stray leading space
    (" permanent: ..."). The space now lives inside the conditional prefix,
    so retryable output is unchanged and permanent output is clean.
    """
    prefix = "🔄 " if classification.should_retry else ""
    return f"{prefix}{classification.category.value}: {classification.reason}"