diff --git a/configs/trajectory_compression.yaml b/configs/trajectory_compression.yaml new file mode 100644 index 000000000..ffaf13981 --- /dev/null +++ b/configs/trajectory_compression.yaml @@ -0,0 +1,97 @@ +# Trajectory Compression Configuration +# +# Post-processes completed agent trajectories to fit within a target token budget. +# Compression preserves head/tail turns and summarizes middle content only as needed. + +# Tokenizer settings for accurate token counting +tokenizer: + # HuggingFace tokenizer name + name: "moonshotai/Kimi-K2-Thinking" + + # Trust remote code (required for some tokenizers) + trust_remote_code: true + +# Compression targets and behavior +compression: + # Target maximum tokens for compressed trajectory + target_max_tokens: 29000 + + # Target size for summary (in tokens) + # This is factored into calculations when determining what to compress + summary_target_tokens: 750 + +# Protected turns that should NEVER be compressed +protected_turns: + # Always protect the first system message (tool definitions) + first_system: true + + # Always protect the first human message (original request) + first_human: true + + # Always protect the first gpt message (initial response/tool_call) + first_gpt: true + + # Always protect the first tool response (result of first action) + first_tool: true + + # Always protect the last 2 complete turn pairs (gpt+tool or gpt only) + # This ensures the model's final actions and conclusions are preserved + last_n_turns: 4 + +# LLM settings for generating summaries (OpenRouter only) +summarization: + # Model to use for summarization (should be fast and cheap) + # Using OpenRouter model path format + model: "google/gemini-3-flash-preview" + + # OpenRouter API settings + base_url: "https://openrouter.ai/api/v1" + + # Environment variable containing OpenRouter API key + api_key_env: "OPENROUTER_API_KEY" + + # Temperature for summarization (lower = more deterministic) + temperature: 0.3 + + # Max retries for API failures 
+ max_retries: 3 + + # Delay between retries (seconds) + retry_delay: 2 + +# Output settings +output: + # Add notice to system message about potential summarization + add_summary_notice: true + + # Text to append to system message + summary_notice_text: "\n\nSome of the conversation may be summarized to preserve context." + + # Output directory suffix (appended to input directory name) + output_suffix: "_compressed" + +# Processing settings +processing: + # Number of parallel workers for batch processing + num_workers: 4 + + # Maximum concurrent API calls for summarization (async parallelism) + max_concurrent_requests: 50 + + # Skip trajectories that are already under target length + skip_under_target: true + + # If true, save trajectories even if compression can't get under target + # (will compress as much as possible) + save_over_limit: true + +# Metrics to track +metrics: + # Log detailed compression statistics + enabled: true + + # Save per-trajectory metrics in output + per_trajectory: false + + # Metrics file name (saved in output directory) + output_file: "compression_metrics.json" diff --git a/mini_swe_runner.py b/mini_swe_runner.py new file mode 100644 index 000000000..c5d943918 --- /dev/null +++ b/mini_swe_runner.py @@ -0,0 +1,704 @@ +#!/usr/bin/env python3 +""" +Mini-SWE-Agent Runner with Hermes Trajectory Format + +This module provides a runner that uses mini-swe-agent's execution environments +(local, docker, modal) but outputs trajectories in the Hermes-Agent format +compatible with batch_runner.py and trajectory_compressor.py. 
+ +Features: +- Uses mini-swe-agent's Docker, Modal, or Local environments for command execution +- Outputs trajectories in Hermes format (from/value pairs with / XML) +- Compatible with the trajectory compression pipeline +- Supports batch processing from JSONL prompt files + +Usage: + # Run a single task with local environment + python mini_swe_runner.py --task "Create a hello world Python script" --env local + + # Run with Docker + python mini_swe_runner.py --task "List files in /tmp" --env docker --image python:3.11-slim + + # Run with Modal (cloud) + python mini_swe_runner.py --task "Install numpy and test it" --env modal --image python:3.11-slim + + # Batch mode from JSONL file + python mini_swe_runner.py --prompts_file prompts.jsonl --output_file trajectories.jsonl --env docker +""" + +import json +import logging +import os +import sys +import time +import uuid +from datetime import datetime +from pathlib import Path +from typing import List, Dict, Any, Optional, Literal + +import fire +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Add mini-swe-agent to path if not installed +mini_swe_path = Path(__file__).parent / "mini-swe-agent" / "src" +if mini_swe_path.exists(): + sys.path.insert(0, str(mini_swe_path)) + + +# ============================================================================ +# Terminal Tool Definition (matches Hermes-Agent format) +# ============================================================================ + +TERMINAL_TOOL_DEFINITION = { + "type": "function", + "function": { + "name": "terminal", + "description": """Execute bash commands in a sandboxed environment. 
+ +**Environment:** +- Isolated execution environment (local, Docker, or Modal cloud) +- Filesystem persists between tool calls within the same task +- Internet access available + +**Command Execution:** +- Provide the command to execute via the 'command' parameter +- Optional 'timeout' parameter in seconds (default: 60) + +**Examples:** +- Run command: `{"command": "ls -la"}` +- With timeout: `{"command": "long_task.sh", "timeout": 300}` + +**Best Practices:** +- Use non-interactive commands (avoid vim, nano, interactive python) +- Pipe to cat if output might be large +- Install tools with apt-get or pip as needed + +**Completion:** +- When task is complete, output: echo "MINI_SWE_AGENT_FINAL_OUTPUT" followed by your result +""", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The bash command to execute" + }, + "timeout": { + "type": "integer", + "description": "Command timeout in seconds (default: 60)" + } + }, + "required": ["command"] + } + } +} + + +# ============================================================================ +# Environment Factory +# ============================================================================ + +def create_environment( + env_type: str = "local", + image: str = "python:3.11-slim", + cwd: str = "/tmp", + timeout: int = 60, + **kwargs +): + """ + Create an execution environment from mini-swe-agent. 
+ + Args: + env_type: One of "local", "docker", "modal" + image: Docker/Modal image name (ignored for local) + cwd: Working directory + timeout: Default command timeout + **kwargs: Additional environment-specific options + + Returns: + Environment instance with execute() method + """ + if env_type == "local": + from minisweagent.environments.local import LocalEnvironment + return LocalEnvironment(cwd=cwd, timeout=timeout) + + elif env_type == "docker": + from minisweagent.environments.docker import DockerEnvironment + return DockerEnvironment(image=image, cwd=cwd, timeout=timeout, **kwargs) + + elif env_type == "modal": + from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment + return SwerexModalEnvironment(image=image, cwd=cwd, timeout=timeout, **kwargs) + + else: + raise ValueError(f"Unknown environment type: {env_type}. Use 'local', 'docker', or 'modal'") + + +# ============================================================================ +# Mini-SWE Runner with Hermes Trajectory Format +# ============================================================================ + +class MiniSWERunner: + """ + Agent runner that uses mini-swe-agent environments but outputs + trajectories in Hermes-Agent format. + """ + + def __init__( + self, + model: str = "claude-sonnet-4-20250514", + base_url: str = None, + api_key: str = None, + env_type: str = "local", + image: str = "python:3.11-slim", + cwd: str = "/tmp", + max_iterations: int = 15, + command_timeout: int = 60, + verbose: bool = False, + ): + """ + Initialize the Mini-SWE Runner. 
+ + Args: + model: Model name for OpenAI-compatible API + base_url: API base URL (optional, uses env vars if not provided) + api_key: API key (optional, uses env vars if not provided) + env_type: Environment type - "local", "docker", or "modal" + image: Docker/Modal image (ignored for local) + cwd: Working directory for commands + max_iterations: Maximum tool-calling iterations + command_timeout: Default timeout for commands + verbose: Enable verbose logging + """ + self.model = model + self.max_iterations = max_iterations + self.command_timeout = command_timeout + self.verbose = verbose + self.env_type = env_type + self.image = image + self.cwd = cwd + + # Setup logging + logging.basicConfig( + level=logging.DEBUG if verbose else logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%H:%M:%S' + ) + self.logger = logging.getLogger(__name__) + + # Initialize OpenAI client + from openai import OpenAI + + client_kwargs = {} + if base_url: + client_kwargs["base_url"] = base_url + + # Handle API key with fallbacks + if api_key: + client_kwargs["api_key"] = api_key + else: + client_kwargs["api_key"] = os.getenv( + "OPENROUTER_API_KEY", + os.getenv("ANTHROPIC_API_KEY", os.getenv("OPENAI_API_KEY", "")) + ) + + self.client = OpenAI(**client_kwargs) + + # Environment will be created per-task + self.env = None + + # Tool definition + self.tools = [TERMINAL_TOOL_DEFINITION] + + print(f"šŸ¤– Mini-SWE Runner initialized") + print(f" Model: {self.model}") + print(f" Environment: {self.env_type}") + if self.env_type != "local": + print(f" Image: {self.image}") + print(f" Max iterations: {self.max_iterations}") + + def _create_env(self): + """Create the execution environment.""" + print(f"šŸ”§ Creating {self.env_type} environment...") + self.env = create_environment( + env_type=self.env_type, + image=self.image, + cwd=self.cwd, + timeout=self.command_timeout + ) + print(f"āœ… Environment ready") + + def _cleanup_env(self): + """Cleanup the execution 
environment.""" + if self.env is not None: + if hasattr(self.env, 'cleanup'): + self.env.cleanup() + elif hasattr(self.env, 'stop'): + self.env.stop() + self.env = None + + def _execute_command(self, command: str, timeout: int = None) -> Dict[str, Any]: + """ + Execute a command in the environment. + + Args: + command: Bash command to execute + timeout: Optional timeout override + + Returns: + Dict with 'output' and 'returncode' + """ + if self.env is None: + self._create_env() + + try: + result = self.env.execute(command, timeout=timeout or self.command_timeout) + return { + "output": result.get("output", ""), + "exit_code": result.get("returncode", 0), + "error": None + } + except Exception as e: + return { + "output": "", + "exit_code": -1, + "error": str(e) + } + + def _format_tools_for_system_message(self) -> str: + """Format tool definitions for the system message.""" + formatted_tools = [] + for tool in self.tools: + func = tool["function"] + formatted_tools.append({ + "name": func["name"], + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + "required": None + }) + return json.dumps(formatted_tools, ensure_ascii=False) + + def _convert_to_hermes_format( + self, + messages: List[Dict[str, Any]], + user_query: str, + completed: bool + ) -> List[Dict[str, Any]]: + """ + Convert internal message format to Hermes trajectory format. + + This produces the exact format used by batch_runner.py. + """ + trajectory = [] + + # System message with tool definitions + system_msg = ( + "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. " + "You may call one or more functions to assist with the user query. If available tools are not relevant in assisting " + "with user query, just respond in natural conversational language. Don't make assumptions about what values to plug " + "into functions. 
After calling & executing the functions, you will be provided with function results within " + "<tool_response></tool_response> XML tags. Here are the available tools:\n" + f"<tools>\n{self._format_tools_for_system_message()}\n</tools>\n\n" + "For each function call return a JSON object, with the following pydantic model json schema for each:\n" + "{'title': 'FunctionCall', 'type': 'object', 'properties': {'name': {'title': 'Name', 'type': 'string'}, " + "'arguments': {'title': 'Arguments', 'type': 'object'}}, 'required': ['name', 'arguments']}\n" + "Each function call should be enclosed within <tool_call></tool_call> XML tags.\n" + "Example:\n<tool_call>\n{'name': <function-name>,'arguments': <args-dict>}\n</tool_call>\n" + ) + + trajectory.append({"from": "system", "value": system_msg}) + trajectory.append({"from": "human", "value": user_query}) + + # Process messages (skip first user message as we already added it) + i = 1 + while i < len(messages): + msg = messages[i] + + if msg["role"] == "assistant": + if "tool_calls" in msg and msg["tool_calls"]: + # Assistant message with tool calls + content = "" + + # Add reasoning if present + if msg.get("reasoning"): + content = f"<think>{msg['reasoning']}</think>" + + if msg.get("content"): + content += msg["content"] + "\n" + + # Add tool calls in XML format + for tool_call in msg["tool_calls"]: + try: + arguments = json.loads(tool_call["function"]["arguments"]) \ + if isinstance(tool_call["function"]["arguments"], str) \ + else tool_call["function"]["arguments"] + except json.JSONDecodeError: + arguments = {} + + tool_call_json = { + "name": tool_call["function"]["name"], + "arguments": arguments + } + content += f"<tool_call>\n{json.dumps(tool_call_json, ensure_ascii=False)}\n</tool_call>\n" + + trajectory.append({"from": "gpt", "value": content.rstrip()}) + + # Collect subsequent tool responses + tool_responses = [] + j = i + 1 + while j < len(messages) and messages[j]["role"] == "tool": + tool_msg = messages[j] + tool_content = tool_msg["content"] + + # Try to parse as JSON + try: + if tool_content.strip().startswith(("{", "[")): + tool_content = json.loads(tool_content) 
+ except (json.JSONDecodeError, AttributeError): + pass + + tool_response = f"<tool_response>\n" + tool_response += json.dumps({ + "tool_call_id": tool_msg.get("tool_call_id", ""), + "name": msg["tool_calls"][len(tool_responses)]["function"]["name"] \ + if len(tool_responses) < len(msg["tool_calls"]) else "unknown", + "content": tool_content + }, ensure_ascii=False) + tool_response += "\n</tool_response>" + tool_responses.append(tool_response) + j += 1 + + if tool_responses: + trajectory.append({"from": "tool", "value": "\n".join(tool_responses)}) + i = j - 1 + + else: + # Regular assistant message (no tool calls) + content = "" + if msg.get("reasoning"): + content = f"<think>{msg['reasoning']}</think>" + content += msg.get("content") or "" + trajectory.append({"from": "gpt", "value": content}) + + elif msg["role"] == "user": + trajectory.append({"from": "human", "value": msg["content"]}) + + i += 1 + + return trajectory + + def run_task(self, task: str) -> Dict[str, Any]: + """ + Run a single task and return the result with trajectory. + + Args: + task: The task/prompt to execute + + Returns: + Dict with trajectory, completion status, and metadata + """ + print(f"\n{'='*60}") + print(f"šŸ“ Task: {task[:80]}{'...' if len(task) > 80 else ''}") + print(f"{'='*60}") + + # Initialize environment + self._create_env() + + # Message history + messages = [{"role": "user", "content": task}] + + # System prompt for the LLM (ephemeral - not saved to trajectory) + system_prompt = """You are an AI agent that can execute bash commands to complete tasks. + +When you need to run commands, use the 'terminal' tool with your bash command. + +**Important:** +- When you have completed the task successfully, run: echo "MINI_SWE_AGENT_FINAL_OUTPUT" followed by a summary +- Be concise and efficient in your approach +- Install any needed tools with apt-get or pip +- Avoid interactive commands (no vim, nano, less, etc.) 
+ +Complete the user's task step by step.""" + + api_call_count = 0 + completed = False + final_response = None + + try: + while api_call_count < self.max_iterations: + api_call_count += 1 + print(f"\nšŸ”„ API call #{api_call_count}/{self.max_iterations}") + + # Prepare API messages + api_messages = [{"role": "system", "content": system_prompt}] + messages + + # Make API call + try: + response = self.client.chat.completions.create( + model=self.model, + messages=api_messages, + tools=self.tools, + timeout=300.0 + ) + except Exception as e: + self.logger.error(f"API call failed: {e}") + break + + assistant_message = response.choices[0].message + + # Log assistant response + if assistant_message.content: + print(f"šŸ¤– Assistant: {assistant_message.content[:100]}...") + + # Check for tool calls + if assistant_message.tool_calls: + print(f"šŸ”§ Tool calls: {len(assistant_message.tool_calls)}") + + # Add assistant message with tool calls + messages.append({ + "role": "assistant", + "content": assistant_message.content, + "tool_calls": [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments + } + } + for tc in assistant_message.tool_calls + ] + }) + + # Execute each tool call + for tc in assistant_message.tool_calls: + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError: + args = {} + + command = args.get("command", "echo 'No command provided'") + timeout = args.get("timeout", self.command_timeout) + + print(f" šŸ“ž terminal: {command[:60]}...") + + # Execute command + result = self._execute_command(command, timeout) + + # Format result + result_json = json.dumps({ + "content": { + "output": result["output"], + "exit_code": result["exit_code"], + "error": result["error"] + } + }, ensure_ascii=False) + + # Check for task completion signal + if "MINI_SWE_AGENT_FINAL_OUTPUT" in result["output"]: + print(f" āœ… Task completion signal detected!") + completed = True + + # Add tool 
response + messages.append({ + "role": "tool", + "content": result_json, + "tool_call_id": tc.id + }) + + print(f" āœ… exit_code={result['exit_code']}, output={len(result['output'])} chars") + + # If task completed, we can stop + if completed: + final_response = assistant_message.content + break + + else: + # No tool calls - final response + final_response = assistant_message.content or "" + messages.append({ + "role": "assistant", + "content": final_response + }) + completed = True + print(f"šŸŽ‰ Agent finished (no more tool calls)") + break + + if api_call_count >= self.max_iterations: + print(f"āš ļø Reached max iterations ({self.max_iterations})") + + finally: + # Cleanup environment + self._cleanup_env() + + # Convert to Hermes trajectory format + trajectory = self._convert_to_hermes_format(messages, task, completed) + + return { + "conversations": trajectory, + "completed": completed, + "api_calls": api_call_count, + "metadata": { + "model": self.model, + "env_type": self.env_type, + "timestamp": datetime.now().isoformat() + } + } + + def run_batch( + self, + prompts: List[str], + output_file: str + ) -> List[Dict[str, Any]]: + """ + Run multiple tasks and save trajectories to a JSONL file. 
+ + Args: + prompts: List of task prompts + output_file: Output JSONL file path + + Returns: + List of results + """ + results = [] + + print(f"\nšŸ“¦ Running batch of {len(prompts)} tasks") + print(f"šŸ“ Output: {output_file}") + + with open(output_file, 'w', encoding='utf-8') as f: + for i, prompt in enumerate(prompts, 1): + print(f"\n{'='*60}") + print(f"šŸ“‹ Task {i}/{len(prompts)}") + print(f"{'='*60}") + + try: + result = self.run_task(prompt) + results.append(result) + + # Write to file immediately + f.write(json.dumps(result, ensure_ascii=False) + "\n") + f.flush() + + print(f"āœ… Task {i} completed (api_calls={result['api_calls']})") + + except Exception as e: + self.logger.error(f"Error on task {i}: {e}") + error_result = { + "conversations": [], + "completed": False, + "api_calls": 0, + "error": str(e), + "metadata": {"timestamp": datetime.now().isoformat()} + } + results.append(error_result) + f.write(json.dumps(error_result, ensure_ascii=False) + "\n") + f.flush() + + print(f"\nāœ… Batch complete! {len(results)} trajectories saved to {output_file}") + return results + + +# ============================================================================ +# CLI Interface +# ============================================================================ + +def main( + task: str = None, + prompts_file: str = None, + output_file: str = "mini-swe-agent-test1.jsonl", + model: str = "claude-sonnet-4-20250514", + base_url: str = None, + api_key: str = None, + env: str = "local", + image: str = "python:3.11-slim", + cwd: str = "/tmp", + max_iterations: int = 15, + timeout: int = 60, + verbose: bool = False, +): + """ + Run mini-swe-agent tasks with Hermes trajectory format output. 
+ + Args: + task: Single task to run (use this OR prompts_file) + prompts_file: JSONL file with prompts (each line: {"prompt": "..."}) + output_file: Output JSONL file for trajectories + model: Model name (default: claude-sonnet-4-20250514) + base_url: API base URL (optional) + api_key: API key (optional, uses env vars) + env: Environment type - "local", "docker", or "modal" + image: Docker/Modal image (default: python:3.11-slim) + cwd: Working directory (default: /tmp) + max_iterations: Maximum tool-calling iterations (default: 15) + timeout: Command timeout in seconds (default: 60) + verbose: Enable verbose logging + + Examples: + # Single task with local environment + python mini_swe_runner.py --task "Create hello.py that prints Hello World" + + # Single task with Docker + python mini_swe_runner.py --task "List files" --env docker + + # Batch from file + python mini_swe_runner.py --prompts_file tasks.jsonl --output_file results.jsonl + """ + print("šŸš€ Mini-SWE Runner with Hermes Trajectory Format") + print("=" * 60) + + # Initialize runner + runner = MiniSWERunner( + model=model, + base_url=base_url, + api_key=api_key, + env_type=env, + image=image, + cwd=cwd, + max_iterations=max_iterations, + command_timeout=timeout, + verbose=verbose, + ) + + if task: + # Single task mode + result = runner.run_task(task) + + # Save to file + with open(output_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + + print(f"\nšŸ“ Trajectory saved to: {output_file}") + print(f"āœ… Completed: {result['completed']}") + print(f"šŸ“ž API calls: {result['api_calls']}") + print(f"šŸ’¬ Turns: {len(result['conversations'])}") + + elif prompts_file: + # Batch mode + prompts = [] + with open(prompts_file, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + try: + entry = json.loads(line) + prompts.append(entry.get("prompt", entry.get("task", ""))) + except json.JSONDecodeError: + prompts.append(line) + + if not 
prompts: + print(f"āŒ No prompts found in {prompts_file}") + return + + runner.run_batch(prompts, output_file) + + else: + print("āŒ Please provide either --task or --prompts_file") + print(" Example: python mini_swe_runner.py --task 'Create a hello world script'") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/scripts/sample_and_compress.py b/scripts/sample_and_compress.py new file mode 100644 index 000000000..c31496f76 --- /dev/null +++ b/scripts/sample_and_compress.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 +""" +Sample and Compress HuggingFace Datasets + +Downloads trajectories from multiple HuggingFace datasets, randomly samples them, +and runs trajectory compression to fit within a target token budget. + +Usage: + python scripts/sample_and_compress.py + + # Custom sample size + python scripts/sample_and_compress.py --total_samples=5000 + + # Custom output name + python scripts/sample_and_compress.py --output_name=compressed_16k +""" + +import json +import random +import os +from pathlib import Path +from typing import List, Dict, Any, Tuple +import fire + +# Load environment variables +from dotenv import load_dotenv +load_dotenv() + + +# Default datasets to sample from +DEFAULT_DATASETS = [ + "NousResearch/swe-terminus-agent-glm-kimi-minimax", + "NousResearch/hermes-agent-megascience-sft1", + "NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT2", + "NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT1", + "NousResearch/terminal-tasks-glm-hermes-agent" +] + + +def load_dataset_from_hf(dataset_name: str) -> List[Dict[str, Any]]: + """ + Load a dataset from HuggingFace. 
+ + Args: + dataset_name: HuggingFace dataset name (e.g., "NousResearch/dataset-name") + + Returns: + List of trajectory entries + """ + from datasets import load_dataset + + print(f" Loading {dataset_name}...") + + try: + # Try loading with default config + ds = load_dataset(dataset_name, split="train") + except Exception as e: + print(f" āš ļø Error loading {dataset_name}: {e}") + return [] + + # Convert to list of dicts + entries = [] + for item in ds: + # Handle different possible formats + if "conversations" in item: + entries.append({"conversations": item["conversations"]}) + elif "messages" in item: + # Convert messages format to conversations format if needed + entries.append({"conversations": item["messages"]}) + else: + # Assume the whole item is the entry + entries.append(dict(item)) + + print(f" āœ… Loaded {len(entries):,} entries from {dataset_name}") + return entries + + +# Global tokenizer for multiprocessing (set in worker init) +_TOKENIZER = None + + +def _init_tokenizer_worker(tokenizer_name: str): + """Initialize tokenizer in worker process.""" + global _TOKENIZER + from transformers import AutoTokenizer + _TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) + + +def _count_tokens_for_entry(entry: Dict) -> Tuple[Dict, int]: + """ + Count tokens for a single entry (used in parallel processing). 
+ + Args: + entry: Trajectory entry with 'conversations' field + + Returns: + Tuple of (entry, token_count) + """ + global _TOKENIZER + + conversations = entry.get("conversations", []) + if not conversations: + return entry, 0 + + total = 0 + for turn in conversations: + value = turn.get("value", "") + if value: + try: + total += len(_TOKENIZER.encode(value)) + except: + # Fallback to character estimate + total += len(value) // 4 + + return entry, total + + +def sample_from_datasets( + datasets: List[str], + total_samples: int, + min_tokens: int = 16000, + tokenizer_name: str = "moonshotai/Kimi-K2-Thinking", + seed: int = 42, + num_proc: int = 8 +) -> List[Dict[str, Any]]: + """ + Load all datasets, filter by token count, then randomly sample from combined pool. + + Args: + datasets: List of HuggingFace dataset names + total_samples: Total number of samples to collect + min_tokens: Minimum token count to include (only sample trajectories >= this) + tokenizer_name: HuggingFace tokenizer for counting tokens + seed: Random seed for reproducibility + num_proc: Number of parallel processes for tokenization + + Returns: + List of sampled trajectory entries + """ + from multiprocessing import Pool + from functools import partial + + random.seed(seed) + + print(f"\nšŸ“„ Loading {len(datasets)} datasets...") + print(f" Minimum tokens: {min_tokens:,} (filtering smaller trajectories)") + print(f" Parallel workers: {num_proc}") + print() + + # Load ALL entries from all datasets into one pool + all_entries = [] + + for dataset_name in datasets: + entries = load_dataset_from_hf(dataset_name) + + if not entries: + print(f" āš ļø Skipping {dataset_name} (no entries loaded)") + continue + + # Add source metadata to each entry + for entry in entries: + entry["_source_dataset"] = dataset_name + + all_entries.extend(entries) + + print(f"\nšŸ“Š Total entries loaded: {len(all_entries):,}") + + # Filter by token count using parallel processing + print(f"\nšŸ” Filtering trajectories 
with >= {min_tokens:,} tokens (using {num_proc} workers)...") + + filtered_entries = [] + token_counts = [] + + # Use multiprocessing for token counting + with Pool( + processes=num_proc, + initializer=_init_tokenizer_worker, + initargs=(tokenizer_name,) + ) as pool: + # Process in chunks and show progress + chunk_size = 1000 + processed = 0 + + for result in pool.imap_unordered(_count_tokens_for_entry, all_entries, chunksize=100): + entry, token_count = result + processed += 1 + + if processed % chunk_size == 0: + print(f" Processed {processed:,}/{len(all_entries):,}...", end="\r") + + if token_count >= min_tokens: + entry["_original_tokens"] = token_count + filtered_entries.append(entry) + token_counts.append(token_count) + + print(f"\n āœ… Found {len(filtered_entries):,} trajectories >= {min_tokens:,} tokens") + + if token_counts: + avg_tokens = sum(token_counts) / len(token_counts) + print(f" šŸ“ˆ Token stats: min={min(token_counts):,}, max={max(token_counts):,}, avg={avg_tokens:,.0f}") + + # Random sample from the filtered pool + if len(filtered_entries) <= total_samples: + print(f"\nāš ļø Only {len(filtered_entries):,} trajectories available, using all of them") + sampled = filtered_entries + else: + sampled = random.sample(filtered_entries, total_samples) + print(f"\nāœ… Randomly sampled {len(sampled):,} trajectories from pool of {len(filtered_entries):,}") + + # Show source distribution + source_counts = {} + for entry in sampled: + source = entry.get("_source_dataset", "unknown").split("/")[-1] + source_counts[source] = source_counts.get(source, 0) + 1 + + print(f"\nšŸ“Œ Sample distribution by source:") + for source, count in sorted(source_counts.items()): + print(f" {source}: {count:,}") + + # Shuffle + random.shuffle(sampled) + + return sampled + + +def save_samples_for_compression( + samples: List[Dict[str, Any]], + output_dir: Path, + batch_size: int = 100 +): + """ + Save samples to JSONL files for trajectory compression. 
+ + Args: + samples: List of trajectory entries + output_dir: Directory to save JSONL files + batch_size: Number of entries per file + """ + output_dir.mkdir(parents=True, exist_ok=True) + + # Split into batches + num_batches = (len(samples) + batch_size - 1) // batch_size + + print(f"\nšŸ’¾ Saving {len(samples)} samples to {output_dir}") + print(f" Batch size: {batch_size}, Total batches: {num_batches}") + + for i in range(num_batches): + start_idx = i * batch_size + end_idx = min((i + 1) * batch_size, len(samples)) + batch = samples[start_idx:end_idx] + + output_file = output_dir / f"batch_{i}.jsonl" + with open(output_file, 'w', encoding='utf-8') as f: + for entry in batch: + f.write(json.dumps(entry, ensure_ascii=False) + '\n') + + print(f" āœ… Saved {num_batches} batch files") + + +def run_compression(input_dir: Path, output_dir: Path, config_path: str): + """ + Run trajectory compression on the sampled data. + + Args: + input_dir: Directory containing JSONL files to compress + output_dir: Directory for compressed output + config_path: Path to compression config YAML + """ + # Import the compressor + import sys + sys.path.insert(0, str(Path(__file__).parent.parent)) + from trajectory_compressor import TrajectoryCompressor, CompressionConfig + + print(f"\nšŸ—œļø Running trajectory compression...") + print(f" Input: {input_dir}") + print(f" Output: {output_dir}") + print(f" Config: {config_path}") + + # Load config + config = CompressionConfig.from_yaml(config_path) + + # Initialize compressor + compressor = TrajectoryCompressor(config) + + # Run compression + compressor.process_directory(input_dir, output_dir) + + +def merge_output_to_single_jsonl(input_dir: Path, output_file: Path): + """ + Merge all JSONL files in a directory into a single JSONL file. 
+ + Args: + input_dir: Directory containing JSONL files + output_file: Output JSONL file path + """ + print(f"\nšŸ“¦ Merging output files into {output_file.name}...") + + all_entries = [] + for jsonl_file in sorted(input_dir.glob("*.jsonl")): + if jsonl_file.name == output_file.name: + continue + with open(jsonl_file, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + all_entries.append(json.loads(line)) + + # Write merged file + with open(output_file, 'w', encoding='utf-8') as f: + for entry in all_entries: + f.write(json.dumps(entry, ensure_ascii=False) + '\n') + + print(f" āœ… Merged {len(all_entries):,} entries into {output_file.name}") + return output_file + + +def main( + total_samples: int = 2500, + output_name: str = "compressed_agentic", + datasets: str = None, + config: str = "configs/trajectory_compression.yaml", + seed: int = 42, + batch_size: int = 100, + min_tokens: int = 16000, + num_proc: int = 8, + skip_download: bool = False, +): + """ + Sample trajectories from HuggingFace datasets and run compression. 
def main(
    total_samples: int = 2500,
    output_name: str = "compressed_agentic",
    datasets: str = None,
    config: str = "configs/trajectory_compression.yaml",
    seed: int = 42,
    batch_size: int = 100,
    min_tokens: int = 16000,
    num_proc: int = 8,
    skip_download: bool = False,
):
    """
    Sample trajectories from HuggingFace datasets and run compression.

    Args:
        total_samples: Total number of samples to collect (default: 2500)
        output_name: Name for output directory/file (default: "compressed_agentic")
        datasets: Comma-separated list of dataset names (uses defaults if not provided)
        config: Path to compression config YAML
        seed: Random seed for reproducibility
        batch_size: Number of entries per JSONL file during processing
        min_tokens: Minimum token count to filter trajectories (default: 16000)
        num_proc: Number of parallel workers for tokenization (default: 8)
        skip_download: Skip download and use existing sampled data
    """
    banner = "=" * 70
    print(banner)
    print("šŸ“Š TRAJECTORY SAMPLING AND COMPRESSION")
    print(banner)

    # Dataset selection: explicit comma-separated list, or the built-ins.
    dataset_list = [d.strip() for d in datasets.split(",")] if datasets else DEFAULT_DATASETS

    print(f"\nšŸ“‹ Configuration:")
    print(f" Total samples: {total_samples:,}")
    print(f" Min tokens filter: {min_tokens:,}")
    print(f" Parallel workers: {num_proc}")
    print(f" Datasets: {len(dataset_list)}")
    for ds in dataset_list:
        print(f" - {ds}")
    print(f" Output name: {output_name}")
    print(f" Config: {config}")
    print(f" Seed: {seed}")

    # All working paths hang off <repo>/data.
    base_dir = Path(__file__).parent.parent
    data_root = base_dir / "data"
    sampled_dir = data_root / f"{output_name}_raw"
    compressed_dir = data_root / f"{output_name}_batches"
    final_output = data_root / f"{output_name}.jsonl"

    if skip_download:
        print(f"\nā­ļø Skipping download, using existing data in {sampled_dir}")
    else:
        # Step 1: Download, filter by token count, and sample from combined pool
        samples = sample_from_datasets(
            dataset_list,
            total_samples,
            min_tokens=min_tokens,
            seed=seed,
            num_proc=num_proc
        )
        if not samples:
            print("āŒ No samples collected. Exiting.")
            return
        # Step 2: Save to JSONL files
        save_samples_for_compression(samples, sampled_dir, batch_size)

    # Step 3: Run compression
    config_path = base_dir / config
    if not config_path.exists():
        print(f"āŒ Config not found: {config_path}")
        return
    run_compression(sampled_dir, compressed_dir, str(config_path))

    # Step 4: Merge into single JSONL file
    merge_output_to_single_jsonl(compressed_dir, final_output)

    print("\n" + banner)
    print("āœ… COMPLETE!")
    print(banner)
    print(f"\nšŸ“ Raw samples: {sampled_dir}")
    print(f"šŸ“ Compressed batches: {compressed_dir}")
    print(f"šŸ“ Final output: {final_output}")
    print(f"\nTo upload to HuggingFace:")
    print(f" huggingface-cli upload NousResearch/{output_name} {final_output}")


