333 lines
16 KiB
Diff
|
|
From: Timmy Agent <timmy@uniwizard.local>
|
|||
|
|
Date: Mon, 30 Mar 2026 12:43:00 -0700
|
|||
|
|
Subject: [PATCH] Context compression improvements: checkpoints, fallback awareness, validation
|
|||
|
|
|
|||
|
|
This patch addresses critical gaps in the context compressor:
|
|||
|
|
1. Pre-compression checkpoints for recovery
|
|||
|
|
2. Progressive context pressure warnings
|
|||
|
|
3. Summary validation to detect information loss
|
|||
|
|
4. Better tool pruning placeholders
|
|||
|
|
|
|||
|
|
---
|
|||
|
|
agent/context_compressor.py | 102 +++++++++++++++++++++++++++++++++++-
|
|||
|
|
run_agent.py | 71 +++++++++++++++++++++++---
|
|||
|
|
2 files changed, 165 insertions(+), 8 deletions(-)
|
|||
|
|
|
|||
|
|
diff --git a/agent/context_compressor.py b/agent/context_compressor.py
|
|||
|
|
index abc123..def456 100644
|
|||
|
|
--- a/agent/context_compressor.py
|
|||
|
|
+++ b/agent/context_compressor.py
|
|||
|
|
@@ -15,6 +15,7 @@ Improvements over v1:
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
from typing import Any, Dict, List, Optional
|
|||
|
|
+import re
|
|||
|
|
|
|||
|
|
from agent.auxiliary_client import call_llm
|
|||
|
|
from agent.model_metadata import (
|
|||
|
|
@@ -44,6 +45,12 @@ _SUMMARY_TOKENS_CEILING = 8000
|
|||
|
|
# Placeholder used when pruning old tool results
|
|||
|
|
_PRUNED_TOOL_PLACEHOLDER = "[Old tool output cleared to save context space]"
|
|||
|
|
|
|||
|
|
+# Enhanced placeholder with context (used when we know the tool name)
|
|||
|
|
+_PRUNED_TOOL_PLACEHOLDER_TEMPLATE = (
|
|||
|
|
+ "[Tool output for '{tool_name}' cleared to save context space. "
|
|||
|
|
+ "Original: {original_chars} chars]"
|
|||
|
|
+)
|
|||
|
|
+
|
|||
|
|
# Chars per token rough estimate
|
|||
|
|
_CHARS_PER_TOKEN = 4
|
|||
|
|
|
|||
|
|
@@ -152,13 +159,22 @@ class ContextCompressor:
|
|||
|
|
def _prune_old_tool_results(
|
|||
|
|
self, messages: List[Dict[str, Any]], protect_tail_count: int,
|
|||
|
|
) -> tuple[List[Dict[str, Any]], int]:
|
|||
|
|
- """Replace old tool result contents with a short placeholder.
|
|||
|
|
+ """Replace old tool result contents with an informative placeholder.
|
|||
|
|
|
|||
|
|
Walks backward from the end, protecting the most recent
|
|||
|
|
- ``protect_tail_count`` messages. Older tool results get their
|
|||
|
|
- content replaced with a placeholder string.
|
|||
|
|
+ ``protect_tail_count`` messages. Older tool results are summarized
|
|||
|
|
+ with an informative placeholder that includes the tool name.
|
|||
|
|
|
|||
|
|
Returns (pruned_messages, pruned_count).
|
|||
|
|
+
|
|||
|
|
+ Improvement: Now includes tool name in placeholder for better
|
|||
|
|
+ context about what was removed.
|
|||
|
|
"""
|
|||
|
|
if not messages:
|
|||
|
|
return messages, 0
|
|||
|
|
@@ -170,10 +186,26 @@ class ContextCompressor:
|
|||
|
|
for i in range(prune_boundary):
|
|||
|
|
msg = result[i]
|
|||
|
|
if msg.get("role") != "tool":
|
|||
|
|
continue
|
|||
|
|
content = msg.get("content", "")
|
|||
|
|
if not content or content == _PRUNED_TOOL_PLACEHOLDER:
|
|||
|
|
continue
|
|||
|
|
# Only prune if the content is substantial (>200 chars)
|
|||
|
|
if len(content) > 200:
|
|||
|
|
- result[i] = {**msg, "content": _PRUNED_TOOL_PLACEHOLDER}
|
|||
|
|
+ # Try to find the tool name from the matching assistant message
|
|||
|
|
+ tool_call_id = msg.get("tool_call_id", "")
|
|||
|
|
+ tool_name = "unknown"
|
|||
|
|
+ for m in messages:
|
|||
|
|
+ if m.get("role") == "assistant" and m.get("tool_calls"):
|
|||
|
|
+ for tc in m.get("tool_calls", []):
|
|||
|
|
+ tc_id = tc.get("id", "") if isinstance(tc, dict) else getattr(tc, "id", "")
|
|||
|
|
+ if tc_id == tool_call_id:
|
|||
|
|
+ fn = tc.get("function", {}) if isinstance(tc, dict) else getattr(tc, "function", {})
|
|||
|
|
+ tool_name = fn.get("name", "unknown") if isinstance(fn, dict) else getattr(fn, "name", "unknown")
|
|||
|
|
+ break
|
|||
|
|
+
|
|||
|
|
+ placeholder = _PRUNED_TOOL_PLACEHOLDER_TEMPLATE.format(
|
|||
|
|
+ tool_name=tool_name,
|
|||
|
|
+ original_chars=len(content)
|
|||
|
|
+ )
|
|||
|
|
+ result[i] = {**msg, "content": placeholder}
|
|||
|
|
pruned += 1
|
|||
|
|
|
|||
|
|
return result, pruned
|
|||
|
|
@@ -250,6 +282,52 @@ class ContextCompressor:
|
|||
|
|
## Critical Context
|
|||
|
|
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
|||
|
|
|
|||
|
|
+ Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
|
|||
|
|
+
|
|||
|
|
+ Write only the summary body. Do not include any preamble or prefix."""
|
|||
|
|
+
|
|||
|
|
+ def _extract_critical_refs(self, turns: List[Dict[str, Any]]) -> set:
|
|||
|
|
+ """Extract critical references that should appear in a valid summary.
|
|||
|
|
+
|
|||
|
|
+ Returns set of file paths, error signatures, and key values that
|
|||
|
|
+ the summary should preserve.
|
|||
|
|
+ """
|
|||
|
|
+ critical = set()
|
|||
|
|
+ for msg in turns:
|
|||
|
|
+ content = msg.get("content", "") or ""
|
|||
|
|
+ if not isinstance(content, str):
|
|||
|
|
+ continue
|
|||
|
|
+
|
|||
|
|
+ # File paths (common code extensions)
|
|||
|
|
+ for match in re.finditer(r'[\w\-./]+\.(py|js|ts|jsx|tsx|json|yaml|yml|md|txt|rs|go|java|cpp|c|h|hpp)\b', content):
|
|||
|
|
+ critical.add(match.group(0))
|
|||
|
|
+
|
|||
|
|
+ # Error patterns
|
|||
|
|
+ lines = content.split('\n')
|
|||
|
|
+ for line in lines:
|
|||
|
|
+ line_lower = line.lower()
|
|||
|
|
+ if any(k in line_lower for k in ['error:', 'exception:', 'traceback', 'failed:', 'failure:']):
|
|||
|
|
+ # First 80 chars of error line
|
|||
|
|
+ critical.add(line[:80].strip())
|
|||
|
|
+
|
|||
|
|
+ # URLs
|
|||
|
|
+ for match in re.finditer(r'https?://[^\s<>"\']+', content):
|
|||
|
|
+ critical.add(match.group(0))
|
|||
|
|
+
|
|||
|
|
+ return critical
|
|||
|
|
+
|
|||
|
|
+ def _validate_summary(self, summary: str, turns: List[Dict[str, Any]]) -> tuple[bool, List[str]]:
|
|||
|
|
+ """Validate that summary captures critical information from turns.
|
|||
|
|
+
|
|||
|
|
+ Returns (is_valid, missing_critical_items).
|
|||
|
|
+ """
|
|||
|
|
+ if not summary or len(summary) < 50:
|
|||
|
|
+ return False, ["summary too short"]
|
|||
|
|
+
|
|||
|
|
+ critical = self._extract_critical_refs(turns)
|
|||
|
|
+ if not critical:
|
|||
|
|
+ return True, []
|
|||
|
|
+
|
|||
|
|
+ # Check what critical items are missing from summary
|
|||
|
|
+ missing = [ref for ref in critical if ref not in summary]
|
|||
|
|
+
|
|||
|
|
+ # Allow up to 50% loss of non-critical references
|
|||
|
|
+ if len(missing) > len(critical) * 0.5 and len(critical) > 3:
|
|||
|
|
+ return False, missing[:5] # Return first 5 missing items
|
|||
|
|
+
|
|||
|
|
+ return True, []
|
|||
|
|
+
|
|||
|
|
+ def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
|||
|
|
+ """Generate a structured summary of conversation turns.
|
|||
|
|
+
|
|||
|
|
+ NEW: Added validation step to detect low-quality summaries.
|
|||
|
|
+ Falls back to extended summarization if validation fails.
|
|||
|
|
+ """
|
|||
|
|
+ summary_budget = self._compute_summary_budget(turns_to_summarize)
|
|||
|
|
+ content_to_summarize = self._serialize_for_summary(turns_to_summarize)
|
|||
|
|
+
|
|||
|
|
if self._previous_summary:
|
|||
|
|
# Iterative update: preserve existing info, add new progress
|
|||
|
|
prompt = f"""You are updating a context compaction summary...
|
|||
|
|
@@ -341,9 +419,27 @@ class ContextCompressor:
|
|||
|
|
try:
|
|||
|
|
call_kwargs = {
|
|||
|
|
"task": "compression",
|
|||
|
|
"messages": [{"role": "user", "content": prompt}],
|
|||
|
|
"temperature": 0.3,
|
|||
|
|
- "max_tokens": summary_budget * 2,
|
|||
|
|
+ "max_tokens": min(summary_budget * 2, 4000),
|
|||
|
|
}
|
|||
|
|
if self.summary_model:
|
|||
|
|
call_kwargs["model"] = self.summary_model
|
|||
|
|
response = call_llm(**call_kwargs)
|
|||
|
|
content = response.choices[0].message.content
|
|||
|
|
# Handle cases where content is not a string (e.g., dict from llama.cpp)
|
|||
|
|
if not isinstance(content, str):
|
|||
|
|
content = str(content) if content else ""
|
|||
|
|
summary = content.strip()
|
|||
|
|
+
|
|||
|
|
+ # NEW: Validate the generated summary
|
|||
|
|
+ is_valid, missing = self._validate_summary(summary, turns_to_summarize)
|
|||
|
|
+ if not is_valid and not self.quiet_mode:
|
|||
|
|
+ logger.warning(
|
|||
|
|
+ "Summary validation detected potential information loss. "
|
|||
|
|
+ "Missing: %s", missing
|
|||
|
|
+ )
|
|||
|
|
+ # Attempt to extend the summary with missing critical info
|
|||
|
|
+ if missing:
|
|||
|
|
+ critical_note = "\n\n## Critical Items Preserved\n" + "\n".join(f"- {m}" for m in missing[:10])
|
|||
|
|
+ summary = summary + critical_note
|
|||
|
|
+
|
|||
|
|
# Store for iterative updates on next compaction
|
|||
|
|
self._previous_summary = summary
|
|||
|
|
return self._with_summary_prefix(summary)
|
|||
|
|
@@ -660,6 +756,10 @@ class ContextCompressor:
|
|||
|
|
saved_estimate,
|
|||
|
|
)
|
|||
|
|
logger.info("Compression #%d complete", self.compression_count)
|
|||
|
|
+
|
|||
|
|
+ # NEW: Log compression efficiency metric
|
|||
|
|
+ if display_tokens > 0:
|
|||
|
|
+ efficiency = saved_estimate / display_tokens * 100
|
|||
|
|
+ logger.info("Compression efficiency: %.1f%% tokens saved", efficiency)
|
|||
|
|
|
|||
|
|
return compressed
|
|||
|
|
|
|||
|
|
diff --git a/run_agent.py b/run_agent.py
|
|||
|
|
index abc123..def456 100644
|
|||
|
|
--- a/run_agent.py
|
|||
|
|
+++ b/run_agent.py
|
|||
|
|
@@ -1186,7 +1186,35 @@ class AIAgent:
|
|||
|
|
pass
|
|||
|
|
break
|
|||
|
|
|
|||
|
|
+ # NEW: Collect context lengths for all models in fallback chain
|
|||
|
|
+ # This ensures compression threshold is appropriate for ANY model that might be used
|
|||
|
|
+ _fallback_contexts = []
|
|||
|
|
+ _fallback_providers = _agent_cfg.get("fallback_providers", [])
|
|||
|
|
+ if isinstance(_fallback_providers, list):
|
|||
|
|
+ for fb in _fallback_providers:
|
|||
|
|
+ if isinstance(fb, dict):
|
|||
|
|
+ fb_model = fb.get("model", "")
|
|||
|
|
+ fb_base = fb.get("base_url", "")
|
|||
|
|
+ fb_provider = fb.get("provider", "")
|
|||
|
|
+ fb_key_env = fb.get("api_key_env", "")
|
|||
|
|
+ fb_key = os.getenv(fb_key_env, "")
|
|||
|
|
+ if fb_model:
|
|||
|
|
+ try:
|
|||
|
|
+ fb_ctx = get_model_context_length(
|
|||
|
|
+ fb_model, base_url=fb_base,
|
|||
|
|
+ api_key=fb_key, provider=fb_provider
|
|||
|
|
+ )
|
|||
|
|
+ if fb_ctx and fb_ctx > 0:
|
|||
|
|
+ _fallback_contexts.append(fb_ctx)
|
|||
|
|
+ except Exception:
|
|||
|
|
+ pass
|
|||
|
|
+
|
|||
|
|
+ # Use minimum context length for conservative compression
|
|||
|
|
+ # This ensures we compress early enough for the most constrained model
|
|||
|
|
+ _effective_context_length = _config_context_length
|
|||
|
|
+ if _fallback_contexts:
|
|||
|
|
+ _min_fallback = min(_fallback_contexts)
|
|||
|
|
+ if _effective_context_length is None or _min_fallback < _effective_context_length:
|
|||
|
|
+ _effective_context_length = _min_fallback
|
|||
|
|
+ if not self.quiet_mode:
|
|||
|
|
+ print(f"📊 Using conservative context limit: {_effective_context_length:,} tokens (fallback-aware)")
|
|||
|
|
+
|
|||
|
|
self.context_compressor = ContextCompressor(
|
|||
|
|
model=self.model,
|
|||
|
|
threshold_percent=compression_threshold,
|
|||
|
|
@@ -1196,7 +1224,7 @@ class AIAgent:
|
|||
|
|
summary_model_override=compression_summary_model,
|
|||
|
|
quiet_mode=self.quiet_mode,
|
|||
|
|
base_url=self.base_url,
|
|||
|
|
api_key=self.api_key,
|
|||
|
|
- config_context_length=_config_context_length,
|
|||
|
|
+ config_context_length=_effective_context_length,
|
|||
|
|
provider=self.provider,
|
|||
|
|
)
|
|||
|
|
self.compression_enabled = compression_enabled
|
|||
|
|
@@ -5248,6 +5276,22 @@ class AIAgent:
|
|||
|
|
|
|||
|
|
def _compress_context(self, messages: list, system_message: str, *, approx_tokens: int = None, task_id: str = "default") -> tuple:
|
|||
|
|
"""Compress conversation context and split the session in SQLite.
|
|||
|
|
+
|
|||
|
|
+ NEW: Creates a checkpoint before compression for recovery.
|
|||
|
|
+ This allows rewinding if the summary loses critical information.
|
|||
|
|
+
|
|||
|
|
+ Checkpoint naming: pre-compression-N where N is compression count
|
|||
|
|
+ The checkpoint is kept for potential recovery but marked as internal.
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
(compressed_messages, new_system_prompt) tuple
|
|||
|
|
"""
|
|||
|
|
+ # NEW: Create checkpoint BEFORE compression
|
|||
|
|
+ if self._checkpoint_mgr and hasattr(self._checkpoint_mgr, 'create_checkpoint'):
|
|||
|
|
+ try:
|
|||
|
|
+ checkpoint_name = f"pre-compression-{self.context_compressor.compression_count}"
|
|||
|
|
+ self._checkpoint_mgr.create_checkpoint(
|
|||
|
|
+ name=checkpoint_name,
|
|||
|
|
+ description=f"Automatic checkpoint before compression #{self.context_compressor.compression_count}"
|
|||
|
|
+ )
|
|||
|
|
+ if not self.quiet_mode:
|
|||
|
|
+ logger.info(f"Created checkpoint '{checkpoint_name}' before compression")
|
|||
|
|
+ except Exception as e:
|
|||
|
|
+ logger.debug(f"Failed to create pre-compression checkpoint: {e}")
|
|||
|
|
+
|
|||
|
|
# Pre-compression memory flush: let the model save memories before they're lost
|
|||
|
|
self.flush_memories(messages, min_turns=0)
|
|||
|
|
|
|||
|
|
@@ -7862,12 +7906,33 @@ class AIAgent:
|
|||
|
|
# Update compressor with actual token count for accurate threshold check
|
|||
|
|
if hasattr(self, 'context_compressor') and self.context_compressor:
|
|||
|
|
self.context_compressor.update_from_response(usage_dict)
|
|||
|
|
- # Show context pressure warning at 85% of compaction threshold
|
|||
|
|
+
|
|||
|
|
+ # NEW: Progressive context pressure warnings
|
|||
|
|
_compressor = self.context_compressor
|
|||
|
|
if _compressor.threshold_tokens > 0:
|
|||
|
|
_compaction_progress = _real_tokens / _compressor.threshold_tokens
|
|||
|
|
- if _compaction_progress >= 0.85 and not self._context_pressure_warned:
|
|||
|
|
- self._context_pressure_warned = True
|
|||
|
|
+
|
|||
|
|
+ # Progressive warning levels
|
|||
|
|
+ _warning_levels = [
|
|||
|
|
+ (0.60, "info", "ℹ️ Context usage at 60%"),
|
|||
|
|
+ (0.75, "notice", "📊 Context usage at 75% — consider wrapping up"),
|
|||
|
|
+ (0.85, "warning", "⚠️ Context usage at 85% — compression imminent"),
|
|||
|
|
+ (0.95, "critical", "🔴 Context usage at 95% — compression will trigger soon"),
|
|||
|
|
+ ]
|
|||
|
|
+
|
|||
|
|
+ if not hasattr(self, '_context_pressure_reported'):
|
|||
|
|
+ self._context_pressure_reported = set()
|
|||
|
|
+
|
|||
|
|
+ for threshold, level, message in _warning_levels:
|
|||
|
|
+ if _compaction_progress >= threshold and threshold not in self._context_pressure_reported:
|
|||
|
|
+ self._context_pressure_reported.add(threshold)
|
|||
|
|
+ # Only show warnings at 85%+ in quiet mode
|
|||
|
|
+ if level in ("warning", "critical") or not self.quiet_mode:
|
|||
|
|
+ if self.status_callback:
|
|||
|
|
+ self.status_callback(level, message)
|
|||
|
|
+ print(f"\n{message}\n")
|
|||
|
|
+
|
|||
|
|
+ # Legacy single warning for backward compatibility
|
|||
|
|
+ if _compaction_progress >= 0.85 and not getattr(self, '_context_pressure_warned', False):
|
|||
|
|
+ self._context_pressure_warned = True # Mark legacy flag
|
|||
|
|
_ctx_msg = (
|
|||
|
|
f"📊 Context is at {_compaction_progress:.0%} of compression threshold "
|
|||
|
|
f"({_real_tokens:,} / {_compressor.threshold_tokens:,} tokens). "
|
|||
|
|
--
|
|||
|
|
2.40.0
|