Co-authored-by: Kimi Claw <kimi@timmytime.ai> Co-committed-by: Kimi Claw <kimi@timmytime.ai>
333 lines · 16 KiB · Diff
From: Timmy Agent <timmy@uniwizard.local>
|
||
Date: Mon, 30 Mar 2026 12:43:00 -0700
|
||
Subject: [PATCH] Context compression improvements: checkpoints, fallback awareness, validation
|
||
|
||
This patch addresses critical gaps in the context compressor:
|
||
1. Pre-compression checkpoints for recovery
|
||
2. Progressive context pressure warnings
|
||
3. Summary validation to detect information loss
|
||
4. Better tool pruning placeholders
|
||
|
||
---
|
||
agent/context_compressor.py | 102 +++++++++++++++++++++++++++++++++++-
|
||
run_agent.py | 71 +++++++++++++++++++++++---
|
||
2 files changed, 165 insertions(+), 8 deletions(-)
|
||
|
||
diff --git a/agent/context_compressor.py b/agent/context_compressor.py
|
||
index abc123..def456 100644
|
||
--- a/agent/context_compressor.py
|
||
+++ b/agent/context_compressor.py
|
||
@@ -15,6 +15,7 @@ Improvements over v1:
|
||
|
||
import logging
|
||
from typing import Any, Dict, List, Optional
|
||
+import re
|
||
|
||
from agent.auxiliary_client import call_llm
|
||
from agent.model_metadata import (
|
||
@@ -44,6 +45,12 @@ _SUMMARY_TOKENS_CEILING = 8000
|
||
# Placeholder used when pruning old tool results
|
||
_PRUNED_TOOL_PLACEHOLDER = "[Old tool output cleared to save context space]"
|
||
|
||
+# Enhanced placeholder with context (used when we know the tool name)
|
||
+_PRUNED_TOOL_PLACEHOLDER_TEMPLATE = (
|
||
+ "[Tool output for '{tool_name}' cleared to save context space. "
|
||
+ "Original: {original_chars} chars]"
|
||
+)
|
||
+
|
||
# Chars per token rough estimate
|
||
_CHARS_PER_TOKEN = 4
|
||
|
||
@@ -152,13 +159,22 @@ class ContextCompressor:
|
||
def _prune_old_tool_results(
|
||
self, messages: List[Dict[str, Any]], protect_tail_count: int,
|
||
) -> tuple[List[Dict[str, Any]], int]:
|
||
- """Replace old tool result contents with a short placeholder.
|
||
+ """Replace old tool result contents with an informative placeholder.
|
||
|
||
Walks backward from the end, protecting the most recent
|
||
- ``protect_tail_count`` messages. Older tool results get their
|
||
- content replaced with a placeholder string.
|
||
+ ``protect_tail_count`` messages. Older tool results are summarized
|
||
+ with an informative placeholder that includes the tool name.
|
||
|
||
Returns (pruned_messages, pruned_count).
|
||
+
|
||
+ Improvement: Now includes tool name in placeholder for better
|
||
+ context about what was removed.
|
||
"""
|
||
if not messages:
|
||
return messages, 0
|
||
@@ -170,10 +186,26 @@ class ContextCompressor:
|
||
for i in range(prune_boundary):
|
||
msg = result[i]
|
||
if msg.get("role") != "tool":
|
||
continue
|
||
content = msg.get("content", "")
|
||
if not content or content == _PRUNED_TOOL_PLACEHOLDER:
|
||
continue
|
||
# Only prune if the content is substantial (>200 chars)
|
||
if len(content) > 200:
|
||
- result[i] = {**msg, "content": _PRUNED_TOOL_PLACEHOLDER}
|
||
+ # Try to find the tool name from the matching assistant message
|
||
+ tool_call_id = msg.get("tool_call_id", "")
|
||
+ tool_name = "unknown"
|
||
+ for m in messages:
|
||
+ if m.get("role") == "assistant" and m.get("tool_calls"):
|
||
+ for tc in m.get("tool_calls", []):
|
||
+ tc_id = tc.get("id", "") if isinstance(tc, dict) else getattr(tc, "id", "")
|
||
+ if tc_id == tool_call_id:
|
||
+ fn = tc.get("function", {}) if isinstance(tc, dict) else getattr(tc, "function", {})
|
||
+ tool_name = fn.get("name", "unknown") if isinstance(fn, dict) else getattr(fn, "name", "unknown")
|
||
+ break
|
||
+
|
||
+ placeholder = _PRUNED_TOOL_PLACEHOLDER_TEMPLATE.format(
|
||
+ tool_name=tool_name,
|
||
+ original_chars=len(content)
|
||
+ )
|
||
+ result[i] = {**msg, "content": placeholder}
|
||
pruned += 1
|
||
|
||
return result, pruned
|
||
@@ -250,6 +282,52 @@ class ContextCompressor:
|
||
## Critical Context
|
||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||
|
||
+ Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
|
||
+
|
||
+ Write only the summary body. Do not include any preamble or prefix."""
|
||
+
|
||
+ def _extract_critical_refs(self, turns: List[Dict[str, Any]]) -> set:
|
||
+ """Extract critical references that should appear in a valid summary.
|
||
+
|
||
+ Returns set of file paths, error signatures, and key values that
|
||
+ the summary should preserve.
|
||
+ """
|
||
+ critical = set()
|
||
+ for msg in turns:
|
||
+ content = msg.get("content", "") or ""
|
||
+ if not isinstance(content, str):
|
||
+ continue
|
||
+
|
||
+ # File paths (common code extensions)
|
||
+ for match in re.finditer(r'[\w\-./]+\.(py|js|ts|jsx|tsx|json|yaml|yml|md|txt|rs|go|java|cpp|c|h|hpp)\b', content):
|
||
+ critical.add(match.group(0))
|
||
+
|
||
+ # Error patterns
|
||
+ lines = content.split('\n')
|
||
+ for line in lines:
|
||
+ line_lower = line.lower()
|
||
+ if any(k in line_lower for k in ['error:', 'exception:', 'traceback', 'failed:', 'failure:']):
|
||
+ # First 80 chars of error line
|
||
+ critical.add(line[:80].strip())
|
||
+
|
||
+ # URLs
|
||
+ for match in re.finditer(r'https?://[^\s<>"\']+', content):
|
||
+ critical.add(match.group(0))
|
||
+
|
||
+ return critical
|
||
+
|
||
def _validate_summary(self, summary: str, turns: List[Dict[str, Any]]) -> tuple[bool, List[str]]:
    """Check that a generated summary retains critical information.

    A summary fails validation when it is missing/too short, or when
    more than half of the critical references extracted from *turns*
    are absent from it (only enforced when more than three critical
    references exist, so tiny reference sets never fail).

    Args:
        summary: Candidate summary text produced by the LLM.
        turns: The conversation turns the summary is meant to replace.

    Returns:
        ``(is_valid, missing_items)`` where ``missing_items`` holds at
        most five of the critical references not found in the summary.
    """
    # Guard: an empty or very short summary cannot be trusted.
    if not summary or len(summary) < 50:
        return False, ["summary too short"]

    must_keep = self._extract_critical_refs(turns)
    if not must_keep:
        # Nothing critical to preserve — trivially valid.
        return True, []

    dropped = [item for item in must_keep if item not in summary]

    # Tolerate up to 50% loss of references before flagging the summary.
    too_lossy = len(must_keep) > 3 and len(dropped) > len(must_keep) * 0.5
    if too_lossy:
        return False, dropped[:5]  # Report only the first five missing items
    return True, []
|
||
+
|
||
+ def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||
+ """Generate a structured summary of conversation turns.
|
||
+
|
||
+ NEW: Added validation step to detect low-quality summaries.
|
||
+ Falls back to extended summarization if validation fails.
|
||
+ """
|
||
+ summary_budget = self._compute_summary_budget(turns_to_summarize)
|
||
+ content_to_summarize = self._serialize_for_summary(turns_to_summarize)
|
||
+
|
||
if self._previous_summary:
|
||
# Iterative update: preserve existing info, add new progress
|
||
prompt = f"""You are updating a context compaction summary...
|
||
@@ -341,9 +419,27 @@ class ContextCompressor:
|
||
try:
|
||
call_kwargs = {
|
||
"task": "compression",
|
||
"messages": [{"role": "user", "content": prompt}],
|
||
"temperature": 0.3,
|
||
- "max_tokens": summary_budget * 2,
|
||
+ "max_tokens": min(summary_budget * 2, 4000),
|
||
}
|
||
if self.summary_model:
|
||
call_kwargs["model"] = self.summary_model
|
||
response = call_llm(**call_kwargs)
|
||
content = response.choices[0].message.content
|
||
# Handle cases where content is not a string (e.g., dict from llama.cpp)
|
||
if not isinstance(content, str):
|
||
content = str(content) if content else ""
|
||
summary = content.strip()
|
||
+
|
||
+ # NEW: Validate the generated summary
|
||
+ is_valid, missing = self._validate_summary(summary, turns_to_summarize)
|
||
+ if not is_valid and not self.quiet_mode:
|
||
+ logger.warning(
|
||
+ "Summary validation detected potential information loss. "
|
||
+ "Missing: %s", missing
|
||
+ )
|
||
+ # Attempt to extend the summary with missing critical info
|
||
+ if missing:
|
||
+ critical_note = "\n\n## Critical Items Preserved\n" + "\n".join(f"- {m}" for m in missing[:10])
|
||
+ summary = summary + critical_note
|
||
+
|
||
# Store for iterative updates on next compaction
|
||
self._previous_summary = summary
|
||
return self._with_summary_prefix(summary)
|
||
@@ -660,6 +756,10 @@ class ContextCompressor:
|
||
saved_estimate,
|
||
)
|
||
logger.info("Compression #%d complete", self.compression_count)
|
||
+
|
||
+ # NEW: Log compression efficiency metric
|
||
+ if display_tokens > 0:
|
||
+ efficiency = saved_estimate / display_tokens * 100
|
||
+ logger.info("Compression efficiency: %.1f%% tokens saved", efficiency)
|
||
|
||
return compressed
|
||
|
||
diff --git a/run_agent.py b/run_agent.py
|
||
index abc123..def456 100644
|
||
--- a/run_agent.py
|
||
+++ b/run_agent.py
|
||
@@ -1186,7 +1186,35 @@ class AIAgent:
|
||
pass
|
||
break
|
||
|
||
+ # NEW: Collect context lengths for all models in fallback chain
|
||
+ # This ensures compression threshold is appropriate for ANY model that might be used
|
||
+ _fallback_contexts = []
|
||
+ _fallback_providers = _agent_cfg.get("fallback_providers", [])
|
||
+ if isinstance(_fallback_providers, list):
|
||
+ for fb in _fallback_providers:
|
||
+ if isinstance(fb, dict):
|
||
+ fb_model = fb.get("model", "")
|
||
+ fb_base = fb.get("base_url", "")
|
||
+ fb_provider = fb.get("provider", "")
|
||
+ fb_key_env = fb.get("api_key_env", "")
|
||
+ fb_key = os.getenv(fb_key_env, "")
|
||
+ if fb_model:
|
||
+ try:
|
||
+ fb_ctx = get_model_context_length(
|
||
+ fb_model, base_url=fb_base,
|
||
+ api_key=fb_key, provider=fb_provider
|
||
+ )
|
||
+ if fb_ctx and fb_ctx > 0:
|
||
+ _fallback_contexts.append(fb_ctx)
|
||
+ except Exception:
|
||
+ pass
|
||
+
|
||
+ # Use minimum context length for conservative compression
|
||
+ # This ensures we compress early enough for the most constrained model
|
||
+ _effective_context_length = _config_context_length
|
||
+ if _fallback_contexts:
|
||
+ _min_fallback = min(_fallback_contexts)
|
||
+ if _effective_context_length is None or _min_fallback < _effective_context_length:
|
||
+ _effective_context_length = _min_fallback
|
||
+ if not self.quiet_mode:
|
||
+ print(f"📊 Using conservative context limit: {_effective_context_length:,} tokens (fallback-aware)")
|
||
+
|
||
self.context_compressor = ContextCompressor(
|
||
model=self.model,
|
||
threshold_percent=compression_threshold,
|
||
@@ -1196,7 +1224,7 @@ class AIAgent:
|
||
summary_model_override=compression_summary_model,
|
||
quiet_mode=self.quiet_mode,
|
||
base_url=self.base_url,
|
||
api_key=self.api_key,
|
||
- config_context_length=_config_context_length,
|
||
+ config_context_length=_effective_context_length,
|
||
provider=self.provider,
|
||
)
|
||
self.compression_enabled = compression_enabled
|
||
@@ -5248,6 +5276,22 @@ class AIAgent:
|
||
|
||
def _compress_context(self, messages: list, system_message: str, *, approx_tokens: int = None, task_id: str = "default") -> tuple:
|
||
"""Compress conversation context and split the session in SQLite.
|
||
+
|
||
+ NEW: Creates a checkpoint before compression for recovery.
|
||
+ This allows rewinding if the summary loses critical information.
|
||
+
|
||
+ Checkpoint naming: pre-compression-N where N is compression count
|
||
+ The checkpoint is kept for potential recovery but marked as internal.
|
||
|
||
Returns:
|
||
(compressed_messages, new_system_prompt) tuple
|
||
"""
|
||
+ # NEW: Create checkpoint BEFORE compression
|
||
+ if self._checkpoint_mgr and hasattr(self._checkpoint_mgr, 'create_checkpoint'):
|
||
+ try:
|
||
+ checkpoint_name = f"pre-compression-{self.context_compressor.compression_count}"
|
||
+ self._checkpoint_mgr.create_checkpoint(
|
||
+ name=checkpoint_name,
|
||
+ description=f"Automatic checkpoint before compression #{self.context_compressor.compression_count}"
|
||
+ )
|
||
+ if not self.quiet_mode:
|
||
+ logger.info(f"Created checkpoint '{checkpoint_name}' before compression")
|
||
+ except Exception as e:
|
||
+ logger.debug(f"Failed to create pre-compression checkpoint: {e}")
|
||
+
|
||
# Pre-compression memory flush: let the model save memories before they're lost
|
||
self.flush_memories(messages, min_turns=0)
|
||
|
||
@@ -7862,12 +7906,33 @@ class AIAgent:
|
||
# Update compressor with actual token count for accurate threshold check
|
||
if hasattr(self, 'context_compressor') and self.context_compressor:
|
||
self.context_compressor.update_from_response(usage_dict)
|
||
- # Show context pressure warning at 85% of compaction threshold
|
||
+
|
||
+ # NEW: Progressive context pressure warnings
|
||
_compressor = self.context_compressor
|
||
if _compressor.threshold_tokens > 0:
|
||
_compaction_progress = _real_tokens / _compressor.threshold_tokens
|
||
- if _compaction_progress >= 0.85 and not self._context_pressure_warned:
|
||
- self._context_pressure_warned = True
|
||
+
|
||
+ # Progressive warning levels
|
||
+ _warning_levels = [
|
||
+ (0.60, "info", "ℹ️ Context usage at 60%"),
|
||
+ (0.75, "notice", "📊 Context usage at 75% — consider wrapping up"),
|
||
+ (0.85, "warning", "⚠️ Context usage at 85% — compression imminent"),
|
||
+ (0.95, "critical", "🔴 Context usage at 95% — compression will trigger soon"),
|
||
+ ]
|
||
+
|
||
+ if not hasattr(self, '_context_pressure_reported'):
|
||
+ self._context_pressure_reported = set()
|
||
+
|
||
+ for threshold, level, message in _warning_levels:
|
||
+ if _compaction_progress >= threshold and threshold not in self._context_pressure_reported:
|
||
+ self._context_pressure_reported.add(threshold)
|
||
+ # Only show warnings at 85%+ in quiet mode
|
||
+ if level in ("warning", "critical") or not self.quiet_mode:
|
||
+ if self.status_callback:
|
||
+ self.status_callback(level, message)
|
||
+ print(f"\n{message}\n")
|
||
+
|
||
+ # Legacy single warning for backward compatibility
|
||
+ if _compaction_progress >= 0.85 and not getattr(self, '_context_pressure_warned', False):
|
||
+ self._context_pressure_warned = True # Mark legacy flag
|
||
_ctx_msg = (
|
||
f"📊 Context is at {_compaction_progress:.0%} of compression threshold "
|
||
f"({_real_tokens:,} / {_compressor.threshold_tokens:,} tokens). "
|
||
--
|
||
2.40.0
|