From 0e2e69a71ddacdc99ba7282f3b4ab93ded881834 Mon Sep 17 00:00:00 2001 From: teknium Date: Mon, 6 Oct 2025 03:17:58 +0000 Subject: [PATCH] Add batch processing capabilities with checkpointing and statistics tracking, along with toolset distribution management. Update README and add test scripts for validation. --- .cursorrules | 23 ++ README.md | 37 ++ batch_runner.py | 709 +++++++++++++++++++++++++++++++++++++ test_run.sh | 10 +- tests/test_batch_runner.py | 129 +++++++ toolset_distributions.py | 269 ++++++++++++++ 6 files changed, 1168 insertions(+), 9 deletions(-) create mode 100644 .cursorrules create mode 100644 batch_runner.py create mode 100644 tests/test_batch_runner.py create mode 100644 toolset_distributions.py diff --git a/.cursorrules b/.cursorrules new file mode 100644 index 000000000..defa74ef3 --- /dev/null +++ b/.cursorrules @@ -0,0 +1,23 @@ +Hermes-Agent is an agent harness for LLMs. + +When building, the tool functionality is in the tools/ directory, where each specific tool (or, in some cases, a group of tools built for the same execution category or API) is placed in its own script. + +Each tool is then consolidated in the model_tools.py file in the repo root. + +There is also a way to consolidate sets of tools in toolsets.py for the agent to use. + +The primary agent runner code is in run_agent, but other runners could be developed using the tools and framework. + +Always ensure consistency between tools, the model_tools.py and toolsets.py when changing any of them, otherwise they could become desynced in a way that is detrimental to functionality. + +The expected pathway for using API keys is to set up and place them in a .env file in the repo root. 
+ +Test scripts will be placed in tests/ + +The run_agent loop is set up to: +- Process the enabled toolsets to provide to the model, +- Pipe in a prompt or problem from the input to the agent, +- Loop the LLM each time it calls a tool, until the model decides no more tools are needed and provides a natural language response, +- Return that response. + +There are additional caveats for logging, where we restructure the "tools" as a system prompt for storage, in a format that can be used and handled properly later. \ No newline at end of file diff --git a/README.md b/README.md index 0a73bc389..4816a9550 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ An AI agent with advanced tool-calling capabilities, featuring a flexible toolse - **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents) - **Creative Tools**: Generate images from text prompts - **Toolsets System**: Organize tools into logical groups for different scenarios +- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking ## Setup @@ -133,6 +134,36 @@ create_custom_toolset( agent = AIAgent(enabled_toolsets=["my_tools"]) ``` +## Batch Processing + +Process multiple prompts from a dataset in parallel with automatic checkpointing and statistics tracking: + +```bash +# Basic batch processing +python batch_runner.py \ + --dataset_file=prompts.jsonl \ + --batch_size=20 \ + --run_name=my_run + +# With specific distribution +python batch_runner.py \ + --dataset_file=prompts.jsonl \ + --batch_size=20 \ + --run_name=image_run \ + --distribution=image_gen \ + --num_workers=4 +``` + +**Key Features:** +- Parallel processing with configurable workers +- Toolset distributions for varied data generation +- Automatic checkpointing and resume capability +- Combined output in `data/<run_name>/trajectories.jsonl` +- Tool usage statistics and success rates + +**Quick Start:** See [QUICKSTART_BATCH.md](QUICKSTART_BATCH.md) for a 5-minute getting started guide. 
+**Full Documentation:** See [BATCH_PROCESSING.md](BATCH_PROCESSING.md) for comprehensive documentation. + ## Command Line Arguments - `--query`: The question or task for the agent @@ -164,10 +195,16 @@ All environment variables can be configured in the `.env` file (copy from `.env. ## Documentation +**Single Agent Usage:** - `TOOLSETS_README.md`: Comprehensive guide to the toolsets system - `toolsets.py`: View and modify available toolsets - `model_tools.py`: Core tool definitions and handlers +**Batch Processing:** +- `QUICKSTART_BATCH.md`: 5-minute quick start guide +- `BATCH_PROCESSING.md`: Complete batch processing documentation +- `toolset_distributions.py`: Toolset distributions for data generation + ## Examples See `TOOLSETS_README.md` for extensive examples of using different toolsets for various scenarios. diff --git a/batch_runner.py b/batch_runner.py new file mode 100644 index 000000000..ab86f2621 --- /dev/null +++ b/batch_runner.py @@ -0,0 +1,709 @@ +#!/usr/bin/env python3 +""" +Batch Agent Runner + +This module provides parallel batch processing capabilities for running the agent +across multiple prompts from a dataset. 
It includes: +- Dataset loading and batching +- Parallel batch processing with multiprocessing +- Checkpointing for fault tolerance and resumption +- Trajectory saving in the proper format (from/value pairs) +- Tool usage statistics aggregation across all batches + +Usage: + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run + + # Resume an interrupted run + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume + + # Use a specific toolset distribution + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen +""" + +import json +import logging +import os +import time +from pathlib import Path +from typing import List, Dict, Any, Optional, Tuple +from datetime import datetime +from multiprocessing import Pool, Manager, Lock +import traceback + +import fire + +from run_agent import AIAgent +from toolset_distributions import ( + get_distribution, + list_distributions, + sample_toolsets_from_distribution, + validate_distribution +) + + +# Global configuration for worker processes +_WORKER_CONFIG = {} + + +def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]: + """ + Extract tool usage statistics from message history. 
+ + Args: + messages (List[Dict]): Message history + + Returns: + Dict: Tool statistics with counts and success/failure rates + """ + tool_stats = {} + + # Track tool calls and their results + tool_calls_map = {} # Map tool_call_id to tool name + + for msg in messages: + # Track tool calls from assistant messages + if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]: + for tool_call in msg["tool_calls"]: + tool_name = tool_call["function"]["name"] + tool_call_id = tool_call["id"] + + # Initialize stats for this tool if not exists + if tool_name not in tool_stats: + tool_stats[tool_name] = { + "count": 0, + "success": 0, + "failure": 0 + } + + tool_stats[tool_name]["count"] += 1 + tool_calls_map[tool_call_id] = tool_name + + # Track tool responses + elif msg["role"] == "tool": + tool_call_id = msg.get("tool_call_id", "") + content = msg.get("content", "") + + # Determine if tool call was successful + is_success = True + try: + # Try to parse as JSON and check for error field + content_json = json.loads(content) if isinstance(content, str) else content + if isinstance(content_json, dict) and "error" in content_json: + is_success = False + except: + # If not JSON, check if content contains error indicators + if not content or "error" in content.lower(): + is_success = False + + # Update success/failure count + if tool_call_id in tool_calls_map: + tool_name = tool_calls_map[tool_call_id] + if is_success: + tool_stats[tool_name]["success"] += 1 + else: + tool_stats[tool_name]["failure"] += 1 + + return tool_stats + + +def _process_single_prompt( + prompt_index: int, + prompt_data: Dict[str, Any], + batch_num: int, + config: Dict[str, Any] +) -> Dict[str, Any]: + """ + Process a single prompt with the agent. 
+ + Args: + prompt_index (int): Index of prompt in dataset + prompt_data (Dict): Prompt data containing 'prompt' field + batch_num (int): Batch number + config (Dict): Configuration dict with agent parameters + + Returns: + Dict: Result containing trajectory, stats, and metadata + """ + prompt = prompt_data["prompt"] + + try: + # Sample toolsets from distribution for this prompt + selected_toolsets = sample_toolsets_from_distribution(config["distribution"]) + + if config.get("verbose"): + print(f" Prompt {prompt_index}: Using toolsets {selected_toolsets}") + + # Initialize agent with sampled toolsets + agent = AIAgent( + base_url=config.get("base_url"), + api_key=config.get("api_key"), + model=config["model"], + max_iterations=config["max_iterations"], + enabled_toolsets=selected_toolsets, + save_trajectories=False, # We handle saving ourselves + verbose_logging=config.get("verbose", False) + ) + + # Run the agent + result = agent.run_conversation(prompt) + + # Extract tool usage statistics + tool_stats = _extract_tool_stats(result["messages"]) + + # Convert to trajectory format (using existing method) + trajectory = agent._convert_to_trajectory_format( + result["messages"], + prompt, + result["completed"] + ) + + return { + "success": True, + "prompt_index": prompt_index, + "trajectory": trajectory, + "tool_stats": tool_stats, + "completed": result["completed"], + "api_calls": result["api_calls"], + "toolsets_used": selected_toolsets, + "metadata": { + "batch_num": batch_num, + "timestamp": datetime.now().isoformat(), + "model": config["model"] + } + } + + except Exception as e: + print(f"โŒ Error processing prompt {prompt_index}: {e}") + if config.get("verbose"): + traceback.print_exc() + + return { + "success": False, + "prompt_index": prompt_index, + "error": str(e), + "trajectory": None, + "tool_stats": {}, + "toolsets_used": [], + "metadata": { + "batch_num": batch_num, + "timestamp": datetime.now().isoformat() + } + } + + +def _process_batch_worker(args: 
Tuple) -> Dict[str, Any]: + """ + Worker function to process a single batch of prompts. + + Args: + args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config) + + Returns: + Dict: Batch results with statistics + """ + batch_num, batch_data, output_dir, completed_prompts_set, config = args + + output_dir = Path(output_dir) + print(f"\n๐Ÿ”„ Batch {batch_num}: Starting ({len(batch_data)} prompts)") + + # Output file for this batch + batch_output_file = output_dir / f"batch_{batch_num}.jsonl" + + # Filter out already completed prompts + prompts_to_process = [ + (idx, data) for idx, data in batch_data + if idx not in completed_prompts_set + ] + + if not prompts_to_process: + print(f"โœ… Batch {batch_num}: Already completed (skipping)") + return { + "batch_num": batch_num, + "processed": 0, + "skipped": len(batch_data), + "tool_stats": {}, + "completed_prompts": [] + } + + print(f" Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)") + + # Initialize aggregated stats for this batch + batch_tool_stats = {} + completed_in_batch = [] + + # Process each prompt sequentially in this batch + for prompt_index, prompt_data in prompts_to_process: + # Process the prompt + result = _process_single_prompt( + prompt_index, + prompt_data, + batch_num, + config + ) + + # Save trajectory if successful + if result["success"] and result["trajectory"]: + trajectory_entry = { + "prompt_index": prompt_index, + "conversations": result["trajectory"], + "metadata": result["metadata"], + "completed": result["completed"], + "api_calls": result["api_calls"], + "toolsets_used": result["toolsets_used"] + } + + # Append to batch output file + with open(batch_output_file, 'a', encoding='utf-8') as f: + f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n") + + # Aggregate tool statistics + for tool_name, stats in result.get("tool_stats", {}).items(): + if tool_name not in batch_tool_stats: + 
batch_tool_stats[tool_name] = { + "count": 0, + "success": 0, + "failure": 0 + } + + batch_tool_stats[tool_name]["count"] += stats["count"] + batch_tool_stats[tool_name]["success"] += stats["success"] + batch_tool_stats[tool_name]["failure"] += stats["failure"] + + completed_in_batch.append(prompt_index) + print(f" โœ… Prompt {prompt_index} completed") + + print(f"โœ… Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)") + + return { + "batch_num": batch_num, + "processed": len(prompts_to_process), + "skipped": len(batch_data) - len(prompts_to_process), + "tool_stats": batch_tool_stats, + "completed_prompts": completed_in_batch + } + + +class BatchRunner: + """ + Manages batch processing of agent prompts with checkpointing and statistics. + """ + + def __init__( + self, + dataset_file: str, + batch_size: int, + run_name: str, + distribution: str = "default", + max_iterations: int = 10, + base_url: str = None, + api_key: str = None, + model: str = "claude-opus-4-20250514", + num_workers: int = 4, + verbose: bool = False + ): + """ + Initialize the batch runner. 
+ + Args: + dataset_file (str): Path to the dataset JSONL file with 'prompt' field + batch_size (int): Number of prompts per batch + run_name (str): Name for this run (used for checkpointing and output) + distribution (str): Toolset distribution to use (default: "default") + max_iterations (int): Max iterations per agent run + base_url (str): Base URL for model API + api_key (str): API key for model + model (str): Model name to use + num_workers (int): Number of parallel workers + verbose (bool): Enable verbose logging + """ + self.dataset_file = Path(dataset_file) + self.batch_size = batch_size + self.run_name = run_name + self.distribution = distribution + self.max_iterations = max_iterations + self.base_url = base_url + self.api_key = api_key + self.model = model + self.num_workers = num_workers + self.verbose = verbose + + # Validate distribution + if not validate_distribution(distribution): + raise ValueError(f"Unknown distribution: {distribution}. Available: {list(list_distributions().keys())}") + + # Setup output directory + self.output_dir = Path("data") / run_name + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Checkpoint file + self.checkpoint_file = self.output_dir / "checkpoint.json" + + # Statistics file + self.stats_file = self.output_dir / "statistics.json" + + # Load dataset + self.dataset = self._load_dataset() + + # Create batches + self.batches = self._create_batches() + + print(f"๐Ÿ“Š Batch Runner Initialized") + print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)") + print(f" Batch size: {self.batch_size}") + print(f" Total batches: {len(self.batches)}") + print(f" Run name: {self.run_name}") + print(f" Distribution: {self.distribution}") + print(f" Output directory: {self.output_dir}") + print(f" Workers: {self.num_workers}") + + def _load_dataset(self) -> List[Dict[str, Any]]: + """ + Load dataset from JSONL file. 
+ + Returns: + List[Dict]: List of dataset entries + """ + if not self.dataset_file.exists(): + raise FileNotFoundError(f"Dataset file not found: {self.dataset_file}") + + dataset = [] + with open(self.dataset_file, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + + try: + entry = json.loads(line) + if 'prompt' not in entry: + print(f"โš ๏ธ Warning: Line {line_num} missing 'prompt' field, skipping") + continue + dataset.append(entry) + except json.JSONDecodeError as e: + print(f"โš ๏ธ Warning: Invalid JSON on line {line_num}: {e}") + continue + + if not dataset: + raise ValueError(f"No valid entries found in dataset file: {self.dataset_file}") + + return dataset + + def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]: + """ + Split dataset into batches with indices. + + Returns: + List of batches, where each batch is a list of (index, entry) tuples + """ + batches = [] + for i in range(0, len(self.dataset), self.batch_size): + batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)] + batches.append(batch) + + return batches + + def _load_checkpoint(self) -> Dict[str, Any]: + """ + Load checkpoint data if it exists. + + Returns: + Dict: Checkpoint data with completed prompt indices + """ + if not self.checkpoint_file.exists(): + return { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } + + try: + with open(self.checkpoint_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"โš ๏ธ Warning: Failed to load checkpoint: {e}") + return { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } + + def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None): + """ + Save checkpoint data. 
+ + Args: + checkpoint_data (Dict): Checkpoint data to save + lock (Lock): Optional lock for thread-safe access + """ + checkpoint_data["last_updated"] = datetime.now().isoformat() + + if lock: + with lock: + with open(self.checkpoint_file, 'w', encoding='utf-8') as f: + json.dump(checkpoint_data, f, indent=2) + else: + with open(self.checkpoint_file, 'w', encoding='utf-8') as f: + json.dump(checkpoint_data, f, indent=2) + + + def run(self, resume: bool = False): + """ + Run the batch processing pipeline. + + Args: + resume (bool): Whether to resume from checkpoint + """ + print("\n" + "=" * 70) + print("๐Ÿš€ Starting Batch Processing") + print("=" * 70) + + # Load checkpoint + checkpoint_data = self._load_checkpoint() if resume else { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } + + if resume and checkpoint_data.get("completed_prompts"): + print(f"๐Ÿ“‚ Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)") + + # Prepare configuration for workers + config = { + "distribution": self.distribution, + "model": self.model, + "max_iterations": self.max_iterations, + "base_url": self.base_url, + "api_key": self.api_key, + "verbose": self.verbose + } + + # Get completed prompts set + completed_prompts_set = set(checkpoint_data.get("completed_prompts", [])) + + # Aggregate statistics across all batches + total_tool_stats = {} + + start_time = time.time() + + # Process batches in parallel + with Pool(processes=self.num_workers) as pool: + # Create tasks for each batch + tasks = [ + ( + batch_num, + batch_data, + str(self.output_dir), # Convert Path to string for pickling + completed_prompts_set, + config + ) + for batch_num, batch_data in enumerate(self.batches) + ] + + # Use map to process batches in parallel + results = pool.map(_process_batch_worker, tasks) + + # Aggregate all batch statistics and update checkpoint + all_completed_prompts = list(completed_prompts_set) + 
for batch_result in results: + # Add newly completed prompts + all_completed_prompts.extend(batch_result.get("completed_prompts", [])) + + # Aggregate tool stats + for tool_name, stats in batch_result.get("tool_stats", {}).items(): + if tool_name not in total_tool_stats: + total_tool_stats[tool_name] = { + "count": 0, + "success": 0, + "failure": 0 + } + + total_tool_stats[tool_name]["count"] += stats["count"] + total_tool_stats[tool_name]["success"] += stats["success"] + total_tool_stats[tool_name]["failure"] += stats["failure"] + + # Save final checkpoint + checkpoint_data["completed_prompts"] = all_completed_prompts + self._save_checkpoint(checkpoint_data) + + # Calculate success rates + for tool_name in total_tool_stats: + stats = total_tool_stats[tool_name] + total_calls = stats["success"] + stats["failure"] + if total_calls > 0: + stats["success_rate"] = round(stats["success"] / total_calls * 100, 2) + stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2) + else: + stats["success_rate"] = 0.0 + stats["failure_rate"] = 0.0 + + # Combine all batch files into a single trajectories.jsonl file + combined_file = self.output_dir / "trajectories.jsonl" + print(f"\n๐Ÿ“ฆ Combining batch files into {combined_file.name}...") + + with open(combined_file, 'w', encoding='utf-8') as outfile: + for batch_num in range(len(self.batches)): + batch_file = self.output_dir / f"batch_{batch_num}.jsonl" + if batch_file.exists(): + with open(batch_file, 'r', encoding='utf-8') as infile: + for line in infile: + outfile.write(line) + + print(f"โœ… Combined {len(self.batches)} batch files into trajectories.jsonl") + + # Save final statistics + final_stats = { + "run_name": self.run_name, + "distribution": self.distribution, + "total_prompts": len(self.dataset), + "total_batches": len(self.batches), + "batch_size": self.batch_size, + "model": self.model, + "completed_at": datetime.now().isoformat(), + "duration_seconds": round(time.time() - start_time, 2), + 
"tool_statistics": total_tool_stats + } + + with open(self.stats_file, 'w', encoding='utf-8') as f: + json.dump(final_stats, f, indent=2) + + # Print summary + print("\n" + "=" * 70) + print("๐Ÿ“Š BATCH PROCESSING COMPLETE") + print("=" * 70) + print(f"โœ… Total prompts processed: {len(self.dataset)}") + print(f"โœ… Total batches: {len(self.batches)}") + print(f"โฑ๏ธ Total duration: {round(time.time() - start_time, 2)}s") + print(f"\n๐Ÿ“ˆ Tool Usage Statistics:") + print("-" * 70) + + if total_tool_stats: + # Sort by count descending + sorted_tools = sorted( + total_tool_stats.items(), + key=lambda x: x[1]["count"], + reverse=True + ) + + print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}") + print("-" * 70) + for tool_name, stats in sorted_tools: + print( + f"{tool_name:<25} " + f"{stats['count']:<10} " + f"{stats['success']:<10} " + f"{stats['failure']:<10} " + f"{stats['success_rate']:.1f}%" + ) + else: + print("No tool calls were made during this run.") + + print(f"\n๐Ÿ’พ Results saved to: {self.output_dir}") + print(f" - Trajectories: trajectories.jsonl (combined)") + print(f" - Individual batches: batch_*.jsonl (for debugging)") + print(f" - Statistics: {self.stats_file.name}") + print(f" - Checkpoint: {self.checkpoint_file.name}") + + +def main( + dataset_file: str = None, + batch_size: int = None, + run_name: str = None, + distribution: str = "default", + model: str = "claude-opus-4-20250514", + api_key: str = None, + base_url: str = "https://api.anthropic.com/v1/", + max_turns: int = 10, + num_workers: int = 4, + resume: bool = False, + verbose: bool = False, + list_distributions: bool = False +): + """ + Run batch processing of agent prompts from a dataset. 
+ + Args: + dataset_file (str): Path to JSONL file with 'prompt' field in each entry + batch_size (int): Number of prompts per batch + run_name (str): Name for this run (used for output and checkpointing) + distribution (str): Toolset distribution to use (default: "default") + model (str): Model name to use (default: "claude-opus-4-20250514") + api_key (str): API key for model authentication + base_url (str): Base URL for model API + max_turns (int): Maximum number of tool calling iterations per prompt (default: 10) + num_workers (int): Number of parallel worker processes (default: 4) + resume (bool): Resume from checkpoint if run was interrupted (default: False) + verbose (bool): Enable verbose logging (default: False) + list_distributions (bool): List available toolset distributions and exit + + Examples: + # Basic usage + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run + + # Resume interrupted run + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume + + # Use specific distribution + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen + + # List available distributions + python batch_runner.py --list_distributions + """ + # Handle list distributions + if list_distributions: + from toolset_distributions import list_distributions as get_all_dists, print_distribution_info + + print("๐Ÿ“Š Available Toolset Distributions") + print("=" * 70) + + all_dists = get_all_dists() + for dist_name in sorted(all_dists.keys()): + print_distribution_info(dist_name) + + print("\n๐Ÿ’ก Usage:") + print(" python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\") + print(" --run_name=my_run --distribution=") + return + + # Validate required arguments + if not dataset_file: + print("โŒ Error: --dataset_file is required") + return + + if not batch_size or batch_size < 1: + print("โŒ Error: --batch_size must be a positive integer") + return 
+ + if not run_name: + print("โŒ Error: --run_name is required") + return + + # Initialize and run batch runner + try: + runner = BatchRunner( + dataset_file=dataset_file, + batch_size=batch_size, + run_name=run_name, + distribution=distribution, + max_iterations=max_turns, + base_url=base_url, + api_key=api_key, + model=model, + num_workers=num_workers, + verbose=verbose + ) + + runner.run(resume=resume) + + except Exception as e: + print(f"\nโŒ Fatal error: {e}") + if verbose: + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + fire.Fire(main) + diff --git a/test_run.sh b/test_run.sh index 182cf284e..66be76d53 100755 --- a/test_run.sh +++ b/test_run.sh @@ -20,12 +20,4 @@ python run_agent.py \ --model claude-sonnet-4-5-20250929 \ --base_url https://api.anthropic.com/v1/ \ --api_key $ANTHROPIC_API_KEY \ - --save_trajectories \ - --enabled_toolsets=web - -# --model claude-sonnet-4-20250514 \ -# -#Possible Toolsets: -#web_tools -#vision_tools -#terminal_tools \ No newline at end of file + --save_trajectories \ No newline at end of file diff --git a/tests/test_batch_runner.py b/tests/test_batch_runner.py new file mode 100644 index 000000000..b6888d291 --- /dev/null +++ b/tests/test_batch_runner.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +""" +Test script for batch runner + +This script tests the batch runner with a small sample dataset +to verify functionality before running large batches. 
+""" + +import json +import shutil +from pathlib import Path + + +def create_test_dataset(): + """Create a small test dataset.""" + test_file = Path("tests/test_dataset.jsonl") + test_file.parent.mkdir(exist_ok=True) + + prompts = [ + {"prompt": "What is 2 + 2?"}, + {"prompt": "What is the capital of France?"}, + {"prompt": "Explain what Python is in one sentence."}, + ] + + with open(test_file, 'w') as f: + for prompt in prompts: + f.write(json.dumps(prompt) + "\n") + + print(f"โœ… Created test dataset: {test_file}") + return test_file + + +def cleanup_test_run(run_name): + """Clean up test run output.""" + output_dir = Path("data") / run_name + if output_dir.exists(): + shutil.rmtree(output_dir) + print(f"๐Ÿ—‘๏ธ Cleaned up test output: {output_dir}") + + +def verify_output(run_name): + """Verify that output files were created correctly.""" + output_dir = Path("data") / run_name + + # Check directory exists + if not output_dir.exists(): + print(f"โŒ Output directory not found: {output_dir}") + return False + + # Check for checkpoint + checkpoint_file = output_dir / "checkpoint.json" + if not checkpoint_file.exists(): + print(f"โŒ Checkpoint file not found: {checkpoint_file}") + return False + + # Check for statistics + stats_file = output_dir / "statistics.json" + if not stats_file.exists(): + print(f"โŒ Statistics file not found: {stats_file}") + return False + + # Check for batch files + batch_files = list(output_dir.glob("batch_*.jsonl")) + if not batch_files: + print(f"โŒ No batch files found in: {output_dir}") + return False + + print(f"โœ… Output verification passed:") + print(f" - Checkpoint: {checkpoint_file}") + print(f" - Statistics: {stats_file}") + print(f" - Batch files: {len(batch_files)}") + + # Load and display statistics + with open(stats_file) as f: + stats = json.load(f) + + print(f"\n๐Ÿ“Š Statistics Summary:") + print(f" - Total prompts: {stats['total_prompts']}") + print(f" - Total batches: {stats['total_batches']}") + print(f" - 
Duration: {stats['duration_seconds']}s") + + if stats.get('tool_statistics'): + print(f" - Tool calls:") + for tool, tool_stats in stats['tool_statistics'].items(): + print(f" โ€ข {tool}: {tool_stats['count']} calls, {tool_stats['success_rate']:.1f}% success") + + return True + + +def main(): + """Run the test.""" + print("๐Ÿงช Batch Runner Test") + print("=" * 60) + + run_name = "test_run" + + # Clean up any previous test run + cleanup_test_run(run_name) + + # Create test dataset + test_file = create_test_dataset() + + print(f"\n๐Ÿ“ To run the test manually:") + print(f" python batch_runner.py \\") + print(f" --dataset_file={test_file} \\") + print(f" --batch_size=2 \\") + print(f" --run_name={run_name} \\") + print(f" --distribution=minimal \\") + print(f" --num_workers=2") + + print(f"\n๐Ÿ’ก Or test with different distributions:") + print(f" python batch_runner.py --list_distributions") + + print(f"\n๐Ÿ” After running, you can verify output with:") + print(f" python tests/test_batch_runner.py --verify") + + # Note: We don't actually run the batch runner here to avoid API calls during testing + # Users should run it manually with their API keys configured + + +if __name__ == "__main__": + import sys + + if "--verify" in sys.argv: + run_name = "test_run" + verify_output(run_name) + else: + main() + diff --git a/toolset_distributions.py b/toolset_distributions.py new file mode 100644 index 000000000..41d093498 --- /dev/null +++ b/toolset_distributions.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +""" +Toolset Distributions Module + +This module defines distributions of toolsets for data generation runs. +Each distribution specifies which toolsets should be used and their probability +of being selected for any given prompt during the batch processing. + +A distribution is a dictionary mapping toolset names to their selection probability (%). +Probabilities should sum to 100, but the system will normalize if they don't. 
+
+Usage:
+    from toolset_distributions import get_distribution, list_distributions
+
+    # Get a specific distribution
+    dist = get_distribution("image_gen")
+
+    # List all available distributions
+    all_dists = list_distributions()
+"""
+
+import copy
+import random
+from typing import Any, Dict, List, Optional
+
+from toolsets import validate_toolset
+
+
+# Distribution definitions
+# Each key is a distribution name. Each value is a dict with:
+#   "description": human-readable summary of the distribution
+#   "toolsets":    mapping of toolset_name -> probability percentage (0-100)
+#                  that the toolset is included in a sample
+DISTRIBUTIONS = {
+    # Default: All tools available 100% of the time
+    "default": {
+        "description": "All available tools, all the time",
+        "toolsets": {
+            "web": 100,
+            "vision": 100,
+            "image_gen": 100,
+            "terminal": 100,
+            "moa": 100
+        }
+    },
+
+    # Image generation focused distribution
+    "image_gen": {
+        "description": "Heavy focus on image generation with vision and web support",
+        "toolsets": {
+            "image_gen": 80,  # 80% chance of image generation tools
+            "vision": 60,     # 60% chance of vision tools
+            "web": 40,        # 40% chance of web tools
+            "moa": 20         # 20% chance of reasoning tools
+        }
+    },
+
+    # Research-focused distribution
+    "research": {
+        "description": "Web research with vision analysis and reasoning",
+        "toolsets": {
+            "web": 90,       # 90% chance of web tools
+            "vision": 50,    # 50% chance of vision tools
+            "moa": 40,       # 40% chance of reasoning tools
+            "terminal": 10   # 10% chance of terminal tools
+        }
+    },
+
+    # Development-focused distribution
+    "development": {
+        "description": "Terminal and reasoning with occasional web lookup",
+        "toolsets": {
+            "terminal": 80,  # 80% chance of terminal tools
+            "moa": 60,       # 60% chance of reasoning tools
+            "web": 30,       # 30% chance of web tools
+            "vision": 10     # 10% chance of vision tools
+        }
+    },
+
+    # Safe mode (no terminal)
+    "safe": {
+        "description": "All tools except terminal for safety",
+        "toolsets": {
+            "web": 80,
+            "vision": 60,
+            "image_gen": 60,
+            "moa": 50
+        }
+    },
+
+    # Balanced distribution
+    "balanced": {
+        "description": "Equal probability of all toolsets",
+        "toolsets": {
+            "web": 50,
+            "vision": 50,
+            "image_gen": 50,
+            "terminal": 50,
+            "moa": 50
+        }
+    },
+
+    # Minimal (web only)
+    "minimal": {
+        "description": "Only web tools for basic research",
+        "toolsets": {
+            "web": 100
+        }
+    },
+
+    # Creative (vision + image generation)
+    "creative": {
+        "description": "Image generation and vision analysis focus",
+        "toolsets": {
+            "image_gen": 90,
+            "vision": 90,
+            "web": 30
+        }
+    },
+
+    # Reasoning heavy
+    "reasoning": {
+        "description": "Heavy mixture of agents usage with minimal other tools",
+        "toolsets": {
+            "moa": 90,
+            "web": 30,
+            "terminal": 20
+        }
+    }
+}
+
+
+def get_distribution(name: str) -> Optional[Dict[str, Any]]:
+    """
+    Get a toolset distribution by name.
+
+    Args:
+        name (str): Name of the distribution
+
+    Returns:
+        Dict: Distribution definition with "description" and "toolsets" keys.
+            This is the entry stored in DISTRIBUTIONS itself, not a copy.
+        None: If distribution not found
+    """
+    return DISTRIBUTIONS.get(name)
+
+
+def list_distributions() -> Dict[str, Dict[str, Any]]:
+    """
+    List all available distributions.
+
+    Returns:
+        Dict: A deep copy of all distribution definitions, so callers may
+            modify the result (including the nested "toolsets" dicts)
+            without altering DISTRIBUTIONS.
+    """
+    return copy.deepcopy(DISTRIBUTIONS)
+
+
+def sample_toolsets_from_distribution(distribution_name: str) -> List[str]:
+    """
+    Sample toolsets based on a distribution's probabilities.
+
+    Each toolset in the distribution has a % chance of being included.
+    This allows multiple toolsets to be active simultaneously.
+
+    Toolsets rejected by validate_toolset() are reported with a warning and
+    never sampled. If the independent rolls select nothing, the valid toolset
+    with the highest probability is returned so the result is only empty
+    when the distribution has no valid toolset at all.
+
+    Args:
+        distribution_name (str): Name of the distribution to sample from
+
+    Returns:
+        List[str]: List of sampled toolset names
+
+    Raises:
+        ValueError: If distribution name is not found
+    """
+    dist = get_distribution(distribution_name)
+    if dist is None:
+        raise ValueError(f"Unknown distribution: {distribution_name}")
+
+    # Validate each toolset exactly once; the fallback below reuses the result
+    # so it can never pick a toolset that failed validation.
+    valid_toolsets = {}
+    for toolset_name, probability in dist["toolsets"].items():
+        if validate_toolset(toolset_name):
+            valid_toolsets[toolset_name] = probability
+        else:
+            print(f"⚠️ Warning: Toolset '{toolset_name}' in distribution '{distribution_name}' is not valid")
+
+    # Roll the dice independently per toolset - include it when the random
+    # value (scaled to 0-100) is less than its probability
+    selected_toolsets = [
+        toolset_name
+        for toolset_name, probability in valid_toolsets.items()
+        if random.random() * 100 < probability
+    ]
+
+    # If no toolsets were selected (can happen with low probabilities),
+    # ensure at least one is returned by picking the highest-probability
+    # valid toolset (first one wins on ties)
+    if not selected_toolsets and valid_toolsets:
+        selected_toolsets.append(max(valid_toolsets, key=valid_toolsets.get))
+
+    return selected_toolsets
+
+
+def validate_distribution(distribution_name: str) -> bool:
+    """
+    Check if a distribution name is valid.
+
+    Args:
+        distribution_name (str): Distribution name to validate
+
+    Returns:
+        bool: True if the name is a key of DISTRIBUTIONS, False otherwise
+    """
+    return distribution_name in DISTRIBUTIONS
+
+
+def print_distribution_info(distribution_name: str) -> None:
+    """
+    Print detailed information about a distribution.
+
+    Prints the description followed by the toolsets ordered from highest to
+    lowest probability. Prints an error line and returns if the name is
+    unknown.
+
+    Args:
+        distribution_name (str): Distribution name
+    """
+    dist = get_distribution(distribution_name)
+    if dist is None:
+        print(f"❌ Unknown distribution: {distribution_name}")
+        return
+
+    print(f"\n📊 Distribution: {distribution_name}")
+    print(f" Description: {dist['description']}")
+    print(" Toolsets:")
+    for toolset, prob in sorted(dist["toolsets"].items(), key=lambda x: x[1], reverse=True):
+        print(f" • {toolset:15} : {prob:3}% chance")
+
+
+if __name__ == "__main__":
+    # Demo and testing of the distributions system
+    print("📊 Toolset Distributions Demo")
+    print("=" * 60)
+
+    # List all distributions
+    print("\n📋 Available Distributions:")
+    print("-" * 40)
+    for name, dist in list_distributions().items():
+        print(f"\n {name}:")
+        print(f" {dist['description']}")
+        toolset_list = ", ".join(f"{ts}({p}%)" for ts, p in dist["toolsets"].items())
+        print(f" Toolsets: {toolset_list}")
+
+    # Demo sampling
+    print("\n\n🎲 Sampling Examples:")
+    print("-" * 40)
+
+    test_distributions = ["image_gen", "research", "balanced", "default"]
+
+    for dist_name in test_distributions:
+        print(f"\n{dist_name}:")
+        # Sample 5 times to show variability
+        samples = [sorted(sample_toolsets_from_distribution(dist_name)) for _ in range(5)]
+        for sample_number, sample in enumerate(samples, start=1):
+            print(f" Sample {sample_number}: {sample}")
+
+    # Show detailed info
+    print("\n\n📊 Detailed Distribution Info:")
+    print("-" * 40)
+    print_distribution_info("image_gen")
+    print_distribution_info("research")
+