if __name__ == "__main__":
    fire.Fire(main)
@dataclass
class CompressionConfig:
    """Configuration for trajectory compression.

    Field defaults mirror configs/trajectory_compression.yaml; any value can
    be overridden per-section via :meth:`from_yaml`.
    """
    # Tokenizer
    tokenizer_name: str = "moonshotai/Kimi-K2-Thinking"
    trust_remote_code: bool = True

    # Compression targets
    target_max_tokens: int = 15250
    summary_target_tokens: int = 750

    # Protected turns
    protect_first_system: bool = True
    protect_first_human: bool = True
    protect_first_gpt: bool = True
    protect_first_tool: bool = True
    protect_last_n_turns: int = 4

    # Summarization (OpenRouter)
    summarization_model: str = "google/gemini-3-flash-preview"
    base_url: str = "https://openrouter.ai/api/v1"
    api_key_env: str = "OPENROUTER_API_KEY"
    temperature: float = 0.3
    max_retries: int = 3
    retry_delay: int = 2

    # Output
    add_summary_notice: bool = True
    summary_notice_text: str = "\n\nSome of your previous tool responses may be summarized to preserve context."
    output_suffix: str = "_compressed"

    # Processing
    num_workers: int = 4
    max_concurrent_requests: int = 50  # Max concurrent API calls for summarization
    skip_under_target: bool = True
    save_over_limit: bool = True

    # Metrics
    metrics_enabled: bool = True
    metrics_per_trajectory: bool = True
    metrics_output_file: str = "compression_metrics.json"

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "CompressionConfig":
        """Load configuration from a YAML file.

        Missing sections/keys fall back to the dataclass defaults. An empty
        file or an empty (null) section also yields defaults instead of
        raising, which the previous `'section' in data` / `data[s].get(...)`
        chains did not tolerate (yaml.safe_load returns None for empty input).

        Args:
            yaml_path: Path to the YAML config file.

        Returns:
            A populated CompressionConfig.
        """
        # YAML section -> {yaml key: config attribute name}
        schema = {
            'tokenizer': {
                'name': 'tokenizer_name',
                'trust_remote_code': 'trust_remote_code',
            },
            'compression': {
                'target_max_tokens': 'target_max_tokens',
                'summary_target_tokens': 'summary_target_tokens',
            },
            'protected_turns': {
                'first_system': 'protect_first_system',
                'first_human': 'protect_first_human',
                'first_gpt': 'protect_first_gpt',
                'first_tool': 'protect_first_tool',
                'last_n_turns': 'protect_last_n_turns',
            },
            'summarization': {
                'model': 'summarization_model',
                'base_url': 'base_url',
                'api_key_env': 'api_key_env',
                'temperature': 'temperature',
                'max_retries': 'max_retries',
                'retry_delay': 'retry_delay',
            },
            'output': {
                'add_summary_notice': 'add_summary_notice',
                'summary_notice_text': 'summary_notice_text',
                'output_suffix': 'output_suffix',
            },
            'processing': {
                'num_workers': 'num_workers',
                'max_concurrent_requests': 'max_concurrent_requests',
                'skip_under_target': 'skip_under_target',
                'save_over_limit': 'save_over_limit',
            },
            'metrics': {
                'enabled': 'metrics_enabled',
                'per_trajectory': 'metrics_per_trajectory',
                'output_file': 'metrics_output_file',
            },
        }

        with open(yaml_path, 'r') as f:
            data = yaml.safe_load(f) or {}  # None for an empty file -> defaults

        config = cls()
        for section, keys in schema.items():
            values = data.get(section) or {}  # tolerate missing/null sections
            for yaml_key, attr in keys.items():
                # `in` (not .get with default) so an explicit null in the YAML
                # still overrides, matching the original .get() semantics.
                if yaml_key in values:
                    setattr(config, attr, values[yaml_key])
        return config
