#!/usr/bin/env python3 """ Trajectory Compressor Post-processes completed agent trajectories to compress them within a target token budget while preserving training signal quality. Compression Strategy: 1. Protect first turns (system, human, first gpt, first tool) 2. Protect last N turns (final actions and conclusions) 3. Compress MIDDLE turns only, starting from 2nd tool response 4. Compress only as much as needed to fit under target 5. Replace compressed region with a single human summary message 6. Keep remaining tool calls intact (model continues working after summary) Usage: # Compress a directory of JSONL files python trajectory_compressor.py --input=data/my_run # Compress a single JSONL file python trajectory_compressor.py --input=data/trajectories.jsonl # Compress 15% sample of a file python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15 # Compress with custom output and token target python trajectory_compressor.py --input=data/trajectories.jsonl --output=compressed.jsonl --target_max_tokens=16000 # Compress 10% sample from a directory python trajectory_compressor.py --input=data/my_run --sample_percent=10 """ import json import os import re import time import yaml import logging import asyncio from pathlib import Path from typing import List, Dict, Any, Optional, Tuple, Callable from dataclasses import dataclass, field from datetime import datetime import fire from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeElapsedColumn, TimeRemainingColumn from rich.console import Console from hermes_constants import OPENROUTER_BASE_URL # Load environment variables from dotenv import load_dotenv load_dotenv() @dataclass class CompressionConfig: """Configuration for trajectory compression.""" # Tokenizer tokenizer_name: str = "moonshotai/Kimi-K2-Thinking" trust_remote_code: bool = True # Compression targets target_max_tokens: int = 15250 summary_target_tokens: int = 750 # Protected turns 
protect_first_system: bool = True protect_first_human: bool = True protect_first_gpt: bool = True protect_first_tool: bool = True protect_last_n_turns: int = 4 # Summarization (OpenRouter) summarization_model: str = "google/gemini-3-flash-preview" base_url: str = OPENROUTER_BASE_URL api_key_env: str = "OPENROUTER_API_KEY" temperature: float = 0.3 max_retries: int = 3 retry_delay: int = 2 # Output add_summary_notice: bool = True summary_notice_text: str = "\n\nSome of your previous tool responses may be summarized to preserve context." output_suffix: str = "_compressed" # Processing num_workers: int = 4 max_concurrent_requests: int = 50 # Max concurrent API calls for summarization skip_under_target: bool = True save_over_limit: bool = True per_trajectory_timeout: int = 300 # Timeout per trajectory in seconds (default: 5 min) # Metrics metrics_enabled: bool = True metrics_per_trajectory: bool = True metrics_output_file: str = "compression_metrics.json" @classmethod def from_yaml(cls, yaml_path: str) -> "CompressionConfig": """Load configuration from YAML file.""" with open(yaml_path, 'r') as f: data = yaml.safe_load(f) config = cls() # Tokenizer if 'tokenizer' in data: config.tokenizer_name = data['tokenizer'].get('name', config.tokenizer_name) config.trust_remote_code = data['tokenizer'].get('trust_remote_code', config.trust_remote_code) # Compression if 'compression' in data: config.target_max_tokens = data['compression'].get('target_max_tokens', config.target_max_tokens) config.summary_target_tokens = data['compression'].get('summary_target_tokens', config.summary_target_tokens) # Protected turns if 'protected_turns' in data: config.protect_first_system = data['protected_turns'].get('first_system', config.protect_first_system) config.protect_first_human = data['protected_turns'].get('first_human', config.protect_first_human) config.protect_first_gpt = data['protected_turns'].get('first_gpt', config.protect_first_gpt) config.protect_first_tool = 
data['protected_turns'].get('first_tool', config.protect_first_tool) config.protect_last_n_turns = data['protected_turns'].get('last_n_turns', config.protect_last_n_turns) # Summarization if 'summarization' in data: config.summarization_model = data['summarization'].get('model', config.summarization_model) config.base_url = data['summarization'].get('base_url', config.base_url) config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env) config.temperature = data['summarization'].get('temperature', config.temperature) config.max_retries = data['summarization'].get('max_retries', config.max_retries) config.retry_delay = data['summarization'].get('retry_delay', config.retry_delay) # Output if 'output' in data: config.add_summary_notice = data['output'].get('add_summary_notice', config.add_summary_notice) config.summary_notice_text = data['output'].get('summary_notice_text', config.summary_notice_text) config.output_suffix = data['output'].get('output_suffix', config.output_suffix) # Processing if 'processing' in data: config.num_workers = data['processing'].get('num_workers', config.num_workers) config.max_concurrent_requests = data['processing'].get('max_concurrent_requests', config.max_concurrent_requests) config.skip_under_target = data['processing'].get('skip_under_target', config.skip_under_target) config.save_over_limit = data['processing'].get('save_over_limit', config.save_over_limit) # Metrics if 'metrics' in data: config.metrics_enabled = data['metrics'].get('enabled', config.metrics_enabled) config.metrics_per_trajectory = data['metrics'].get('per_trajectory', config.metrics_per_trajectory) config.metrics_output_file = data['metrics'].get('output_file', config.metrics_output_file) return config @dataclass class TrajectoryMetrics: """Metrics for a single trajectory compression.""" original_tokens: int = 0 compressed_tokens: int = 0 tokens_saved: int = 0 compression_ratio: float = 1.0 original_turns: int = 0 compressed_turns: int = 0 
turns_removed: int = 0 turns_compressed_start_idx: int = -1 turns_compressed_end_idx: int = -1 turns_in_compressed_region: int = 0 was_compressed: bool = False still_over_limit: bool = False skipped_under_target: bool = False summarization_api_calls: int = 0 summarization_errors: int = 0 def to_dict(self) -> Dict[str, Any]: return { "original_tokens": self.original_tokens, "compressed_tokens": self.compressed_tokens, "tokens_saved": self.tokens_saved, "compression_ratio": round(self.compression_ratio, 4), "original_turns": self.original_turns, "compressed_turns": self.compressed_turns, "turns_removed": self.turns_removed, "compression_region": { "start_idx": self.turns_compressed_start_idx, "end_idx": self.turns_compressed_end_idx, "turns_count": self.turns_in_compressed_region, }, "was_compressed": self.was_compressed, "still_over_limit": self.still_over_limit, "skipped_under_target": self.skipped_under_target, "summarization_api_calls": self.summarization_api_calls, "summarization_errors": self.summarization_errors, } @dataclass class AggregateMetrics: """Aggregate metrics across all trajectories.""" total_trajectories: int = 0 trajectories_compressed: int = 0 trajectories_skipped_under_target: int = 0 trajectories_still_over_limit: int = 0 trajectories_failed: int = 0 total_tokens_before: int = 0 total_tokens_after: int = 0 total_tokens_saved: int = 0 total_turns_before: int = 0 total_turns_after: int = 0 total_turns_removed: int = 0 total_summarization_calls: int = 0 total_summarization_errors: int = 0 # Distribution stats compression_ratios: List[float] = field(default_factory=list) tokens_saved_list: List[int] = field(default_factory=list) turns_removed_list: List[int] = field(default_factory=list) processing_start_time: str = "" processing_end_time: str = "" processing_duration_seconds: float = 0.0 def add_trajectory_metrics(self, metrics: TrajectoryMetrics): """Add a trajectory's metrics to the aggregate.""" self.total_trajectories += 1 
self.total_tokens_before += metrics.original_tokens self.total_tokens_after += metrics.compressed_tokens self.total_tokens_saved += metrics.tokens_saved self.total_turns_before += metrics.original_turns self.total_turns_after += metrics.compressed_turns self.total_turns_removed += metrics.turns_removed self.total_summarization_calls += metrics.summarization_api_calls self.total_summarization_errors += metrics.summarization_errors if metrics.was_compressed: self.trajectories_compressed += 1 self.compression_ratios.append(metrics.compression_ratio) self.tokens_saved_list.append(metrics.tokens_saved) self.turns_removed_list.append(metrics.turns_removed) if metrics.skipped_under_target: self.trajectories_skipped_under_target += 1 if metrics.still_over_limit: self.trajectories_still_over_limit += 1 def to_dict(self) -> Dict[str, Any]: avg_compression_ratio = ( sum(self.compression_ratios) / len(self.compression_ratios) if self.compression_ratios else 1.0 ) avg_tokens_saved = ( sum(self.tokens_saved_list) / len(self.tokens_saved_list) if self.tokens_saved_list else 0 ) avg_turns_removed = ( sum(self.turns_removed_list) / len(self.turns_removed_list) if self.turns_removed_list else 0 ) return { "summary": { "total_trajectories": self.total_trajectories, "trajectories_compressed": self.trajectories_compressed, "trajectories_skipped_under_target": self.trajectories_skipped_under_target, "trajectories_still_over_limit": self.trajectories_still_over_limit, "trajectories_failed": self.trajectories_failed, "compression_rate": round(self.trajectories_compressed / max(self.total_trajectories, 1), 4), }, "tokens": { "total_before": self.total_tokens_before, "total_after": self.total_tokens_after, "total_saved": self.total_tokens_saved, "overall_compression_ratio": round(self.total_tokens_after / max(self.total_tokens_before, 1), 4), }, "turns": { "total_before": self.total_turns_before, "total_after": self.total_turns_after, "total_removed": self.total_turns_removed, }, 
"averages": { "avg_compression_ratio": round(avg_compression_ratio, 4), "avg_tokens_saved_per_compressed": round(avg_tokens_saved, 1), "avg_turns_removed_per_compressed": round(avg_turns_removed, 2), }, "summarization": { "total_api_calls": self.total_summarization_calls, "total_errors": self.total_summarization_errors, "success_rate": round(1 - (self.total_summarization_errors / max(self.total_summarization_calls, 1)), 4), }, "processing": { "start_time": self.processing_start_time, "end_time": self.processing_end_time, "duration_seconds": round(self.processing_duration_seconds, 2), }, } class TrajectoryCompressor: """ Compresses agent trajectories to fit within a target token budget. Compression strategy: 1. Keep protected head turns (system, human, first gpt+tool) 2. Keep protected tail turns (last N turns) 3. From the compressible middle region, compress only as much as needed 4. Replace compressed turns with a single human summary message 5. Keep remaining middle turns intact (model continues with tools) """ def __init__(self, config: CompressionConfig): """Initialize the compressor.""" self.config = config self.aggregate_metrics = AggregateMetrics() # Initialize tokenizer self._init_tokenizer() # Initialize OpenRouter client self._init_summarizer() logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%H:%M:%S' ) self.logger = logging.getLogger(__name__) def _init_tokenizer(self): """Initialize HuggingFace tokenizer for token counting.""" try: from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained( self.config.tokenizer_name, trust_remote_code=self.config.trust_remote_code ) print(f"āœ… Loaded tokenizer: {self.config.tokenizer_name}") except Exception as e: raise RuntimeError(f"Failed to load tokenizer '{self.config.tokenizer_name}': {e}") def _init_summarizer(self): """Initialize LLM routing for summarization (sync and async). 
Uses call_llm/async_call_llm from the centralized provider router which handles auth, headers, and provider detection internally. For custom endpoints, falls back to raw client construction. """ from agent.auxiliary_client import call_llm, async_call_llm provider = self._detect_provider() if provider: # Store provider for use in _generate_summary calls self._llm_provider = provider self._use_call_llm = True # Verify the provider is available from agent.auxiliary_client import resolve_provider_client client, _ = resolve_provider_client( provider, model=self.config.summarization_model) if client is None: raise RuntimeError( f"Provider '{provider}' is not configured. " f"Check your API key or run: hermes setup") self.client = None # Not used directly self.async_client = None # Not used directly else: # Custom endpoint — use config's raw base_url + api_key_env self._use_call_llm = False api_key = os.getenv(self.config.api_key_env) if not api_key: raise RuntimeError( f"Missing API key. Set {self.config.api_key_env} " f"environment variable.") from openai import OpenAI, AsyncOpenAI self.client = OpenAI( api_key=api_key, base_url=self.config.base_url) self.async_client = AsyncOpenAI( api_key=api_key, base_url=self.config.base_url) print(f"āœ… Initialized summarizer client: {self.config.summarization_model}") print(f" Max concurrent requests: {self.config.max_concurrent_requests}") def _detect_provider(self) -> str: """Detect the provider name from the configured base_url.""" url = self.config.base_url.lower() if "openrouter" in url: return "openrouter" if "nousresearch.com" in url: return "nous" if "chatgpt.com/backend-api/codex" in url: return "codex" if "api.z.ai" in url: return "zai" if "moonshot.ai" in url or "api.kimi.com" in url: return "kimi-coding" if "minimaxi.com" in url: return "minimax-cn" if "minimax.io" in url: return "minimax" # Unknown base_url — not a known provider return "" def count_tokens(self, text: str) -> int: """Count tokens in text using the 
configured tokenizer.""" if not text: return 0 try: return len(self.tokenizer.encode(text)) except Exception: # Fallback to character estimate return len(text) // 4 def count_trajectory_tokens(self, trajectory: List[Dict[str, str]]) -> int: """Count total tokens in a trajectory.""" return sum(self.count_tokens(turn.get("value", "")) for turn in trajectory) def count_turn_tokens(self, trajectory: List[Dict[str, str]]) -> List[int]: """Count tokens for each turn in a trajectory.""" return [self.count_tokens(turn.get("value", "")) for turn in trajectory] def _find_protected_indices(self, trajectory: List[Dict[str, str]]) -> Tuple[set, int, int]: """ Find indices of protected turns. Returns: Tuple of (protected_set, compressible_start, compressible_end) """ n = len(trajectory) protected = set() # Track first occurrences first_system = first_human = first_gpt = first_tool = None for i, turn in enumerate(trajectory): role = turn.get("from", "") if role == "system" and first_system is None: first_system = i elif role == "human" and first_human is None: first_human = i elif role == "gpt" and first_gpt is None: first_gpt = i elif role == "tool" and first_tool is None: first_tool = i # Protect first turns if self.config.protect_first_system and first_system is not None: protected.add(first_system) if self.config.protect_first_human and first_human is not None: protected.add(first_human) if self.config.protect_first_gpt and first_gpt is not None: protected.add(first_gpt) if self.config.protect_first_tool and first_tool is not None: protected.add(first_tool) # Protect last N turns for i in range(max(0, n - self.config.protect_last_n_turns), n): protected.add(i) # Determine compressible region # Start after the last protected head turn head_protected = [i for i in protected if i < n // 2] tail_protected = [i for i in protected if i >= n // 2] compressible_start = max(head_protected) + 1 if head_protected else 0 compressible_end = min(tail_protected) if tail_protected else n 
return protected, compressible_start, compressible_end def _extract_turn_content_for_summary(self, trajectory: List[Dict[str, str]], start: int, end: int) -> str: """ Extract content from turns to be summarized. Args: trajectory: Full trajectory start: Start index (inclusive) end: End index (exclusive) Returns: Formatted string of turn contents for summarization """ parts = [] for i in range(start, end): turn = trajectory[i] role = turn.get("from", "unknown") value = turn.get("value", "") # Truncate very long values for the summary prompt if len(value) > 3000: value = value[:1500] + "\n...[truncated]...\n" + value[-500:] parts.append(f"[Turn {i} - {role.upper()}]:\n{value}") return "\n\n".join(parts) def _generate_summary(self, content: str, metrics: TrajectoryMetrics) -> str: """ Generate a summary of the compressed turns using OpenRouter. Args: content: The content to summarize metrics: Metrics object to update Returns: Summary string """ prompt = f"""Summarize the following agent conversation turns concisely. This summary will replace these turns in the conversation history. Write the summary from a neutral perspective describing what the assistant did and learned. Include: 1. What actions the assistant took (tool calls, searches, file operations) 2. Key information or results obtained 3. Any important decisions or findings 4. Relevant data, file names, values, or outputs Keep the summary factual and informative. Target approximately {self.config.summary_target_tokens} tokens. 
--- TURNS TO SUMMARIZE: {content} --- Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" for attempt in range(self.config.max_retries): try: metrics.summarization_api_calls += 1 if getattr(self, '_use_call_llm', False): from agent.auxiliary_client import call_llm response = call_llm( provider=self._llm_provider, model=self.config.summarization_model, messages=[{"role": "user", "content": prompt}], temperature=self.config.temperature, max_tokens=self.config.summary_target_tokens * 2, ) else: response = self.client.chat.completions.create( model=self.config.summarization_model, messages=[{"role": "user", "content": prompt}], temperature=self.config.temperature, max_tokens=self.config.summary_target_tokens * 2, ) summary = response.choices[0].message.content.strip() # Ensure it starts with the prefix if not summary.startswith("[CONTEXT SUMMARY]:"): summary = "[CONTEXT SUMMARY]: " + summary return summary except Exception as e: metrics.summarization_errors += 1 self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}") if attempt < self.config.max_retries - 1: time.sleep(self.config.retry_delay * (attempt + 1)) else: # Fallback: create a basic summary return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]" async def _generate_summary_async(self, content: str, metrics: TrajectoryMetrics) -> str: """ Generate a summary of the compressed turns using OpenRouter (async version). Args: content: The content to summarize metrics: Metrics object to update Returns: Summary string """ prompt = f"""Summarize the following agent conversation turns concisely. This summary will replace these turns in the conversation history. Write the summary from a neutral perspective describing what the assistant did and learned. Include: 1. What actions the assistant took (tool calls, searches, file operations) 2. Key information or results obtained 3. 
Any important decisions or findings 4. Relevant data, file names, values, or outputs Keep the summary factual and informative. Target approximately {self.config.summary_target_tokens} tokens. --- TURNS TO SUMMARIZE: {content} --- Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" for attempt in range(self.config.max_retries): try: metrics.summarization_api_calls += 1 if getattr(self, '_use_call_llm', False): from agent.auxiliary_client import async_call_llm response = await async_call_llm( provider=self._llm_provider, model=self.config.summarization_model, messages=[{"role": "user", "content": prompt}], temperature=self.config.temperature, max_tokens=self.config.summary_target_tokens * 2, ) else: response = await self.async_client.chat.completions.create( model=self.config.summarization_model, messages=[{"role": "user", "content": prompt}], temperature=self.config.temperature, max_tokens=self.config.summary_target_tokens * 2, ) summary = response.choices[0].message.content.strip() # Ensure it starts with the prefix if not summary.startswith("[CONTEXT SUMMARY]:"): summary = "[CONTEXT SUMMARY]: " + summary return summary except Exception as e: metrics.summarization_errors += 1 self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}") if attempt < self.config.max_retries - 1: await asyncio.sleep(self.config.retry_delay * (attempt + 1)) else: # Fallback: create a basic summary return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]" def compress_trajectory( self, trajectory: List[Dict[str, str]] ) -> Tuple[List[Dict[str, str]], TrajectoryMetrics]: """ Compress a single trajectory to fit within target token budget. Algorithm: 1. Count total tokens 2. If under target, skip 3. Find compressible region (between protected head and tail) 4. Calculate how many tokens need to be saved 5. 
Accumulate turns from start of compressible region until savings met 6. Replace accumulated turns with single human summary 7. Keep remaining turns intact Args: trajectory: List of conversation turns Returns: Tuple of (compressed_trajectory, metrics) """ metrics = TrajectoryMetrics() metrics.original_turns = len(trajectory) # Count tokens per turn turn_tokens = self.count_turn_tokens(trajectory) total_tokens = sum(turn_tokens) metrics.original_tokens = total_tokens # Check if compression needed if total_tokens <= self.config.target_max_tokens: metrics.skipped_under_target = True metrics.compressed_tokens = total_tokens metrics.compressed_turns = len(trajectory) metrics.compression_ratio = 1.0 return trajectory, metrics # Find protected regions protected, compress_start, compress_end = self._find_protected_indices(trajectory) # Check if there's anything to compress if compress_start >= compress_end: # Nothing to compress, return as-is metrics.compressed_tokens = total_tokens metrics.compressed_turns = len(trajectory) metrics.still_over_limit = total_tokens > self.config.target_max_tokens return trajectory, metrics # Calculate how much we need to save tokens_to_save = total_tokens - self.config.target_max_tokens # We'll replace N turns with 1 summary turn # Net savings = (sum of N turns' tokens) - summary_target_tokens # We need: net_savings >= tokens_to_save # So: sum of turns >= tokens_to_save + summary_target_tokens target_tokens_to_compress = tokens_to_save + self.config.summary_target_tokens # Accumulate turns from compress_start until we have enough savings accumulated_tokens = 0 compress_until = compress_start for i in range(compress_start, compress_end): accumulated_tokens += turn_tokens[i] compress_until = i + 1 # Exclusive end # Check if we have enough savings if accumulated_tokens >= target_tokens_to_compress: break # If we still don't have enough savings, compress the entire compressible region if accumulated_tokens < target_tokens_to_compress and 
compress_until < compress_end: compress_until = compress_end accumulated_tokens = sum(turn_tokens[compress_start:compress_end]) # Record compression region metrics.turns_compressed_start_idx = compress_start metrics.turns_compressed_end_idx = compress_until metrics.turns_in_compressed_region = compress_until - compress_start # Extract content for summary content_to_summarize = self._extract_turn_content_for_summary( trajectory, compress_start, compress_until ) # Generate summary summary = self._generate_summary(content_to_summarize, metrics) # Build compressed trajectory compressed = [] # Add head (turns before compression region) for i in range(compress_start): turn = trajectory[i].copy() # Add notice to system message if turn.get("from") == "system" and self.config.add_summary_notice: turn["value"] = turn["value"] + self.config.summary_notice_text compressed.append(turn) # Add summary as human message compressed.append({ "from": "human", "value": summary }) # Add tail (turns after compression region) for i in range(compress_until, len(trajectory)): compressed.append(trajectory[i].copy()) # Calculate final metrics metrics.compressed_turns = len(compressed) metrics.compressed_tokens = self.count_trajectory_tokens(compressed) metrics.turns_removed = metrics.original_turns - metrics.compressed_turns metrics.tokens_saved = metrics.original_tokens - metrics.compressed_tokens metrics.compression_ratio = metrics.compressed_tokens / max(metrics.original_tokens, 1) metrics.was_compressed = True metrics.still_over_limit = metrics.compressed_tokens > self.config.target_max_tokens return compressed, metrics async def compress_trajectory_async( self, trajectory: List[Dict[str, str]] ) -> Tuple[List[Dict[str, str]], TrajectoryMetrics]: """ Compress a single trajectory to fit within target token budget (async version). Same algorithm as compress_trajectory but uses async API calls for summarization. 
""" metrics = TrajectoryMetrics() metrics.original_turns = len(trajectory) # Count tokens per turn turn_tokens = self.count_turn_tokens(trajectory) total_tokens = sum(turn_tokens) metrics.original_tokens = total_tokens # Check if compression needed if total_tokens <= self.config.target_max_tokens: metrics.skipped_under_target = True metrics.compressed_tokens = total_tokens metrics.compressed_turns = len(trajectory) metrics.compression_ratio = 1.0 return trajectory, metrics # Find protected regions protected, compress_start, compress_end = self._find_protected_indices(trajectory) # Check if there's anything to compress if compress_start >= compress_end: metrics.compressed_tokens = total_tokens metrics.compressed_turns = len(trajectory) metrics.still_over_limit = total_tokens > self.config.target_max_tokens return trajectory, metrics # Calculate how much we need to save tokens_to_save = total_tokens - self.config.target_max_tokens target_tokens_to_compress = tokens_to_save + self.config.summary_target_tokens # Accumulate turns from compress_start until we have enough savings accumulated_tokens = 0 compress_until = compress_start for i in range(compress_start, compress_end): accumulated_tokens += turn_tokens[i] compress_until = i + 1 if accumulated_tokens >= target_tokens_to_compress: break # If we still don't have enough savings, compress the entire compressible region if accumulated_tokens < target_tokens_to_compress and compress_until < compress_end: compress_until = compress_end accumulated_tokens = sum(turn_tokens[compress_start:compress_end]) # Record compression region metrics.turns_compressed_start_idx = compress_start metrics.turns_compressed_end_idx = compress_until metrics.turns_in_compressed_region = compress_until - compress_start # Extract content for summary content_to_summarize = self._extract_turn_content_for_summary( trajectory, compress_start, compress_until ) # Generate summary (ASYNC) summary = await 
self._generate_summary_async(content_to_summarize, metrics) # Build compressed trajectory compressed = [] # Add head (turns before compression region) for i in range(compress_start): turn = trajectory[i].copy() if turn.get("from") == "system" and self.config.add_summary_notice: turn["value"] = turn["value"] + self.config.summary_notice_text compressed.append(turn) # Add summary as human message compressed.append({ "from": "human", "value": summary }) # Add tail (turns after compression region) for i in range(compress_until, len(trajectory)): compressed.append(trajectory[i].copy()) # Calculate final metrics metrics.compressed_turns = len(compressed) metrics.compressed_tokens = self.count_trajectory_tokens(compressed) metrics.turns_removed = metrics.original_turns - metrics.compressed_turns metrics.tokens_saved = metrics.original_tokens - metrics.compressed_tokens metrics.compression_ratio = metrics.compressed_tokens / max(metrics.original_tokens, 1) metrics.was_compressed = True metrics.still_over_limit = metrics.compressed_tokens > self.config.target_max_tokens return compressed, metrics async def process_entry_async(self, entry: Dict[str, Any]) -> Tuple[Dict[str, Any], TrajectoryMetrics]: """ Process a single JSONL entry (async version). """ if "conversations" not in entry: metrics = TrajectoryMetrics() return entry, metrics trajectory = entry["conversations"] compressed_trajectory, metrics = await self.compress_trajectory_async(trajectory) # Create new entry with compressed trajectory result = entry.copy() result["conversations"] = compressed_trajectory # Add compression metadata if enabled if self.config.metrics_per_trajectory and metrics.was_compressed: result["compression_metrics"] = metrics.to_dict() return result, metrics def process_entry(self, entry: Dict[str, Any]) -> Tuple[Dict[str, Any], TrajectoryMetrics]: """ Process a single JSONL entry. 
Args: entry: JSONL entry containing 'conversations' field Returns: Tuple of (processed_entry, metrics) """ if "conversations" not in entry: metrics = TrajectoryMetrics() return entry, metrics trajectory = entry["conversations"] compressed_trajectory, metrics = self.compress_trajectory(trajectory) # Create new entry with compressed trajectory result = entry.copy() result["conversations"] = compressed_trajectory # Add compression metadata if enabled if self.config.metrics_per_trajectory and metrics.was_compressed: result["compression_metrics"] = metrics.to_dict() return result, metrics def process_file( self, input_path: Path, output_path: Path, progress_callback: Optional[Callable[[TrajectoryMetrics], None]] = None ) -> List[TrajectoryMetrics]: """ Process a single JSONL file. Args: input_path: Path to input JSONL file output_path: Path to output JSONL file progress_callback: Optional callback called after each entry with its metrics Returns: List of metrics for each trajectory """ file_metrics = [] # Read all entries entries = [] with open(input_path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f, 1): line = line.strip() if line: try: entries.append(json.loads(line)) except json.JSONDecodeError as e: self.logger.warning(f"Skipping invalid JSON at {input_path}:{line_num}: {e}") # Process entries processed_entries = [] for entry in entries: try: processed_entry, metrics = self.process_entry(entry) processed_entries.append(processed_entry) file_metrics.append(metrics) self.aggregate_metrics.add_trajectory_metrics(metrics) # Call progress callback if provided if progress_callback: progress_callback(metrics) except Exception as e: self.logger.error(f"Error processing entry: {e}") self.aggregate_metrics.trajectories_failed += 1 # Keep original entry on error processed_entries.append(entry) empty_metrics = TrajectoryMetrics() file_metrics.append(empty_metrics) if progress_callback: progress_callback(empty_metrics) # Write output 
def process_directory(self, input_dir: Path, output_dir: Path):
    """
    Process all JSONL files in a directory using async parallel processing.

    Synchronous entry point: runs ``_process_directory_async`` to completion
    via ``asyncio.run``.

    Args:
        input_dir: Input directory containing JSONL files
        output_dir: Output directory for compressed files
    """
    asyncio.run(self._process_directory_async(input_dir, output_dir))

async def _process_directory_async(self, input_dir: Path, output_dir: Path):
    """
    Async implementation of directory processing with parallel API calls.

    Loads every entry from every ``*.jsonl`` file in ``input_dir``, processes
    them concurrently (bounded by ``config.max_concurrent_requests``), and
    writes one output file per input file into ``output_dir`` with entries in
    their original order. Afterwards the summary report is printed and, when
    ``config.metrics_enabled`` is set, the aggregate metrics are saved as JSON.

    Per-entry outcomes:
        * success            -> processed entry is written
        * timeout            -> entry is omitted from the output, counted as failed
        * any other error    -> original entry is written unchanged, counted as failed

    Returns early (after logging a warning) when the directory contains no
    JSONL files.

    Args:
        input_dir: Input directory containing JSONL files
        output_dir: Output directory for compressed files
    """
    console = Console()

    # Record start time
    self.aggregate_metrics.processing_start_time = datetime.now().isoformat()
    start_time = time.time()

    # Find all JSONL files
    jsonl_files = sorted(input_dir.glob("*.jsonl"))
    if not jsonl_files:
        self.logger.warning(f"No JSONL files found in {input_dir}")
        return

    # Load ALL entries from all files
    console.print("\n[dim]Loading all entries...[/dim]")
    all_entries = []  # List of (file_path, line_num, entry)
    for file_path in jsonl_files:
        with open(file_path, 'r', encoding='utf-8') as f:
            # 1-based line numbers so log messages point at the real line
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError as e:
                    self.logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e}")
                    continue
                all_entries.append((file_path, line_num, entry))

    total_entries = len(all_entries)

    console.print(f"\n{'='*60}")
    console.print(f"šŸ“‚ Input: {input_dir}")
    console.print(f"šŸ“‚ Output: {output_dir}")
    console.print(f"šŸ“„ Files to process: {len(jsonl_files)}")
    console.print(f"šŸ“Š Total trajectories: {total_entries:,}")
    console.print(f"šŸŽÆ Target max tokens: {self.config.target_max_tokens:,}")
    console.print(f"šŸ“ Summary target tokens: {self.config.summary_target_tokens}")
    console.print(f"⚔ Max concurrent API calls: {self.config.max_concurrent_requests}")
    console.print(f"{'='*60}\n")

    # Semaphore bounds how many entries are processed at the same time
    semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)

    # Counters for the progress display; the lock serializes updates between
    # coroutines (this is asyncio coordination, not OS-thread safety)
    progress_lock = asyncio.Lock()
    compressed_count = 0
    skipped_count = 0
    api_calls = 0
    in_flight = 0
    timeout_count = 0

    # Results storage: {file_path: {line_num: (processed_entry, metrics) | None}}
    results = {f: {} for f in jsonl_files}

    def status_line() -> str:
        """Build the status text from the current counter values."""
        return f"[dim]āœ… {compressed_count} compressed | ā­ļø {skipped_count} skipped | ā±ļø {timeout_count} timeout | šŸ”„ {api_calls} API calls | ⚔ {in_flight} in-flight[/dim]"

    async def process_single(file_path: Path, entry_idx: int, entry: Dict, progress, main_task, status_task):
        """Process a single entry with semaphore rate limiting and timeout."""
        nonlocal compressed_count, skipped_count, api_calls, in_flight, timeout_count

        async with semaphore:
            async with progress_lock:
                in_flight += 1

            try:
                # Apply per-trajectory timeout
                processed_entry, metrics = await asyncio.wait_for(
                    self.process_entry_async(entry),
                    timeout=self.config.per_trajectory_timeout
                )

                async with progress_lock:
                    self.aggregate_metrics.add_trajectory_metrics(metrics)
                    if metrics.was_compressed:
                        compressed_count += 1
                        api_calls += metrics.summarization_api_calls
                    if metrics.skipped_under_target:
                        skipped_count += 1

                results[file_path][entry_idx] = (processed_entry, metrics)

            except asyncio.TimeoutError:
                self.logger.warning(f"Timeout processing entry from {file_path}:{entry_idx} (>{self.config.per_trajectory_timeout}s)")
                async with progress_lock:
                    self.aggregate_metrics.trajectories_failed += 1
                    timeout_count += 1
                # Skip this entry entirely (don't include in output)
                results[file_path][entry_idx] = None

            except Exception as e:
                self.logger.error(f"Error processing entry from {file_path}:{entry_idx}: {e}")
                async with progress_lock:
                    self.aggregate_metrics.trajectories_failed += 1
                # Keep original entry on error
                results[file_path][entry_idx] = (entry, TrajectoryMetrics())

            finally:
                # Single exit path: each entry leaves the in-flight set and
                # advances the bar exactly once, whatever the outcome, and
                # the status line is refreshed for every outcome.
                async with progress_lock:
                    in_flight -= 1
                    progress.advance(main_task)
                    progress.update(status_task, description=status_line())

    # Create progress bar
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
        refresh_per_second=10  # Higher refresh for async
    ) as progress:
        # Main task for overall progress
        main_task = progress.add_task(
            f"[cyan]Compressing {total_entries:,} trajectories",
            total=total_entries
        )

        # Status line task
        status_task = progress.add_task(
            "[dim]Starting...[/dim]",
            total=None
        )

        # Create all tasks
        tasks = [
            process_single(file_path, entry_idx, entry, progress, main_task, status_task)
            for file_path, entry_idx, entry in all_entries
        ]

        # Run all tasks concurrently (semaphore limits actual concurrency)
        await asyncio.gather(*tasks)

        # Remove status task
        progress.remove_task(status_task)

    # Write results to output files (preserving original order)
    console.print("\n[dim]Writing output files...[/dim]")
    output_dir.mkdir(parents=True, exist_ok=True)

    for file_path in jsonl_files:
        output_path = output_dir / file_path.name
        file_results = results[file_path]

        # Sort by original line number to preserve order, skip None (timed out) entries
        sorted_entries = [
            file_results[idx][0]
            for idx in sorted(file_results)
            if file_results[idx] is not None
        ]

        with open(output_path, 'w', encoding='utf-8') as f:
            for entry in sorted_entries:
                f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    # Record end time
    self.aggregate_metrics.processing_end_time = datetime.now().isoformat()
    self.aggregate_metrics.processing_duration_seconds = time.time() - start_time

    # Print summary
    self._print_summary()

    # Save metrics
    if self.config.metrics_enabled:
        metrics_path = output_dir / self.config.metrics_output_file
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(self.aggregate_metrics.to_dict(), f, indent=2)
        console.print(f"\nšŸ’¾ Metrics saved to {metrics_path}")
def _print_summary(self):
    """
    Print the aggregate compression report to stdout.

    Reads ``self.aggregate_metrics.to_dict()`` and renders a boxed report
    with sections for trajectories, tokens, conversation turns, averages,
    the summarization API and processing time. When per-trajectory
    compression ratios were recorded, a short min/max/median distribution
    summary follows the box.
    """
    m = self.aggregate_metrics.to_dict()

    # Headline trajectory counters
    summary = m['summary']
    total = summary['total_trajectories']
    compressed = summary['trajectories_compressed']
    skipped = summary['trajectories_skipped_under_target']
    over_limit = summary['trajectories_still_over_limit']
    failed = summary['trajectories_failed']

    # Token totals
    tok = m['tokens']
    tokens_before = tok['total_before']
    tokens_after = tok['total_after']
    tokens_saved = tok['total_saved']

    # Shares of the total; the denominator is floored at 1 to avoid dividing by zero
    denom = max(total, 1)
    compressed_pct = (compressed / denom) * 100
    skipped_pct = (skipped / denom) * 100
    over_limit_pct = (over_limit / denom) * 100

    # Borders that repeat between sections
    heavy_rule = f"ā• {'═'*70}ā•£"
    light_rule = f"ā•‘{'─'*70}ā•‘"

    print("\n")
    print(f"ā•”{'═'*70}ā•—")
    print(f"ā•‘{'TRAJECTORY COMPRESSION REPORT':^70}ā•‘")
    print(heavy_rule)

    # --- Trajectories ---
    print(f"ā•‘{'':2}šŸ“ TRAJECTORIES{' '*54}ā•‘")
    print(light_rule)
    print(f"ā•‘{'':4}Total Processed: {total:>10,}{' '*32}ā•‘")
    print(f"ā•‘{'':4}ā”œā”€ Compressed: {compressed:>10,} ({compressed_pct:>5.1f}%){' '*18}ā•‘")
    print(f"ā•‘{'':4}ā”œā”€ Skipped (under limit):{skipped:>9,} ({skipped_pct:>5.1f}%){' '*18}ā•‘")
    print(f"ā•‘{'':4}ā”œā”€ Still over limit: {over_limit:>10,} ({over_limit_pct:>5.1f}%){' '*18}ā•‘")
    print(f"ā•‘{'':4}└─ Failed: {failed:>10,}{' '*32}ā•‘")
    print(heavy_rule)

    # --- Tokens ---
    print(f"ā•‘{'':2}šŸ”¢ TOKENS{' '*60}ā•‘")
    print(light_rule)
    print(f"ā•‘{'':4}Before Compression: {tokens_before:>15,} tokens{' '*21}ā•‘")
    print(f"ā•‘{'':4}After Compression: {tokens_after:>15,} tokens{' '*21}ā•‘")
    print(f"ā•‘{'':4}Total Saved: {tokens_saved:>15,} tokens{' '*21}ā•‘")
    print(f"ā•‘{'':4}Overall Compression: {tok['overall_compression_ratio']:>14.1%}{' '*28}ā•‘")
    if tokens_before > 0:
        savings_pct = (tokens_saved / tokens_before) * 100
        print(f"ā•‘{'':4}Space Savings: {savings_pct:>14.1f}%{' '*28}ā•‘")
    print(heavy_rule)

    # --- Conversation turns ---
    print(f"ā•‘{'':2}šŸ’¬ CONVERSATION TURNS{' '*48}ā•‘")
    print(light_rule)
    turns = m['turns']
    print(f"ā•‘{'':4}Before Compression: {turns['total_before']:>15,} turns{' '*22}ā•‘")
    print(f"ā•‘{'':4}After Compression: {turns['total_after']:>15,} turns{' '*22}ā•‘")
    print(f"ā•‘{'':4}Total Removed: {turns['total_removed']:>15,} turns{' '*22}ā•‘")
    print(heavy_rule)

    # --- Averages (for compressed trajectories only) ---
    print(f"ā•‘{'':2}šŸ“ˆ AVERAGES (Compressed Trajectories Only){' '*27}ā•‘")
    print(light_rule)
    if compressed > 0:
        avg = m['averages']
        print(f"ā•‘{'':4}Avg Compression Ratio: {avg['avg_compression_ratio']:>14.1%}{' '*28}ā•‘")
        print(f"ā•‘{'':4}Avg Tokens Saved: {avg['avg_tokens_saved_per_compressed']:>14,.0f}{' '*28}ā•‘")
        print(f"ā•‘{'':4}Avg Turns Removed: {avg['avg_turns_removed_per_compressed']:>14.1f}{' '*28}ā•‘")
    else:
        print(f"ā•‘{'':4}No trajectories were compressed{' '*38}ā•‘")
    print(heavy_rule)

    # --- Summarization API ---
    print(f"ā•‘{'':2}šŸ¤– SUMMARIZATION API{' '*49}ā•‘")
    print(light_rule)
    api = m['summarization']
    print(f"ā•‘{'':4}API Calls Made: {api['total_api_calls']:>15,}{' '*27}ā•‘")
    print(f"ā•‘{'':4}Errors: {api['total_errors']:>15,}{' '*27}ā•‘")
    print(f"ā•‘{'':4}Success Rate: {api['success_rate']:>14.1%}{' '*28}ā•‘")
    print(heavy_rule)

    # --- Processing time ---
    duration = m['processing']['duration_seconds']
    # Durations above one minute are shown in minutes, otherwise in seconds
    time_str = f"{duration/60:.1f} minutes" if duration > 60 else f"{duration:.1f} seconds"
    throughput = total / max(duration, 0.001)

    print(f"ā•‘{'':2}ā±ļø PROCESSING TIME{' '*51}ā•‘")
    print(light_rule)
    print(f"ā•‘{'':4}Duration: {time_str:>20}{' '*22}ā•‘")
    print(f"ā•‘{'':4}Throughput: {throughput:>15.1f} traj/sec{' '*18}ā•‘")
    timing = m['processing']
    # Timestamps are cut to their first 19 characters before display
    print(f"ā•‘{'':4}Started: {timing['start_time'][:19]:>20}{' '*22}ā•‘")
    print(f"ā•‘{'':4}Finished: {timing['end_time'][:19]:>20}{' '*22}ā•‘")
    print(f"ā•š{'═'*70}ā•")

    # --- Distribution summary, only when ratios were recorded ---
    ratios = self.aggregate_metrics.compression_ratios
    if ratios:
        saved = self.aggregate_metrics.tokens_saved_list

        print("\nšŸ“Š Distribution Summary:")
        # "median" is the element at index len // 2 of the sorted values
        ratio_median = sorted(ratios)[len(ratios) // 2]
        print(f" Compression ratios: min={min(ratios):.2%}, max={max(ratios):.2%}, median={ratio_median:.2%}")
        print(f" Tokens saved: min={min(saved):,}, max={max(saved):,}, median={sorted(saved)[len(saved)//2]:,}")
def _load_jsonl_entries(path: Path, location: str = "line ") -> List[Dict[str, Any]]:
    """Read a JSONL file, skipping blank lines and warning about invalid ones.

    Args:
        path: JSONL file to read.
        location: Text placed before the 1-based line number in warnings
            (e.g. ``"line "`` or ``"file.jsonl:"``).

    Returns:
        The parsed entries in file order.
    """
    entries = []
    with open(path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"āš ļø Skipping invalid JSON at {location}{line_num}: {e}")
    return entries


def main(
    input: str,
    output: str = None,
    config: str = "configs/trajectory_compression.yaml",
    target_max_tokens: int = None,
    tokenizer: str = None,
    sample_percent: float = None,
    seed: int = 42,
    dry_run: bool = False,
):
    """
    Compress agent trajectories to fit within a target token budget.

    Supports both single JSONL files and directories containing multiple JSONL files.
    Optionally sample a percentage of trajectories before compression.

    A single input file is staged through a temporary directory so that it
    goes through the same directory pipeline; its result (and metrics file,
    if one was produced) is then copied next to ``output``.

    Args:
        input: Path to JSONL file or directory containing JSONL files
        output: Output path (file for file input, directory for dir input)
                Default: adds "_compressed" suffix to input name
        config: Path to YAML configuration file; defaults are used if it does not exist
        target_max_tokens: Override target token count from config
        tokenizer: Override tokenizer name from config
        sample_percent: Sample this percentage of trajectories before compression;
                        must satisfy 0 < sample_percent <= 100. At least one
                        trajectory is kept whenever any are available.
        seed: Random seed for sampling reproducibility (default: 42)
        dry_run: Analyze without compressing (just show what would happen)

    Examples:
        # Compress a directory (original behavior)
        python trajectory_compressor.py --input=data/my_run

        # Compress a single file
        python trajectory_compressor.py --input=data/trajectories.jsonl

        # Compress 15% sample of a file
        python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15

        # Compress 10% sample with custom output
        python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=10 --output=data/sampled_compressed.jsonl
    """
    import random
    import tempfile
    import shutil

    print("šŸ—œļø Trajectory Compressor")
    print("=" * 60)

    # Load configuration
    config_path = Path(config)
    if config_path.exists():
        print(f"šŸ“‹ Loading config from {config}")
        compression_config = CompressionConfig.from_yaml(config)
    else:
        print(f"āš ļø Config not found at {config}, using defaults")
        compression_config = CompressionConfig()

    # Apply CLI overrides
    if target_max_tokens:
        compression_config.target_max_tokens = target_max_tokens
    if tokenizer:
        compression_config.tokenizer_name = tokenizer

    # Validate sample_percent
    if sample_percent is not None:
        if sample_percent <= 0 or sample_percent > 100:
            print(f"āŒ sample_percent must be greater than 0 and at most 100, got {sample_percent}")
            return
        print(f"šŸŽ² Will sample {sample_percent}% of trajectories (seed={seed})")

    # Setup paths and determine input type
    input_path = Path(input)
    if not input_path.exists():
        print(f"āŒ Input not found: {input}")
        return

    is_file_input = input_path.is_file()

    if is_file_input:
        print(f"šŸ“„ Input mode: Single JSONL file")

        # For file input, default output is file with _compressed suffix
        if output:
            output_path = Path(output)
        else:
            output_path = input_path.parent / (input_path.stem + compression_config.output_suffix + ".jsonl")

        # Load entries from the single file
        entries = _load_jsonl_entries(input_path)

        total_entries = len(entries)
        print(f" Loaded {total_entries:,} trajectories from {input_path.name}")

        # Sample if requested
        if sample_percent is not None:
            random.seed(seed)
            # Keep at least one entry, but never ask for more than exist:
            # random.sample raises ValueError when the sample exceeds the
            # population, which would otherwise crash on an empty file.
            sample_size = min(max(1, int(total_entries * sample_percent / 100)), total_entries)
            entries = random.sample(entries, sample_size)
            print(f" Sampled {len(entries):,} trajectories ({sample_percent}% of {total_entries:,})")

        if dry_run:
            print(f"\nšŸ” DRY RUN MODE - analyzing without writing")
            print(f"šŸ“„ Would process: {len(entries):,} trajectories")
            print(f"šŸ“„ Would output to: {output_path}")
            return

        # Create a temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_input_dir = Path(temp_dir) / "input"
            temp_output_dir = Path(temp_dir) / "output"
            temp_input_dir.mkdir()

            # Write entries to temp file
            temp_input_file = temp_input_dir / "trajectories.jsonl"
            with open(temp_input_file, 'w', encoding='utf-8') as f:
                for entry in entries:
                    f.write(json.dumps(entry, ensure_ascii=False) + '\n')

            # Initialize compressor and process
            compressor = TrajectoryCompressor(compression_config)
            compressor.process_directory(temp_input_dir, temp_output_dir)

            # Copy result to output path (merge all files in temp_output_dir)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as out_f:
                for jsonl_file in sorted(temp_output_dir.glob("*.jsonl")):
                    with open(jsonl_file, 'r', encoding='utf-8') as in_f:
                        for line in in_f:
                            out_f.write(line)

            # Copy metrics file if it exists (must happen before the temp dir is removed)
            metrics_file = temp_output_dir / compression_config.metrics_output_file
            if metrics_file.exists():
                metrics_output = output_path.parent / (output_path.stem + "_metrics.json")
                shutil.copy(metrics_file, metrics_output)
                print(f"šŸ’¾ Metrics saved to {metrics_output}")

        print(f"\nāœ… Compression complete!")
        print(f"šŸ“„ Output: {output_path}")

    else:
        # Directory input - original behavior
        print(f"šŸ“ Input mode: Directory of JSONL files")

        if output:
            output_path = Path(output)
        else:
            output_path = input_path.parent / (input_path.name + compression_config.output_suffix)

        # If sampling is requested for directory mode, we need to handle it differently
        if sample_percent is not None:
            print(f"\nāš ļø Sampling from directory: will sample {sample_percent}% from each file")

            # Create a temp directory with sampled files
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_input_dir = Path(temp_dir) / "input"
                temp_input_dir.mkdir()

                random.seed(seed)
                total_original = 0
                total_sampled = 0

                # Sample from each JSONL file
                for jsonl_file in sorted(input_path.glob("*.jsonl")):
                    # Invalid lines are reported with the file name, not silently dropped
                    entries = _load_jsonl_entries(jsonl_file, location=f"{jsonl_file.name}:")

                    total_original += len(entries)
                    sample_size = max(1, int(len(entries) * sample_percent / 100))
                    sampled_entries = random.sample(entries, min(sample_size, len(entries)))
                    total_sampled += len(sampled_entries)

                    # Write sampled entries
                    temp_file = temp_input_dir / jsonl_file.name
                    with open(temp_file, 'w', encoding='utf-8') as f:
                        for entry in sampled_entries:
                            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

                print(f" Sampled {total_sampled:,} from {total_original:,} total trajectories")

                if dry_run:
                    print(f"\nšŸ” DRY RUN MODE - analyzing without writing")
                    print(f"šŸ“ Would process: {temp_input_dir}")
                    print(f"šŸ“ Would output to: {output_path}")
                    return

                # Initialize compressor and process the sampled data
                compressor = TrajectoryCompressor(compression_config)
                compressor.process_directory(temp_input_dir, output_path)
        else:
            if dry_run:
                print(f"\nšŸ” DRY RUN MODE - analyzing without writing")
                print(f"šŸ“ Would process: {input_path}")
                print(f"šŸ“ Would output to: {output_path}")
                return

            # Initialize compressor and process directly
            compressor = TrajectoryCompressor(compression_config)
            compressor.process_directory(input_path, output_path)

        print("\nāœ… Compression complete!")


if __name__ == "__main__":
    fire.Fire(main)