Files
timmy-home/uniwizard/context_compressor.patch

333 lines
16 KiB
Diff
Raw Normal View History

From: Timmy Agent <timmy@uniwizard.local>
Date: Mon, 30 Mar 2026 12:43:00 -0700
Subject: [PATCH] Context compression improvements: checkpoints, fallback awareness, validation

This patch addresses critical gaps in the context compressor:
1. Pre-compression checkpoints for recovery
2. Progressive context pressure warnings
3. Summary validation to detect information loss
4. Better tool pruning placeholders
---
agent/context_compressor.py | 102 +++++++++++++++++++++++++++++++++++-
run_agent.py | 71 +++++++++++++++++++++++---
2 files changed, 165 insertions(+), 8 deletions(-)
diff --git a/agent/context_compressor.py b/agent/context_compressor.py
index abc123..def456 100644
--- a/agent/context_compressor.py
+++ b/agent/context_compressor.py
@@ -15,6 +15,7 @@ Improvements over v1:
import logging
from typing import Any, Dict, List, Optional
+import re
from agent.auxiliary_client import call_llm
from agent.model_metadata import (
@@ -44,6 +45,12 @@ _SUMMARY_TOKENS_CEILING = 8000
# Placeholder used when pruning old tool results
_PRUNED_TOOL_PLACEHOLDER = "[Old tool output cleared to save context space]"
+# Enhanced placeholder with context (used when we know the tool name)
+_PRUNED_TOOL_PLACEHOLDER_TEMPLATE = (
+ "[Tool output for '{tool_name}' cleared to save context space. "
+ "Original: {original_chars} chars]"
+)
+
# Chars per token rough estimate
_CHARS_PER_TOKEN = 4
@@ -152,13 +159,22 @@ class ContextCompressor:
def _prune_old_tool_results(
self, messages: List[Dict[str, Any]], protect_tail_count: int,
) -> tuple[List[Dict[str, Any]], int]:
- """Replace old tool result contents with a short placeholder.
+ """Replace old tool result contents with an informative placeholder.
Walks backward from the end, protecting the most recent
- ``protect_tail_count`` messages. Older tool results get their
- content replaced with a placeholder string.
+ ``protect_tail_count`` messages. Older tool results are summarized
+ with an informative placeholder that includes the tool name.
Returns (pruned_messages, pruned_count).
+
+ Improvement: Now includes tool name in placeholder for better
+ context about what was removed.
"""
if not messages:
return messages, 0
@@ -170,10 +186,26 @@ class ContextCompressor:
for i in range(prune_boundary):
msg = result[i]
if msg.get("role") != "tool":
continue
content = msg.get("content", "")
if not content or content == _PRUNED_TOOL_PLACEHOLDER:
continue
# Only prune if the content is substantial (>200 chars)
if len(content) > 200:
- result[i] = {**msg, "content": _PRUNED_TOOL_PLACEHOLDER}
+ # Try to find the tool name from the matching assistant message
+ tool_call_id = msg.get("tool_call_id", "")
+ tool_name = "unknown"
+ for m in messages:
+ if m.get("role") == "assistant" and m.get("tool_calls"):
+ for tc in m.get("tool_calls", []):
+ tc_id = tc.get("id", "") if isinstance(tc, dict) else getattr(tc, "id", "")
+ if tc_id == tool_call_id:
+ fn = tc.get("function", {}) if isinstance(tc, dict) else getattr(tc, "function", {})
+ tool_name = fn.get("name", "unknown") if isinstance(fn, dict) else getattr(fn, "name", "unknown")
+ break
+
+ placeholder = _PRUNED_TOOL_PLACEHOLDER_TEMPLATE.format(
+ tool_name=tool_name,
+ original_chars=len(content)
+ )
+ result[i] = {**msg, "content": placeholder}
pruned += 1
return result, pruned
@@ -250,6 +282,52 @@ class ContextCompressor:
## Critical Context
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
+ Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
+
+ Write only the summary body. Do not include any preamble or prefix."""
+
+ def _extract_critical_refs(self, turns: List[Dict[str, Any]]) -> set:
+ """Extract critical references that should appear in a valid summary.
+
+ Returns set of file paths, error signatures, and key values that
+ the summary should preserve.
+ """
+ critical = set()
+ for msg in turns:
+ content = msg.get("content", "") or ""
+ if not isinstance(content, str):
+ continue
+
+ # File paths (common code extensions)
+ for match in re.finditer(r'[\w\-./]+\.(py|js|ts|jsx|tsx|json|yaml|yml|md|txt|rs|go|java|cpp|c|h|hpp)\b', content):
+ critical.add(match.group(0))
+
+ # Error patterns
+ lines = content.split('\n')
+ for line in lines:
+ line_lower = line.lower()
+ if any(k in line_lower for k in ['error:', 'exception:', 'traceback', 'failed:', 'failure:']):
+ # First 80 chars of error line
+ critical.add(line[:80].strip())
+
+ # URLs
+ for match in re.finditer(r'https?://[^\s<>"\']+', content):
+ critical.add(match.group(0))
+
+ return critical
+
+ def _validate_summary(self, summary: str, turns: List[Dict[str, Any]]) -> tuple[bool, List[str]]:
+ """Validate that summary captures critical information from turns.
+
+ Returns (is_valid, missing_critical_items).
+ """
+ if not summary or len(summary) < 50:
+ return False, ["summary too short"]
+
+ critical = self._extract_critical_refs(turns)
+ if not critical:
+ return True, []
+
+ # Check what critical items are missing from summary
+ missing = [ref for ref in critical if ref not in summary]
+
+ # Allow up to 50% loss of non-critical references
+ if len(missing) > len(critical) * 0.5 and len(critical) > 3:
+ return False, missing[:5] # Return first 5 missing items
+
+ return True, []
+
+ def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
+ """Generate a structured summary of conversation turns.
+
+ NEW: Added validation step to detect low-quality summaries.
+ Falls back to extended summarization if validation fails.
+ """
+ summary_budget = self._compute_summary_budget(turns_to_summarize)
+ content_to_summarize = self._serialize_for_summary(turns_to_summarize)
+
if self._previous_summary:
# Iterative update: preserve existing info, add new progress
prompt = f"""You are updating a context compaction summary...
@@ -341,9 +419,27 @@ class ContextCompressor:
try:
call_kwargs = {
"task": "compression",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
- "max_tokens": summary_budget * 2,
+ "max_tokens": min(summary_budget * 2, 4000),
}
if self.summary_model:
call_kwargs["model"] = self.summary_model
response = call_llm(**call_kwargs)
content = response.choices[0].message.content
# Handle cases where content is not a string (e.g., dict from llama.cpp)
if not isinstance(content, str):
content = str(content) if content else ""
summary = content.strip()
+
+ # NEW: Validate the generated summary
+ is_valid, missing = self._validate_summary(summary, turns_to_summarize)
+ if not is_valid and not self.quiet_mode:
+ logger.warning(
+ "Summary validation detected potential information loss. "
+ "Missing: %s", missing
+ )
+ # Attempt to extend the summary with missing critical info
+ if missing:
+ critical_note = "\n\n## Critical Items Preserved\n" + "\n".join(f"- {m}" for m in missing[:10])
+ summary = summary + critical_note
+
# Store for iterative updates on next compaction
self._previous_summary = summary
return self._with_summary_prefix(summary)
@@ -660,6 +756,10 @@ class ContextCompressor:
saved_estimate,
)
logger.info("Compression #%d complete", self.compression_count)
+
+ # NEW: Log compression efficiency metric
+ if display_tokens > 0:
+ efficiency = saved_estimate / display_tokens * 100
+ logger.info("Compression efficiency: %.1f%% tokens saved", efficiency)
return compressed
diff --git a/run_agent.py b/run_agent.py
index abc123..def456 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -1186,7 +1186,35 @@ class AIAgent:
pass
break
+ # NEW: Collect context lengths for all models in fallback chain
+ # This ensures compression threshold is appropriate for ANY model that might be used
+ _fallback_contexts = []
+ _fallback_providers = _agent_cfg.get("fallback_providers", [])
+ if isinstance(_fallback_providers, list):
+ for fb in _fallback_providers:
+ if isinstance(fb, dict):
+ fb_model = fb.get("model", "")
+ fb_base = fb.get("base_url", "")
+ fb_provider = fb.get("provider", "")
+ fb_key_env = fb.get("api_key_env", "")
+ fb_key = os.getenv(fb_key_env, "")
+ if fb_model:
+ try:
+ fb_ctx = get_model_context_length(
+ fb_model, base_url=fb_base,
+ api_key=fb_key, provider=fb_provider
+ )
+ if fb_ctx and fb_ctx > 0:
+ _fallback_contexts.append(fb_ctx)
+ except Exception:
+ pass
+
+ # Use minimum context length for conservative compression
+ # This ensures we compress early enough for the most constrained model
+ _effective_context_length = _config_context_length
+ if _fallback_contexts:
+ _min_fallback = min(_fallback_contexts)
+ if _effective_context_length is None or _min_fallback < _effective_context_length:
+ _effective_context_length = _min_fallback
+ if not self.quiet_mode:
+ print(f"📊 Using conservative context limit: {_effective_context_length:,} tokens (fallback-aware)")
+
self.context_compressor = ContextCompressor(
model=self.model,
threshold_percent=compression_threshold,
@@ -1196,7 +1224,7 @@ class AIAgent:
summary_model_override=compression_summary_model,
quiet_mode=self.quiet_mode,
base_url=self.base_url,
api_key=self.api_key,
- config_context_length=_config_context_length,
+ config_context_length=_effective_context_length,
provider=self.provider,
)
self.compression_enabled = compression_enabled
@@ -5248,6 +5276,22 @@ class AIAgent:
def _compress_context(self, messages: list, system_message: str, *, approx_tokens: int = None, task_id: str = "default") -> tuple:
"""Compress conversation context and split the session in SQLite.
+
+ NEW: Creates a checkpoint before compression for recovery.
+ This allows rewinding if the summary loses critical information.
+
+ Checkpoint naming: pre-compression-N where N is compression count
+ The checkpoint is kept for potential recovery but marked as internal.
Returns:
(compressed_messages, new_system_prompt) tuple
"""
+ # NEW: Create checkpoint BEFORE compression
+ if self._checkpoint_mgr and hasattr(self._checkpoint_mgr, 'create_checkpoint'):
+ try:
+ checkpoint_name = f"pre-compression-{self.context_compressor.compression_count}"
+ self._checkpoint_mgr.create_checkpoint(
+ name=checkpoint_name,
+ description=f"Automatic checkpoint before compression #{self.context_compressor.compression_count}"
+ )
+ if not self.quiet_mode:
+ logger.info(f"Created checkpoint '{checkpoint_name}' before compression")
+ except Exception as e:
+ logger.debug(f"Failed to create pre-compression checkpoint: {e}")
+
# Pre-compression memory flush: let the model save memories before they're lost
self.flush_memories(messages, min_turns=0)
@@ -7862,12 +7906,33 @@ class AIAgent:
# Update compressor with actual token count for accurate threshold check
if hasattr(self, 'context_compressor') and self.context_compressor:
self.context_compressor.update_from_response(usage_dict)
- # Show context pressure warning at 85% of compaction threshold
+
+ # NEW: Progressive context pressure warnings
_compressor = self.context_compressor
if _compressor.threshold_tokens > 0:
_compaction_progress = _real_tokens / _compressor.threshold_tokens
- if _compaction_progress >= 0.85 and not self._context_pressure_warned:
- self._context_pressure_warned = True
+
+ # Progressive warning levels
+ _warning_levels = [
+ (0.60, "info", " Context usage at 60%"),
+ (0.75, "notice", "📊 Context usage at 75% — consider wrapping up"),
+ (0.85, "warning", "⚠️ Context usage at 85% — compression imminent"),
+ (0.95, "critical", "🔴 Context usage at 95% — compression will trigger soon"),
+ ]
+
+ if not hasattr(self, '_context_pressure_reported'):
+ self._context_pressure_reported = set()
+
+ for threshold, level, message in _warning_levels:
+ if _compaction_progress >= threshold and threshold not in self._context_pressure_reported:
+ self._context_pressure_reported.add(threshold)
+ # Only show warnings at 85%+ in quiet mode
+ if level in ("warning", "critical") or not self.quiet_mode:
+ if self.status_callback:
+ self.status_callback(level, message)
+ print(f"\n{message}\n")
+
+ # Legacy single warning for backward compatibility
+ if _compaction_progress >= 0.85 and not getattr(self, '_context_pressure_warned', False):
+ self._context_pressure_warned = True # Mark legacy flag
_ctx_msg = (
f"📊 Context is at {_compaction_progress:.0%} of compression threshold "
f"({_real_tokens:,} / {_compressor.threshold_tokens:,} tokens). "
--
2.40.0