@dataclass
class TrajectoryMetrics:
    """Metrics for a single trajectory compression."""
    original_tokens: int = 0
    compressed_tokens: int = 0
    tokens_saved: int = 0
    compression_ratio: float = 1.0

    original_turns: int = 0
    compressed_turns: int = 0
    turns_removed: int = 0

    turns_compressed_start_idx: int = -1
    turns_compressed_end_idx: int = -1
    turns_in_compressed_region: int = 0

    was_compressed: bool = False
    still_over_limit: bool = False
    skipped_under_target: bool = False

    summarization_api_calls: int = 0
    summarization_errors: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (ratio rounded to 4 places)."""
        region = {
            "start_idx": self.turns_compressed_start_idx,
            "end_idx": self.turns_compressed_end_idx,
            "turns_count": self.turns_in_compressed_region,
        }
        payload = {
            "original_tokens": self.original_tokens,
            "compressed_tokens": self.compressed_tokens,
            "tokens_saved": self.tokens_saved,
            "compression_ratio": round(self.compression_ratio, 4),
            "original_turns": self.original_turns,
            "compressed_turns": self.compressed_turns,
            "turns_removed": self.turns_removed,
            "compression_region": region,
            "was_compressed": self.was_compressed,
            "still_over_limit": self.still_over_limit,
            "skipped_under_target": self.skipped_under_target,
            "summarization_api_calls": self.summarization_api_calls,
            "summarization_errors": self.summarization_errors,
        }
        return payload
@dataclass
class AggregateMetrics:
    """Aggregate metrics across all trajectories."""
    total_trajectories: int = 0
    trajectories_compressed: int = 0
    trajectories_skipped_under_target: int = 0
    trajectories_still_over_limit: int = 0
    trajectories_failed: int = 0

    total_tokens_before: int = 0
    total_tokens_after: int = 0
    total_tokens_saved: int = 0

    total_turns_before: int = 0
    total_turns_after: int = 0
    total_turns_removed: int = 0

    total_summarization_calls: int = 0
    total_summarization_errors: int = 0

    # Distribution stats
    compression_ratios: List[float] = field(default_factory=list)
    tokens_saved_list: List[int] = field(default_factory=list)
    turns_removed_list: List[int] = field(default_factory=list)

    processing_start_time: str = ""
    processing_end_time: str = ""
    processing_duration_seconds: float = 0.0

    def add_trajectory_metrics(self, metrics: "TrajectoryMetrics"):
        """Fold one trajectory's metrics into the running totals."""
        self.total_trajectories += 1

        # Accumulate every running total in one data-driven pass.
        deltas = (
            ("total_tokens_before", metrics.original_tokens),
            ("total_tokens_after", metrics.compressed_tokens),
            ("total_tokens_saved", metrics.tokens_saved),
            ("total_turns_before", metrics.original_turns),
            ("total_turns_after", metrics.compressed_turns),
            ("total_turns_removed", metrics.turns_removed),
            ("total_summarization_calls", metrics.summarization_api_calls),
            ("total_summarization_errors", metrics.summarization_errors),
        )
        for attr, delta in deltas:
            setattr(self, attr, getattr(self, attr) + delta)

        # Per-outcome counters and distribution samples.
        if metrics.was_compressed:
            self.trajectories_compressed += 1
            self.compression_ratios.append(metrics.compression_ratio)
            self.tokens_saved_list.append(metrics.tokens_saved)
            self.turns_removed_list.append(metrics.turns_removed)
        if metrics.skipped_under_target:
            self.trajectories_skipped_under_target += 1
        if metrics.still_over_limit:
            self.trajectories_still_over_limit += 1

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a nested JSON-friendly report dict."""
        def _avg(values, empty):
            # Mean of a sample list, with an explicit value for "no samples".
            return sum(values) / len(values) if values else empty

        return {
            "summary": {
                "total_trajectories": self.total_trajectories,
                "trajectories_compressed": self.trajectories_compressed,
                "trajectories_skipped_under_target": self.trajectories_skipped_under_target,
                "trajectories_still_over_limit": self.trajectories_still_over_limit,
                "trajectories_failed": self.trajectories_failed,
                "compression_rate": round(self.trajectories_compressed / max(self.total_trajectories, 1), 4),
            },
            "tokens": {
                "total_before": self.total_tokens_before,
                "total_after": self.total_tokens_after,
                "total_saved": self.total_tokens_saved,
                "overall_compression_ratio": round(self.total_tokens_after / max(self.total_tokens_before, 1), 4),
            },
            "turns": {
                "total_before": self.total_turns_before,
                "total_after": self.total_turns_after,
                "total_removed": self.total_turns_removed,
            },
            "averages": {
                "avg_compression_ratio": round(_avg(self.compression_ratios, 1.0), 4),
                "avg_tokens_saved_per_compressed": round(_avg(self.tokens_saved_list, 0), 1),
                "avg_turns_removed_per_compressed": round(_avg(self.turns_removed_list, 0), 2),
            },
            "summarization": {
                "total_api_calls": self.total_summarization_calls,
                "total_errors": self.total_summarization_errors,
                "success_rate": round(1 - (self.total_summarization_errors / max(self.total_summarization_calls, 1)), 4),
            },
            "processing": {
                "start_time": self.processing_start_time,
                "end_time": self.processing_end_time,
                "duration_seconds": round(self.processing_duration_seconds, 2),
            },
        }
class TrajectoryCompressor:
    """
    Compresses agent trajectories to fit within a target token budget.

    Compression strategy:
    1. Keep protected head turns (system, human, first gpt+tool)
    2. Keep protected tail turns (last N turns)
    3. From the compressible middle region, compress only as much as needed
    4. Replace compressed turns with a single human summary message
    5. Keep remaining middle turns intact (model continues with tools)
    """

    def __init__(self, config: CompressionConfig):
        """Initialize the compressor.

        Args:
            config: Fully-populated compression settings.

        Raises:
            RuntimeError: if the tokenizer cannot be loaded or the OpenRouter
                API key environment variable is unset (see the two _init_*
                helpers below).
        """
        self.config = config
        self.aggregate_metrics = AggregateMetrics()

        # Initialize tokenizer
        self._init_tokenizer()

        # Initialize OpenRouter client
        self._init_summarizer()

        # NOTE(review): basicConfig mutates process-global logging state on
        # every instantiation; consider configuring logging once at module
        # import or in the CLI entry point instead.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%H:%M:%S'
        )
        self.logger = logging.getLogger(__name__)

    def _init_tokenizer(self):
        """Initialize HuggingFace tokenizer for token counting.

        Downloads/loads the tokenizer named in the config; imported lazily so
        `transformers` is only required when a compressor is constructed.
        """
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.tokenizer_name,
                trust_remote_code=self.config.trust_remote_code
            )
            print(f"āœ… Loaded tokenizer: {self.config.tokenizer_name}")
        except Exception as e:
            # Fail fast: every compression decision depends on token counts.
            raise RuntimeError(f"Failed to load tokenizer '{self.config.tokenizer_name}': {e}")
Set {self.config.api_key_env} environment variable.") + + from openai import OpenAI, AsyncOpenAI + + # Sync client (for backwards compatibility) + self.client = OpenAI( + api_key=api_key, + base_url=self.config.base_url + ) + + # Async client for parallel processing + self.async_client = AsyncOpenAI( + api_key=api_key, + base_url=self.config.base_url + ) + + print(f"āœ… Initialized OpenRouter client: {self.config.summarization_model}") + print(f" Max concurrent requests: {self.config.max_concurrent_requests}") + + def count_tokens(self, text: str) -> int: + """Count tokens in text using the configured tokenizer.""" + if not text: + return 0 + try: + return len(self.tokenizer.encode(text)) + except Exception: + # Fallback to character estimate + return len(text) // 4 + + def count_trajectory_tokens(self, trajectory: List[Dict[str, str]]) -> int: + """Count total tokens in a trajectory.""" + return sum(self.count_tokens(turn.get("value", "")) for turn in trajectory) + + def count_turn_tokens(self, trajectory: List[Dict[str, str]]) -> List[int]: + """Count tokens for each turn in a trajectory.""" + return [self.count_tokens(turn.get("value", "")) for turn in trajectory] + + def _find_protected_indices(self, trajectory: List[Dict[str, str]]) -> Tuple[set, int, int]: + """ + Find indices of protected turns. 
+ + Returns: + Tuple of (protected_set, compressible_start, compressible_end) + """ + n = len(trajectory) + protected = set() + + # Track first occurrences + first_system = first_human = first_gpt = first_tool = None + + for i, turn in enumerate(trajectory): + role = turn.get("from", "") + if role == "system" and first_system is None: + first_system = i + elif role == "human" and first_human is None: + first_human = i + elif role == "gpt" and first_gpt is None: + first_gpt = i + elif role == "tool" and first_tool is None: + first_tool = i + + # Protect first turns + if self.config.protect_first_system and first_system is not None: + protected.add(first_system) + if self.config.protect_first_human and first_human is not None: + protected.add(first_human) + if self.config.protect_first_gpt and first_gpt is not None: + protected.add(first_gpt) + if self.config.protect_first_tool and first_tool is not None: + protected.add(first_tool) + + # Protect last N turns + for i in range(max(0, n - self.config.protect_last_n_turns), n): + protected.add(i) + + # Determine compressible region + # Start after the last protected head turn + head_protected = [i for i in protected if i < n // 2] + tail_protected = [i for i in protected if i >= n // 2] + + compressible_start = max(head_protected) + 1 if head_protected else 0 + compressible_end = min(tail_protected) if tail_protected else n + + return protected, compressible_start, compressible_end + + def _extract_turn_content_for_summary(self, trajectory: List[Dict[str, str]], start: int, end: int) -> str: + """ + Extract content from turns to be summarized. 
+ + Args: + trajectory: Full trajectory + start: Start index (inclusive) + end: End index (exclusive) + + Returns: + Formatted string of turn contents for summarization + """ + parts = [] + for i in range(start, end): + turn = trajectory[i] + role = turn.get("from", "unknown") + value = turn.get("value", "") + + # Truncate very long values for the summary prompt + if len(value) > 3000: + value = value[:1500] + "\n...[truncated]...\n" + value[-500:] + + parts.append(f"[Turn {i} - {role.upper()}]:\n{value}") + + return "\n\n".join(parts) + + def _generate_summary(self, content: str, metrics: TrajectoryMetrics) -> str: + """ + Generate a summary of the compressed turns using OpenRouter. + + Args: + content: The content to summarize + metrics: Metrics object to update + + Returns: + Summary string + """ + prompt = f"""Summarize the following agent conversation turns concisely. This summary will replace these turns in the conversation history. + +Write the summary from a neutral perspective describing what the assistant did and learned. Include: +1. What actions the assistant took (tool calls, searches, file operations) +2. Key information or results obtained +3. Any important decisions or findings +4. Relevant data, file names, values, or outputs + +Keep the summary factual and informative. Target approximately {self.config.summary_target_tokens} tokens. 
    def _generate_summary(self, content: str, metrics: TrajectoryMetrics) -> str:
        """
        Generate a summary of the compressed turns using OpenRouter.

        Retries up to config.max_retries times with linearly increasing
        delay; on total failure returns a static placeholder summary so the
        pipeline never crashes on API errors.

        Args:
            content: The content to summarize
            metrics: Metrics object to update

        Returns:
            Summary string (always prefixed with "[CONTEXT SUMMARY]:")
        """
        # NOTE(review): this prompt is duplicated verbatim in
        # _generate_summary_async — any change must be mirrored there.
        prompt = f"""Summarize the following agent conversation turns concisely. This summary will replace these turns in the conversation history.

Write the summary from a neutral perspective describing what the assistant did and learned. Include:
1. What actions the assistant took (tool calls, searches, file operations)
2. Key information or results obtained
3. Any important decisions or findings
4. Relevant data, file names, values, or outputs

Keep the summary factual and informative. Target approximately {self.config.summary_target_tokens} tokens.

---
TURNS TO SUMMARIZE:
{content}
---

Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""

        for attempt in range(self.config.max_retries):
            try:
                # Count the attempt even if it fails, so metrics reflect load.
                metrics.summarization_api_calls += 1

                response = self.client.chat.completions.create(
                    model=self.config.summarization_model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.config.temperature,
                    # Headroom: allow up to twice the target summary size.
                    max_tokens=self.config.summary_target_tokens * 2,
                )

                summary = response.choices[0].message.content.strip()

                # Ensure it starts with the prefix
                if not summary.startswith("[CONTEXT SUMMARY]:"):
                    summary = "[CONTEXT SUMMARY]: " + summary

                return summary

            except Exception as e:
                metrics.summarization_errors += 1
                self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")

                if attempt < self.config.max_retries - 1:
                    # Linear backoff: delay grows with each failed attempt.
                    time.sleep(self.config.retry_delay * (attempt + 1))
                else:
                    # Fallback: create a basic summary
                    return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
    async def _generate_summary_async(self, content: str, metrics: TrajectoryMetrics) -> str:
        """
        Generate a summary of the compressed turns using OpenRouter (async version).

        Mirrors _generate_summary exactly, but awaits the async client and
        sleeps with asyncio so concurrent summarizations are not blocked.

        Args:
            content: The content to summarize
            metrics: Metrics object to update

        Returns:
            Summary string (always prefixed with "[CONTEXT SUMMARY]:")
        """
        # NOTE(review): prompt duplicated from _generate_summary — keep in sync.
        prompt = f"""Summarize the following agent conversation turns concisely. This summary will replace these turns in the conversation history.

Write the summary from a neutral perspective describing what the assistant did and learned. Include:
1. What actions the assistant took (tool calls, searches, file operations)
2. Key information or results obtained
3. Any important decisions or findings
4. Relevant data, file names, values, or outputs

Keep the summary factual and informative. Target approximately {self.config.summary_target_tokens} tokens.

---
TURNS TO SUMMARIZE:
{content}
---

Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""

        for attempt in range(self.config.max_retries):
            try:
                metrics.summarization_api_calls += 1

                response = await self.async_client.chat.completions.create(
                    model=self.config.summarization_model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.config.temperature,
                    # Headroom: allow up to twice the target summary size.
                    max_tokens=self.config.summary_target_tokens * 2,
                )

                summary = response.choices[0].message.content.strip()

                # Ensure it starts with the prefix
                if not summary.startswith("[CONTEXT SUMMARY]:"):
                    summary = "[CONTEXT SUMMARY]: " + summary

                return summary

            except Exception as e:
                metrics.summarization_errors += 1
                self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")

                if attempt < self.config.max_retries - 1:
                    # Linear backoff without blocking the event loop.
                    await asyncio.sleep(self.config.retry_delay * (attempt + 1))
                else:
                    # Fallback: create a basic summary
                    return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
    def compress_trajectory(
        self,
        trajectory: List[Dict[str, str]]
    ) -> Tuple[List[Dict[str, str]], TrajectoryMetrics]:
        """
        Compress a single trajectory to fit within target token budget.

        Algorithm:
        1. Count total tokens
        2. If under target, skip
        3. Find compressible region (between protected head and tail)
        4. Calculate how many tokens need to be saved
        5. Accumulate turns from start of compressible region until savings met
        6. Replace accumulated turns with single human summary
        7. Keep remaining turns intact

        Args:
            trajectory: List of conversation turns

        Returns:
            Tuple of (compressed_trajectory, metrics)
        """
        metrics = TrajectoryMetrics()
        metrics.original_turns = len(trajectory)

        # Count tokens per turn
        turn_tokens = self.count_turn_tokens(trajectory)
        total_tokens = sum(turn_tokens)
        metrics.original_tokens = total_tokens

        # Check if compression needed
        if total_tokens <= self.config.target_max_tokens:
            metrics.skipped_under_target = True
            metrics.compressed_tokens = total_tokens
            metrics.compressed_turns = len(trajectory)
            metrics.compression_ratio = 1.0
            return trajectory, metrics

        # Find protected regions
        protected, compress_start, compress_end = self._find_protected_indices(trajectory)

        # Check if there's anything to compress
        if compress_start >= compress_end:
            # Nothing to compress, return as-is
            metrics.compressed_tokens = total_tokens
            metrics.compressed_turns = len(trajectory)
            metrics.still_over_limit = total_tokens > self.config.target_max_tokens
            return trajectory, metrics

        # Calculate how much we need to save
        tokens_to_save = total_tokens - self.config.target_max_tokens

        # We'll replace N turns with 1 summary turn
        # Net savings = (sum of N turns' tokens) - summary_target_tokens
        # We need: net_savings >= tokens_to_save
        # So: sum of turns >= tokens_to_save + summary_target_tokens
        target_tokens_to_compress = tokens_to_save + self.config.summary_target_tokens

        # Accumulate turns from compress_start until we have enough savings
        accumulated_tokens = 0
        compress_until = compress_start

        for i in range(compress_start, compress_end):
            accumulated_tokens += turn_tokens[i]
            compress_until = i + 1  # Exclusive end

            # Check if we have enough savings
            if accumulated_tokens >= target_tokens_to_compress:
                break

        # If we still don't have enough savings, compress the entire compressible region
        if accumulated_tokens < target_tokens_to_compress and compress_until < compress_end:
            compress_until = compress_end
            accumulated_tokens = sum(turn_tokens[compress_start:compress_end])

        # Record compression region
        metrics.turns_compressed_start_idx = compress_start
        metrics.turns_compressed_end_idx = compress_until
        metrics.turns_in_compressed_region = compress_until - compress_start

        # Extract content for summary
        content_to_summarize = self._extract_turn_content_for_summary(
            trajectory, compress_start, compress_until
        )

        # Generate summary
        summary = self._generate_summary(content_to_summarize, metrics)

        # Build compressed trajectory
        compressed = []

        # Add head (turns before compression region)
        for i in range(compress_start):
            turn = trajectory[i].copy()
            # Add notice to system message
            if turn.get("from") == "system" and self.config.add_summary_notice:
                turn["value"] = turn["value"] + self.config.summary_notice_text
            compressed.append(turn)

        # Add summary as human message.
        # NOTE(review): if the turn at compress_until is itself "human", this
        # yields two consecutive human messages — confirm downstream training
        # format tolerates that.
        compressed.append({
            "from": "human",
            "value": summary
        })

        # Add tail (turns after compression region)
        for i in range(compress_until, len(trajectory)):
            compressed.append(trajectory[i].copy())

        # Calculate final metrics (recount, since the summary replaces turns)
        metrics.compressed_turns = len(compressed)
        metrics.compressed_tokens = self.count_trajectory_tokens(compressed)
        metrics.turns_removed = metrics.original_turns - metrics.compressed_turns
        metrics.tokens_saved = metrics.original_tokens - metrics.compressed_tokens
        metrics.compression_ratio = metrics.compressed_tokens / max(metrics.original_tokens, 1)
        metrics.was_compressed = True
        metrics.still_over_limit = metrics.compressed_tokens > self.config.target_max_tokens

        return compressed, metrics
    async def compress_trajectory_async(
        self,
        trajectory: List[Dict[str, str]]
    ) -> Tuple[List[Dict[str, str]], TrajectoryMetrics]:
        """
        Compress a single trajectory to fit within target token budget (async version).

        Same algorithm as compress_trajectory but uses async API calls for summarization.

        Args:
            trajectory: List of conversation turns

        Returns:
            Tuple of (compressed_trajectory, metrics)
        """
        # NOTE(review): deliberate duplicate of compress_trajectory apart from
        # the awaited summary call — keep the two implementations in sync.
        metrics = TrajectoryMetrics()
        metrics.original_turns = len(trajectory)

        # Count tokens per turn
        turn_tokens = self.count_turn_tokens(trajectory)
        total_tokens = sum(turn_tokens)
        metrics.original_tokens = total_tokens

        # Check if compression needed
        if total_tokens <= self.config.target_max_tokens:
            metrics.skipped_under_target = True
            metrics.compressed_tokens = total_tokens
            metrics.compressed_turns = len(trajectory)
            metrics.compression_ratio = 1.0
            return trajectory, metrics

        # Find protected regions
        protected, compress_start, compress_end = self._find_protected_indices(trajectory)

        # Check if there's anything to compress
        if compress_start >= compress_end:
            metrics.compressed_tokens = total_tokens
            metrics.compressed_turns = len(trajectory)
            metrics.still_over_limit = total_tokens > self.config.target_max_tokens
            return trajectory, metrics

        # Calculate how much we need to save; the summary's own budget is
        # added because it will occupy space in the output.
        tokens_to_save = total_tokens - self.config.target_max_tokens
        target_tokens_to_compress = tokens_to_save + self.config.summary_target_tokens

        # Accumulate turns from compress_start until we have enough savings
        accumulated_tokens = 0
        compress_until = compress_start

        for i in range(compress_start, compress_end):
            accumulated_tokens += turn_tokens[i]
            compress_until = i + 1
            if accumulated_tokens >= target_tokens_to_compress:
                break

        # If we still don't have enough savings, compress the entire compressible region
        if accumulated_tokens < target_tokens_to_compress and compress_until < compress_end:
            compress_until = compress_end
            accumulated_tokens = sum(turn_tokens[compress_start:compress_end])

        # Record compression region
        metrics.turns_compressed_start_idx = compress_start
        metrics.turns_compressed_end_idx = compress_until
        metrics.turns_in_compressed_region = compress_until - compress_start

        # Extract content for summary
        content_to_summarize = self._extract_turn_content_for_summary(
            trajectory, compress_start, compress_until
        )

        # Generate summary (ASYNC)
        summary = await self._generate_summary_async(content_to_summarize, metrics)

        # Build compressed trajectory
        compressed = []

        # Add head (turns before compression region)
        for i in range(compress_start):
            turn = trajectory[i].copy()
            if turn.get("from") == "system" and self.config.add_summary_notice:
                turn["value"] = turn["value"] + self.config.summary_notice_text
            compressed.append(turn)

        # Add summary as human message
        compressed.append({
            "from": "human",
            "value": summary
        })

        # Add tail (turns after compression region)
        for i in range(compress_until, len(trajectory)):
            compressed.append(trajectory[i].copy())

        # Calculate final metrics (recount, since the summary replaces turns)
        metrics.compressed_turns = len(compressed)
        metrics.compressed_tokens = self.count_trajectory_tokens(compressed)
        metrics.turns_removed = metrics.original_turns - metrics.compressed_turns
        metrics.tokens_saved = metrics.original_tokens - metrics.compressed_tokens
        metrics.compression_ratio = metrics.compressed_tokens / max(metrics.original_tokens, 1)
        metrics.was_compressed = True
        metrics.still_over_limit = metrics.compressed_tokens > self.config.target_max_tokens

        return compressed, metrics
+ """ + if "conversations" not in entry: + metrics = TrajectoryMetrics() + return entry, metrics + + trajectory = entry["conversations"] + compressed_trajectory, metrics = await self.compress_trajectory_async(trajectory) + + # Create new entry with compressed trajectory + result = entry.copy() + result["conversations"] = compressed_trajectory + + # Add compression metadata if enabled + if self.config.metrics_per_trajectory and metrics.was_compressed: + result["compression_metrics"] = metrics.to_dict() + + return result, metrics + + def process_entry(self, entry: Dict[str, Any]) -> Tuple[Dict[str, Any], TrajectoryMetrics]: + """ + Process a single JSONL entry. + + Args: + entry: JSONL entry containing 'conversations' field + + Returns: + Tuple of (processed_entry, metrics) + """ + if "conversations" not in entry: + metrics = TrajectoryMetrics() + return entry, metrics + + trajectory = entry["conversations"] + compressed_trajectory, metrics = self.compress_trajectory(trajectory) + + # Create new entry with compressed trajectory + result = entry.copy() + result["conversations"] = compressed_trajectory + + # Add compression metadata if enabled + if self.config.metrics_per_trajectory and metrics.was_compressed: + result["compression_metrics"] = metrics.to_dict() + + return result, metrics + + def process_file( + self, + input_path: Path, + output_path: Path, + progress_callback: Optional[Callable[[TrajectoryMetrics], None]] = None + ) -> List[TrajectoryMetrics]: + """ + Process a single JSONL file. 
+ + Args: + input_path: Path to input JSONL file + output_path: Path to output JSONL file + progress_callback: Optional callback called after each entry with its metrics + + Returns: + List of metrics for each trajectory + """ + file_metrics = [] + + # Read all entries + entries = [] + with open(input_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if line: + try: + entries.append(json.loads(line)) + except json.JSONDecodeError as e: + self.logger.warning(f"Skipping invalid JSON at {input_path}:{line_num}: {e}") + + # Process entries + processed_entries = [] + for entry in entries: + try: + processed_entry, metrics = self.process_entry(entry) + processed_entries.append(processed_entry) + file_metrics.append(metrics) + self.aggregate_metrics.add_trajectory_metrics(metrics) + + # Call progress callback if provided + if progress_callback: + progress_callback(metrics) + + except Exception as e: + self.logger.error(f"Error processing entry: {e}") + self.aggregate_metrics.trajectories_failed += 1 + # Keep original entry on error + processed_entries.append(entry) + empty_metrics = TrajectoryMetrics() + file_metrics.append(empty_metrics) + + if progress_callback: + progress_callback(empty_metrics) + + # Write output + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + for entry in processed_entries: + f.write(json.dumps(entry, ensure_ascii=False) + '\n') + + return file_metrics + + def process_directory(self, input_dir: Path, output_dir: Path): + """ + Process all JSONL files in a directory using async parallel processing. 
+ + Args: + input_dir: Input directory containing JSONL files + output_dir: Output directory for compressed files + """ + # Run the async version + asyncio.run(self._process_directory_async(input_dir, output_dir)) + + async def _process_directory_async(self, input_dir: Path, output_dir: Path): + """ + Async implementation of directory processing with parallel API calls. + """ + console = Console() + + # Record start time + self.aggregate_metrics.processing_start_time = datetime.now().isoformat() + start_time = time.time() + + # Find all JSONL files + jsonl_files = sorted(input_dir.glob("*.jsonl")) + + if not jsonl_files: + self.logger.warning(f"No JSONL files found in {input_dir}") + return + + # Load ALL entries from all files + console.print("\n[dim]Loading all entries...[/dim]") + all_entries = [] # List of (file_path, entry_idx, entry) + + for file_path in jsonl_files: + with open(file_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f): + line = line.strip() + if line: + try: + entry = json.loads(line) + all_entries.append((file_path, line_num, entry)) + except json.JSONDecodeError as e: + self.logger.warning(f"Skipping invalid JSON at {file_path}:{line_num}: {e}") + + total_entries = len(all_entries) + + console.print(f"\n{'='*60}") + console.print(f"šŸ“‚ Input: {input_dir}") + console.print(f"šŸ“‚ Output: {output_dir}") + console.print(f"šŸ“„ Files to process: {len(jsonl_files)}") + console.print(f"šŸ“Š Total trajectories: {total_entries:,}") + console.print(f"šŸŽÆ Target max tokens: {self.config.target_max_tokens:,}") + console.print(f"šŸ“ Summary target tokens: {self.config.summary_target_tokens}") + console.print(f"⚔ Max concurrent API calls: {self.config.max_concurrent_requests}") + console.print(f"{'='*60}\n") + + # Create semaphore for rate limiting + semaphore = asyncio.Semaphore(self.config.max_concurrent_requests) + + # Tracking for progress display (thread-safe with lock) + progress_lock = asyncio.Lock() + compressed_count = 0 
        skipped_count = 0
        api_calls = 0
        in_flight = 0

        # Results storage: {file_path: {entry_idx: (processed_entry, metrics)}}
        results = {f: {} for f in jsonl_files}

        async def process_single(file_path: Path, entry_idx: int, entry: Dict,
                                 progress, main_task, status_task):
            """Process a single entry with semaphore rate limiting."""
            nonlocal compressed_count, skipped_count, api_calls, in_flight

            async with semaphore:
                # Track in-flight
                async with progress_lock:
                    in_flight += 1

                try:
                    processed_entry, metrics = await self.process_entry_async(entry)
                    results[file_path][entry_idx] = (processed_entry, metrics)

                    # Update aggregate metrics (with lock for thread safety)
                    # NOTE(review): asyncio runs on one thread; the lock's real
                    # job is keeping this read-modify-write batch atomic with
                    # respect to other coroutines — confirm that's the intent.
                    async with progress_lock:
                        self.aggregate_metrics.add_trajectory_metrics(metrics)

                        # Update counters
                        if metrics.was_compressed:
                            compressed_count += 1
                            api_calls += metrics.summarization_api_calls
                        if metrics.skipped_under_target:
                            skipped_count += 1

                        in_flight -= 1

                        # Update progress
                        progress.advance(main_task)
                        progress.update(
                            status_task,
                            description=f"[dim]āœ… {compressed_count} compressed | ā­ļø {skipped_count} skipped | šŸ”„ {api_calls} API calls | ⚔ {in_flight} in-flight[/dim]"
                        )

                except Exception as e:
                    self.logger.error(f"Error processing entry from {file_path}:{entry_idx}: {e}")

                    async with progress_lock:
                        self.aggregate_metrics.trajectories_failed += 1
                        in_flight -= 1
                        progress.advance(main_task)

                    # Keep original entry on error
                    results[file_path][entry_idx] = (entry, TrajectoryMetrics())

        # Create progress bar
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("•"),
            TimeElapsedColumn(),
            TextColumn("•"),
            TimeRemainingColumn(),
            console=console,
            refresh_per_second=10  # Higher refresh for async
        ) as progress:
            # Main task for overall progress
            main_task = progress.add_task(
                f"[cyan]Compressing {total_entries:,} trajectories",
                total=total_entries
            )

            # Status line task (total=None -> indeterminate spinner row)
            status_task = progress.add_task(
                "[dim]Starting...[/dim]",
                total=None
            )

            # Create all tasks
            tasks = [
                process_single(file_path, entry_idx, entry, progress, main_task, status_task)
                for file_path, entry_idx, entry in all_entries
            ]

            # Run all tasks concurrently (semaphore limits actual concurrency)
            await asyncio.gather(*tasks)

            # Remove status task
            progress.remove_task(status_task)

        # Write results to output files (preserving original order)
        console.print("\n[dim]Writing output files...[/dim]")
        output_dir.mkdir(parents=True, exist_ok=True)

        for file_path in jsonl_files:
            output_path = output_dir / file_path.name
            file_results = results[file_path]

            # Sort by original entry index to preserve order
            sorted_entries = [file_results[idx][0] for idx in sorted(file_results.keys())]

            with open(output_path, 'w', encoding='utf-8') as f:
                for entry in sorted_entries:
                    f.write(json.dumps(entry, ensure_ascii=False) + '\n')

        # Record end time
        self.aggregate_metrics.processing_end_time = datetime.now().isoformat()
        self.aggregate_metrics.processing_duration_seconds = time.time() - start_time

        # Print summary
        self._print_summary()

        # Save metrics
        if self.config.metrics_enabled:
            metrics_path = output_dir / self.config.metrics_output_file
            with open(metrics_path, 'w') as f:
                json.dump(self.aggregate_metrics.to_dict(), f, indent=2)
            console.print(f"\nšŸ’¾ Metrics saved to {metrics_path}")

    def _print_summary(self):
        """Print comprehensive compression summary statistics."""
        m = self.aggregate_metrics.to_dict()

        # Calculate some additional stats
        total = m['summary']['total_trajectories']
        compressed = m['summary']['trajectories_compressed']
        skipped = m['summary']['trajectories_skipped_under_target']
        over_limit = m['summary']['trajectories_still_over_limit']
        failed = m['summary']['trajectories_failed']

        # Token stats
        tokens_before = m['tokens']['total_before']
        tokens_after = m['tokens']['total_after']
        tokens_saved = m['tokens']['total_saved']

        # Calculate percentages (max(total, 1) guards division by zero)
        compressed_pct = (compressed / max(total, 1)) * 100
        skipped_pct = (skipped / max(total, 1)) * 100
        over_limit_pct = (over_limit / max(total, 1)) * 100

        print(f"\n")
        print(f"ā•”{'═'*70}ā•—")
        print(f"ā•‘{'TRAJECTORY COMPRESSION REPORT':^70}ā•‘")
        print(f"ā• {'═'*70}ā•£")

        # Trajectories section
        print(f"ā•‘{'':2}šŸ“ TRAJECTORIES{' '*54}ā•‘")
        print(f"ā•‘{'─'*70}ā•‘")
        print(f"ā•‘{'':4}Total Processed: {total:>10,}{' '*32}ā•‘")
        print(f"ā•‘{'':4}ā”œā”€ Compressed: {compressed:>10,} ({compressed_pct:>5.1f}%){' '*18}ā•‘")
        print(f"ā•‘{'':4}ā”œā”€ Skipped (under limit):{skipped:>9,} ({skipped_pct:>5.1f}%){' '*18}ā•‘")
        print(f"ā•‘{'':4}ā”œā”€ Still over limit: {over_limit:>10,} ({over_limit_pct:>5.1f}%){' '*18}ā•‘")
        print(f"ā•‘{'':4}└─ Failed: {failed:>10,}{' '*32}ā•‘")

        print(f"ā• {'═'*70}ā•£")

        # Tokens section
        print(f"ā•‘{'':2}šŸ”¢ TOKENS{' '*60}ā•‘")
        print(f"ā•‘{'─'*70}ā•‘")
        print(f"ā•‘{'':4}Before Compression: {tokens_before:>15,} tokens{' '*21}ā•‘")
        print(f"ā•‘{'':4}After Compression: {tokens_after:>15,} tokens{' '*21}ā•‘")
        print(f"ā•‘{'':4}Total Saved: {tokens_saved:>15,} tokens{' '*21}ā•‘")
        print(f"ā•‘{'':4}Overall Compression: {m['tokens']['overall_compression_ratio']:>14.1%}{' '*28}ā•‘")

        if tokens_before > 0:
            savings_pct = (tokens_saved / tokens_before) * 100
            print(f"ā•‘{'':4}Space Savings: {savings_pct:>14.1f}%{' '*28}ā•‘")

        print(f"ā• {'═'*70}ā•£")

        # Turns section
        print(f"ā•‘{'':2}šŸ’¬ CONVERSATION TURNS{' '*48}ā•‘")
        print(f"ā•‘{'─'*70}ā•‘")
        print(f"ā•‘{'':4}Before Compression: {m['turns']['total_before']:>15,} turns{' '*22}ā•‘")
        print(f"ā•‘{'':4}After Compression: {m['turns']['total_after']:>15,} turns{' '*22}ā•‘")
        print(f"ā•‘{'':4}Total Removed: {m['turns']['total_removed']:>15,} turns{' '*22}ā•‘")

        print(f"ā• {'═'*70}ā•£")

        # Averages section
(for compressed trajectories only) + print(f"ā•‘{'':2}šŸ“ˆ AVERAGES (Compressed Trajectories Only){' '*27}ā•‘") + print(f"ā•‘{'─'*70}ā•‘") + if compressed > 0: + print(f"ā•‘{'':4}Avg Compression Ratio: {m['averages']['avg_compression_ratio']:>14.1%}{' '*28}ā•‘") + print(f"ā•‘{'':4}Avg Tokens Saved: {m['averages']['avg_tokens_saved_per_compressed']:>14,.0f}{' '*28}ā•‘") + print(f"ā•‘{'':4}Avg Turns Removed: {m['averages']['avg_turns_removed_per_compressed']:>14.1f}{' '*28}ā•‘") + else: + print(f"ā•‘{'':4}No trajectories were compressed{' '*38}ā•‘") + + print(f"ā• {'═'*70}ā•£") + + # Summarization API section + print(f"ā•‘{'':2}šŸ¤– SUMMARIZATION API{' '*49}ā•‘") + print(f"ā•‘{'─'*70}ā•‘") + print(f"ā•‘{'':4}API Calls Made: {m['summarization']['total_api_calls']:>15,}{' '*27}ā•‘") + print(f"ā•‘{'':4}Errors: {m['summarization']['total_errors']:>15,}{' '*27}ā•‘") + print(f"ā•‘{'':4}Success Rate: {m['summarization']['success_rate']:>14.1%}{' '*28}ā•‘") + + print(f"ā• {'═'*70}ā•£") + + # Processing time section + duration = m['processing']['duration_seconds'] + if duration > 60: + time_str = f"{duration/60:.1f} minutes" + else: + time_str = f"{duration:.1f} seconds" + + throughput = total / max(duration, 0.001) + + print(f"ā•‘{'':2}ā±ļø PROCESSING TIME{' '*51}ā•‘") + print(f"ā•‘{'─'*70}ā•‘") + print(f"ā•‘{'':4}Duration: {time_str:>20}{' '*22}ā•‘") + print(f"ā•‘{'':4}Throughput: {throughput:>15.1f} traj/sec{' '*18}ā•‘") + print(f"ā•‘{'':4}Started: {m['processing']['start_time'][:19]:>20}{' '*22}ā•‘") + print(f"ā•‘{'':4}Finished: {m['processing']['end_time'][:19]:>20}{' '*22}ā•‘") + + print(f"ā•š{'═'*70}ā•") + + # Distribution summary if we have data + if self.aggregate_metrics.compression_ratios: + ratios = self.aggregate_metrics.compression_ratios + tokens_saved_list = self.aggregate_metrics.tokens_saved_list + + print(f"\nšŸ“Š Distribution Summary:") + print(f" Compression ratios: min={min(ratios):.2%}, max={max(ratios):.2%}, 
median={sorted(ratios)[len(ratios)//2]:.2%}") + print(f" Tokens saved: min={min(tokens_saved_list):,}, max={max(tokens_saved_list):,}, median={sorted(tokens_saved_list)[len(tokens_saved_list)//2]:,}") + + +def main( + input_dir: str, + output_dir: str = None, + config: str = "configs/trajectory_compression.yaml", + target_max_tokens: int = None, + tokenizer: str = None, + dry_run: bool = False, +): + """ + Compress agent trajectories to fit within a target token budget. + + Args: + input_dir: Directory containing JSONL trajectory files + output_dir: Output directory (default: input_dir + "_compressed") + config: Path to YAML configuration file + target_max_tokens: Override target token count from config + tokenizer: Override tokenizer name from config + dry_run: Analyze without compressing (just show what would happen) + """ + print("šŸ—œļø Trajectory Compressor") + print("=" * 60) + + # Load configuration + config_path = Path(config) + if config_path.exists(): + print(f"šŸ“‹ Loading config from {config}") + compression_config = CompressionConfig.from_yaml(config) + else: + print(f"āš ļø Config not found at {config}, using defaults") + compression_config = CompressionConfig() + + # Apply CLI overrides + if target_max_tokens: + compression_config.target_max_tokens = target_max_tokens + if tokenizer: + compression_config.tokenizer_name = tokenizer + + # Setup paths + input_path = Path(input_dir) + if not input_path.exists(): + print(f"āŒ Input directory not found: {input_dir}") + return + + if output_dir: + output_path = Path(output_dir) + else: + output_path = input_path.parent / (input_path.name + compression_config.output_suffix) + + if dry_run: + print(f"\nšŸ” DRY RUN MODE - analyzing without writing") + print(f"šŸ“ Would process: {input_path}") + print(f"šŸ“ Would output to: {output_path}") + # TODO: Implement dry run analysis + return + + # Initialize compressor + compressor = TrajectoryCompressor(compression_config) + + # Process directory + 
    compressor.process_directory(input_path, output_path)

    print("\nāœ… Compression complete!")


if __name__ == "__main__":
    # CLI entry point: fire maps main()'s parameters to command-line flags.
    fire.Fire(main)