diff --git a/.cursorrules b/.cursorrules new file mode 100644 index 000000000..defa74ef3 --- /dev/null +++ b/.cursorrules @@ -0,0 +1,23 @@ +Hermes-Agent is an agent harness for LLMs. + +When building, the tool functionality is in the tools/ directory, where each specific tool (or in some cases, tools that are built for the same execution category or api) are placed in a script each their own. + +Each tool is then consolidated in the model_tools.py file in the repo root. + +There is also a way to consolidate sets of tools in toolsets.py for the agent to use. + +The primary agent runner code is in run_agent, but other runners could be developed using the tools and framework. + +Always ensure consistency between tools, the model_tools.py and toolsets.py when changing any of them, otherwise they could become desynced in a way that is detrimental to functionality. + +The expected pathway for using API keys is to setup and place them in a .env file in the repo root. + +Test scripts will be placed in tests/ + +The run_agent loop is setup to: +- Process the enabled toolsets to provide to the model, +- Pipe in a prompt or problem from the input to the agent, +- Loop the LLM each time it calls a tool, until the model decides no more tools are needed and provides a natural language response, +- Return that response. + +There are additional caveats for logging, where we restructure the "tools" as a system prompt for storage later into a format that can be used and handled properly later. \ No newline at end of file diff --git a/.env.example b/.env.example new file mode 100644 index 000000000..bae4ec14f --- /dev/null +++ b/.env.example @@ -0,0 +1,49 @@ +# Hermes Agent Environment Configuration +# Copy this file to .env and fill in your API keys +# Get API keys from the URLs listed below + +# ============================================================================= +# REQUIRED API KEYS +# ============================================================================= + +# Anthropic API Key - Main agent model +# Get at: https://console.anthropic.com/ +ANTHROPIC_API_KEY= + +# Firecrawl API Key - Web search, extract, and crawl +# Get at: https://firecrawl.dev/ +FIRECRAWL_API_KEY= + +# Nous Research API Key - Vision analysis and multi-model reasoning +# Get at: https://inference-api.nousresearch.com/ +NOUS_API_KEY= + +# Morph API Key - Terminal/command execution tools +# Get at: https://morph.so/ +MORPH_API_KEY= + +# FAL.ai API Key - Image generation +# Get at: https://fal.ai/ +FAL_KEY= + +# ============================================================================= +# OPTIONAL API KEYS +# ============================================================================= + +# OpenAI API Key - Optional, for enhanced Hecate features +# Get at: https://platform.openai.com/ +OPENAI_API_KEY= + +# ============================================================================= +# OPTIONAL CONFIGURATION +# ============================================================================= + +# Terminal Tool Settings +HECATE_VM_LIFETIME_SECONDS=300 +HECATE_DEFAULT_SNAPSHOT_ID=snapshot_p5294qxt + +# Debug Logging (set to "true" to enable, logs saved to ./logs/) +WEB_TOOLS_DEBUG=false +VISION_TOOLS_DEBUG=false +MOA_TOOLS_DEBUG=false +IMAGE_TOOLS_DEBUG=false diff --git a/.gitignore b/.gitignore index 734c47142..ad61aed98 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,17 @@ __pycache__/ export* __pycache__/model_tools.cpython-310.pyc __pycache__/web_tools.cpython-310.pyc -logs/ \ No newline at end of file +logs/ +data/ +.pytest_cache/ +tmp/ +temp_vision_images/ +hermes-*/* +examples/ +tests/quick_test_dataset.jsonl +tests/sample_dataset.jsonl +run_datagen_kimik2-thinking.sh +run_datagen_megascience_glm4-6.sh +run_datagen_sonnet.sh +source-data/* +run_datagen_megascience_glm4-6.sh diff --git a/README.md b/README.md index 541a01a24..65f135667 100644 --- a/README.md +++ b/README.md @@ -10,15 +10,46 @@ An AI agent with advanced tool-calling capabilities, featuring a flexible toolse - **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents) - **Creative Tools**: Generate images from text prompts - **Toolsets System**: Organize tools into logical groups for different scenarios +- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking +- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets ## Setup + +### 1. Install Dependencies ```bash +# Create and activate virtual environment (recommended) +python3 -m venv venv +source venv/bin/activate # On Windows: venv\Scripts\activate + +# Install required packages pip install -r requirements.txt + +# Install Hecate for terminal tools git clone git@github.com:NousResearch/hecate.git cd hecate pip install -e . +cd .. ``` +### 2. Configure Environment Variables +```bash +# Copy the example environment file +cp .env.example .env + +# Edit .env and add your API keys +nano .env # or use your preferred editor +``` + +**Required API Keys:** +- `ANTHROPIC_API_KEY` - Main agent model (get at: https://console.anthropic.com/) +- `FIRECRAWL_API_KEY` - Web tools (get at: https://firecrawl.dev/) +- `NOUS_API_KEY` - Vision & reasoning tools (get at: https://inference-api.nousresearch.com/) +- `MORPH_API_KEY` - Terminal tools (get at: https://morph.so/) +- `FAL_KEY` - Image generation (get at: https://fal.ai/) +- `OPENAI_API_KEY` - Optional, for some Hecate features + +See `.env.example` for all available configuration options including debug settings and terminal tool configuration. + ## Toolsets System The agent uses a toolsets system for organizing and managing tools. All tools must be part of a toolset to be accessible - individual tool selection is not supported. This ensures consistent and logical grouping of capabilities. @@ -47,6 +78,9 @@ python run_agent.py --enabled_toolsets=research --query "Find latest AI papers" # Combine multiple toolsets python run_agent.py --enabled_toolsets=web,vision --query "Analyze this website" +# Enable all toolsets explicitly (same as omitting the flag) +python run_agent.py --enabled_toolsets=all --query "Do web research and run commands if helpful" + # Safe mode (no terminal access) python run_agent.py --enabled_toolsets=safe --query "Help without running commands" @@ -101,34 +135,109 @@ create_custom_toolset( agent = AIAgent(enabled_toolsets=["my_tools"]) ``` +## Batch Processing + +Process multiple prompts from a dataset in parallel with automatic checkpointing and statistics tracking: + +```bash +# Basic batch processing +python batch_runner.py \ + --dataset_file=prompts.jsonl \ + --batch_size=20 \ + --run_name=my_run + +# With specific distribution +python batch_runner.py \ + --dataset_file=prompts.jsonl \ + --batch_size=20 \ + --run_name=image_run \ + --distribution=image_gen \ + --num_workers=4 +``` + +**Key Features:** +- Parallel processing with configurable workers +- Toolset distributions for varied data generation +- Automatic checkpointing and resume capability +- Combined output in `data//trajectories.jsonl` +- Tool usage statistics and success rates + +**Quick Start:** See [QUICKSTART_BATCH.md](QUICKSTART_BATCH.md) for a 5-minute getting started guide. +**Full Documentation:** See [BATCH_PROCESSING.md](BATCH_PROCESSING.md) for comprehensive documentation. + +### Ephemeral System Prompts + +The ephemeral system prompt feature allows you to guide the model's behavior during batch processing **without** saving that prompt to the training dataset trajectories. This is useful for: + +- Guiding model behavior during data collection +- Adding task-specific instructions +- Keeping saved trajectories clean and focused on tool-calling format + +**Example:** +```bash +python batch_runner.py \ + --dataset_file=prompts.jsonl \ + --batch_size=10 \ + --run_name=my_run \ + --ephemeral_system_prompt="You are a helpful assistant focused on image generation." +``` + +The ephemeral prompt will influence the model's behavior during execution, but **only the standard tool-calling system prompt** will be saved in the trajectory files. + +**Documentation:** See [docs/ephemeral_system_prompt.md](docs/ephemeral_system_prompt.md) for complete details. + ## Command Line Arguments +**Single Agent (`run_agent.py`):** - `--query`: The question or task for the agent - `--model`: Model to use (default: claude-opus-4-20250514) - `--api_key`: API key for authentication - `--base_url`: API endpoint URL - `--max_turns`: Maximum number of tool-calling iterations -- `--enabled_toolsets`: Comma-separated list of toolsets to enable +- `--enabled_toolsets`: Comma-separated list of toolsets to enable. Use `all` (or `*`) to enable everything. If omitted, all toolsets are enabled by default. - `--disabled_toolsets`: Comma-separated list of toolsets to disable - `--list_tools`: List all available toolsets and tools - `--save_trajectories`: Save conversation trajectories to JSONL files +**Batch Processing (`batch_runner.py`):** +- `--dataset_file`: Path to JSONL file with prompts +- `--batch_size`: Number of prompts per batch +- `--run_name`: Name for this run (for output/checkpointing) +- `--distribution`: Toolset distribution to use (default: "default") +- `--num_workers`: Number of parallel workers (default: 4) +- `--resume`: Resume from checkpoint if interrupted +- `--ephemeral_system_prompt`: System prompt used during execution but NOT saved to trajectories +- `--list_distributions`: List available toolset distributions + ## Environment Variables -Set these environment variables to enable different tools: +All environment variables can be configured in the `.env` file (copy from `.env.example`). -- `FIRECRAWL_API_KEY`: For web tools (search, extract, crawl) -- `MORPH_API_KEY`: For terminal tools -- `NOUS_API_KEY`: For vision and reasoning tools -- `FAL_KEY`: For image generation tools -- `ANTHROPIC_API_KEY`: For the main agent model +**Core API Keys:** +- `ANTHROPIC_API_KEY`: Main agent model +- `FIRECRAWL_API_KEY`: Web tools (search, extract, crawl) +- `NOUS_API_KEY`: Vision and reasoning tools +- `MORPH_API_KEY`: Terminal tools +- `FAL_KEY`: Image generation tools +- `OPENAI_API_KEY`: Optional, for some Hecate features + +**Configuration Options:** +- `HECATE_VM_LIFETIME_SECONDS`: VM lifetime (default: 300) +- `HECATE_DEFAULT_SNAPSHOT_ID`: Default snapshot (default: snapshot_p5294qxt) +- `WEB_TOOLS_DEBUG`, `VISION_TOOLS_DEBUG`, `MOA_TOOLS_DEBUG`, `IMAGE_TOOLS_DEBUG`: Enable debug logging ## Documentation +**Single Agent Usage:** - `TOOLSETS_README.md`: Comprehensive guide to the toolsets system - `toolsets.py`: View and modify available toolsets - `model_tools.py`: Core tool definitions and handlers +**Batch Processing:** +- `QUICKSTART_BATCH.md`: 5-minute quick start guide +- `BATCH_PROCESSING.md`: Complete batch processing documentation +- `toolset_distributions.py`: Toolset distributions for data generation + ## Examples See `TOOLSETS_README.md` for extensive examples of using different toolsets for various scenarios. diff --git a/batch_runner.py b/batch_runner.py new file mode 100644 index 000000000..2487d9fb1 --- /dev/null +++ b/batch_runner.py @@ -0,0 +1,753 @@ +#!/usr/bin/env python3 +""" +Batch Agent Runner + +This module provides parallel batch processing capabilities for running the agent +across multiple prompts from a dataset. It includes: +- Dataset loading and batching +- Parallel batch processing with multiprocessing +- Checkpointing for fault tolerance and resumption +- Trajectory saving in the proper format (from/value pairs) +- Tool usage statistics aggregation across all batches + +Usage: + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run + + # Resume an interrupted run + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume + + # Use a specific toolset distribution + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen +""" + +import json +import logging +import os +import time +from pathlib import Path +from typing import List, Dict, Any, Optional, Tuple +from datetime import datetime +from multiprocessing import Pool, Manager, Lock +import traceback + +import fire + +from run_agent import AIAgent +from toolset_distributions import ( + get_distribution, + list_distributions, + sample_toolsets_from_distribution, + validate_distribution +) + + +# Global configuration for worker processes +_WORKER_CONFIG = {} + + +def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]: + """ + Extract tool usage statistics from message history. + + Args: + messages (List[Dict]): Message history + + Returns: + Dict: Tool statistics with counts and success/failure rates + """ + tool_stats = {} + + # Track tool calls and their results + tool_calls_map = {} # Map tool_call_id to tool name + + for msg in messages: + # Track tool calls from assistant messages + if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]: + for tool_call in msg["tool_calls"]: + tool_name = tool_call["function"]["name"] + tool_call_id = tool_call["id"] + + # Initialize stats for this tool if not exists + if tool_name not in tool_stats: + tool_stats[tool_name] = { + "count": 0, + "success": 0, + "failure": 0 + } + + tool_stats[tool_name]["count"] += 1 + tool_calls_map[tool_call_id] = tool_name + + # Track tool responses + elif msg["role"] == "tool": + tool_call_id = msg.get("tool_call_id", "") + content = msg.get("content", "") + + # Determine if tool call was successful + is_success = True + try: + # Try to parse as JSON and check for actual error values + content_json = json.loads(content) if isinstance(content, str) else content + + if isinstance(content_json, dict): + # Check if error field exists AND has a non-null value + if "error" in content_json and content_json["error"] is not None: + is_success = False + + # Special handling for terminal tool responses + # Terminal wraps its response in a "content" field + if "content" in content_json and isinstance(content_json["content"], dict): + inner_content = content_json["content"] + # Check for actual error (non-null error field) + # Note: non-zero exit codes are not failures - the model can self-correct + if inner_content.get("error") is not None: + is_success = False + + # Check for "success": false pattern used by some tools + if content_json.get("success") is False: + is_success = False + + except: + # If not JSON, check if content is empty or explicitly states an error + # Note: We avoid simple substring matching to prevent false positives + if not content: + is_success = False + # Only mark as failure if it explicitly starts with "Error:" or "ERROR:" + elif content.strip().lower().startswith("error:"): + is_success = False + + # Update success/failure count + if tool_call_id in tool_calls_map: + tool_name = tool_calls_map[tool_call_id] + if is_success: + tool_stats[tool_name]["success"] += 1 + else: + tool_stats[tool_name]["failure"] += 1 + + return tool_stats + + +def _process_single_prompt( + prompt_index: int, + prompt_data: Dict[str, Any], + batch_num: int, + config: Dict[str, Any] +) -> Dict[str, Any]: + """ + Process a single prompt with the agent. + + Args: + prompt_index (int): Index of prompt in dataset + prompt_data (Dict): Prompt data containing 'prompt' field + batch_num (int): Batch number + config (Dict): Configuration dict with agent parameters + + Returns: + Dict: Result containing trajectory, stats, and metadata + """ + prompt = prompt_data["prompt"] + + try: + # Sample toolsets from distribution for this prompt + selected_toolsets = sample_toolsets_from_distribution(config["distribution"]) + + if config.get("verbose"): + print(f" Prompt {prompt_index}: Using toolsets {selected_toolsets}") + + # Initialize agent with sampled toolsets + agent = AIAgent( + base_url=config.get("base_url"), + api_key=config.get("api_key"), + model=config["model"], + max_iterations=config["max_iterations"], + enabled_toolsets=selected_toolsets, + save_trajectories=False, # We handle saving ourselves + verbose_logging=config.get("verbose", False), + ephemeral_system_prompt=config.get("ephemeral_system_prompt"), + log_prefix_chars=config.get("log_prefix_chars", 100) + ) + + # Run the agent with task_id to ensure each task gets its own isolated VM + result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}") + + # Extract tool usage statistics + tool_stats = _extract_tool_stats(result["messages"]) + + # Convert to trajectory format (using existing method) + trajectory = agent._convert_to_trajectory_format( + result["messages"], + prompt, + result["completed"] + ) + + return { + "success": True, + "prompt_index": prompt_index, + "trajectory": trajectory, + "tool_stats": tool_stats, + "completed": result["completed"], + "api_calls": result["api_calls"], + "toolsets_used": selected_toolsets, + "metadata": { + "batch_num": batch_num, + "timestamp": datetime.now().isoformat(), + "model": config["model"] + } + } + + except Exception as e: + print(f"āŒ Error processing prompt {prompt_index}: {e}") + if config.get("verbose"): + traceback.print_exc() + + return { + "success": False, + "prompt_index": prompt_index, + "error": str(e), + "trajectory": None, + "tool_stats": {}, + "toolsets_used": [], + "metadata": { + "batch_num": batch_num, + "timestamp": datetime.now().isoformat() + } + } + + +def _process_batch_worker(args: Tuple) -> Dict[str, Any]: + """ + Worker function to process a single batch of prompts. + + Args: + args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config) + + Returns: + Dict: Batch results with statistics + """ + batch_num, batch_data, output_dir, completed_prompts_set, config = args + + output_dir = Path(output_dir) + print(f"\nšŸ”„ Batch {batch_num}: Starting ({len(batch_data)} prompts)") + + # Output file for this batch + batch_output_file = output_dir / f"batch_{batch_num}.jsonl" + + # Filter out already completed prompts + prompts_to_process = [ + (idx, data) for idx, data in batch_data + if idx not in completed_prompts_set + ] + + if not prompts_to_process: + print(f"āœ… Batch {batch_num}: Already completed (skipping)") + return { + "batch_num": batch_num, + "processed": 0, + "skipped": len(batch_data), + "tool_stats": {}, + "completed_prompts": [] + } + + print(f" Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)") + + # Initialize aggregated stats for this batch + batch_tool_stats = {} + completed_in_batch = [] + + # Process each prompt sequentially in this batch + for prompt_index, prompt_data in prompts_to_process: + # Process the prompt + result = _process_single_prompt( + prompt_index, + prompt_data, + batch_num, + config + ) + + # Save trajectory if successful + if result["success"] and result["trajectory"]: + trajectory_entry = { + "prompt_index": prompt_index, + "conversations": result["trajectory"], + "metadata": result["metadata"], + "completed": result["completed"], + "api_calls": result["api_calls"], + "toolsets_used": result["toolsets_used"] + } + + # Append to batch output file + with open(batch_output_file, 'a', encoding='utf-8') as f: + f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n") + + # Aggregate tool statistics + for tool_name, stats in result.get("tool_stats", {}).items(): + if tool_name not in batch_tool_stats: + batch_tool_stats[tool_name] = { + "count": 0, + "success": 0, + "failure": 0 + } + + batch_tool_stats[tool_name]["count"] += stats["count"] + batch_tool_stats[tool_name]["success"] += stats["success"] + batch_tool_stats[tool_name]["failure"] += stats["failure"] + + completed_in_batch.append(prompt_index) + print(f" āœ… Prompt {prompt_index} completed") + + print(f"āœ… Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)") + + return { + "batch_num": batch_num, + "processed": len(prompts_to_process), + "skipped": len(batch_data) - len(prompts_to_process), + "tool_stats": batch_tool_stats, + "completed_prompts": completed_in_batch + } + + +class BatchRunner: + """ + Manages batch processing of agent prompts with checkpointing and statistics. + """ + + def __init__( + self, + dataset_file: str, + batch_size: int, + run_name: str, + distribution: str = "default", + max_iterations: int = 10, + base_url: str = None, + api_key: str = None, + model: str = "claude-opus-4-20250514", + num_workers: int = 4, + verbose: bool = False, + ephemeral_system_prompt: str = None, + log_prefix_chars: int = 100, + ): + """ + Initialize the batch runner. + + Args: + dataset_file (str): Path to the dataset JSONL file with 'prompt' field + batch_size (int): Number of prompts per batch + run_name (str): Name for this run (used for checkpointing and output) + distribution (str): Toolset distribution to use (default: "default") + max_iterations (int): Max iterations per agent run + base_url (str): Base URL for model API + api_key (str): API key for model + model (str): Model name to use + num_workers (int): Number of parallel workers + verbose (bool): Enable verbose logging + ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional) + log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20) + """ + self.dataset_file = Path(dataset_file) + self.batch_size = batch_size + self.run_name = run_name + self.distribution = distribution + self.max_iterations = max_iterations + self.base_url = base_url + self.api_key = api_key + self.model = model + self.num_workers = num_workers + self.verbose = verbose + self.ephemeral_system_prompt = ephemeral_system_prompt + self.log_prefix_chars = log_prefix_chars + + # Validate distribution + if not validate_distribution(distribution): + raise ValueError(f"Unknown distribution: {distribution}. Available: {list(list_distributions().keys())}") + + # Setup output directory + self.output_dir = Path("data") / run_name + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Checkpoint file + self.checkpoint_file = self.output_dir / "checkpoint.json" + + # Statistics file + self.stats_file = self.output_dir / "statistics.json" + + # Load dataset + self.dataset = self._load_dataset() + + # Create batches + self.batches = self._create_batches() + + print(f"šŸ“Š Batch Runner Initialized") + print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)") + print(f" Batch size: {self.batch_size}") + print(f" Total batches: {len(self.batches)}") + print(f" Run name: {self.run_name}") + print(f" Distribution: {self.distribution}") + print(f" Output directory: {self.output_dir}") + print(f" Workers: {self.num_workers}") + if self.ephemeral_system_prompt: + prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt + print(f" šŸ”’ Ephemeral system prompt: '{prompt_preview}'") + + def _load_dataset(self) -> List[Dict[str, Any]]: + """ + Load dataset from JSONL file. + + Returns: + List[Dict]: List of dataset entries + """ + if not self.dataset_file.exists(): + raise FileNotFoundError(f"Dataset file not found: {self.dataset_file}") + + dataset = [] + with open(self.dataset_file, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + + try: + entry = json.loads(line) + if 'prompt' not in entry: + print(f"āš ļø Warning: Line {line_num} missing 'prompt' field, skipping") + continue + dataset.append(entry) + except json.JSONDecodeError as e: + print(f"āš ļø Warning: Invalid JSON on line {line_num}: {e}") + continue + + if not dataset: + raise ValueError(f"No valid entries found in dataset file: {self.dataset_file}") + + return dataset + + def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]: + """ + Split dataset into batches with indices. + + Returns: + List of batches, where each batch is a list of (index, entry) tuples + """ + batches = [] + for i in range(0, len(self.dataset), self.batch_size): + batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)] + batches.append(batch) + + return batches + + def _load_checkpoint(self) -> Dict[str, Any]: + """ + Load checkpoint data if it exists. + + Returns: + Dict: Checkpoint data with completed prompt indices + """ + if not self.checkpoint_file.exists(): + return { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } + + try: + with open(self.checkpoint_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"āš ļø Warning: Failed to load checkpoint: {e}") + return { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } + + def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None): + """ + Save checkpoint data. + + Args: + checkpoint_data (Dict): Checkpoint data to save + lock (Lock): Optional lock for thread-safe access + """ + checkpoint_data["last_updated"] = datetime.now().isoformat() + + if lock: + with lock: + with open(self.checkpoint_file, 'w', encoding='utf-8') as f: + json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) + else: + with open(self.checkpoint_file, 'w', encoding='utf-8') as f: + json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) + + + def run(self, resume: bool = False): + """ + Run the batch processing pipeline. + + Args: + resume (bool): Whether to resume from checkpoint + """ + print("\n" + "=" * 70) + print("šŸš€ Starting Batch Processing") + print("=" * 70) + + # Load checkpoint + checkpoint_data = self._load_checkpoint() if resume else { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } + + if resume and checkpoint_data.get("completed_prompts"): + print(f"šŸ“‚ Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)") + + # Prepare configuration for workers + config = { + "distribution": self.distribution, + "model": self.model, + "max_iterations": self.max_iterations, + "base_url": self.base_url, + "api_key": self.api_key, + "verbose": self.verbose, + "ephemeral_system_prompt": self.ephemeral_system_prompt, + "log_prefix_chars": self.log_prefix_chars + } + + # Get completed prompts set + completed_prompts_set = set(checkpoint_data.get("completed_prompts", [])) + + # Aggregate statistics across all batches + total_tool_stats = {} + + start_time = time.time() + + # Process batches in parallel + with Pool(processes=self.num_workers) as pool: + # Create tasks for each batch + tasks = [ + ( + batch_num, + batch_data, + str(self.output_dir), # Convert Path to string for pickling + completed_prompts_set, + config + ) + for batch_num, batch_data in enumerate(self.batches) + ] + + # Use map to process batches in parallel + results = pool.map(_process_batch_worker, tasks) + + # Aggregate all batch statistics and update checkpoint + all_completed_prompts = list(completed_prompts_set) + for batch_result in results: + # Add newly completed prompts + all_completed_prompts.extend(batch_result.get("completed_prompts", [])) + + # Aggregate tool stats + for tool_name, stats in batch_result.get("tool_stats", {}).items(): + if tool_name not in total_tool_stats: + total_tool_stats[tool_name] = { + "count": 0, + "success": 0, + "failure": 0 + } + + total_tool_stats[tool_name]["count"] += stats["count"] + total_tool_stats[tool_name]["success"] += stats["success"] + total_tool_stats[tool_name]["failure"] += stats["failure"] + + # Save final checkpoint + checkpoint_data["completed_prompts"] = all_completed_prompts + self._save_checkpoint(checkpoint_data) + + # Calculate success rates + for tool_name in total_tool_stats: + stats = total_tool_stats[tool_name] + total_calls = stats["success"] + stats["failure"] + if total_calls > 0: + stats["success_rate"] = round(stats["success"] / total_calls * 100, 2) + stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2) + else: + stats["success_rate"] = 0.0 + stats["failure_rate"] = 0.0 + + # Combine all batch files into a single trajectories.jsonl file + combined_file = self.output_dir / "trajectories.jsonl" + print(f"\nšŸ“¦ Combining batch files into {combined_file.name}...") + + with open(combined_file, 'w', encoding='utf-8') as outfile: + for batch_num in range(len(self.batches)): + batch_file = self.output_dir / f"batch_{batch_num}.jsonl" + if batch_file.exists(): + with open(batch_file, 'r', encoding='utf-8') as infile: + for line in infile: + outfile.write(line) + + print(f"āœ… Combined {len(self.batches)} batch files into trajectories.jsonl") + + # Save final statistics + final_stats = { + "run_name": self.run_name, + "distribution": self.distribution, + "total_prompts": len(self.dataset), + "total_batches": len(self.batches), + "batch_size": self.batch_size, + "model": self.model, + "completed_at": datetime.now().isoformat(), + "duration_seconds": round(time.time() - start_time, 2), + "tool_statistics": total_tool_stats + } + + with open(self.stats_file, 'w', encoding='utf-8') as f: + json.dump(final_stats, f, indent=2, ensure_ascii=False) + + # Print summary + print("\n" + "=" * 70) + print("šŸ“Š BATCH PROCESSING COMPLETE") + print("=" * 70) + print(f"āœ… Total prompts processed: {len(self.dataset)}") + print(f"āœ… Total batches: {len(self.batches)}") + print(f"ā±ļø Total duration: {round(time.time() - start_time, 2)}s") + print(f"\nšŸ“ˆ Tool Usage Statistics:") + print("-" * 70) + + if total_tool_stats: + # Sort by count descending + sorted_tools = sorted( + total_tool_stats.items(), + key=lambda x: x[1]["count"], + reverse=True + ) + + print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}") + print("-" * 70) + for tool_name, stats in sorted_tools: + print( + f"{tool_name:<25} " + f"{stats['count']:<10} " + f"{stats['success']:<10} " + f"{stats['failure']:<10} " + f"{stats['success_rate']:.1f}%" + ) + else: + print("No tool calls were made during this run.") + + print(f"\nšŸ’¾ Results saved to: {self.output_dir}") + print(f" - Trajectories: trajectories.jsonl (combined)") + print(f" - Individual batches: batch_*.jsonl (for debugging)") + print(f" - Statistics: {self.stats_file.name}") + print(f" - Checkpoint: {self.checkpoint_file.name}") + + +def main( + dataset_file: str = None, + batch_size: int = None, + run_name: str = None, + distribution: str = "default", + model: str = "claude-opus-4-20250514", + api_key: str = None, + base_url: str = "https://api.anthropic.com/v1/", + max_turns: int = 10, + num_workers: int = 4, + resume: bool = False, + verbose: bool = False, + list_distributions: bool = False, + ephemeral_system_prompt: str = None, + log_prefix_chars: int = 100, +): + """ + Run batch processing of agent prompts from a dataset. + + Args: + dataset_file (str): Path to JSONL file with 'prompt' field in each entry + batch_size (int): Number of prompts per batch + run_name (str): Name for this run (used for output and checkpointing) + distribution (str): Toolset distribution to use (default: "default") + model (str): Model name to use (default: "claude-opus-4-20250514") + api_key (str): API key for model authentication + base_url (str): Base URL for model API + max_turns (int): Maximum number of tool calling iterations per prompt (default: 10) + num_workers (int): Number of parallel worker processes (default: 4) + resume (bool): Resume from checkpoint if run was interrupted (default: False) + verbose (bool): Enable verbose logging (default: False) + list_distributions (bool): List available toolset distributions and exit + ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional) + log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20) + + Examples: + # Basic usage + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run + + # Resume interrupted run + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume + + # Use specific distribution + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen + + # With ephemeral system prompt (not saved to dataset) + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\ + --ephemeral_system_prompt="You are a helpful assistant focused on image generation." + + # List available distributions + python batch_runner.py --list_distributions + """ + # Handle list distributions + if list_distributions: + from toolset_distributions import list_distributions as get_all_dists, print_distribution_info + + print("šŸ“Š Available Toolset Distributions") + print("=" * 70) + + all_dists = get_all_dists() + for dist_name in sorted(all_dists.keys()): + print_distribution_info(dist_name) + + print("\nšŸ’” Usage:") + print(" python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\") + print(" --run_name=my_run --distribution=") + return + + # Validate required arguments + if not dataset_file: + print("āŒ Error: --dataset_file is required") + return + + if not batch_size or batch_size < 1: + print("āŒ Error: --batch_size must be a positive integer") + return + + if not run_name: + print("āŒ Error: --run_name is required") + return + + # Initialize and run batch runner + try: + runner = BatchRunner( + dataset_file=dataset_file, + batch_size=batch_size, + run_name=run_name, + distribution=distribution, + max_iterations=max_turns, + base_url=base_url, + api_key=api_key, + model=model, + num_workers=num_workers, + verbose=verbose, + ephemeral_system_prompt=ephemeral_system_prompt, + log_prefix_chars=log_prefix_chars + ) + + runner.run(resume=resume) + + except Exception as e: + print(f"\nāŒ Fatal error: {e}") + if verbose: + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + fire.Fire(main) + diff --git a/model_tools.py b/model_tools.py index 42f068604..eb27b7535 100644 --- a/model_tools.py +++ b/model_tools.py @@ -28,13 +28,15 @@ Usage: import json import asyncio -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional -from web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key -from terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION -from vision_tools import vision_analyze_tool, check_vision_requirements -from mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements -from image_generation_tool import image_generate_tool, check_image_generation_requirements +from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key +from tools.simple_terminal_tool import simple_terminal_tool, check_requirements as check_simple_terminal_requirements, SIMPLE_TERMINAL_TOOL_DESCRIPTION +# Keep old terminal tool for backwards compatibility if needed +# from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION +from tools.vision_tools import vision_analyze_tool, check_vision_requirements +from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements +from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements from toolsets import ( get_toolset, resolve_toolset, resolve_multiple_toolsets, get_all_toolsets, get_toolset_names, validate_toolset, @@ -111,7 +113,7 @@ def get_web_tool_definitions() -> List[Dict[str, Any]]: def get_terminal_tool_definitions() -> List[Dict[str, Any]]: """ Get tool definitions for terminal tools in OpenAI's expected format. - + Returns: List[Dict]: List of terminal tool definitions compatible with OpenAI API """ @@ -120,7 +122,7 @@ def get_terminal_tool_definitions() -> List[Dict[str, Any]]: "type": "function", "function": { "name": "terminal", - "description": TERMINAL_TOOL_DESCRIPTION, + "description": SIMPLE_TERMINAL_TOOL_DESCRIPTION, "parameters": { "type": "object", "properties": { @@ -128,28 +130,18 @@ def get_terminal_tool_definitions() -> List[Dict[str, Any]]: "type": "string", "description": "The command to execute on the VM" }, - "input_keys": { - "type": "string", - "description": "Keystrokes to send to the most recent interactive session (e.g., 'hello\\n' for typing hello + Enter). If no active session exists, this will be ignored." - }, "background": { "type": "boolean", "description": "Whether to run the command in the background (default: false)", "default": False }, - "idle_threshold": { - "type": "number", - "description": "Seconds to wait for output before considering session idle (default: 5.0)", - "default": 5.0, - "minimum": 0.1 - }, "timeout": { "type": "integer", "description": "Command timeout in seconds (optional)", "minimum": 1 } }, - "required": [] + "required": ["command"] } } } @@ -262,11 +254,11 @@ def get_all_tool_names() -> List[str]: # Web tools if check_firecrawl_api_key(): tool_names.extend(["web_search", "web_extract", "web_crawl"]) - - # Terminal tools - if check_hecate_requirements(): + + # Terminal tools + if check_simple_terminal_requirements(): tool_names.extend(["terminal"]) - + # Vision tools if check_vision_requirements(): tool_names.extend(["vision_analyze"]) @@ -346,11 +338,11 @@ def get_tool_definitions( if check_firecrawl_api_key(): for tool in get_web_tool_definitions(): all_available_tools_map[tool["function"]["name"]] = tool - - if check_hecate_requirements(): + + if check_simple_terminal_requirements(): for tool in get_terminal_tool_definitions(): all_available_tools_map[tool["function"]["name"]] = tool - + if check_vision_requirements(): for tool in get_vision_tool_definitions(): all_available_tools_map[tool["function"]["name"]] = tool @@ -478,30 +470,29 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) return asyncio.run(web_crawl_tool(url, instructions, "basic")) else: - return json.dumps({"error": f"Unknown web function: {function_name}"}) + return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False) -def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any]) -> str: +def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str: """ Handle function calls for terminal tools. - + Args: function_name (str): Name of the terminal function to call function_args (Dict): Arguments for the function - + task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional) + Returns: str: Function result as JSON string """ if function_name == "terminal": command = function_args.get("command") - input_keys = function_args.get("input_keys") background = function_args.get("background", False) - idle_threshold = function_args.get("idle_threshold", 5.0) timeout = function_args.get("timeout") - return terminal_tool(command, input_keys, None, background, idle_threshold, timeout) - + return simple_terminal_tool(command=command, background=background, timeout=timeout, task_id=task_id) + else: - return json.dumps({"error": f"Unknown terminal function: {function_name}"}) + return json.dumps({"error": f"Unknown terminal function: {function_name}"}, ensure_ascii=False) def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str: @@ -525,7 +516,7 @@ def handle_vision_function_call(function_name: str, function_args: Dict[str, Any return asyncio.run(vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash")) else: - return json.dumps({"error": f"Unknown vision function: {function_name}"}) + return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False) def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str: @@ -543,13 +534,13 @@ def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) user_prompt = function_args.get("user_prompt", "") if not user_prompt: - return json.dumps({"error": "user_prompt is required for MoA processing"}) + return json.dumps({"error": "user_prompt is required for MoA processing"}, ensure_ascii=False) # Run async function in event loop return asyncio.run(mixture_of_agents_tool(user_prompt=user_prompt)) else: - return json.dumps({"error": f"Unknown MoA function: {function_name}"}) + return json.dumps({"error": f"Unknown MoA function: {function_name}"}, ensure_ascii=False) def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str: @@ -567,7 +558,7 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any] prompt = function_args.get("prompt", "") if not prompt: - return json.dumps({"success": False, "image": None}) + return json.dumps({"success": False, "image": None}, ensure_ascii=False) image_size = function_args.get("image_size", "landscape_16_9") @@ -581,8 +572,21 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any] allow_nsfw_images = True seed = None - # Run async function in event loop - return asyncio.run(image_generate_tool( + # Run async function in event loop with proper handling for multiprocessing + try: + # Try to get existing event loop + loop = asyncio.get_event_loop() + if loop.is_closed(): + # If closed, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # No event loop in current thread, create one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run the coroutine in the event loop + result = loop.run_until_complete(image_generate_tool( prompt=prompt, image_size=image_size, num_inference_steps=num_inference_steps, @@ -594,26 +598,29 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any] allow_nsfw_images=allow_nsfw_images, seed=seed )) + + return result else: - return json.dumps({"error": f"Unknown image generation function: {function_name}"}) + return json.dumps({"error": f"Unknown image generation function: {function_name}"}, ensure_ascii=False) -def handle_function_call(function_name: str, function_args: Dict[str, Any]) -> str: +def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str: """ Main function call dispatcher that routes calls to appropriate toolsets. - + This function determines which toolset a function belongs to and dispatches the call to the appropriate handler. This makes it easy to add new toolsets without changing the main calling interface. - + Args: function_name (str): Name of the function to call function_args (Dict): Arguments for the function - + task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional) + Returns: str: Function result as JSON string - + Raises: None: Returns error as JSON string instead of raising exceptions """ @@ -621,32 +628,33 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any]) -> s # Route web tools if function_name in ["web_search", "web_extract", "web_crawl"]: return handle_web_function_call(function_name, function_args) - + # Route terminal tools elif function_name in ["terminal"]: - return handle_terminal_function_call(function_name, function_args) - + return handle_terminal_function_call(function_name, function_args, task_id) + # Route vision tools elif function_name in ["vision_analyze"]: return handle_vision_function_call(function_name, function_args) - + # Route MoA tools elif function_name in ["mixture_of_agents"]: return handle_moa_function_call(function_name, function_args) - + # Route image generation tools elif function_name in ["image_generate"]: return handle_image_function_call(function_name, function_args) - + else: error_msg = f"Unknown function: {function_name}" print(f"āŒ {error_msg}") - return json.dumps({"error": error_msg}) + + return json.dumps({"error": error_msg}, ensure_ascii=False) except Exception as e: error_msg = f"Error executing {function_name}: {str(e)}" print(f"āŒ {error_msg}") - return json.dumps({"error": error_msg}) + return json.dumps({"error": error_msg}, ensure_ascii=False) def get_available_toolsets() -> Dict[str, Dict[str, Any]]: """ @@ -663,10 +671,10 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]: "requirements": ["FIRECRAWL_API_KEY environment variable"] }, "terminal_tools": { - "available": check_hecate_requirements(), - "tools": ["terminal_tool"], - "description": "Execute commands with optional interactive session support on Linux VMs", - "requirements": ["MORPH_API_KEY environment variable", "hecate package"] + "available": check_simple_terminal_requirements(), + "tools": ["simple_terminal_tool"], + "description": "Execute commands on secure Linux VMs without session persistence", + "requirements": ["MORPH_API_KEY environment variable"] }, "vision_tools": { "available": check_vision_requirements(), @@ -693,13 +701,13 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]: def check_toolset_requirements() -> Dict[str, bool]: """ Check if all requirements for available toolsets are met. - + Returns: Dict: Status of each toolset's requirements """ return { "web_tools": check_firecrawl_api_key(), - "terminal_tools": check_hecate_requirements(), + "terminal_tools": check_simple_terminal_requirements(), "vision_tools": check_vision_requirements(), "moa_tools": check_moa_requirements(), "image_tools": check_image_generation_requirements() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..17ad4e69e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "hermes-agent" +version = "0.1.0" +description = "AI agent with advanced tool-calling and toolsets" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Hermes Agent" }] +license = { text = "MIT" } +dependencies = [ + "firecrawl-py", + "openai", + "fal-client", + "python-dotenv", + "fire" +] + +[project.scripts] +hermes-agent = "run_agent:main" + +[tool.setuptools] +py-modules = ["run_agent", "model_tools", "toolsets"] + +[tool.setuptools.packages.find] +include = ["tools"] diff --git a/requirements.txt b/requirements.txt index a8c9eda41..f9b9514dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,7 @@ openai fal-client fire git@github.com:NousResearch/hecate.git -tenacity \ No newline at end of file +tenacity +python-dotenv +fire +httpx diff --git a/run_agent.py b/run_agent.py index 5fabfe078..97fb37087 100644 --- a/run_agent.py +++ b/run_agent.py @@ -28,9 +28,22 @@ from typing import List, Dict, Any, Optional from openai import OpenAI import fire from datetime import datetime +from pathlib import Path + +# Load environment variables from .env file +from dotenv import load_dotenv + +# Load .env file if it exists +env_path = Path(__file__).parent / '.env' +if env_path.exists(): + load_dotenv(dotenv_path=env_path) + print(f"āœ… Loaded environment variables from {env_path}") +else: + print(f"ā„¹ļø No .env file found at {env_path}. Using system environment variables.") # Import our tool system from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements +from tools.terminal_tool import cleanup_vm class AIAgent: @@ -42,20 +55,22 @@ class AIAgent: """ def __init__( - self, - base_url: str = None, - api_key: str = None, + self, + base_url: str = None, + api_key: str = None, model: str = "gpt-4", max_iterations: int = 10, tool_delay: float = 1.0, enabled_toolsets: List[str] = None, disabled_toolsets: List[str] = None, save_trajectories: bool = False, - verbose_logging: bool = False + verbose_logging: bool = False, + ephemeral_system_prompt: str = None, + log_prefix_chars: int = 100, ): """ Initialize the AI Agent. - + Args: base_url (str): Base URL for the model API (optional) api_key (str): API key for authentication (optional, uses env var if not provided) @@ -66,13 +81,17 @@ class AIAgent: disabled_toolsets (List[str]): Disable tools from these toolsets (optional) save_trajectories (bool): Whether to save conversation trajectories to JSONL files (default: False) verbose_logging (bool): Enable verbose logging for debugging (default: False) + ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional) + log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20) """ self.model = model self.max_iterations = max_iterations self.tool_delay = tool_delay self.save_trajectories = save_trajectories self.verbose_logging = verbose_logging - + self.ephemeral_system_prompt = ephemeral_system_prompt + self.log_prefix_chars = log_prefix_chars + # Store toolset filtering options self.enabled_toolsets = enabled_toolsets self.disabled_toolsets = disabled_toolsets @@ -84,10 +103,11 @@ class AIAgent: format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%H:%M:%S' ) - # Also set OpenAI client logging to debug - logging.getLogger('openai').setLevel(logging.DEBUG) - logging.getLogger('httpx').setLevel(logging.DEBUG) - print("šŸ” Verbose logging enabled") + # Keep OpenAI and httpx at INFO level to avoid massive base64 logs + # Even in verbose mode, we don't want to see full request/response bodies + logging.getLogger('openai').setLevel(logging.INFO) + logging.getLogger('httpx').setLevel(logging.WARNING) + print("šŸ” Verbose logging enabled (OpenAI/httpx request bodies suppressed)") else: # Set logging to INFO level for important messages only logging.basicConfig( @@ -145,6 +165,11 @@ class AIAgent: # Show trajectory saving status if self.save_trajectories: print("šŸ“ Trajectory saving enabled") + + # Show ephemeral system prompt status + if self.ephemeral_system_prompt: + prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt + print(f"šŸ”’ Ephemeral system prompt: '{prompt_preview}' (not saved to trajectories)") def _format_tools_for_system_message(self) -> str: """ @@ -168,7 +193,7 @@ class AIAgent: } formatted_tools.append(formatted_tool) - return json.dumps(formatted_tools) + return json.dumps(formatted_tools, ensure_ascii=False) def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]: """ @@ -229,7 +254,7 @@ class AIAgent: "name": tool_call["function"]["name"], "arguments": json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"] } - content += f"\n{json.dumps(tool_call_json)}\n\n" + content += f"\n{json.dumps(tool_call_json, ensure_ascii=False)}\n\n" trajectory.append({ "from": "gpt", @@ -256,7 +281,7 @@ class AIAgent: "tool_call_id": tool_msg.get("tool_call_id", ""), "name": msg["tool_calls"][len(tool_responses)]["function"]["name"] if len(tool_responses) < len(msg["tool_calls"]) else "unknown", "content": tool_content - }) + }, ensure_ascii=False) tool_response += "\n" tool_responses.append(tool_response) j += 1 @@ -321,22 +346,27 @@ class AIAgent: print(f"āš ļø Failed to save trajectory: {e}") def run_conversation( - self, - user_message: str, - system_message: str = None, - conversation_history: List[Dict[str, Any]] = None + self, + user_message: str, + system_message: str = None, + conversation_history: List[Dict[str, Any]] = None, + task_id: str = None ) -> Dict[str, Any]: """ Run a complete conversation with tool calling until completion. - + Args: user_message (str): The user's message/question - system_message (str): Custom system message (optional) + system_message (str): Custom system message (optional, overrides ephemeral_system_prompt if provided) conversation_history (List[Dict]): Previous conversation messages (optional) - + task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional, auto-generated if not provided) + Returns: Dict: Complete conversation result with final response and message history """ + # Generate unique task_id if not provided to isolate VMs between concurrent tasks + import uuid + effective_task_id = task_id or str(uuid.uuid4()) # Initialize conversation messages = conversation_history or [] @@ -348,13 +378,17 @@ class AIAgent: print(f"šŸ’¬ Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'") + # Determine which system prompt to use for API calls (ephemeral) + # Priority: explicit system_message > ephemeral_system_prompt > None + active_system_prompt = system_message if system_message is not None else self.ephemeral_system_prompt + # Main conversation loop api_call_count = 0 final_response = None while api_call_count < self.max_iterations: api_call_count += 1 - print(f"\nšŸ”„ Making API call #{api_call_count}...") + print(f"\nšŸ”„ Making OpenAI-compatible API call #{api_call_count}...") # Log request details if verbose if self.verbose_logging: @@ -363,33 +397,40 @@ class AIAgent: api_start_time = time.time() retry_count = 0 - max_retries = 3 - + max_retries = 6 # Increased to allow longer backoff periods + while retry_count <= max_retries: try: + # Prepare messages for API call + # If we have an ephemeral system prompt, prepend it to the messages + api_messages = messages.copy() + if active_system_prompt: + # Insert system message at the beginning + api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages + # Make API call with tools response = self.client.chat.completions.create( model=self.model, - messages=messages, + messages=api_messages, tools=self.tools if self.tools else None, - timeout=60.0 # Add explicit timeout + timeout=300.0 # 5 minute timeout for long-running agent tasks ) - + api_duration = time.time() - api_start_time - print(f"ā±ļø API call completed in {api_duration:.2f}s") - + print(f"ā±ļø OpenAI-compatible API call completed in {api_duration:.2f}s") + if self.verbose_logging: logging.debug(f"API Response received - Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}") - + break # Success, exit retry loop - + except Exception as api_error: retry_count += 1 if retry_count > max_retries: raise api_error - - wait_time = min(2 ** retry_count, 10) # Exponential backoff, max 10s - print(f"āš ļø API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}") + + wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s + print(f"āš ļø OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}") print(f"ā³ Retrying in {wait_time}s...") logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}") time.sleep(wait_time) @@ -436,28 +477,33 @@ class AIAgent: print(f"āŒ Invalid JSON in tool call arguments: {e}") function_args = {} - print(f" šŸ“ž Tool {i}: {function_name}({list(function_args.keys())})") - + # Preview tool call arguments + args_str = json.dumps(function_args, ensure_ascii=False) + args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str + print(f" šŸ“ž Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}") + tool_start_time = time.time() - - # Execute the tool - function_result = handle_function_call(function_name, function_args) - + + # Execute the tool with task_id to isolate VMs between concurrent tasks + function_result = handle_function_call(function_name, function_args, effective_task_id) + tool_duration = time.time() - tool_start_time result_preview = function_result[:200] if len(function_result) > 200 else function_result - + if self.verbose_logging: logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s") logging.debug(f"Tool result preview: {result_preview}...") - + # Add tool result to conversation messages.append({ "role": "tool", "content": function_result, "tool_call_id": tool_call.id }) - - print(f" āœ… Tool {i} completed in {tool_duration:.2f}s") + + # Preview tool response + response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result + print(f" āœ… Tool {i} completed in {tool_duration:.2f}s - {response_preview}") # Delay between tool calls if self.tool_delay > 0 and i < len(assistant_message.tool_calls): @@ -476,11 +522,11 @@ class AIAgent: "content": final_response }) - print(f"šŸŽ‰ Conversation completed after {api_call_count} API call(s)") + print(f"šŸŽ‰ Conversation completed after {api_call_count} OpenAI-compatible API call(s)") break except Exception as e: - error_msg = f"Error during API call #{api_call_count}: {str(e)}" + error_msg = f"Error during OpenAI-compatible API call #{api_call_count}: {str(e)}" print(f"āŒ {error_msg}") if self.verbose_logging: @@ -505,10 +551,17 @@ class AIAgent: # Determine if conversation completed successfully completed = final_response is not None and api_call_count < self.max_iterations - + # Save trajectory if enabled self._save_trajectory(messages, user_message, completed) - + + # Clean up VM for this task after conversation completes + try: + cleanup_vm(effective_task_id) + except Exception as e: + if self.verbose_logging: + logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}") + return { "final_response": final_response, "messages": messages, @@ -532,7 +585,7 @@ class AIAgent: def main( query: str = None, - model: str = "claude-opus-4-20250514", + model: str = "claude-opus-4-20250514", api_key: str = None, base_url: str = "https://api.anthropic.com/v1/", max_turns: int = 10, @@ -540,25 +593,27 @@ def main( disabled_toolsets: str = None, list_tools: bool = False, save_trajectories: bool = False, - verbose: bool = False + verbose: bool = False, + log_prefix_chars: int = 20 ): """ Main function for running the agent directly. - + Args: query (str): Natural language query for the agent. Defaults to Python 3.13 example. model (str): Model name to use. Defaults to claude-opus-4-20250514. api_key (str): API key for authentication. Uses ANTHROPIC_API_KEY env var if not provided. base_url (str): Base URL for the model API. Defaults to https://api.anthropic.com/v1/ max_turns (int): Maximum number of API call iterations. Defaults to 10. - enabled_toolsets (str): Comma-separated list of toolsets to enable. Supports predefined - toolsets (e.g., "research", "development", "safe"). + enabled_toolsets (str): Comma-separated list of toolsets to enable. Supports predefined + toolsets (e.g., "research", "development", "safe"). Multiple toolsets can be combined: "web,vision" disabled_toolsets (str): Comma-separated list of toolsets to disable (e.g., "terminal") list_tools (bool): Just list available tools and exit save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False. verbose (bool): Enable verbose logging for debugging. Defaults to False. - + log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20. + Toolset Examples: - "research": Web search, extract, crawl + vision tools """ @@ -675,7 +730,8 @@ def main( enabled_toolsets=enabled_toolsets_list, disabled_toolsets=disabled_toolsets_list, save_trajectories=save_trajectories, - verbose_logging=verbose + verbose_logging=verbose, + log_prefix_chars=log_prefix_chars ) except RuntimeError as e: print(f"āŒ Failed to initialize agent: {e}") diff --git a/run_datagen_images.sh b/run_datagen_images.sh new file mode 100644 index 000000000..79e448ec6 --- /dev/null +++ b/run_datagen_images.sh @@ -0,0 +1,12 @@ +python batch_runner.py \ + --dataset_file="hermes-agent-imagen-data/hermes_agent_imagen_eval.jsonl" \ + --batch_size=10 \ + --run_name="imagen_eval_gpt5" \ + --distribution="image_gen" \ + --model="gpt-5" \ + --base_url="https://api.openai.com/v1" \ + --api_key="${OPENAI_API_KEY}" \ + --num_workers=4 \ + --max_turns=5 \ + --verbose \ + --ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt." \ No newline at end of file diff --git a/run_datagen_megascience.sh b/run_datagen_megascience.sh new file mode 100755 index 000000000..da1e8e1f8 --- /dev/null +++ b/run_datagen_megascience.sh @@ -0,0 +1,12 @@ +python batch_runner.py \ + --dataset_file="hermes-agent-megascience-data/hermes_agent_megascience_eval.jsonl" \ + --batch_size=10 \ + --run_name="megascience_eval_gpt5_2" \ + --distribution="science" \ + --model="gpt-5" \ + --base_url="https://api.openai.com/v1" \ + --api_key="${OPENAI_API_KEY}" \ + --num_workers=5 \ + --max_turns=30 \ + --verbose \ + --ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use a tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should not be confident in your own reasoning, knowledge, or calculations without using a tool to verify or validate your work." \ No newline at end of file diff --git a/run_datagen_megascience_glm4-6.sh b/run_datagen_megascience_glm4-6.sh new file mode 100755 index 000000000..2e8c49d14 --- /dev/null +++ b/run_datagen_megascience_glm4-6.sh @@ -0,0 +1,12 @@ +python batch_runner.py \ + --dataset_file="hermes-agent-megascience-data/hermes_agent_megascience_eval.jsonl" \ + --batch_size=10 \ + --run_name="megascience_eval_glm4-6-fixedterminal-2" \ + --distribution="science" \ + --model="z-ai/glm-4.6" \ + --base_url="https://openrouter.ai/api/v1" \ + --api_key="${OPENROUTER_API_KEY}" \ + --num_workers=5 \ + --max_turns=30 \ + --verbose \ + --ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use a tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run." \ No newline at end of file diff --git a/terminal_tool.py b/terminal_tool.py deleted file mode 100644 index e01d7a617..000000000 --- a/terminal_tool.py +++ /dev/null @@ -1,234 +0,0 @@ -#!/usr/bin/env python3 -""" -Terminal Tool Module - -This module provides a single terminal tool using Hecate's VM infrastructure. -It wraps Hecate's functionality to provide a simple interface for executing commands -on Morph VMs with automatic lifecycle management. - -Available tool: -- terminal_tool: Execute commands with optional interactive session support - -Usage: - from terminal_tool import terminal_tool - - # Execute a single command - result = terminal_tool("ls -la") - - # Execute in an interactive session - result = terminal_tool("python", input_keys="print('hello')\\nexit()\\n") -""" - -import json -import os -from typing import Optional, Dict, Any -from hecate import run_tool_with_lifecycle_management -from morphcloud._llm import ToolCall - -# Detailed description for the terminal tool based on Hermes Terminal system prompt -TERMINAL_TOOL_DESCRIPTION = """Execute commands on a secure, persistent Linux VM environment with full interactive application support. - -**Environment:** -- Minimal Debian-based OS with internet access -- Automatic VM lifecycle management (creates on-demand, reuses, cleans up) -- **Full state persistence across tool calls**: current directory (pwd), environment variables, activated virtual environments (conda/venv), running processes, and command history all persist between consecutive tool calls -- Session state managed automatically via tmux - -**Command Execution:** -- Simple commands: Just provide the 'command' parameter -- Background processes: Set 'background': True for servers/long-running tasks -- Interactive applications automatically detected and handled - -**Interactive Applications (TUIs/Pagers/Prompts):** -When commands enter interactive mode (vim, nano, less, git prompts, package managers, etc.), you'll receive screen content with "frozen" status. This is NORMAL - the session is still active and waiting for input. - -**To interact with frozen sessions:** -1. Use 'input_keys' parameter with keystrokes to send -2. System auto-detects and uses the active session -3. Session stays active until application exits - -**Special Key Syntax for input_keys:** -- ``: Escape key -- ``: Enter/Return -- ``, ``, ``: Control combinations -- ``, ``, ``, ``: Arrow keys -- ``, ``: Tab and Backspace -- `` through ``: Function keys -- ``: Shift+Tab -- Uppercase letters for Shift+letter (e.g., 'V' for Shift+V) -- Symbols for Shift+number (e.g., '!' for Shift+1, ':' for Shift+;) - -**Examples:** -- Start vim: `{"command": "vim file.txt"}` -- Type in vim: `{"input_keys": "iHello World"}` -- Save and quit: `{"input_keys": ":wq"}` -- Navigate in less: `{"input_keys": "j"}` -- Quit less: `{"input_keys": "q"}` - -**Best Practices:** -- Run servers/long processes in background with separate tool calls -- Chain multiple foreground commands in single call if needed -- Monitor disk usage for large tasks, clean up to free space -- Test components incrementally with mock inputs -- Install whatever tools needed - full system access provided""" - -def terminal_tool( - command: Optional[str] = None, - input_keys: Optional[str] = None, - session_id: Optional[str] = None, - background: bool = False, - idle_threshold: float = 5.0, - timeout: Optional[int] = None -) -> str: - """ - Execute a command on a Morph VM with optional interactive session support. - - This tool uses Hecate's VM lifecycle management to automatically create - and manage VMs. VMs are reused within the configured lifetime window - and automatically cleaned up after inactivity. - - Args: - command: The command to execute (optional if continuing existing session) - input_keys: Keystrokes to send to interactive session (e.g., "hello\\n") - session_id: ID of existing session to continue (optional) - background: Whether to run the command in the background (default: False) - idle_threshold: Seconds to wait for output before considering session idle (default: 5.0) - timeout: Command timeout in seconds (optional) - - Returns: - str: JSON string containing command output, session info, exit code, and any errors - - Examples: - # Execute a simple command - >>> result = terminal_tool(command="ls -la /tmp") - - # Start an interactive Python session - >>> result = terminal_tool(command="python3") - >>> session_data = json.loads(result) - >>> session_id = session_data["session_id"] - - # Send input to the session - >>> result = terminal_tool(input_keys="print('Hello')\\n", session_id=session_id) - - # Run a background task - >>> result = terminal_tool(command="sleep 60", background=True) - """ - try: - # Build tool input based on provided parameters - tool_input = {} - - if command: - tool_input["command"] = command - if input_keys: - tool_input["input_keys"] = input_keys - if session_id: - tool_input["session_id"] = session_id - if background: - tool_input["background"] = background - if idle_threshold != 5.0: - tool_input["idle_threshold"] = idle_threshold - if timeout is not None: - tool_input["timeout"] = timeout - - tool_call = ToolCall( - name="run_command", - input=tool_input - ) - - # Execute with lifecycle management - result = run_tool_with_lifecycle_management(tool_call) - - # Format the result with all possible fields - # Map hecate's "stdout" to "output" for compatibility - formatted_result = { - "output": result.get("stdout", result.get("output", "")), - "screen": result.get("screen", ""), - "session_id": result.get("session_id"), - "exit_code": result.get("returncode", result.get("exit_code", -1)), - "error": result.get("error"), - "status": "active" if result.get("session_id") else "ended" - } - - return json.dumps(formatted_result) - - except Exception as e: - return json.dumps({ - "output": "", - "screen": "", - "session_id": None, - "exit_code": -1, - "error": f"Failed to execute terminal command: {str(e)}", - "status": "error" - }) - -def check_hecate_requirements() -> bool: - """ - Check if all requirements for terminal tools are met. - - Returns: - bool: True if all requirements are met, False otherwise - """ - # Check for required environment variables - required_vars = ["MORPH_API_KEY"] - optional_vars = ["OPENAI_API_KEY"] # Needed for Hecate's LLM features - - missing_required = [var for var in required_vars if not os.getenv(var)] - missing_optional = [var for var in optional_vars if not os.getenv(var)] - - if missing_required: - print(f"Missing required environment variables: {', '.join(missing_required)}") - return False - - if missing_optional: - print(f"Warning: Missing optional environment variables: {', '.join(missing_optional)}") - print(" (Some Hecate features may be limited)") - - # Check if Hecate is importable - try: - import hecate - return True - except ImportError: - print("Hecate is not installed. Please install it with: pip install hecate") - return False - -# Module-level initialization check -_requirements_met = check_hecate_requirements() - -if __name__ == "__main__": - """ - Simple test/demo when run directly - """ - print("Terminal Tool Module") - print("=" * 40) - - if not _requirements_met: - print("Requirements not met. Please check the messages above.") - exit(1) - - print("All requirements met!") - print("\nAvailable Tool:") - print(" - terminal_tool: Execute commands with optional interactive session support") - - print("\nUsage Examples:") - print(" # Execute a command") - print(" result = terminal_tool(command='ls -la')") - print(" ") - print(" # Start an interactive session") - print(" result = terminal_tool(command='python3')") - print(" session_data = json.loads(result)") - print(" session_id = session_data['session_id']") - print(" ") - print(" # Send input to the session") - print(" result = terminal_tool(") - print(" input_keys='print(\"Hello\")\\\\n',") - print(" session_id=session_id") - print(" )") - print(" ") - print(" # Run a background task") - print(" result = terminal_tool(command='sleep 60', background=True)") - - print("\nEnvironment Variables:") - print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}") - print(f" OPENAI_API_KEY: {'Set' if os.getenv('OPENAI_API_KEY') else 'Not set (optional)'}") - print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300)") - print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_p5294qxt')} (default: snapshot_p5294qxt)") \ No newline at end of file diff --git a/test_run.sh b/test_run.sh old mode 100644 new mode 100755 index ff4ffc3c2..66be76d53 --- a/test_run.sh +++ b/test_run.sh @@ -17,15 +17,7 @@ export WEB_TOOLS_DEBUG=true python run_agent.py \ --query "$PROMPT" \ --max_turns 30 \ - --model claude-sonnet-4-20250514 \ + --model claude-sonnet-4-5-20250929 \ --base_url https://api.anthropic.com/v1/ \ --api_key $ANTHROPIC_API_KEY \ - --save_trajectories \ - --enabled_toolsets=web - -# --model claude-sonnet-4-20250514 \ -# -#Possible Toolsets: -#web_tools -#vision_tools -#terminal_tools \ No newline at end of file + --save_trajectories \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_batch_runner.py b/tests/test_batch_runner.py new file mode 100644 index 000000000..41b0b72b1 --- /dev/null +++ b/tests/test_batch_runner.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +""" +Test script for batch runner + +This script tests the batch runner with a small sample dataset +to verify functionality before running large batches. +""" + +import json +import shutil +from pathlib import Path + + +def create_test_dataset(): + """Create a small test dataset.""" + test_file = Path("tests/test_dataset.jsonl") + test_file.parent.mkdir(exist_ok=True) + + prompts = [ + {"prompt": "What is 2 + 2?"}, + {"prompt": "What is the capital of France?"}, + {"prompt": "Explain what Python is in one sentence."}, + ] + + with open(test_file, 'w') as f: + for prompt in prompts: + f.write(json.dumps(prompt, ensure_ascii=False) + "\n") + + print(f"āœ… Created test dataset: {test_file}") + return test_file + + +def cleanup_test_run(run_name): + """Clean up test run output.""" + output_dir = Path("data") / run_name + if output_dir.exists(): + shutil.rmtree(output_dir) + print(f"šŸ—‘ļø Cleaned up test output: {output_dir}") + + +def verify_output(run_name): + """Verify that output files were created correctly.""" + output_dir = Path("data") / run_name + + # Check directory exists + if not output_dir.exists(): + print(f"āŒ Output directory not found: {output_dir}") + return False + + # Check for checkpoint + checkpoint_file = output_dir / "checkpoint.json" + if not checkpoint_file.exists(): + print(f"āŒ Checkpoint file not found: {checkpoint_file}") + return False + + # Check for statistics + stats_file = output_dir / "statistics.json" + if not stats_file.exists(): + print(f"āŒ Statistics file not found: {stats_file}") + return False + + # Check for batch files + batch_files = list(output_dir.glob("batch_*.jsonl")) + if not batch_files: + print(f"āŒ No batch files found in: {output_dir}") + return False + + print(f"āœ… Output verification passed:") + print(f" - Checkpoint: {checkpoint_file}") + print(f" - Statistics: {stats_file}") + print(f" - Batch files: {len(batch_files)}") + + # Load and display statistics + with open(stats_file) as f: + stats = json.load(f) + + print(f"\nšŸ“Š Statistics Summary:") + print(f" - Total prompts: {stats['total_prompts']}") + print(f" - Total batches: {stats['total_batches']}") + print(f" - Duration: {stats['duration_seconds']}s") + + if stats.get('tool_statistics'): + print(f" - Tool calls:") + for tool, tool_stats in stats['tool_statistics'].items(): + print(f" • {tool}: {tool_stats['count']} calls, {tool_stats['success_rate']:.1f}% success") + + return True + + +def main(): + """Run the test.""" + print("🧪 Batch Runner Test") + print("=" * 60) + + run_name = "test_run" + + # Clean up any previous test run + cleanup_test_run(run_name) + + # Create test dataset + test_file = create_test_dataset() + + print(f"\nšŸ“ To run the test manually:") + print(f" python batch_runner.py \\") + print(f" --dataset_file={test_file} \\") + print(f" --batch_size=2 \\") + print(f" --run_name={run_name} \\") + print(f" --distribution=minimal \\") + print(f" --num_workers=2") + + print(f"\nšŸ’” Or test with different distributions:") + print(f" python batch_runner.py --list_distributions") + + print(f"\nšŸ” After running, you can verify output with:") + print(f" python tests/test_batch_runner.py --verify") + + # Note: We don't actually run the batch runner here to avoid API calls during testing + # Users should run it manually with their API keys configured + + +if __name__ == "__main__": + import sys + + if "--verify" in sys.argv: + run_name = "test_run" + verify_output(run_name) + else: + main() + diff --git a/tests/test_checkpoint_resumption.py b/tests/test_checkpoint_resumption.py new file mode 100644 index 000000000..d7c88910f --- /dev/null +++ b/tests/test_checkpoint_resumption.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python3 +""" +Test script to verify checkpoint behavior in batch_runner.py + +This script simulates batch processing with intentional failures to test: +1. Whether checkpoints are saved incrementally during processing +2. Whether resume functionality works correctly after interruption +3. Whether data integrity is maintained across checkpoint cycles + +Usage: + # Test current implementation + python tests/test_checkpoint_resumption.py --test_current + + # Test after fix is applied + python tests/test_checkpoint_resumption.py --test_fixed + + # Run full comparison + python tests/test_checkpoint_resumption.py --compare +""" + +import json +import os +import shutil +import sys +import time +import signal +from pathlib import Path +from typing import List, Dict, Any +import traceback + +# Add parent directory to path to import batch_runner +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +def create_test_dataset(num_prompts: int = 20) -> Path: + """Create a small test dataset for checkpoint testing.""" + test_data_dir = Path("tests/test_data") + test_data_dir.mkdir(parents=True, exist_ok=True) + + dataset_file = test_data_dir / "checkpoint_test_dataset.jsonl" + + with open(dataset_file, 'w', encoding='utf-8') as f: + for i in range(num_prompts): + entry = { + "prompt": f"Test prompt {i}: What is 2+2? Just answer briefly.", + "test_id": i + } + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + print(f"āœ… Created test dataset: {dataset_file} ({num_prompts} prompts)") + return dataset_file + + +def monitor_checkpoint_during_run(checkpoint_file: Path, duration: int = 30) -> List[Dict[str, Any]]: + """ + Monitor checkpoint file during a batch run to see when it gets updated. + + Args: + checkpoint_file: Path to checkpoint file to monitor + duration: How long to monitor (seconds) + + Returns: + List of checkpoint snapshots with timestamps + """ + snapshots = [] + start_time = time.time() + last_mtime = None + + print(f"\nšŸ” Monitoring checkpoint file: {checkpoint_file}") + print(f" Duration: {duration}s") + print("-" * 70) + + while time.time() - start_time < duration: + if checkpoint_file.exists(): + current_mtime = checkpoint_file.stat().st_mtime + + # Check if file was modified + if last_mtime is None or current_mtime != last_mtime: + elapsed = time.time() - start_time + + try: + with open(checkpoint_file, 'r') as f: + checkpoint_data = json.load(f) + + snapshot = { + "elapsed_seconds": round(elapsed, 2), + "completed_count": len(checkpoint_data.get("completed_prompts", [])), + "completed_prompts": checkpoint_data.get("completed_prompts", [])[:5], # First 5 for display + "timestamp": checkpoint_data.get("last_updated") + } + + snapshots.append(snapshot) + + print(f"[{elapsed:6.2f}s] Checkpoint updated: {snapshot['completed_count']} prompts completed") + + except Exception as e: + print(f"[{elapsed:6.2f}s] Error reading checkpoint: {e}") + + last_mtime = current_mtime + else: + if len(snapshots) == 0: + print(f"[{time.time() - start_time:6.2f}s] Checkpoint file not yet created...") + + time.sleep(0.5) # Check every 0.5 seconds + + return snapshots + + +def test_current_implementation(): + """Test the current checkpoint implementation.""" + print("\n" + "=" * 70) + print("TEST 1: Current Implementation - Checkpoint Timing") + print("=" * 70) + print("\nšŸ“ Testing whether checkpoints are saved incrementally during run...") + + # Setup + dataset_file = create_test_dataset(num_prompts=12) + run_name = "checkpoint_test_current" + output_dir = Path("data") / run_name + + # Clean up any existing test data + if output_dir.exists(): + shutil.rmtree(output_dir) + + # Import here to avoid issues if module changes + from batch_runner import BatchRunner + + checkpoint_file = output_dir / "checkpoint.json" + + # Start monitoring in a separate process would be ideal, but for simplicity + # we'll just check before and after + print(f"\nā–¶ļø Starting batch run...") + print(f" Dataset: {dataset_file}") + print(f" Batch size: 3 (4 batches total)") + print(f" Workers: 2") + print(f" Expected behavior: If incremental, checkpoint should update during run") + + start_time = time.time() + + try: + runner = BatchRunner( + dataset_file=str(dataset_file), + batch_size=3, + run_name=run_name, + distribution="default", + max_iterations=3, # Keep it short + model="claude-opus-4-20250514", + num_workers=2, + verbose=False + ) + + # Run with monitoring + import threading + snapshots = [] + + def monitor(): + nonlocal snapshots + snapshots = monitor_checkpoint_during_run(checkpoint_file, duration=60) + + monitor_thread = threading.Thread(target=monitor, daemon=True) + monitor_thread.start() + + runner.run(resume=False) + + monitor_thread.join(timeout=2) + + except Exception as e: + print(f"āŒ Error during run: {e}") + traceback.print_exc() + return False + + elapsed = time.time() - start_time + + # Analyze results + print("\n" + "=" * 70) + print("šŸ“Š TEST RESULTS") + print("=" * 70) + print(f"Total run time: {elapsed:.2f}s") + print(f"Checkpoint updates observed: {len(snapshots)}") + + if len(snapshots) == 0: + print("\nāŒ ISSUE: No checkpoint updates observed during run") + print(" This suggests checkpoints are only saved at the end") + return False + elif len(snapshots) == 1: + print("\nāš ļø WARNING: Only 1 checkpoint update (likely at the end)") + print(" This confirms the bug - no incremental checkpointing") + return False + else: + print(f"\nāœ… GOOD: Multiple checkpoint updates ({len(snapshots)}) observed") + print(" Checkpointing appears to be incremental") + + # Show timeline + print("\nšŸ“ˆ Checkpoint Timeline:") + for i, snapshot in enumerate(snapshots, 1): + print(f" {i}. [{snapshot['elapsed_seconds']:6.2f}s] " + f"{snapshot['completed_count']} prompts completed") + + return True + + +def test_interruption_and_resume(): + """Test that resume actually works after interruption.""" + print("\n" + "=" * 70) + print("TEST 2: Interruption and Resume") + print("=" * 70) + print("\nšŸ“ Testing whether resume works after manual interruption...") + + # Setup + dataset_file = create_test_dataset(num_prompts=15) + run_name = "checkpoint_test_resume" + output_dir = Path("data") / run_name + + # Clean up any existing test data + if output_dir.exists(): + shutil.rmtree(output_dir) + + from batch_runner import BatchRunner + + checkpoint_file = output_dir / "checkpoint.json" + + print(f"\nā–¶ļø Starting first run (will process 5 prompts, then simulate interruption)...") + + try: + # Create a modified dataset with only first 5 prompts for initial run + temp_dataset = Path("tests/test_data/checkpoint_test_resume_partial.jsonl") + with open(dataset_file, 'r') as f: + lines = f.readlines()[:5] + with open(temp_dataset, 'w') as f: + f.writelines(lines) + + runner = BatchRunner( + dataset_file=str(temp_dataset), + batch_size=2, + run_name=run_name, + distribution="default", + max_iterations=3, + model="claude-opus-4-20250514", + num_workers=1, + verbose=False + ) + + runner.run(resume=False) + + # Check checkpoint after first run + if not checkpoint_file.exists(): + print("āŒ ERROR: Checkpoint file not created after first run") + return False + + with open(checkpoint_file, 'r') as f: + checkpoint_data = json.load(f) + + initial_completed = len(checkpoint_data.get("completed_prompts", [])) + print(f"āœ… First run completed: {initial_completed} prompts saved to checkpoint") + + # Now try to resume with full dataset + print(f"\nā–¶ļø Starting resume run with full dataset (15 prompts)...") + + runner2 = BatchRunner( + dataset_file=str(dataset_file), + batch_size=2, + run_name=run_name, + distribution="default", + max_iterations=3, + model="claude-opus-4-20250514", + num_workers=1, + verbose=False + ) + + runner2.run(resume=True) + + # Check final checkpoint + with open(checkpoint_file, 'r') as f: + final_checkpoint = json.load(f) + + final_completed = len(final_checkpoint.get("completed_prompts", [])) + + print("\n" + "=" * 70) + print("šŸ“Š TEST RESULTS") + print("=" * 70) + print(f"Initial completed: {initial_completed}") + print(f"Final completed: {final_completed}") + print(f"Expected: 15") + + if final_completed == 15: + print("\nāœ… PASS: Resume successfully completed all prompts") + return True + else: + print(f"\nāŒ FAIL: Expected 15 completed, got {final_completed}") + return False + + except Exception as e: + print(f"āŒ Error during test: {e}") + traceback.print_exc() + return False + + +def test_simulated_crash(): + """Test behavior when process crashes mid-execution.""" + print("\n" + "=" * 70) + print("TEST 3: Simulated Crash During Execution") + print("=" * 70) + print("\nšŸ“ This test would require running in a subprocess and killing it...") + print(" Skipping for safety - manual testing recommended") + return None + + +def print_test_plan(): + """Print the detailed test and fix plan.""" + print("\n" + "=" * 70) + print("CHECKPOINT FIX - DETAILED PLAN") + print("=" * 70) + + print(""" +šŸ“‹ PROBLEM SUMMARY +------------------ +Current implementation uses pool.map() which blocks until ALL batches complete. +Checkpoint is only saved after all batches finish (line 558-559). + +If process crashes during batch processing: +- All progress is lost +- Resume does nothing (no incremental checkpoint was saved) + +šŸ“‹ PROPOSED SOLUTION +-------------------- +Replace pool.map() with pool.imap_unordered() to get results as they complete. +Save checkpoint after EACH batch completes using a multiprocessing Lock. + +Key changes: +1. Use Manager().Lock() for thread-safe checkpoint writes +2. Replace pool.map() with pool.imap_unordered() +3. Update checkpoint after each batch result +4. Maintain backward compatibility with existing checkpoints + +šŸ“‹ IMPLEMENTATION STEPS +----------------------- +1. Add Manager and Lock initialization before Pool creation +2. Pass shared checkpoint data and lock to workers (via Manager) +3. Replace pool.map() with pool.imap_unordered() +4. In result loop: save checkpoint after each batch +5. Add error handling for checkpoint write failures + +šŸ“‹ RISKS & MITIGATIONS +---------------------- +Risk: Checkpoint file corruption if two processes write simultaneously +→ Mitigation: Use multiprocessing.Lock() for exclusive access + +Risk: Performance impact from frequent checkpoint writes +→ Mitigation: Checkpoint writes are fast (small JSON), negligible impact + +Risk: Breaking existing runs that are already checkpointed +→ Mitigation: Maintain checkpoint format, only change timing + +Risk: Bugs in multiprocessing lock/manager code +→ Mitigation: Thorough testing with this test script + +šŸ“‹ TESTING STRATEGY +------------------- +1. Run test_current_implementation() - Confirm bug exists +2. Apply fix to batch_runner.py +3. Run test_current_implementation() again - Should see incremental updates +4. Run test_interruption_and_resume() - Verify resume works +5. Manual test: Start run, kill process mid-batch, resume + +šŸ“‹ ROLLBACK PLAN +---------------- +If issues arise: +1. Git revert the changes +2. Original code is working (just missing incremental checkpoint) +3. No data corruption risk - checkpoints are write-only +""") + + +def main( + test_current: bool = False, + test_resume: bool = False, + test_crash: bool = False, + compare: bool = False, + show_plan: bool = False +): + """ + Run checkpoint behavior tests. + + Args: + test_current: Test current implementation checkpoint timing + test_resume: Test interruption and resume functionality + test_crash: Test simulated crash scenario (manual) + compare: Run all tests and compare + show_plan: Show detailed fix plan + """ + if show_plan or (not any([test_current, test_resume, test_crash, compare])): + print_test_plan() + return + + results = {} + + if test_current or compare: + results['current'] = test_current_implementation() + + if test_resume or compare: + results['resume'] = test_interruption_and_resume() + + if test_crash or compare: + results['crash'] = test_simulated_crash() + + # Summary + if results: + print("\n" + "=" * 70) + print("OVERALL TEST SUMMARY") + print("=" * 70) + for test_name, result in results.items(): + if result is None: + status = "ā­ļø SKIPPED" + elif result: + status = "āœ… PASS" + else: + status = "āŒ FAIL" + print(f"{status} - {test_name}") + + +if __name__ == "__main__": + import fire + fire.Fire(main) + diff --git a/tests/test_nous_api_limits.py b/tests/test_nous_api_limits.py new file mode 100755 index 000000000..25265a0cc --- /dev/null +++ b/tests/test_nous_api_limits.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Test script to diagnose Nous API 400 errors with gemini-2.5-flash model. +This tests various content lengths and parameters to identify what causes failures. +""" + +import asyncio +import os +from openai import AsyncOpenAI +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Initialize the Nous API client +nous_client = AsyncOpenAI( + api_key=os.getenv("NOUS_API_KEY"), + base_url="https://inference-api.nousresearch.com/v1" +) + +MODEL = "gemini-2.5-flash" + +async def test_api_call(test_name: str, content_length: int, **kwargs): + """Test an API call with specific parameters.""" + print(f"\n{'='*60}") + print(f"Test: {test_name}") + print(f"Content length: {content_length:,} characters") + print(f"Additional params: {kwargs}") + print(f"{'='*60}") + + # Generate test content + content = "A" * content_length + + system_prompt = """You are an expert content analyst. Your job is to process web content and create a comprehensive yet concise summary that preserves all important information while dramatically reducing bulk. + +Create a well-structured markdown summary that includes: +1. Key excerpts (quotes, code snippets, important facts) in their original format +2. Comprehensive summary of all other important information +3. Proper markdown formatting with headers, bullets, and emphasis + +Your goal is to preserve ALL important information while reducing length. Never lose key facts, figures, insights, or actionable information. Make it scannable and well-organized.""" + + user_prompt = f"""Please process this web content and create a comprehensive markdown summary: + +CONTENT TO PROCESS: +{content} + +Create a markdown summary that captures all key information in a well-organized, scannable format. Include important quotes and code snippets in their original formatting. Focus on actionable information, specific details, and unique insights.""" + + try: + response = await nous_client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + **kwargs + ) + + result = response.choices[0].message.content + print(f"āœ… SUCCESS") + print(f" Response length: {len(result)} characters") + print(f" Model used: {response.model}") + print(f" Usage: {response.usage}") + return True + + except Exception as e: + print(f"āŒ FAILED: {str(e)}") + return False + +async def main(): + """Run all tests.""" + print("Testing Nous API with gemini-2.5-flash model") + print(f"API Key present: {'Yes' if os.getenv('NOUS_API_KEY') else 'No'}") + + results = {} + + # Test 1: Small content (should always work) + results['small'] = await test_api_call( + "Small content (5,000 chars)", + 5000, + temperature=0.1, + max_tokens=4000 + ) + await asyncio.sleep(1) + + # Test 2: Medium content (around what was failing) + results['medium'] = await test_api_call( + "Medium content (20,000 chars)", + 20000, + temperature=0.1, + max_tokens=4000 + ) + await asyncio.sleep(1) + + # Test 3: Large content (79,625 chars like the error) + results['large'] = await test_api_call( + "Large content (79,625 chars)", + 79625, + temperature=0.1, + max_tokens=4000 + ) + await asyncio.sleep(1) + + # Test 4: Very large content (100k chars) + results['very_large'] = await test_api_call( + "Very large content (100,000 chars)", + 100000, + temperature=0.1, + max_tokens=4000 + ) + await asyncio.sleep(1) + + # Test 5: Same as working case but different max_tokens + results['diff_max_tokens'] = await test_api_call( + "Medium content with higher max_tokens", + 20000, + temperature=0.1, + max_tokens=8000 + ) + await asyncio.sleep(1) + + # Test 6: No max_tokens specified + results['no_max_tokens'] = await test_api_call( + "Medium content without max_tokens", + 20000, + temperature=0.1 + ) + await asyncio.sleep(1) + + # Test 7: With actual web content (mixed characters) + mixed_content = """ + This is a test of web content with various characters: + - Unicode: ä½ å„½äø–ē•Œ šŸŒ + - Special chars: <>&"' + - Numbers: 123456789 + - Markdown: **bold** _italic_ `code` + - URLs: https://example.com + """ * 1000 # Repeat to make it ~79k chars + + print(f"\n{'='*60}") + print(f"Test: Mixed content (real-world scenario)") + print(f"Content length: {len(mixed_content):,} characters") + print(f"{'='*60}") + + try: + response = await nous_client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": "Summarize this content."}, + {"role": "user", "content": mixed_content} + ], + temperature=0.1, + max_tokens=4000 + ) + print(f"āœ… SUCCESS") + results['mixed_content'] = True + except Exception as e: + print(f"āŒ FAILED: {str(e)}") + results['mixed_content'] = False + + # Summary + print(f"\n{'='*60}") + print("SUMMARY OF RESULTS:") + print(f"{'='*60}") + for test, passed in results.items(): + status = "āœ… PASS" if passed else "āŒ FAIL" + print(f"{test:20s}: {status}") + + passed = sum(results.values()) + total = len(results) + print(f"\nTotal: {passed}/{total} tests passed") + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/tests/test_nous_api_pattern.py b/tests/test_nous_api_pattern.py new file mode 100644 index 000000000..d450a6dc9 --- /dev/null +++ b/tests/test_nous_api_pattern.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +""" +Test to understand the pattern of failures - it's not about content length! +""" + +import asyncio +import os +from openai import AsyncOpenAI +from dotenv import load_dotenv + +load_dotenv() + +nous_client = AsyncOpenAI( + api_key=os.getenv("NOUS_API_KEY"), + base_url="https://inference-api.nousresearch.com/v1" +) + +MODEL = "gemini-2.5-flash" + +async def quick_test(description: str, content: str, **kwargs): + """Quick API test.""" + print(f"\n{description} ({len(content):,} chars)...", end=" ") + + try: + response = await nous_client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": "Summarize this."}, + {"role": "user", "content": content} + ], + **kwargs + ) + print(f"āœ… SUCCESS") + return True + except Exception as e: + print(f"āŒ FAILED: {str(e)[:80]}") + return False + +async def main(): + print("Testing different content types and parameters...") + + # Theory 1: Repeated characters trigger validation + print("\n" + "="*60) + print("THEORY 1: Repeated characters") + print("="*60) + await quick_test("Repeated 'A's (5k)", "A" * 5000, temperature=0.1, max_tokens=4000) + await asyncio.sleep(0.5) + await quick_test("Repeated 'A's (79k)", "A" * 79625, temperature=0.1, max_tokens=4000) + await asyncio.sleep(0.5) + await quick_test("Varied text (5k)", "Test content. " * 400, temperature=0.1, max_tokens=4000) + await asyncio.sleep(0.5) + await quick_test("Varied text (79k)", "Test content with variety. " * 3000, temperature=0.1, max_tokens=4000) + + # Theory 2: max_tokens parameter + print("\n" + "="*60) + print("THEORY 2: max_tokens parameter") + print("="*60) + content = "Test " * 4000 # 20k chars + await quick_test("max_tokens=4000", content, temperature=0.1, max_tokens=4000) + await asyncio.sleep(0.5) + await quick_test("max_tokens=8000", content, temperature=0.1, max_tokens=8000) + await asyncio.sleep(0.5) + await quick_test("max_tokens=2000", content, temperature=0.1, max_tokens=2000) + await asyncio.sleep(0.5) + await quick_test("No max_tokens", content, temperature=0.1) + + # Theory 3: Temperature parameter + print("\n" + "="*60) + print("THEORY 3: Temperature parameter") + print("="*60) + content = "Test " * 4000 + await quick_test("temperature=0.1", content, temperature=0.1, max_tokens=4000) + await asyncio.sleep(0.5) + await quick_test("temperature=0.0", content, temperature=0.0, max_tokens=4000) + await asyncio.sleep(0.5) + await quick_test("temperature=0.5", content, temperature=0.5, max_tokens=4000) + await asyncio.sleep(0.5) + await quick_test("No temperature", content, max_tokens=4000) + + # Theory 4: System prompt impact + print("\n" + "="*60) + print("THEORY 4: System prompt length") + print("="*60) + + short_system = "Summarize this." + long_system = """You are an expert content analyst. Your job is to process web content and create a comprehensive yet concise summary that preserves all important information while dramatically reducing bulk. + +Create a well-structured markdown summary that includes: +1. Key excerpts (quotes, code snippets, important facts) in their original format +2. Comprehensive summary of all other important information +3. Proper markdown formatting with headers, bullets, and emphasis + +Your goal is to preserve ALL important information while reducing length.""" + + content = "A" * 5000 + + print(f"\nShort system prompt...", end=" ") + try: + response = await nous_client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": short_system}, + {"role": "user", "content": content} + ], + temperature=0.1, + max_tokens=4000 + ) + print(f"āœ… SUCCESS") + except Exception as e: + print(f"āŒ FAILED") + + await asyncio.sleep(0.5) + + print(f"Long system prompt...", end=" ") + try: + response = await nous_client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": long_system}, + {"role": "user", "content": content} + ], + temperature=0.1, + max_tokens=4000 + ) + print(f"āœ… SUCCESS") + except Exception as e: + print(f"āŒ FAILED") + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/tests/test_temperature_fix.py b/tests/test_temperature_fix.py new file mode 100644 index 000000000..bab2ed282 --- /dev/null +++ b/tests/test_temperature_fix.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +""" +Test to confirm: temperature < 0.3 causes failures on Nous API +""" + +import asyncio +import os +from openai import AsyncOpenAI +from dotenv import load_dotenv + +load_dotenv() + +nous_client = AsyncOpenAI( + api_key=os.getenv("NOUS_API_KEY"), + base_url="https://inference-api.nousresearch.com/v1" +) + +MODEL = "gemini-2.5-flash" + +async def test_temp(temp_value): + """Test a specific temperature value.""" + content = "Test content. " * 1000 # 14k chars + + print(f"Testing temperature={temp_value}...", end=" ") + + try: + response = await nous_client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": "Summarize this content."}, + {"role": "user", "content": content} + ], + temperature=temp_value, + max_tokens=4000 + ) + print(f"āœ… SUCCESS") + return True + except Exception as e: + print(f"āŒ FAILED") + return False + +async def main(): + print("Testing temperature threshold for Nous API...") + print("="*60) + + temps = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 1.0] + + for temp in temps: + await test_temp(temp) + await asyncio.sleep(0.5) + + print("="*60) + print("\nNow testing with ACTUAL web_tools.py content and parameters:") + print("="*60) + + # Simulate the actual web_tools.py call + system_prompt = """You are an expert content analyst. Your job is to process web content and create a comprehensive yet concise summary that preserves all important information while dramatically reducing bulk. + +Create a well-structured markdown summary that includes: +1. Key excerpts (quotes, code snippets, important facts) in their original format +2. Comprehensive summary of all other important information +3. Proper markdown formatting with headers, bullets, and emphasis + +Your goal is to preserve ALL important information while reducing length. Never lose key facts, figures, insights, or actionable information. Make it scannable and well-organized.""" + + content = "Sample web page content. " * 3000 # ~75k chars like the real failures + + user_prompt = f"""Please process this web content and create a comprehensive markdown summary: + +CONTENT TO PROCESS: +{content} + +Create a markdown summary that captures all key information in a well-organized, scannable format. Include important quotes and code snippets in their original formatting. Focus on actionable information, specific details, and unique insights.""" + + print(f"\nActual web_tools call (temp=0.1, {len(content):,} chars)...", end=" ") + try: + response = await nous_client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + temperature=0.1, + max_tokens=4000 + ) + print(f"āœ… SUCCESS") + except: + print(f"āŒ FAILED") + + await asyncio.sleep(0.5) + + print(f"Same call but with temp=0.3...", end=" ") + try: + response = await nous_client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + temperature=0.3, + max_tokens=4000 + ) + print(f"āœ… SUCCESS") + except: + print(f"āŒ FAILED") + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/test_web_tools.py b/tests/test_web_tools.py similarity index 97% rename from test_web_tools.py rename to tests/test_web_tools.py index 7c86becb6..3214ee283 100644 --- a/test_web_tools.py +++ b/tests/test_web_tools.py @@ -1,620 +1,620 @@ -#!/usr/bin/env python3 -""" -Comprehensive Test Suite for Web Tools Module - -This script tests all web tools functionality to ensure they work correctly. -Run this after any updates to the web_tools.py module or Firecrawl library. - -Usage: - python test_web_tools.py # Run all tests - python test_web_tools.py --no-llm # Skip LLM processing tests - python test_web_tools.py --verbose # Show detailed output - -Requirements: - - FIRECRAWL_API_KEY environment variable must be set - - NOUS_API_KEY environment vitinariable (optional, for LLM tests) -""" - -import json -import asyncio -import sys -import os -import argparse -from datetime import datetime -from typing import List, Dict, Any - -# Import the web tools to test -from web_tools import ( - web_search_tool, - web_extract_tool, - web_crawl_tool, - check_firecrawl_api_key, - check_nous_api_key, - get_debug_session_info -) - - -class Colors: - """ANSI color codes for terminal output""" - HEADER = '\033[95m' - BLUE = '\033[94m' - CYAN = '\033[96m' - GREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' - - -def print_header(text: str): - """Print a formatted header""" - print(f"\n{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}") - print(f"{Colors.HEADER}{Colors.BOLD}{text}{Colors.ENDC}") - print(f"{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}") - - -def print_section(text: str): - """Print a formatted section header""" - print(f"\n{Colors.CYAN}{Colors.BOLD}šŸ“Œ {text}{Colors.ENDC}") - print(f"{Colors.CYAN}{'-'*50}{Colors.ENDC}") - - -def print_success(text: str): - """Print success message""" - print(f"{Colors.GREEN}āœ… {text}{Colors.ENDC}") - - -def print_error(text: str): - """Print error message""" - print(f"{Colors.FAIL}āŒ {text}{Colors.ENDC}") - - -def print_warning(text: str): - """Print warning message""" - print(f"{Colors.WARNING}āš ļø {text}{Colors.ENDC}") - - -def print_info(text: str, indent: int = 0): - """Print info message""" - indent_str = " " * indent - print(f"{indent_str}{Colors.BLUE}ā„¹ļø {text}{Colors.ENDC}") - - -class WebToolsTester: - """Test suite for web tools""" - - def __init__(self, verbose: bool = False, test_llm: bool = True): - self.verbose = verbose - self.test_llm = test_llm - self.test_results = { - "passed": [], - "failed": [], - "skipped": [] - } - self.start_time = None - self.end_time = None - - def log_result(self, test_name: str, status: str, details: str = ""): - """Log test result""" - result = { - "test": test_name, - "status": status, - "details": details, - "timestamp": datetime.now().isoformat() - } - - if status == "passed": - self.test_results["passed"].append(result) - print_success(f"{test_name}: {details}" if details else test_name) - elif status == "failed": - self.test_results["failed"].append(result) - print_error(f"{test_name}: {details}" if details else test_name) - elif status == "skipped": - self.test_results["skipped"].append(result) - print_warning(f"{test_name} skipped: {details}" if details else f"{test_name} skipped") - - def test_environment(self) -> bool: - """Test environment setup and API keys""" - print_section("Environment Check") - - # Check Firecrawl API key - if not check_firecrawl_api_key(): - self.log_result("Firecrawl API Key", "failed", "FIRECRAWL_API_KEY not set") - return False - else: - self.log_result("Firecrawl API Key", "passed", "Found") - - # Check Nous API key (optional) - if not check_nous_api_key(): - self.log_result("Nous API Key", "skipped", "NOUS_API_KEY not set (LLM tests will be skipped)") - self.test_llm = False - else: - self.log_result("Nous API Key", "passed", "Found") - - # Check debug mode - debug_info = get_debug_session_info() - if debug_info["enabled"]: - print_info(f"Debug mode enabled - Session: {debug_info['session_id']}") - print_info(f"Debug log: {debug_info['log_path']}") - - return True - - def test_web_search(self) -> List[str]: - """Test web search functionality""" - print_section("Test 1: Web Search") - - test_queries = [ - ("Python web scraping tutorial", 5), - ("Firecrawl API documentation", 3), - ("inflammatory arthritis symptoms treatment", 8) # Test medical query from your example - ] - - extracted_urls = [] - - for query, limit in test_queries: - try: - print(f"\n Testing search: '{query}' (limit={limit})") - - if self.verbose: - print(f" Calling web_search_tool(query='{query}', limit={limit})") - - # Perform search - result = web_search_tool(query, limit) - - # Parse result - try: - data = json.loads(result) - except json.JSONDecodeError as e: - self.log_result(f"Search: {query[:30]}...", "failed", f"Invalid JSON: {e}") - if self.verbose: - print(f" Raw response (first 500 chars): {result[:500]}...") - continue - - if "error" in data: - self.log_result(f"Search: {query[:30]}...", "failed", f"API error: {data['error']}") - continue - - # Check structure - if "success" not in data or "data" not in data: - self.log_result(f"Search: {query[:30]}...", "failed", "Missing success or data fields") - if self.verbose: - print(f" Response keys: {list(data.keys())}") - continue - - web_results = data.get("data", {}).get("web", []) - - if not web_results: - self.log_result(f"Search: {query[:30]}...", "failed", "Empty web results array") - if self.verbose: - print(f" data.web content: {data.get('data', {}).get('web')}") - continue - - # Validate each result - valid_results = 0 - missing_fields = [] - - for i, result in enumerate(web_results): - required_fields = ["url", "title", "description"] - has_all_fields = all(key in result for key in required_fields) - - if has_all_fields: - valid_results += 1 - # Collect URLs for extraction test - if len(extracted_urls) < 3: - extracted_urls.append(result["url"]) - - if self.verbose: - print(f" Result {i+1}: āœ“ {result['title'][:50]}...") - print(f" URL: {result['url'][:60]}...") - else: - missing = [f for f in required_fields if f not in result] - missing_fields.append(f"Result {i+1} missing: {missing}") - if self.verbose: - print(f" Result {i+1}: āœ— Missing fields: {missing}") - - # Log results - if valid_results == len(web_results): - self.log_result( - f"Search: {query[:30]}...", - "passed", - f"All {valid_results} results valid" - ) - else: - self.log_result( - f"Search: {query[:30]}...", - "failed", - f"Only {valid_results}/{len(web_results)} valid. Issues: {'; '.join(missing_fields[:3])}" - ) - - except Exception as e: - self.log_result(f"Search: {query[:30]}...", "failed", f"Exception: {type(e).__name__}: {str(e)}") - if self.verbose: - import traceback - print(f" Traceback: {traceback.format_exc()}") - - if self.verbose and extracted_urls: - print(f"\n URLs collected for extraction test: {len(extracted_urls)}") - for url in extracted_urls: - print(f" - {url}") - - return extracted_urls - - async def test_web_extract(self, urls: List[str] = None): - """Test web content extraction""" - print_section("Test 2: Web Extract (without LLM)") - - # Use provided URLs or defaults - if not urls: - urls = [ - "https://docs.firecrawl.dev/introduction", - "https://www.python.org/about/" - ] - print(f" Using default URLs for testing") - else: - print(f" Using {len(urls)} URLs from search results") - - # Test extraction - if urls: - try: - test_urls = urls[:2] # Test with max 2 URLs - print(f"\n Extracting content from {len(test_urls)} URL(s)...") - for url in test_urls: - print(f" - {url}") - - if self.verbose: - print(f" Calling web_extract_tool(urls={test_urls}, format='markdown', use_llm_processing=False)") - - result = await web_extract_tool( - test_urls, - format="markdown", - use_llm_processing=False - ) - - # Parse result - try: - data = json.loads(result) - except json.JSONDecodeError as e: - self.log_result("Extract (no LLM)", "failed", f"Invalid JSON: {e}") - if self.verbose: - print(f" Raw response (first 500 chars): {result[:500]}...") - return - - if "error" in data: - self.log_result("Extract (no LLM)", "failed", f"API error: {data['error']}") - return - - results = data.get("results", []) - - if not results: - self.log_result("Extract (no LLM)", "failed", "No results in response") - if self.verbose: - print(f" Response keys: {list(data.keys())}") - return - - # Validate each result - valid_results = 0 - failed_results = 0 - total_content_length = 0 - extraction_details = [] - - for i, result in enumerate(results): - title = result.get("title", "No title") - content = result.get("content", "") - error = result.get("error") - - if error: - failed_results += 1 - extraction_details.append(f"Page {i+1}: ERROR - {error}") - if self.verbose: - print(f" Page {i+1}: āœ— Error - {error}") - elif content: - content_len = len(content) - total_content_length += content_len - valid_results += 1 - extraction_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)") - if self.verbose: - print(f" Page {i+1}: āœ“ {title[:50]}... - {content_len} characters") - print(f" First 100 chars: {content[:100]}...") - else: - extraction_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)") - if self.verbose: - print(f" Page {i+1}: ⚠ {title[:50]}... - Empty content") - - # Log results - if valid_results > 0: - self.log_result( - "Extract (no LLM)", - "passed", - f"{valid_results}/{len(results)} pages extracted, {total_content_length} total chars" - ) - else: - self.log_result( - "Extract (no LLM)", - "failed", - f"No valid content. {failed_results} errors, {len(results) - failed_results} empty" - ) - if self.verbose: - print(f"\n Extraction details:") - for detail in extraction_details: - print(f" {detail}") - - except Exception as e: - self.log_result("Extract (no LLM)", "failed", f"Exception: {type(e).__name__}: {str(e)}") - if self.verbose: - import traceback - print(f" Traceback: {traceback.format_exc()}") - - async def test_web_extract_with_llm(self, urls: List[str] = None): - """Test web extraction with LLM processing""" - print_section("Test 3: Web Extract (with Gemini LLM)") - - if not self.test_llm: - self.log_result("Extract (with LLM)", "skipped", "LLM testing disabled") - return - - # Use a URL likely to have substantial content - test_url = urls[0] if urls else "https://docs.firecrawl.dev/features/scrape" - - try: - print(f"\n Extracting and processing: {test_url}") - - result = await web_extract_tool( - [test_url], - format="markdown", - use_llm_processing=True, - min_length=1000 # Lower threshold for testing - ) - - data = json.loads(result) - - if "error" in data: - self.log_result("Extract (with LLM)", "failed", data["error"]) - return - - results = data.get("results", []) - - if not results: - self.log_result("Extract (with LLM)", "failed", "No results returned") - return - - result = results[0] - content = result.get("content", "") - - if content: - content_len = len(content) - - # Check if content was actually processed (should be shorter than typical raw content) - if content_len > 0: - self.log_result( - "Extract (with LLM)", - "passed", - f"Content processed: {content_len} chars" - ) - - if self.verbose: - print(f"\n First 300 chars of processed content:") - print(f" {content[:300]}...") - else: - self.log_result("Extract (with LLM)", "failed", "No content after processing") - else: - self.log_result("Extract (with LLM)", "failed", "No content field in result") - - except json.JSONDecodeError as e: - self.log_result("Extract (with LLM)", "failed", f"Invalid JSON: {e}") - except Exception as e: - self.log_result("Extract (with LLM)", "failed", str(e)) - - async def test_web_crawl(self): - """Test web crawling functionality""" - print_section("Test 4: Web Crawl") - - test_sites = [ - ("https://docs.firecrawl.dev", None, 2), # Test docs site - ("https://firecrawl.dev", None, 3), # Test main site - ] - - for url, instructions, expected_min_pages in test_sites: - try: - print(f"\n Testing crawl of: {url}") - if instructions: - print(f" Instructions: {instructions}") - else: - print(f" No instructions (general crawl)") - print(f" Expected minimum pages: {expected_min_pages}") - - # Show what's being called - if self.verbose: - print(f" Calling web_crawl_tool(url='{url}', instructions={instructions}, use_llm_processing=False)") - - result = await web_crawl_tool( - url, - instructions=instructions, - use_llm_processing=False # Disable LLM for faster testing - ) - - # Check if result is valid JSON - try: - data = json.loads(result) - except json.JSONDecodeError as e: - self.log_result(f"Crawl: {url}", "failed", f"Invalid JSON response: {e}") - if self.verbose: - print(f" Raw response (first 500 chars): {result[:500]}...") - continue - - # Check for errors - if "error" in data: - self.log_result(f"Crawl: {url}", "failed", f"API error: {data['error']}") - continue - - # Get results - results = data.get("results", []) - - if not results: - self.log_result(f"Crawl: {url}", "failed", "No pages in results array") - if self.verbose: - print(f" Full response: {json.dumps(data, indent=2)[:1000]}...") - continue - - # Analyze pages - valid_pages = 0 - empty_pages = 0 - total_content = 0 - page_details = [] - - for i, page in enumerate(results): - content = page.get("content", "") - title = page.get("title", "Untitled") - error = page.get("error") - - if error: - page_details.append(f"Page {i+1}: ERROR - {error}") - elif content: - valid_pages += 1 - content_len = len(content) - total_content += content_len - page_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)") - else: - empty_pages += 1 - page_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)") - - # Show detailed results if verbose - if self.verbose: - print(f"\n Crawl Results:") - print(f" Total pages returned: {len(results)}") - print(f" Valid pages (with content): {valid_pages}") - print(f" Empty pages: {empty_pages}") - print(f" Total content size: {total_content} characters") - print(f"\n Page Details:") - for detail in page_details[:10]: # Show first 10 pages - print(f" - {detail}") - if len(page_details) > 10: - print(f" ... and {len(page_details) - 10} more pages") - - # Determine pass/fail - if valid_pages >= expected_min_pages: - self.log_result( - f"Crawl: {url}", - "passed", - f"{valid_pages}/{len(results)} valid pages, {total_content} chars total" - ) - else: - self.log_result( - f"Crawl: {url}", - "failed", - f"Only {valid_pages} valid pages (expected >= {expected_min_pages}), {empty_pages} empty, {len(results)} total" - ) - - except Exception as e: - self.log_result(f"Crawl: {url}", "failed", f"Exception: {type(e).__name__}: {str(e)}") - if self.verbose: - import traceback - print(f" Traceback:") - print(" " + "\n ".join(traceback.format_exc().split("\n"))) - - async def run_all_tests(self): - """Run all tests""" - self.start_time = datetime.now() - - print_header("WEB TOOLS TEST SUITE") - print(f"Started at: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") - - # Test environment - if not self.test_environment(): - print_error("\nCannot proceed without required API keys!") - return False - - # Test search and collect URLs - urls = self.test_web_search() - - # Test extraction - await self.test_web_extract(urls if urls else None) - - # Test extraction with LLM - if self.test_llm: - await self.test_web_extract_with_llm(urls if urls else None) - - # Test crawling - await self.test_web_crawl() - - # Print summary - self.end_time = datetime.now() - duration = (self.end_time - self.start_time).total_seconds() - - print_header("TEST SUMMARY") - print(f"Duration: {duration:.2f} seconds") - print(f"\n{Colors.GREEN}Passed: {len(self.test_results['passed'])}{Colors.ENDC}") - print(f"{Colors.FAIL}Failed: {len(self.test_results['failed'])}{Colors.ENDC}") - print(f"{Colors.WARNING}Skipped: {len(self.test_results['skipped'])}{Colors.ENDC}") - - # List failed tests - if self.test_results["failed"]: - print(f"\n{Colors.FAIL}{Colors.BOLD}Failed Tests:{Colors.ENDC}") - for test in self.test_results["failed"]: - print(f" - {test['test']}: {test['details']}") - - # Save results to file - self.save_results() - - return len(self.test_results["failed"]) == 0 - - def save_results(self): - """Save test results to a JSON file""" - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = f"test_results_web_tools_{timestamp}.json" - - results = { - "test_suite": "Web Tools", - "start_time": self.start_time.isoformat() if self.start_time else None, - "end_time": self.end_time.isoformat() if self.end_time else None, - "duration_seconds": (self.end_time - self.start_time).total_seconds() if self.start_time and self.end_time else None, - "summary": { - "passed": len(self.test_results["passed"]), - "failed": len(self.test_results["failed"]), - "skipped": len(self.test_results["skipped"]) - }, - "results": self.test_results, - "environment": { - "firecrawl_api_key": check_firecrawl_api_key(), - "nous_api_key": check_nous_api_key(), - "debug_mode": get_debug_session_info()["enabled"] - } - } - - try: - with open(filename, 'w') as f: - json.dump(results, f, indent=2) - print_info(f"Test results saved to: {filename}") - except Exception as e: - print_warning(f"Failed to save results: {e}") - - -async def main(): - """Main entry point""" - parser = argparse.ArgumentParser(description="Test Web Tools Module") - parser.add_argument("--no-llm", action="store_true", help="Skip LLM processing tests") - parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed output") - parser.add_argument("--debug", action="store_true", help="Enable debug mode for web tools") - - args = parser.parse_args() - - # Set debug mode if requested - if args.debug: - os.environ["WEB_TOOLS_DEBUG"] = "true" - print_info("Debug mode enabled for web tools") - - # Create tester - tester = WebToolsTester( - verbose=args.verbose, - test_llm=not args.no_llm - ) - - # Run tests - success = await tester.run_all_tests() - - # Exit with appropriate code - sys.exit(0 if success else 1) - - -if __name__ == "__main__": - asyncio.run(main()) +#!/usr/bin/env python3 +""" +Comprehensive Test Suite for Web Tools Module + +This script tests all web tools functionality to ensure they work correctly. +Run this after any updates to the web_tools.py module or Firecrawl library. + +Usage: + python test_web_tools.py # Run all tests + python test_web_tools.py --no-llm # Skip LLM processing tests + python test_web_tools.py --verbose # Show detailed output + +Requirements: + - FIRECRAWL_API_KEY environment variable must be set + - NOUS_API_KEY environment vitinariable (optional, for LLM tests) +""" + +import json +import asyncio +import sys +import os +import argparse +from datetime import datetime +from typing import List, Dict, Any + +# Import the web tools to test (updated path after moving tools/) +from tools.web_tools import ( + web_search_tool, + web_extract_tool, + web_crawl_tool, + check_firecrawl_api_key, + check_nous_api_key, + get_debug_session_info +) + + +class Colors: + """ANSI color codes for terminal output""" + HEADER = '\033[95m' + BLUE = '\033[94m' + CYAN = '\033[96m' + GREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +def print_header(text: str): + """Print a formatted header""" + print(f"\n{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}") + print(f"{Colors.HEADER}{Colors.BOLD}{text}{Colors.ENDC}") + print(f"{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}") + + +def print_section(text: str): + """Print a formatted section header""" + print(f"\n{Colors.CYAN}{Colors.BOLD}šŸ“Œ {text}{Colors.ENDC}") + print(f"{Colors.CYAN}{'-'*50}{Colors.ENDC}") + + +def print_success(text: str): + """Print success message""" + print(f"{Colors.GREEN}āœ… {text}{Colors.ENDC}") + + +def print_error(text: str): + """Print error message""" + print(f"{Colors.FAIL}āŒ {text}{Colors.ENDC}") + + +def print_warning(text: str): + """Print warning message""" + print(f"{Colors.WARNING}āš ļø {text}{Colors.ENDC}") + + +def print_info(text: str, indent: int = 0): + """Print info message""" + indent_str = " " * indent + print(f"{indent_str}{Colors.BLUE}ā„¹ļø {text}{Colors.ENDC}") + + +class WebToolsTester: + """Test suite for web tools""" + + def __init__(self, verbose: bool = False, test_llm: bool = True): + self.verbose = verbose + self.test_llm = test_llm + self.test_results = { + "passed": [], + "failed": [], + "skipped": [] + } + self.start_time = None + self.end_time = None + + def log_result(self, test_name: str, status: str, details: str = ""): + """Log test result""" + result = { + "test": test_name, + "status": status, + "details": details, + "timestamp": datetime.now().isoformat() + } + + if status == "passed": + self.test_results["passed"].append(result) + print_success(f"{test_name}: {details}" if details else test_name) + elif status == "failed": + self.test_results["failed"].append(result) + print_error(f"{test_name}: {details}" if details else test_name) + elif status == "skipped": + self.test_results["skipped"].append(result) + print_warning(f"{test_name} skipped: {details}" if details else f"{test_name} skipped") + + def test_environment(self) -> bool: + """Test environment setup and API keys""" + print_section("Environment Check") + + # Check Firecrawl API key + if not check_firecrawl_api_key(): + self.log_result("Firecrawl API Key", "failed", "FIRECRAWL_API_KEY not set") + return False + else: + self.log_result("Firecrawl API Key", "passed", "Found") + + # Check Nous API key (optional) + if not check_nous_api_key(): + self.log_result("Nous API Key", "skipped", "NOUS_API_KEY not set (LLM tests will be skipped)") + self.test_llm = False + else: + self.log_result("Nous API Key", "passed", "Found") + + # Check debug mode + debug_info = get_debug_session_info() + if debug_info["enabled"]: + print_info(f"Debug mode enabled - Session: {debug_info['session_id']}") + print_info(f"Debug log: {debug_info['log_path']}") + + return True + + def test_web_search(self) -> List[str]: + """Test web search functionality""" + print_section("Test 1: Web Search") + + test_queries = [ + ("Python web scraping tutorial", 5), + ("Firecrawl API documentation", 3), + ("inflammatory arthritis symptoms treatment", 8) # Test medical query from your example + ] + + extracted_urls = [] + + for query, limit in test_queries: + try: + print(f"\n Testing search: '{query}' (limit={limit})") + + if self.verbose: + print(f" Calling web_search_tool(query='{query}', limit={limit})") + + # Perform search + result = web_search_tool(query, limit) + + # Parse result + try: + data = json.loads(result) + except json.JSONDecodeError as e: + self.log_result(f"Search: {query[:30]}...", "failed", f"Invalid JSON: {e}") + if self.verbose: + print(f" Raw response (first 500 chars): {result[:500]}...") + continue + + if "error" in data: + self.log_result(f"Search: {query[:30]}...", "failed", f"API error: {data['error']}") + continue + + # Check structure + if "success" not in data or "data" not in data: + self.log_result(f"Search: {query[:30]}...", "failed", "Missing success or data fields") + if self.verbose: + print(f" Response keys: {list(data.keys())}") + continue + + web_results = data.get("data", {}).get("web", []) + + if not web_results: + self.log_result(f"Search: {query[:30]}...", "failed", "Empty web results array") + if self.verbose: + print(f" data.web content: {data.get('data', {}).get('web')}") + continue + + # Validate each result + valid_results = 0 + missing_fields = [] + + for i, result in enumerate(web_results): + required_fields = ["url", "title", "description"] + has_all_fields = all(key in result for key in required_fields) + + if has_all_fields: + valid_results += 1 + # Collect URLs for extraction test + if len(extracted_urls) < 3: + extracted_urls.append(result["url"]) + + if self.verbose: + print(f" Result {i+1}: āœ“ {result['title'][:50]}...") + print(f" URL: {result['url'][:60]}...") + else: + missing = [f for f in required_fields if f not in result] + missing_fields.append(f"Result {i+1} missing: {missing}") + if self.verbose: + print(f" Result {i+1}: āœ— Missing fields: {missing}") + + # Log results + if valid_results == len(web_results): + self.log_result( + f"Search: {query[:30]}...", + "passed", + f"All {valid_results} results valid" + ) + else: + self.log_result( + f"Search: {query[:30]}...", + "failed", + f"Only {valid_results}/{len(web_results)} valid. Issues: {'; '.join(missing_fields[:3])}" + ) + + except Exception as e: + self.log_result(f"Search: {query[:30]}...", "failed", f"Exception: {type(e).__name__}: {str(e)}") + if self.verbose: + import traceback + print(f" Traceback: {traceback.format_exc()}") + + if self.verbose and extracted_urls: + print(f"\n URLs collected for extraction test: {len(extracted_urls)}") + for url in extracted_urls: + print(f" - {url}") + + return extracted_urls + + async def test_web_extract(self, urls: List[str] = None): + """Test web content extraction""" + print_section("Test 2: Web Extract (without LLM)") + + # Use provided URLs or defaults + if not urls: + urls = [ + "https://docs.firecrawl.dev/introduction", + "https://www.python.org/about/" + ] + print(f" Using default URLs for testing") + else: + print(f" Using {len(urls)} URLs from search results") + + # Test extraction + if urls: + try: + test_urls = urls[:2] # Test with max 2 URLs + print(f"\n Extracting content from {len(test_urls)} URL(s)...") + for url in test_urls: + print(f" - {url}") + + if self.verbose: + print(f" Calling web_extract_tool(urls={test_urls}, format='markdown', use_llm_processing=False)") + + result = await web_extract_tool( + test_urls, + format="markdown", + use_llm_processing=False + ) + + # Parse result + try: + data = json.loads(result) + except json.JSONDecodeError as e: + self.log_result("Extract (no LLM)", "failed", f"Invalid JSON: {e}") + if self.verbose: + print(f" Raw response (first 500 chars): {result[:500]}...") + return + + if "error" in data: + self.log_result("Extract (no LLM)", "failed", f"API error: {data['error']}") + return + + results = data.get("results", []) + + if not results: + self.log_result("Extract (no LLM)", "failed", "No results in response") + if self.verbose: + print(f" Response keys: {list(data.keys())}") + return + + # Validate each result + valid_results = 0 + failed_results = 0 + total_content_length = 0 + extraction_details = [] + + for i, result in enumerate(results): + title = result.get("title", "No title") + content = result.get("content", "") + error = result.get("error") + + if error: + failed_results += 1 + extraction_details.append(f"Page {i+1}: ERROR - {error}") + if self.verbose: + print(f" Page {i+1}: āœ— Error - {error}") + elif content: + content_len = len(content) + total_content_length += content_len + valid_results += 1 + extraction_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)") + if self.verbose: + print(f" Page {i+1}: āœ“ {title[:50]}... - {content_len} characters") + print(f" First 100 chars: {content[:100]}...") + else: + extraction_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)") + if self.verbose: + print(f" Page {i+1}: ⚠ {title[:50]}... - Empty content") + + # Log results + if valid_results > 0: + self.log_result( + "Extract (no LLM)", + "passed", + f"{valid_results}/{len(results)} pages extracted, {total_content_length} total chars" + ) + else: + self.log_result( + "Extract (no LLM)", + "failed", + f"No valid content. {failed_results} errors, {len(results) - failed_results} empty" + ) + if self.verbose: + print(f"\n Extraction details:") + for detail in extraction_details: + print(f" {detail}") + + except Exception as e: + self.log_result("Extract (no LLM)", "failed", f"Exception: {type(e).__name__}: {str(e)}") + if self.verbose: + import traceback + print(f" Traceback: {traceback.format_exc()}") + + async def test_web_extract_with_llm(self, urls: List[str] = None): + """Test web extraction with LLM processing""" + print_section("Test 3: Web Extract (with Gemini LLM)") + + if not self.test_llm: + self.log_result("Extract (with LLM)", "skipped", "LLM testing disabled") + return + + # Use a URL likely to have substantial content + test_url = urls[0] if urls else "https://docs.firecrawl.dev/features/scrape" + + try: + print(f"\n Extracting and processing: {test_url}") + + result = await web_extract_tool( + [test_url], + format="markdown", + use_llm_processing=True, + min_length=1000 # Lower threshold for testing + ) + + data = json.loads(result) + + if "error" in data: + self.log_result("Extract (with LLM)", "failed", data["error"]) + return + + results = data.get("results", []) + + if not results: + self.log_result("Extract (with LLM)", "failed", "No results returned") + return + + result = results[0] + content = result.get("content", "") + + if content: + content_len = len(content) + + # Check if content was actually processed (should be shorter than typical raw content) + if content_len > 0: + self.log_result( + "Extract (with LLM)", + "passed", + f"Content processed: {content_len} chars" + ) + + if self.verbose: + print(f"\n First 300 chars of processed content:") + print(f" {content[:300]}...") + else: + self.log_result("Extract (with LLM)", "failed", "No content after processing") + else: + self.log_result("Extract (with LLM)", "failed", "No content field in result") + + except json.JSONDecodeError as e: + self.log_result("Extract (with LLM)", "failed", f"Invalid JSON: {e}") + except Exception as e: + self.log_result("Extract (with LLM)", "failed", str(e)) + + async def test_web_crawl(self): + """Test web crawling functionality""" + print_section("Test 4: Web Crawl") + + test_sites = [ + ("https://docs.firecrawl.dev", None, 2), # Test docs site + ("https://firecrawl.dev", None, 3), # Test main site + ] + + for url, instructions, expected_min_pages in test_sites: + try: + print(f"\n Testing crawl of: {url}") + if instructions: + print(f" Instructions: {instructions}") + else: + print(f" No instructions (general crawl)") + print(f" Expected minimum pages: {expected_min_pages}") + + # Show what's being called + if self.verbose: + print(f" Calling web_crawl_tool(url='{url}', instructions={instructions}, use_llm_processing=False)") + + result = await web_crawl_tool( + url, + instructions=instructions, + use_llm_processing=False # Disable LLM for faster testing + ) + + # Check if result is valid JSON + try: + data = json.loads(result) + except json.JSONDecodeError as e: + self.log_result(f"Crawl: {url}", "failed", f"Invalid JSON response: {e}") + if self.verbose: + print(f" Raw response (first 500 chars): {result[:500]}...") + continue + + # Check for errors + if "error" in data: + self.log_result(f"Crawl: {url}", "failed", f"API error: {data['error']}") + continue + + # Get results + results = data.get("results", []) + + if not results: + self.log_result(f"Crawl: {url}", "failed", "No pages in results array") + if self.verbose: + print(f" Full response: {json.dumps(data, indent=2)[:1000]}...") + continue + + # Analyze pages + valid_pages = 0 + empty_pages = 0 + total_content = 0 + page_details = [] + + for i, page in enumerate(results): + content = page.get("content", "") + title = page.get("title", "Untitled") + error = page.get("error") + + if error: + page_details.append(f"Page {i+1}: ERROR - {error}") + elif content: + valid_pages += 1 + content_len = len(content) + total_content += content_len + page_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)") + else: + empty_pages += 1 + page_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)") + + # Show detailed results if verbose + if self.verbose: + print(f"\n Crawl Results:") + print(f" Total pages returned: {len(results)}") + print(f" Valid pages (with content): {valid_pages}") + print(f" Empty pages: {empty_pages}") + print(f" Total content size: {total_content} characters") + print(f"\n Page Details:") + for detail in page_details[:10]: # Show first 10 pages + print(f" - {detail}") + if len(page_details) > 10: + print(f" ... and {len(page_details) - 10} more pages") + + # Determine pass/fail + if valid_pages >= expected_min_pages: + self.log_result( + f"Crawl: {url}", + "passed", + f"{valid_pages}/{len(results)} valid pages, {total_content} chars total" + ) + else: + self.log_result( + f"Crawl: {url}", + "failed", + f"Only {valid_pages} valid pages (expected >= {expected_min_pages}), {empty_pages} empty, {len(results)} total" + ) + + except Exception as e: + self.log_result(f"Crawl: {url}", "failed", f"Exception: {type(e).__name__}: {str(e)}") + if self.verbose: + import traceback + print(f" Traceback:") + print(" " + "\n ".join(traceback.format_exc().split("\n"))) + + async def run_all_tests(self): + """Run all tests""" + self.start_time = datetime.now() + + print_header("WEB TOOLS TEST SUITE") + print(f"Started at: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") + + # Test environment + if not self.test_environment(): + print_error("\nCannot proceed without required API keys!") + return False + + # Test search and collect URLs + urls = self.test_web_search() + + # Test extraction + await self.test_web_extract(urls if urls else None) + + # Test extraction with LLM + if self.test_llm: + await self.test_web_extract_with_llm(urls if urls else None) + + # Test crawling + await self.test_web_crawl() + + # Print summary + self.end_time = datetime.now() + duration = (self.end_time - self.start_time).total_seconds() + + print_header("TEST SUMMARY") + print(f"Duration: {duration:.2f} seconds") + print(f"\n{Colors.GREEN}Passed: {len(self.test_results['passed'])}{Colors.ENDC}") + print(f"{Colors.FAIL}Failed: {len(self.test_results['failed'])}{Colors.ENDC}") + print(f"{Colors.WARNING}Skipped: {len(self.test_results['skipped'])}{Colors.ENDC}") + + # List failed tests + if self.test_results["failed"]: + print(f"\n{Colors.FAIL}{Colors.BOLD}Failed Tests:{Colors.ENDC}") + for test in self.test_results["failed"]: + print(f" - {test['test']}: {test['details']}") + + # Save results to file + self.save_results() + + return len(self.test_results["failed"]) == 0 + + def save_results(self): + """Save test results to a JSON file""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"test_results_web_tools_{timestamp}.json" + + results = { + "test_suite": "Web Tools", + "start_time": self.start_time.isoformat() if self.start_time else None, + "end_time": self.end_time.isoformat() if self.end_time else None, + "duration_seconds": (self.end_time - self.start_time).total_seconds() if self.start_time and self.end_time else None, + "summary": { + "passed": len(self.test_results["passed"]), + "failed": len(self.test_results["failed"]), + "skipped": len(self.test_results["skipped"]) + }, + "results": self.test_results, + "environment": { + "firecrawl_api_key": check_firecrawl_api_key(), + "nous_api_key": check_nous_api_key(), + "debug_mode": get_debug_session_info()["enabled"] + } + } + + try: + with open(filename, 'w') as f: + json.dump(results, f, indent=2, ensure_ascii=False) + print_info(f"Test results saved to: {filename}") + except Exception as e: + print_warning(f"Failed to save results: {e}") + + +async def main(): + """Main entry point""" + parser = argparse.ArgumentParser(description="Test Web Tools Module") + parser.add_argument("--no-llm", action="store_true", help="Skip LLM processing tests") + parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed output") + parser.add_argument("--debug", action="store_true", help="Enable debug mode for web tools") + + args = parser.parse_args() + + # Set debug mode if requested + if args.debug: + os.environ["WEB_TOOLS_DEBUG"] = "true" + print_info("Debug mode enabled for web tools") + + # Create tester + tester = WebToolsTester( + verbose=args.verbose, + test_llm=not args.no_llm + ) + + # Run tests + success = await tester.run_all_tests() + + # Exit with appropriate code + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 000000000..9843e757e --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +""" +Tools Package + +This package contains all the specific tool implementations for the Hermes Agent. +Each module provides specialized functionality for different capabilities: + +- web_tools: Web search, content extraction, and crawling +- terminal_tool: Command execution on virtual machines +- vision_tools: Image analysis and understanding +- mixture_of_agents_tool: Multi-model collaborative reasoning +- image_generation_tool: Text-to-image generation with upscaling + +The tools are imported into model_tools.py which provides a unified interface +for the AI agent to access all capabilities. +""" + +# Export all tools for easy importing +from .web_tools import ( + web_search_tool, + web_extract_tool, + web_crawl_tool, + check_firecrawl_api_key +) + +from .terminal_tool import ( + terminal_tool, + check_hecate_requirements, + TERMINAL_TOOL_DESCRIPTION +) + +from .vision_tools import ( + vision_analyze_tool, + check_vision_requirements +) + +from .mixture_of_agents_tool import ( + mixture_of_agents_tool, + check_moa_requirements +) + +from .image_generation_tool import ( + image_generate_tool, + check_image_generation_requirements +) + +__all__ = [ + # Web tools + 'web_search_tool', + 'web_extract_tool', + 'web_crawl_tool', + 'check_firecrawl_api_key', + # Terminal tools + 'terminal_tool', + 'check_hecate_requirements', + 'TERMINAL_TOOL_DESCRIPTION', + # Vision tools + 'vision_analyze_tool', + 'check_vision_requirements', + # MoA tools + 'mixture_of_agents_tool', + 'check_moa_requirements', + # Image generation tools + 'image_generate_tool', + 'check_image_generation_requirements', +] + diff --git a/image_generation_tool.py b/tools/image_generation_tool.py similarity index 98% rename from image_generation_tool.py rename to tools/image_generation_tool.py index 09d4e51d6..7ceae7dbc 100644 --- a/image_generation_tool.py +++ b/tools/image_generation_tool.py @@ -319,9 +319,6 @@ async def image_generate_tool( if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0: raise ValueError("Prompt is required and must be a non-empty string") - if len(prompt) > 1000: - raise ValueError("Prompt must be 1000 characters or less") - # Check API key availability if not os.getenv("FAL_KEY"): raise ValueError("FAL_KEY environment variable not set") @@ -417,7 +414,7 @@ async def image_generate_tool( _log_debug_call("image_generate_tool", debug_call_data) _save_debug_log() - return json.dumps(response_data, indent=2) + return json.dumps(response_data, indent=2, ensure_ascii=False) except Exception as e: generation_time = (datetime.datetime.now() - start_time).total_seconds() @@ -435,7 +432,7 @@ async def image_generate_tool( _log_debug_call("image_generate_tool", debug_call_data) _save_debug_log() - return json.dumps(response_data, indent=2) + return json.dumps(response_data, indent=2, ensure_ascii=False) def check_fal_api_key() -> bool: diff --git a/mixture_of_agents_tool.py b/tools/mixture_of_agents_tool.py similarity index 96% rename from mixture_of_agents_tool.py rename to tools/mixture_of_agents_tool.py index 206b14963..c94d9e1de 100644 --- a/mixture_of_agents_tool.py +++ b/tools/mixture_of_agents_tool.py @@ -1,586 +1,586 @@ -#!/usr/bin/env python3 -""" -Mixture-of-Agents Tool Module - -This module implements the Mixture-of-Agents (MoA) methodology that leverages -the collective strengths of multiple LLMs through a layered architecture to -achieve state-of-the-art performance on complex reasoning tasks. - -Based on the research paper: "Mixture-of-Agents Enhances Large Language Model Capabilities" -by Junlin Wang et al. (arXiv:2406.04692v1) - -Key Features: -- Multi-layer LLM collaboration for enhanced reasoning -- Parallel processing of reference models for efficiency -- Intelligent aggregation and synthesis of diverse responses -- Specialized for extremely difficult problems requiring intense reasoning -- Optimized for coding, mathematics, and complex analytical tasks - -Available Tool: -- mixture_of_agents_tool: Process complex queries using multiple frontier models - -Architecture: -1. Reference models generate diverse initial responses in parallel -2. Aggregator model synthesizes responses into a high-quality output -3. Multiple layers can be used for iterative refinement (future enhancement) - -Models Used: -- Reference Models: claude-opus-4-20250514, gemini-2.5-pro, o4-mini, deepseek-r1 -- Aggregator Model: claude-opus-4-20250514 (highest capability for synthesis) - -Configuration: - To customize the MoA setup, modify the configuration constants at the top of this file: - - REFERENCE_MODELS: List of models for generating diverse initial responses - - AGGREGATOR_MODEL: Model used to synthesize the final response - - REFERENCE_TEMPERATURE/AGGREGATOR_TEMPERATURE: Sampling temperatures - - MIN_SUCCESSFUL_REFERENCES: Minimum successful models needed to proceed - -Usage: - from mixture_of_agents_tool import mixture_of_agents_tool - import asyncio - - # Process a complex query - result = await mixture_of_agents_tool( - user_prompt="Solve this complex mathematical proof..." - ) -""" - -import json -import os -import asyncio -import uuid -import datetime -from pathlib import Path -from typing import Dict, Any, List, Optional -from openai import AsyncOpenAI - -# Initialize Nous Research API client for MoA processing -nous_client = AsyncOpenAI( - api_key=os.getenv("NOUS_API_KEY"), - base_url="https://inference-api.nousresearch.com/v1" -) - -# Configuration for MoA processing -# Reference models - these generate diverse initial responses in parallel -REFERENCE_MODELS = [ - "claude-opus-4-20250514", - "gemini-2.5-pro", - "gpt-5", - "deepseek-r1" -] - -# Aggregator model - synthesizes reference responses into final output -AGGREGATOR_MODEL = "claude-opus-4-20250514" # Use highest capability model for aggregation - -# Temperature settings optimized for MoA performance -REFERENCE_TEMPERATURE = 0.6 # Balanced creativity for diverse perspectives -AGGREGATOR_TEMPERATURE = 0.4 # Focused synthesis for consistency - -# Failure handling configuration -MIN_SUCCESSFUL_REFERENCES = 1 # Minimum successful reference models needed to proceed - -# System prompt for the aggregator model (from the research paper) -AGGREGATOR_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. - -Responses from models:""" - -# Debug mode configuration -DEBUG_MODE = os.getenv("MOA_TOOLS_DEBUG", "false").lower() == "true" -DEBUG_SESSION_ID = str(uuid.uuid4()) -DEBUG_LOG_PATH = Path("./logs") -DEBUG_DATA = { - "session_id": DEBUG_SESSION_ID, - "start_time": datetime.datetime.now().isoformat(), - "debug_enabled": DEBUG_MODE, - "tool_calls": [] -} if DEBUG_MODE else None - -# Create logs directory if debug mode is enabled -if DEBUG_MODE: - DEBUG_LOG_PATH.mkdir(exist_ok=True) - print(f"šŸ› MoA debug mode enabled - Session ID: {DEBUG_SESSION_ID}") - - -def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None: - """ - Log a debug call entry to the global debug data structure. - - Args: - tool_name (str): Name of the tool being called - call_data (Dict[str, Any]): Data about the call including parameters and results - """ - if not DEBUG_MODE or not DEBUG_DATA: - return - - call_entry = { - "timestamp": datetime.datetime.now().isoformat(), - "tool_name": tool_name, - **call_data - } - - DEBUG_DATA["tool_calls"].append(call_entry) - - -def _save_debug_log() -> None: - """ - Save the current debug data to a JSON file in the logs directory. - """ - if not DEBUG_MODE or not DEBUG_DATA: - return - - try: - debug_filename = f"moa_tools_debug_{DEBUG_SESSION_ID}.json" - debug_filepath = DEBUG_LOG_PATH / debug_filename - - # Update end time - DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat() - DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"]) - - with open(debug_filepath, 'w', encoding='utf-8') as f: - json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False) - - print(f"šŸ› MoA debug log saved: {debug_filepath}") - - except Exception as e: - print(f"āŒ Error saving MoA debug log: {str(e)}") - - -def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str: - """ - Construct the final system prompt for the aggregator including all model responses. - - Args: - system_prompt (str): Base system prompt for aggregation - responses (List[str]): List of responses from reference models - - Returns: - str: Complete system prompt with enumerated responses - """ - response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)]) - return f"{system_prompt}\n\n{response_text}" - - -async def _run_reference_model_safe( - model: str, - user_prompt: str, - temperature: float = REFERENCE_TEMPERATURE, - max_tokens: int = 32000, - max_retries: int = 3 -) -> tuple[str, str, bool]: - """ - Run a single reference model with retry logic and graceful failure handling. - - Args: - model (str): Model identifier to use - user_prompt (str): The user's query - temperature (float): Sampling temperature for response generation - max_tokens (int): Maximum tokens in response - max_retries (int): Maximum number of retry attempts - - Returns: - tuple[str, str, bool]: (model_name, response_content_or_error, success_flag) - """ - for attempt in range(max_retries): - try: - print(f"šŸ¤– Querying {model} (attempt {attempt + 1}/{max_retries})") - - # Build parameters for the API call - api_params = { - "model": model, - "messages": [{"role": "user", "content": user_prompt}] - } - - # GPT models (especially gpt-4o-mini) don't support custom temperature values - # Only include temperature for non-GPT models - if not model.lower().startswith('gpt-'): - api_params["temperature"] = temperature - - response = await nous_client.chat.completions.create(**api_params) - - content = response.choices[0].message.content.strip() - print(f"āœ… {model} responded ({len(content)} characters)") - return model, content, True - - except Exception as e: - error_str = str(e) - # Log more detailed error information for debugging - if "invalid" in error_str.lower(): - print(f"āš ļø {model} invalid request error (attempt {attempt + 1}): {error_str}") - elif "rate" in error_str.lower() or "limit" in error_str.lower(): - print(f"āš ļø {model} rate limit error (attempt {attempt + 1}): {error_str}") - else: - print(f"āš ļø {model} unknown error (attempt {attempt + 1}): {error_str}") - - if attempt < max_retries - 1: - # Exponential backoff for rate limiting - sleep_time = 2 ** attempt - print(f" Retrying in {sleep_time}s...") - await asyncio.sleep(sleep_time) - else: - error_msg = f"{model} failed after {max_retries} attempts: {error_str}" - print(f"āŒ {error_msg}") - return model, error_msg, False - - -async def _run_aggregator_model( - system_prompt: str, - user_prompt: str, - temperature: float = AGGREGATOR_TEMPERATURE, - max_tokens: int = None -) -> str: - """ - Run the aggregator model to synthesize the final response. - - Args: - system_prompt (str): System prompt with all reference responses - user_prompt (str): Original user query - temperature (float): Focused temperature for consistent aggregation - max_tokens (int): Maximum tokens in final response - - Returns: - str: Synthesized final response - """ - print(f"🧠 Running aggregator model: {AGGREGATOR_MODEL}") - - # Build parameters for the API call - api_params = { - "model": AGGREGATOR_MODEL, - "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ] - } - - # GPT models (especially gpt-4o-mini) don't support custom temperature values - # Only include temperature for non-GPT models - if not AGGREGATOR_MODEL.lower().startswith('gpt-'): - api_params["temperature"] = temperature - - response = await nous_client.chat.completions.create(**api_params) - - content = response.choices[0].message.content.strip() - print(f"āœ… Aggregation complete ({len(content)} characters)") - return content - - -async def mixture_of_agents_tool( - user_prompt: str, - reference_models: Optional[List[str]] = None, - aggregator_model: Optional[str] = None -) -> str: - """ - Process a complex query using the Mixture-of-Agents methodology. - - This tool leverages multiple frontier language models to collaboratively solve - extremely difficult problems requiring intense reasoning. It's particularly - effective for: - - Complex mathematical proofs and calculations - - Advanced coding problems and algorithm design - - Multi-step analytical reasoning tasks - - Problems requiring diverse domain expertise - - Tasks where single models show limitations - - The MoA approach uses a fixed 2-layer architecture: - 1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6) - 2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4) - - Args: - user_prompt (str): The complex query or problem to solve - reference_models (Optional[List[str]]): Custom reference models to use - aggregator_model (Optional[str]): Custom aggregator model to use - - Returns: - str: JSON string containing the MoA results with the following structure: - { - "success": bool, - "response": str, - "models_used": { - "reference_models": List[str], - "aggregator_model": str - }, - "processing_time": float - } - - Raises: - Exception: If MoA processing fails or API key is not set - """ - start_time = datetime.datetime.now() - - debug_call_data = { - "parameters": { - "user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt, - "reference_models": reference_models or REFERENCE_MODELS, - "aggregator_model": aggregator_model or AGGREGATOR_MODEL, - "reference_temperature": REFERENCE_TEMPERATURE, - "aggregator_temperature": AGGREGATOR_TEMPERATURE, - "min_successful_references": MIN_SUCCESSFUL_REFERENCES - }, - "error": None, - "success": False, - "reference_responses_count": 0, - "failed_models_count": 0, - "failed_models": [], - "final_response_length": 0, - "processing_time_seconds": 0, - "models_used": {} - } - - try: - print(f"šŸš€ Starting Mixture-of-Agents processing...") - print(f"šŸ“ Query: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}") - - # Validate API key availability - if not os.getenv("NOUS_API_KEY"): - raise ValueError("NOUS_API_KEY environment variable not set") - - # Use provided models or defaults - ref_models = reference_models or REFERENCE_MODELS - agg_model = aggregator_model or AGGREGATOR_MODEL - - print(f"šŸ”„ Using {len(ref_models)} reference models in 2-layer MoA architecture") - - # Layer 1: Generate diverse responses from reference models (with failure handling) - print("šŸ“” Layer 1: Generating reference responses...") - model_results = await asyncio.gather(*[ - _run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) - for model in ref_models - ]) - - # Separate successful and failed responses - successful_responses = [] - failed_models = [] - - for model_name, content, success in model_results: - if success: - successful_responses.append(content) - else: - failed_models.append(model_name) - - successful_count = len(successful_responses) - failed_count = len(failed_models) - - print(f"šŸ“Š Reference model results: {successful_count} successful, {failed_count} failed") - - if failed_models: - print(f"āš ļø Failed models: {', '.join(failed_models)}") - - # Check if we have enough successful responses to proceed - if successful_count < MIN_SUCCESSFUL_REFERENCES: - raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.") - - debug_call_data["reference_responses_count"] = successful_count - debug_call_data["failed_models_count"] = failed_count - debug_call_data["failed_models"] = failed_models - - # Layer 2: Aggregate responses using the aggregator model - print("🧠 Layer 2: Synthesizing final response...") - aggregator_system_prompt = _construct_aggregator_prompt( - AGGREGATOR_SYSTEM_PROMPT, - successful_responses - ) - - final_response = await _run_aggregator_model( - aggregator_system_prompt, - user_prompt, - AGGREGATOR_TEMPERATURE - ) - - # Calculate processing time - end_time = datetime.datetime.now() - processing_time = (end_time - start_time).total_seconds() - - print(f"āœ… MoA processing completed in {processing_time:.2f} seconds") - - # Prepare successful response (only final aggregated result, minimal fields) - result = { - "success": True, - "response": final_response, - "models_used": { - "reference_models": ref_models, - "aggregator_model": agg_model - } - } - - debug_call_data["success"] = True - debug_call_data["final_response_length"] = len(final_response) - debug_call_data["processing_time_seconds"] = processing_time - debug_call_data["models_used"] = result["models_used"] - - # Log debug information - _log_debug_call("mixture_of_agents_tool", debug_call_data) - _save_debug_log() - - return json.dumps(result, indent=2) - - except Exception as e: - error_msg = f"Error in MoA processing: {str(e)}" - print(f"āŒ {error_msg}") - - # Calculate processing time even for errors - end_time = datetime.datetime.now() - processing_time = (end_time - start_time).total_seconds() - - # Prepare error response (minimal fields) - result = { - "success": False, - "response": "MoA processing failed. Please try again or use a single model for this query.", - "models_used": { - "reference_models": reference_models or REFERENCE_MODELS, - "aggregator_model": aggregator_model or AGGREGATOR_MODEL - }, - "error": error_msg - } - - debug_call_data["error"] = error_msg - debug_call_data["processing_time_seconds"] = processing_time - _log_debug_call("mixture_of_agents_tool", debug_call_data) - _save_debug_log() - - return json.dumps(result, indent=2) - - -def check_nous_api_key() -> bool: - """ - Check if the Nous Research API key is available in environment variables. - - Returns: - bool: True if API key is set, False otherwise - """ - return bool(os.getenv("NOUS_API_KEY")) - - -def check_moa_requirements() -> bool: - """ - Check if all requirements for MoA tools are met. - - Returns: - bool: True if requirements are met, False otherwise - """ - return check_nous_api_key() - - -def get_debug_session_info() -> Dict[str, Any]: - """ - Get information about the current debug session. - - Returns: - Dict[str, Any]: Dictionary containing debug session information - """ - if not DEBUG_MODE or not DEBUG_DATA: - return { - "enabled": False, - "session_id": None, - "log_path": None, - "total_calls": 0 - } - - return { - "enabled": True, - "session_id": DEBUG_SESSION_ID, - "log_path": str(DEBUG_LOG_PATH / f"moa_tools_debug_{DEBUG_SESSION_ID}.json"), - "total_calls": len(DEBUG_DATA["tool_calls"]) - } - - -def get_available_models() -> Dict[str, List[str]]: - """ - Get information about available models for MoA processing. - - Returns: - Dict[str, List[str]]: Dictionary with reference and aggregator models - """ - return { - "reference_models": REFERENCE_MODELS, - "aggregator_models": [AGGREGATOR_MODEL], - "supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL] - } - - -def get_moa_configuration() -> Dict[str, Any]: - """ - Get the current MoA configuration settings. - - Returns: - Dict[str, Any]: Dictionary containing all configuration parameters - """ - return { - "reference_models": REFERENCE_MODELS, - "aggregator_model": AGGREGATOR_MODEL, - "reference_temperature": REFERENCE_TEMPERATURE, - "aggregator_temperature": AGGREGATOR_TEMPERATURE, - "min_successful_references": MIN_SUCCESSFUL_REFERENCES, - "total_reference_models": len(REFERENCE_MODELS), - "failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail" - } - - -if __name__ == "__main__": - """ - Simple test/demo when run directly - """ - print("šŸ¤– Mixture-of-Agents Tool Module") - print("=" * 50) - - # Check if API key is available - api_available = check_nous_api_key() - - if not api_available: - print("āŒ NOUS_API_KEY environment variable not set") - print("Please set your API key: export NOUS_API_KEY='your-key-here'") - print("Get API key at: https://inference-api.nousresearch.com/") - exit(1) - else: - print("āœ… Nous Research API key found") - - print("šŸ› ļø MoA tools ready for use!") - - # Show current configuration - config = get_moa_configuration() - print(f"\nāš™ļø Current Configuration:") - print(f" šŸ¤– Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}") - print(f" 🧠 Aggregator model: {config['aggregator_model']}") - print(f" šŸŒ”ļø Reference temperature: {config['reference_temperature']}") - print(f" šŸŒ”ļø Aggregator temperature: {config['aggregator_temperature']}") - print(f" šŸ›”ļø Failure tolerance: {config['failure_tolerance']}") - print(f" šŸ“Š Minimum successful models: {config['min_successful_references']}") - - # Show debug mode status - if DEBUG_MODE: - print(f"\nšŸ› Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}") - print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{DEBUG_SESSION_ID}.json") - else: - print("\nšŸ› Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)") - - print("\nBasic usage:") - print(" from mixture_of_agents_tool import mixture_of_agents_tool") - print(" import asyncio") - print("") - print(" async def main():") - print(" result = await mixture_of_agents_tool(") - print(" user_prompt='Solve this complex mathematical proof...'") - print(" )") - print(" print(result)") - print(" asyncio.run(main())") - - print("\nBest use cases:") - print(" - Complex mathematical proofs and calculations") - print(" - Advanced coding problems and algorithm design") - print(" - Multi-step analytical reasoning tasks") - print(" - Problems requiring diverse domain expertise") - print(" - Tasks where single models show limitations") - - print("\nPerformance characteristics:") - print(" - Higher latency due to multiple model calls") - print(" - Significantly improved quality for complex tasks") - print(" - Parallel processing for efficiency") - print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation") - print(" - Token-efficient: only returns final aggregated response") - print(" - Resilient: continues with partial model failures") - print(f" - Configurable: easy to modify models and settings at top of file") - print(" - State-of-the-art results on challenging benchmarks") - - print("\nDebug mode:") - print(" # Enable debug logging") - print(" export MOA_TOOLS_DEBUG=true") - print(" # Debug logs capture all MoA processing steps and metrics") - print(" # Logs saved to: ./logs/moa_tools_debug_UUID.json") +#!/usr/bin/env python3 +""" +Mixture-of-Agents Tool Module + +This module implements the Mixture-of-Agents (MoA) methodology that leverages +the collective strengths of multiple LLMs through a layered architecture to +achieve state-of-the-art performance on complex reasoning tasks. + +Based on the research paper: "Mixture-of-Agents Enhances Large Language Model Capabilities" +by Junlin Wang et al. (arXiv:2406.04692v1) + +Key Features: +- Multi-layer LLM collaboration for enhanced reasoning +- Parallel processing of reference models for efficiency +- Intelligent aggregation and synthesis of diverse responses +- Specialized for extremely difficult problems requiring intense reasoning +- Optimized for coding, mathematics, and complex analytical tasks + +Available Tool: +- mixture_of_agents_tool: Process complex queries using multiple frontier models + +Architecture: +1. Reference models generate diverse initial responses in parallel +2. Aggregator model synthesizes responses into a high-quality output +3. Multiple layers can be used for iterative refinement (future enhancement) + +Models Used: +- Reference Models: claude-opus-4-20250514, gemini-2.5-pro, o4-mini, deepseek-r1 +- Aggregator Model: claude-opus-4-20250514 (highest capability for synthesis) + +Configuration: + To customize the MoA setup, modify the configuration constants at the top of this file: + - REFERENCE_MODELS: List of models for generating diverse initial responses + - AGGREGATOR_MODEL: Model used to synthesize the final response + - REFERENCE_TEMPERATURE/AGGREGATOR_TEMPERATURE: Sampling temperatures + - MIN_SUCCESSFUL_REFERENCES: Minimum successful models needed to proceed + +Usage: + from mixture_of_agents_tool import mixture_of_agents_tool + import asyncio + + # Process a complex query + result = await mixture_of_agents_tool( + user_prompt="Solve this complex mathematical proof..." + ) +""" + +import json +import os +import asyncio +import uuid +import datetime +from pathlib import Path +from typing import Dict, Any, List, Optional +from openai import AsyncOpenAI + +# Initialize Nous Research API client for MoA processing +nous_client = AsyncOpenAI( + api_key=os.getenv("NOUS_API_KEY"), + base_url="https://inference-api.nousresearch.com/v1" +) + +# Configuration for MoA processing +# Reference models - these generate diverse initial responses in parallel +REFERENCE_MODELS = [ + "claude-opus-4-20250514", + "gemini-2.5-pro", + "gpt-5", + "deepseek-r1" +] + +# Aggregator model - synthesizes reference responses into final output +AGGREGATOR_MODEL = "claude-opus-4-20250514" # Use highest capability model for aggregation + +# Temperature settings optimized for MoA performance +REFERENCE_TEMPERATURE = 0.6 # Balanced creativity for diverse perspectives +AGGREGATOR_TEMPERATURE = 0.4 # Focused synthesis for consistency + +# Failure handling configuration +MIN_SUCCESSFUL_REFERENCES = 1 # Minimum successful reference models needed to proceed + +# System prompt for the aggregator model (from the research paper) +AGGREGATOR_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability. + +Responses from models:""" + +# Debug mode configuration +DEBUG_MODE = os.getenv("MOA_TOOLS_DEBUG", "false").lower() == "true" +DEBUG_SESSION_ID = str(uuid.uuid4()) +DEBUG_LOG_PATH = Path("./logs") +DEBUG_DATA = { + "session_id": DEBUG_SESSION_ID, + "start_time": datetime.datetime.now().isoformat(), + "debug_enabled": DEBUG_MODE, + "tool_calls": [] +} if DEBUG_MODE else None + +# Create logs directory if debug mode is enabled +if DEBUG_MODE: + DEBUG_LOG_PATH.mkdir(exist_ok=True) + print(f"šŸ› MoA debug mode enabled - Session ID: {DEBUG_SESSION_ID}") + + +def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None: + """ + Log a debug call entry to the global debug data structure. + + Args: + tool_name (str): Name of the tool being called + call_data (Dict[str, Any]): Data about the call including parameters and results + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + call_entry = { + "timestamp": datetime.datetime.now().isoformat(), + "tool_name": tool_name, + **call_data + } + + DEBUG_DATA["tool_calls"].append(call_entry) + + +def _save_debug_log() -> None: + """ + Save the current debug data to a JSON file in the logs directory. + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + try: + debug_filename = f"moa_tools_debug_{DEBUG_SESSION_ID}.json" + debug_filepath = DEBUG_LOG_PATH / debug_filename + + # Update end time + DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat() + DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"]) + + with open(debug_filepath, 'w', encoding='utf-8') as f: + json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False) + + print(f"šŸ› MoA debug log saved: {debug_filepath}") + + except Exception as e: + print(f"āŒ Error saving MoA debug log: {str(e)}") + + +def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str: + """ + Construct the final system prompt for the aggregator including all model responses. + + Args: + system_prompt (str): Base system prompt for aggregation + responses (List[str]): List of responses from reference models + + Returns: + str: Complete system prompt with enumerated responses + """ + response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)]) + return f"{system_prompt}\n\n{response_text}" + + +async def _run_reference_model_safe( + model: str, + user_prompt: str, + temperature: float = REFERENCE_TEMPERATURE, + max_tokens: int = 32000, + max_retries: int = 6 +) -> tuple[str, str, bool]: + """ + Run a single reference model with retry logic and graceful failure handling. + + Args: + model (str): Model identifier to use + user_prompt (str): The user's query + temperature (float): Sampling temperature for response generation + max_tokens (int): Maximum tokens in response + max_retries (int): Maximum number of retry attempts + + Returns: + tuple[str, str, bool]: (model_name, response_content_or_error, success_flag) + """ + for attempt in range(max_retries): + try: + print(f"šŸ¤– Querying {model} (attempt {attempt + 1}/{max_retries})") + + # Build parameters for the API call + api_params = { + "model": model, + "messages": [{"role": "user", "content": user_prompt}] + } + + # GPT models (especially gpt-4o-mini) don't support custom temperature values + # Only include temperature for non-GPT models + if not model.lower().startswith('gpt-'): + api_params["temperature"] = temperature + + response = await nous_client.chat.completions.create(**api_params) + + content = response.choices[0].message.content.strip() + print(f"āœ… {model} responded ({len(content)} characters)") + return model, content, True + + except Exception as e: + error_str = str(e) + # Log more detailed error information for debugging + if "invalid" in error_str.lower(): + print(f"āš ļø {model} invalid request error (attempt {attempt + 1}): {error_str}") + elif "rate" in error_str.lower() or "limit" in error_str.lower(): + print(f"āš ļø {model} rate limit error (attempt {attempt + 1}): {error_str}") + else: + print(f"āš ļø {model} unknown error (attempt {attempt + 1}): {error_str}") + + if attempt < max_retries - 1: + # Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s + sleep_time = min(2 ** (attempt + 1), 60) + print(f" Retrying in {sleep_time}s...") + await asyncio.sleep(sleep_time) + else: + error_msg = f"{model} failed after {max_retries} attempts: {error_str}" + print(f"āŒ {error_msg}") + return model, error_msg, False + + +async def _run_aggregator_model( + system_prompt: str, + user_prompt: str, + temperature: float = AGGREGATOR_TEMPERATURE, + max_tokens: int = None +) -> str: + """ + Run the aggregator model to synthesize the final response. + + Args: + system_prompt (str): System prompt with all reference responses + user_prompt (str): Original user query + temperature (float): Focused temperature for consistent aggregation + max_tokens (int): Maximum tokens in final response + + Returns: + str: Synthesized final response + """ + print(f"🧠 Running aggregator model: {AGGREGATOR_MODEL}") + + # Build parameters for the API call + api_params = { + "model": AGGREGATOR_MODEL, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ] + } + + # GPT models (especially gpt-4o-mini) don't support custom temperature values + # Only include temperature for non-GPT models + if not AGGREGATOR_MODEL.lower().startswith('gpt-'): + api_params["temperature"] = temperature + + response = await nous_client.chat.completions.create(**api_params) + + content = response.choices[0].message.content.strip() + print(f"āœ… Aggregation complete ({len(content)} characters)") + return content + + +async def mixture_of_agents_tool( + user_prompt: str, + reference_models: Optional[List[str]] = None, + aggregator_model: Optional[str] = None +) -> str: + """ + Process a complex query using the Mixture-of-Agents methodology. + + This tool leverages multiple frontier language models to collaboratively solve + extremely difficult problems requiring intense reasoning. It's particularly + effective for: + - Complex mathematical proofs and calculations + - Advanced coding problems and algorithm design + - Multi-step analytical reasoning tasks + - Problems requiring diverse domain expertise + - Tasks where single models show limitations + + The MoA approach uses a fixed 2-layer architecture: + 1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6) + 2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4) + + Args: + user_prompt (str): The complex query or problem to solve + reference_models (Optional[List[str]]): Custom reference models to use + aggregator_model (Optional[str]): Custom aggregator model to use + + Returns: + str: JSON string containing the MoA results with the following structure: + { + "success": bool, + "response": str, + "models_used": { + "reference_models": List[str], + "aggregator_model": str + }, + "processing_time": float + } + + Raises: + Exception: If MoA processing fails or API key is not set + """ + start_time = datetime.datetime.now() + + debug_call_data = { + "parameters": { + "user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt, + "reference_models": reference_models or REFERENCE_MODELS, + "aggregator_model": aggregator_model or AGGREGATOR_MODEL, + "reference_temperature": REFERENCE_TEMPERATURE, + "aggregator_temperature": AGGREGATOR_TEMPERATURE, + "min_successful_references": MIN_SUCCESSFUL_REFERENCES + }, + "error": None, + "success": False, + "reference_responses_count": 0, + "failed_models_count": 0, + "failed_models": [], + "final_response_length": 0, + "processing_time_seconds": 0, + "models_used": {} + } + + try: + print(f"šŸš€ Starting Mixture-of-Agents processing...") + print(f"šŸ“ Query: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}") + + # Validate API key availability + if not os.getenv("NOUS_API_KEY"): + raise ValueError("NOUS_API_KEY environment variable not set") + + # Use provided models or defaults + ref_models = reference_models or REFERENCE_MODELS + agg_model = aggregator_model or AGGREGATOR_MODEL + + print(f"šŸ”„ Using {len(ref_models)} reference models in 2-layer MoA architecture") + + # Layer 1: Generate diverse responses from reference models (with failure handling) + print("šŸ“” Layer 1: Generating reference responses...") + model_results = await asyncio.gather(*[ + _run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) + for model in ref_models + ]) + + # Separate successful and failed responses + successful_responses = [] + failed_models = [] + + for model_name, content, success in model_results: + if success: + successful_responses.append(content) + else: + failed_models.append(model_name) + + successful_count = len(successful_responses) + failed_count = len(failed_models) + + print(f"šŸ“Š Reference model results: {successful_count} successful, {failed_count} failed") + + if failed_models: + print(f"āš ļø Failed models: {', '.join(failed_models)}") + + # Check if we have enough successful responses to proceed + if successful_count < MIN_SUCCESSFUL_REFERENCES: + raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.") + + debug_call_data["reference_responses_count"] = successful_count + debug_call_data["failed_models_count"] = failed_count + debug_call_data["failed_models"] = failed_models + + # Layer 2: Aggregate responses using the aggregator model + print("🧠 Layer 2: Synthesizing final response...") + aggregator_system_prompt = _construct_aggregator_prompt( + AGGREGATOR_SYSTEM_PROMPT, + successful_responses + ) + + final_response = await _run_aggregator_model( + aggregator_system_prompt, + user_prompt, + AGGREGATOR_TEMPERATURE + ) + + # Calculate processing time + end_time = datetime.datetime.now() + processing_time = (end_time - start_time).total_seconds() + + print(f"āœ… MoA processing completed in {processing_time:.2f} seconds") + + # Prepare successful response (only final aggregated result, minimal fields) + result = { + "success": True, + "response": final_response, + "models_used": { + "reference_models": ref_models, + "aggregator_model": agg_model + } + } + + debug_call_data["success"] = True + debug_call_data["final_response_length"] = len(final_response) + debug_call_data["processing_time_seconds"] = processing_time + debug_call_data["models_used"] = result["models_used"] + + # Log debug information + _log_debug_call("mixture_of_agents_tool", debug_call_data) + _save_debug_log() + + return json.dumps(result, indent=2, ensure_ascii=False) + + except Exception as e: + error_msg = f"Error in MoA processing: {str(e)}" + print(f"āŒ {error_msg}") + + # Calculate processing time even for errors + end_time = datetime.datetime.now() + processing_time = (end_time - start_time).total_seconds() + + # Prepare error response (minimal fields) + result = { + "success": False, + "response": "MoA processing failed. Please try again or use a single model for this query.", + "models_used": { + "reference_models": reference_models or REFERENCE_MODELS, + "aggregator_model": aggregator_model or AGGREGATOR_MODEL + }, + "error": error_msg + } + + debug_call_data["error"] = error_msg + debug_call_data["processing_time_seconds"] = processing_time + _log_debug_call("mixture_of_agents_tool", debug_call_data) + _save_debug_log() + + return json.dumps(result, indent=2, ensure_ascii=False) + + +def check_nous_api_key() -> bool: + """ + Check if the Nous Research API key is available in environment variables. + + Returns: + bool: True if API key is set, False otherwise + """ + return bool(os.getenv("NOUS_API_KEY")) + + +def check_moa_requirements() -> bool: + """ + Check if all requirements for MoA tools are met. + + Returns: + bool: True if requirements are met, False otherwise + """ + return check_nous_api_key() + + +def get_debug_session_info() -> Dict[str, Any]: + """ + Get information about the current debug session. + + Returns: + Dict[str, Any]: Dictionary containing debug session information + """ + if not DEBUG_MODE or not DEBUG_DATA: + return { + "enabled": False, + "session_id": None, + "log_path": None, + "total_calls": 0 + } + + return { + "enabled": True, + "session_id": DEBUG_SESSION_ID, + "log_path": str(DEBUG_LOG_PATH / f"moa_tools_debug_{DEBUG_SESSION_ID}.json"), + "total_calls": len(DEBUG_DATA["tool_calls"]) + } + + +def get_available_models() -> Dict[str, List[str]]: + """ + Get information about available models for MoA processing. + + Returns: + Dict[str, List[str]]: Dictionary with reference and aggregator models + """ + return { + "reference_models": REFERENCE_MODELS, + "aggregator_models": [AGGREGATOR_MODEL], + "supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL] + } + + +def get_moa_configuration() -> Dict[str, Any]: + """ + Get the current MoA configuration settings. + + Returns: + Dict[str, Any]: Dictionary containing all configuration parameters + """ + return { + "reference_models": REFERENCE_MODELS, + "aggregator_model": AGGREGATOR_MODEL, + "reference_temperature": REFERENCE_TEMPERATURE, + "aggregator_temperature": AGGREGATOR_TEMPERATURE, + "min_successful_references": MIN_SUCCESSFUL_REFERENCES, + "total_reference_models": len(REFERENCE_MODELS), + "failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail" + } + + +if __name__ == "__main__": + """ + Simple test/demo when run directly + """ + print("šŸ¤– Mixture-of-Agents Tool Module") + print("=" * 50) + + # Check if API key is available + api_available = check_nous_api_key() + + if not api_available: + print("āŒ NOUS_API_KEY environment variable not set") + print("Please set your API key: export NOUS_API_KEY='your-key-here'") + print("Get API key at: https://inference-api.nousresearch.com/") + exit(1) + else: + print("āœ… Nous Research API key found") + + print("šŸ› ļø MoA tools ready for use!") + + # Show current configuration + config = get_moa_configuration() + print(f"\nāš™ļø Current Configuration:") + print(f" šŸ¤– Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}") + print(f" 🧠 Aggregator model: {config['aggregator_model']}") + print(f" šŸŒ”ļø Reference temperature: {config['reference_temperature']}") + print(f" šŸŒ”ļø Aggregator temperature: {config['aggregator_temperature']}") + print(f" šŸ›”ļø Failure tolerance: {config['failure_tolerance']}") + print(f" šŸ“Š Minimum successful models: {config['min_successful_references']}") + + # Show debug mode status + if DEBUG_MODE: + print(f"\nšŸ› Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}") + print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{DEBUG_SESSION_ID}.json") + else: + print("\nšŸ› Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)") + + print("\nBasic usage:") + print(" from mixture_of_agents_tool import mixture_of_agents_tool") + print(" import asyncio") + print("") + print(" async def main():") + print(" result = await mixture_of_agents_tool(") + print(" user_prompt='Solve this complex mathematical proof...'") + print(" )") + print(" print(result)") + print(" asyncio.run(main())") + + print("\nBest use cases:") + print(" - Complex mathematical proofs and calculations") + print(" - Advanced coding problems and algorithm design") + print(" - Multi-step analytical reasoning tasks") + print(" - Problems requiring diverse domain expertise") + print(" - Tasks where single models show limitations") + + print("\nPerformance characteristics:") + print(" - Higher latency due to multiple model calls") + print(" - Significantly improved quality for complex tasks") + print(" - Parallel processing for efficiency") + print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation") + print(" - Token-efficient: only returns final aggregated response") + print(" - Resilient: continues with partial model failures") + print(f" - Configurable: easy to modify models and settings at top of file") + print(" - State-of-the-art results on challenging benchmarks") + + print("\nDebug mode:") + print(" # Enable debug logging") + print(" export MOA_TOOLS_DEBUG=true") + print(" # Debug logs capture all MoA processing steps and metrics") + print(" # Logs saved to: ./logs/moa_tools_debug_UUID.json") diff --git a/tools/simple_terminal_tool.py b/tools/simple_terminal_tool.py new file mode 100644 index 000000000..6ebfeeda7 --- /dev/null +++ b/tools/simple_terminal_tool.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +""" +Simple Terminal Tool Module + +A simplified terminal tool that executes commands on MorphCloud VMs without tmux. +No session persistence, no interactive app support - just simple command execution. + +Features: +- Direct SSH command execution +- Background task support +- VM lifecycle management with TTL +- Automatic cleanup after inactivity + +Usage: + from simple_terminal_tool import simple_terminal_tool + + # Execute a simple command + result = simple_terminal_tool("ls -la") + + # Execute in background + result = simple_terminal_tool("python server.py", background=True) +""" + +import json +import os +import time +import threading +import atexit +from typing import Optional, Dict, Any + +# Tool description for LLM +SIMPLE_TERMINAL_TOOL_DESCRIPTION = """Execute commands on a secure Linux VM environment. + +**Environment:** +- Minimal Debian-based OS with internet access +- Automatic VM lifecycle management (creates on-demand, reuses, cleans up) +- Filesystem is persisted between tool calls but environment variables, venvs, etc are reset. + +**Command Execution:** +- Simple commands: Just provide the 'command' parameter +- Background processes: Set 'background': True for servers/long-running tasks +- Command timeout: Optional 'timeout' parameter in seconds + +**Examples:** +- Run command: `{"command": "ls -la"}` +- Background task: `{"command": "source path/to/my/venv/bin/activate && python server.py", "background": True}` +- With timeout: `{"command": "long_task.sh", "timeout": 300}` + +**Best Practices:** +- Run servers/long processes in background +- Monitor disk usage for large tasks +- Install whatever tools you need with sudo apt-get +- Do not be afraid to run pip with --break-system-packages + +**Things to avoid** +- Do NOT use interactive tools such as tmux, vim, nano, python repl - you will get stuck. Even git sometimes becomes interactive if the output is large. If you're not sure pipe to cat. +""" + +# Global state for VM lifecycle management +_active_instances: Dict[str, Any] = {} +_last_activity: Dict[str, float] = {} +_instance_lock = threading.Lock() +_cleanup_thread = None +_cleanup_running = False + + +def _cleanup_inactive_vms(vm_lifetime_seconds: int = 300): + """Clean up VMs that have been inactive for longer than vm_lifetime_seconds.""" + global _active_instances, _last_activity + + current_time = time.time() + tasks_to_cleanup = [] + + with _instance_lock: + for task_id, last_time in list(_last_activity.items()): + if current_time - last_time > vm_lifetime_seconds: + tasks_to_cleanup.append(task_id) + + for task_id in tasks_to_cleanup: + try: + if task_id in _active_instances: + instance = _active_instances[task_id] + if hasattr(instance, 'terminate'): + instance.terminate() + elif hasattr(instance, 'stop'): + instance.stop() + elif hasattr(instance, 'delete'): + instance.delete() + + del _active_instances[task_id] + print(f"[VM Cleanup] Terminated inactive VM for task: {task_id}") + + if task_id in _last_activity: + del _last_activity[task_id] + + except Exception as e: + # 404 errors are benign - VM already cleaned up by TTL + error_str = str(e) + if "404" in error_str or "InstanceNotFoundError" in error_str or "not found" in error_str.lower(): + print(f"[VM Cleanup] VM for task {task_id} already cleaned up (likely TTL expiration)") + else: + print(f"[VM Cleanup] Error cleaning up VM for task {task_id}: {e}") + + +def _cleanup_thread_worker(): + """Background thread worker that periodically cleans up inactive VMs.""" + global _cleanup_running + + while _cleanup_running: + try: + vm_lifetime = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300")) + _cleanup_inactive_vms(vm_lifetime) + except Exception as e: + print(f"[VM Cleanup] Error in cleanup thread: {e}") + + for _ in range(60): + if not _cleanup_running: + break + time.sleep(1) + + +def _start_cleanup_thread(): + """Start the background cleanup thread if not already running.""" + global _cleanup_thread, _cleanup_running + + with _instance_lock: + if _cleanup_thread is None or not _cleanup_thread.is_alive(): + _cleanup_running = True + _cleanup_thread = threading.Thread(target=_cleanup_thread_worker, daemon=True) + _cleanup_thread.start() + + +def _stop_cleanup_thread(): + """Stop the background cleanup thread.""" + global _cleanup_running + _cleanup_running = False + if _cleanup_thread is not None: + _cleanup_thread.join(timeout=5) + + +def cleanup_vm(task_id: str): + """Manually clean up a specific VM by task_id.""" + global _active_instances, _last_activity + + with _instance_lock: + try: + if task_id in _active_instances: + instance = _active_instances[task_id] + if hasattr(instance, 'terminate'): + instance.terminate() + elif hasattr(instance, 'stop'): + instance.stop() + elif hasattr(instance, 'delete'): + instance.delete() + + del _active_instances[task_id] + print(f"[VM Cleanup] Manually terminated VM for task: {task_id}") + + if task_id in _last_activity: + del _last_activity[task_id] + + except Exception as e: + # 404 errors are benign - VM already cleaned up by TTL + error_str = str(e) + if "404" in error_str or "InstanceNotFoundError" in error_str or "not found" in error_str.lower(): + print(f"[VM Cleanup] VM for task {task_id} already cleaned up (likely TTL expiration)") + else: + print(f"[VM Cleanup] Error manually cleaning up VM for task {task_id}: {e}") + + +atexit.register(_stop_cleanup_thread) + + +def _execute_ssh_command(instance, command: str, timeout: Optional[int] = None) -> Dict[str, Any]: + """ + Execute a command via SSH on the VM instance. + + Args: + instance: MorphVM instance + command: Command to execute + timeout: Optional timeout in seconds + + Returns: + dict with stdout, stderr, returncode + """ + ssh_context_manager = None + try: + # Use the instance's SSH context manager + ssh_context_manager = instance.ssh() + ssh_context = ssh_context_manager.__enter__() + + # Execute the command + result = ssh_context.run(command, get_pty=False, timeout=timeout or 120) + + # Close the SSH connection + if ssh_context_manager: + try: + ssh_context_manager.__exit__(None, None, None) + except: + pass + + return { + "stdout": result.stdout or "", + "stderr": result.stderr or "", + "returncode": result.returncode + } + + except Exception as e: + # Close connection on error + if ssh_context_manager: + try: + ssh_context_manager.__exit__(None, None, None) + except: + pass + + # Check if it's a timeout + error_str = str(e).lower() + if "timeout" in error_str: + return { + "stdout": "", + "stderr": f"Command timed out after {timeout or 120} seconds", + "returncode": 124 + } + + return { + "stdout": "", + "stderr": f"SSH execution failed: {str(e)}", + "returncode": -1 + } + + +def simple_terminal_tool( + command: str, + background: bool = False, + timeout: Optional[int] = None, + task_id: Optional[str] = None +) -> str: + """ + Execute a command on a MorphCloud VM without session persistence. + + Args: + command: The command to execute + background: Whether to run in background (default: False) + timeout: Command timeout in seconds (default: 120) + task_id: Unique identifier for VM isolation (optional) + + Returns: + str: JSON string with output, exit_code, and error fields + + Examples: + # Execute a simple command + >>> result = simple_terminal_tool(command="ls -la /tmp") + + # Run a background task + >>> result = simple_terminal_tool(command="python server.py", background=True) + + # With custom timeout + >>> result = simple_terminal_tool(command="long_task.sh", timeout=300) + """ + global _active_instances, _last_activity + + try: + # Import required modules + try: + from morphcloud.api import MorphCloudClient + except ImportError as import_error: + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Terminal tool disabled: {import_error}", + "status": "disabled" + }, ensure_ascii=False) + + # Get configuration + vm_ttl_seconds = int(os.getenv("HECATE_VM_TTL_SECONDS", "1200")) + snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "snapshot_defv9tjg") + + # Check API key + morph_api_key = os.getenv("MORPH_API_KEY") + if not morph_api_key: + return json.dumps({ + "output": "", + "exit_code": -1, + "error": "MORPH_API_KEY environment variable not set", + "status": "disabled" + }, ensure_ascii=False) + + # Use task_id for VM isolation + effective_task_id = task_id or "default" + + # Start cleanup thread + _start_cleanup_thread() + + # Get or create VM instance + with _instance_lock: + if effective_task_id not in _active_instances: + morph_client = MorphCloudClient(api_key=morph_api_key) + _active_instances[effective_task_id] = morph_client.instances.start( + snapshot_id=snapshot_id, + ttl_seconds=vm_ttl_seconds, + ttl_action="stop" + ) + + # Update last activity time + _last_activity[effective_task_id] = time.time() + instance = _active_instances[effective_task_id] + + # Wait for instance to be ready + instance.wait_until_ready() + + # Prepare command for execution + if background: + # Run in background with nohup and redirect output + exec_command = f"nohup {command} > /tmp/bg_output.log 2>&1 &" + result = _execute_ssh_command(instance, exec_command, timeout=10) + + # For background tasks, return immediately with info + if result["returncode"] == 0: + return json.dumps({ + "output": "Background task started successfully", + "exit_code": 0, + "error": None + }, ensure_ascii=False) + else: + return json.dumps({ + "output": result["stdout"], + "exit_code": result["returncode"], + "error": result["stderr"] + }, ensure_ascii=False) + else: + # Run foreground command + result = _execute_ssh_command(instance, command, timeout=timeout) + + # Combine stdout and stderr for output + output = result["stdout"] + if result["stderr"] and result["returncode"] != 0: + output = f"{output}\n{result['stderr']}" if output else result["stderr"] + + return json.dumps({ + "output": output.strip(), + "exit_code": result["returncode"], + "error": result["stderr"] if result["returncode"] != 0 else None + }, ensure_ascii=False) + + except Exception as e: + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Failed to execute command: {str(e)}", + "status": "error" + }, ensure_ascii=False) + + +def check_requirements() -> bool: + """Check if all requirements for the simple terminal tool are met.""" + required_vars = ["MORPH_API_KEY"] + missing_required = [var for var in required_vars if not os.getenv(var)] + + if missing_required: + print(f"Missing required environment variables: {', '.join(missing_required)}") + return False + + try: + from morphcloud.api import MorphCloudClient + return True + except Exception as e: + print(f"MorphCloud not available: {e}") + return False + + +if __name__ == "__main__": + """Simple test when run directly.""" + print("Simple Terminal Tool Module") + print("=" * 40) + + if not check_requirements(): + print("Requirements not met. Please check the messages above.") + exit(1) + + print("All requirements met!") + print("\nAvailable Tool:") + print(" - simple_terminal_tool: Execute commands without session persistence") + + print("\nUsage Examples:") + print(" # Execute a command") + print(" result = simple_terminal_tool(command='ls -la')") + print(" ") + print(" # Run a background task") + print(" result = simple_terminal_tool(command='python server.py', background=True)") + + print("\nEnvironment Variables:") + print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}") + print(f" HECATE_VM_TTL_SECONDS: {os.getenv('HECATE_VM_TTL_SECONDS', '1200')} (default: 1200 / 20 minutes)") + print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300 / 5 minutes)") + print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_defv9tjg')}") diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py new file mode 100644 index 000000000..e4b436426 --- /dev/null +++ b/tools/terminal_tool.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +""" +Terminal Tool Module + +This module provides a single terminal tool using Hecate's VM infrastructure. +It wraps Hecate's functionality to provide a simple interface for executing commands +on Morph VMs with automatic lifecycle management. + +VM Lifecycle: +- VMs have a TTL (time to live) set at creation (default: 20 minutes) +- VMs are also cleaned up locally after 5 minutes of inactivity +- Timer resets with each use + +Available tool: +- terminal_tool: Execute commands with optional interactive session support + +Usage: + from terminal_tool import terminal_tool + + # Execute a single command + result = terminal_tool("ls -la") + + # Execute in an interactive session + result = terminal_tool("python", input_keys="print('hello')\\nexit()\\n") +""" + +import json +import os +import uuid +import threading +import time +import atexit +from typing import Optional, Dict, Any + +# Detailed description for the terminal tool based on Hermes Terminal system prompt +TERMINAL_TOOL_DESCRIPTION = """Execute commands on a secure, persistent Linux VM environment with full interactive application support. + +**Environment:** +- Minimal Debian-based OS with internet access +- Automatic VM lifecycle management (creates on-demand, reuses, cleans up) +- **Full state persistence across tool calls**: current directory (pwd), environment variables, activated virtual environments (conda/venv), running processes, and command history all persist between consecutive tool calls +- Session state managed automatically via tmux + +**Command Execution:** +- Simple commands: Just provide the 'command' parameter +- Background processes: Set 'background': True for servers/long-running tasks +- Interactive applications automatically detected and handled + +**Interactive Applications (TUIs/Pagers/Prompts):** +When commands enter interactive mode (vim, nano, less, git prompts, package managers, etc.), you'll receive screen content with "frozen" status. This is NORMAL - the session is still active and waiting for input. + +**To interact with frozen sessions:** +1. Use 'input_keys' parameter with keystrokes to send +2. System auto-detects and uses the active session +3. Session stays active until application exits + +**Special Key Syntax for input_keys:** +- ``: Escape key +- ``: Enter/Return +- ``, ``, ``: Control combinations +- ``, ``, ``, ``: Arrow keys +- ``, ``: Tab and Backspace +- `` through ``: Function keys +- ``: Shift+Tab +- Uppercase letters for Shift+letter (e.g., 'V' for Shift+V) +- Symbols for Shift+number (e.g., '!' for Shift+1, ':' for Shift+;) + +**Examples:** +- Start vim: `{"command": "vim file.txt"}` +- Type in vim: `{"input_keys": "iHello World"}` +- Save and quit: `{"input_keys": ":wq"}` +- Navigate in less: `{"input_keys": "j"}` +- Quit less: `{"input_keys": "q"}` + +**Best Practices:** +- Run servers/long processes in background with separate tool calls +- Chain multiple foreground commands in single call if needed +- Monitor disk usage for large tasks, clean up to free space +- Test components incrementally with mock inputs +- Install whatever tools needed - full system access provided""" + +# Global state for VM lifecycle management +# These persist across tool calls to enable session continuity +# Changed to dictionaries keyed by task_id to prevent leakage between concurrent tasks +_active_instances: Dict[str, Any] = {} +_active_contexts: Dict[str, Any] = {} +_last_activity: Dict[str, float] = {} # Track last activity time for each VM +_instance_lock = threading.Lock() +_cleanup_thread = None +_cleanup_running = False + +def _cleanup_inactive_vms(vm_lifetime_seconds: int = 300): + """ + Clean up VMs that have been inactive for longer than vm_lifetime_seconds. + This function should be called periodically by a background thread. + + Args: + vm_lifetime_seconds: Maximum lifetime in seconds for inactive VMs (default: 300) + """ + global _active_instances, _active_contexts, _last_activity + + current_time = time.time() + tasks_to_cleanup = [] + + with _instance_lock: + # Find all VMs that have been inactive for too long + for task_id, last_time in list(_last_activity.items()): + if current_time - last_time > vm_lifetime_seconds: + tasks_to_cleanup.append(task_id) + + # Clean up the inactive VMs + for task_id in tasks_to_cleanup: + try: + if task_id in _active_instances: + instance = _active_instances[task_id] + # Terminate the VM instance + if hasattr(instance, 'terminate'): + instance.terminate() + elif hasattr(instance, 'stop'): + instance.stop() + elif hasattr(instance, 'delete'): + instance.delete() + + # Remove from tracking dictionaries + del _active_instances[task_id] + print(f"[VM Cleanup] Terminated inactive VM for task: {task_id}") + + if task_id in _active_contexts: + del _active_contexts[task_id] + + if task_id in _last_activity: + del _last_activity[task_id] + + except Exception as e: + print(f"[VM Cleanup] Error cleaning up VM for task {task_id}: {e}") + +def _cleanup_thread_worker(): + """ + Background thread worker that periodically cleans up inactive VMs. + Runs every 60 seconds. + """ + global _cleanup_running + + while _cleanup_running: + try: + vm_lifetime = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300")) + _cleanup_inactive_vms(vm_lifetime) + except Exception as e: + print(f"[VM Cleanup] Error in cleanup thread: {e}") + + # Sleep for 60 seconds, but check every second if we should stop + for _ in range(60): + if not _cleanup_running: + break + time.sleep(1) + +def _start_cleanup_thread(): + """ + Start the background cleanup thread if it's not already running. + """ + global _cleanup_thread, _cleanup_running + + with _instance_lock: + if _cleanup_thread is None or not _cleanup_thread.is_alive(): + _cleanup_running = True + _cleanup_thread = threading.Thread(target=_cleanup_thread_worker, daemon=True) + _cleanup_thread.start() + +def _stop_cleanup_thread(): + """ + Stop the background cleanup thread. + """ + global _cleanup_running + _cleanup_running = False + if _cleanup_thread is not None: + _cleanup_thread.join(timeout=5) + +def cleanup_vm(task_id: str): + """ + Manually clean up a specific VM by task_id. + This should be called when a task is completed. + + Args: + task_id: The task ID of the VM to clean up + """ + global _active_instances, _active_contexts, _last_activity + + with _instance_lock: + try: + if task_id in _active_instances: + instance = _active_instances[task_id] + # Terminate the VM instance + if hasattr(instance, 'terminate'): + instance.terminate() + elif hasattr(instance, 'stop'): + instance.stop() + elif hasattr(instance, 'delete'): + instance.delete() + + # Remove from tracking dictionaries + del _active_instances[task_id] + print(f"[VM Cleanup] Manually terminated VM for task: {task_id}") + + if task_id in _active_contexts: + del _active_contexts[task_id] + + if task_id in _last_activity: + del _last_activity[task_id] + + except Exception as e: + print(f"[VM Cleanup] Error manually cleaning up VM for task {task_id}: {e}") + +# Register cleanup on program exit +atexit.register(_stop_cleanup_thread) + +def terminal_tool( + command: Optional[str] = None, + input_keys: Optional[str] = None, + session_id: Optional[str] = None, + background: bool = False, + idle_threshold: float = 5.0, + timeout: Optional[int] = None, + task_id: Optional[str] = None +) -> str: + """ + Execute a command on a Morph VM with optional interactive session support. + + This tool uses Hecate's VM lifecycle management to automatically create + and manage VMs. VMs are reused within the configured lifetime window + and automatically cleaned up after inactivity. + + Args: + command: The command to execute (optional if continuing existing session) + input_keys: Keystrokes to send to interactive session (e.g., "hello\\n") + session_id: ID of existing session to continue (optional) + background: Whether to run the command in the background (default: False) + idle_threshold: Seconds to wait for output before considering session idle (default: 5.0) + timeout: Command timeout in seconds (optional) + task_id: Unique identifier for this task to isolate VMs between concurrent tasks (optional) + + Returns: + str: JSON string containing command output, session info, exit code, and any errors + + Examples: + # Execute a simple command + >>> result = terminal_tool(command="ls -la /tmp") + + # Start an interactive Python session + >>> result = terminal_tool(command="python3") + >>> session_data = json.loads(result) + >>> session_id = session_data["session_id"] + + # Send input to the session + >>> result = terminal_tool(input_keys="print('Hello')\\n", session_id=session_id) + + # Run a background task + >>> result = terminal_tool(command="sleep 60", background=True) + """ + global _active_instances, _active_contexts + + try: + # Import required modules lazily so this module can be imported + # even when hecate is not installed + try: + from morphcloud._llm import ToolCall + from morphcloud.api import MorphCloudClient + from hecate.cli import run_tool, ExecutionContext + from rich.console import Console + import io + except ImportError as import_error: + return json.dumps({ + "output": "", + "screen": "", + "exit_code": -1, + "error": f"Terminal tool is disabled due to import error: {import_error}", + "status": "disabled" + }, ensure_ascii=False) + + + # Get configuration from environment + vm_lifetime_seconds = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300")) + vm_ttl_seconds = int(os.getenv("HECATE_VM_TTL_SECONDS", "1200")) # 20 minutes default + snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "snapshot_defv9tjg") + + # Check API key + morph_api_key = os.getenv("MORPH_API_KEY") + if not morph_api_key: + return json.dumps({ + "output": "", + "screen": "", + "exit_code": -1, + "error": "MORPH_API_KEY environment variable not set", + "status": "disabled" + }, ensure_ascii=False) + + # Use task_id to isolate VMs between concurrent tasks + # If no task_id provided, use "default" for backward compatibility + effective_task_id = task_id or "default" + + # Start the cleanup thread if not already running + _start_cleanup_thread() + + # Get or create VM instance and execution context per task + # This is critical for interactive session support - the context must persist! + with _instance_lock: + if effective_task_id not in _active_instances: + morph_client = MorphCloudClient(api_key=morph_api_key) + _active_instances[effective_task_id] = morph_client.instances.start( + snapshot_id=snapshot_id, + ttl_seconds=vm_ttl_seconds, + ttl_action="stop" + ) + + # Get or create persistent execution context per task + if effective_task_id not in _active_contexts: + _active_contexts[effective_task_id] = ExecutionContext() + + # Update last activity time for this VM (resets the inactivity timer) + _last_activity[effective_task_id] = time.time() + + instance = _active_instances[effective_task_id] + ctx = _active_contexts[effective_task_id] + + # Build tool input based on provided parameters + tool_input = {} + + if command: + tool_input["command"] = command + if input_keys: + tool_input["input_keys"] = input_keys + if session_id: + tool_input["session_id"] = session_id + if background: + tool_input["background"] = background + if idle_threshold != 5.0: + tool_input["idle_threshold"] = idle_threshold + if timeout is not None: + tool_input["timeout"] = timeout + + tool_call = ToolCall( + name="run_command", + input=tool_input + ) + + # Create a console for output (redirect to string buffer to avoid printing) + console_output = io.StringIO() + console = Console(file=console_output, force_terminal=False, legacy_windows=False) + + # Generate unique tool block ID + tool_block_id = f"tool_{uuid.uuid4().hex[:8]}" + + # Execute the tool with hecate + result = run_tool( + tool_call=tool_call, + instance=instance, + console=console, + tool_block_id=tool_block_id, + ctx=ctx + ) + + # Format the result with only essential fields for the LLM + # Map hecate's "stdout" to "output" for compatibility + formatted_result = { + "output": result.get("stdout", result.get("output", "")), + "screen": result.get("screen", ""), + "exit_code": result.get("returncode", result.get("exit_code", -1)), + "error": result.get("error") + } + + return json.dumps(formatted_result, ensure_ascii=False) + + except Exception as e: + return json.dumps({ + "output": "", + "screen": "", + "exit_code": -1, + "error": f"Failed to execute terminal command: {str(e)}", + "status": "error" + }, ensure_ascii=False) + +def check_hecate_requirements() -> bool: + """ + Check if all requirements for terminal tools are met. + + Returns: + bool: True if all requirements are met, False otherwise + """ + # Check for required environment variables + required_vars = ["MORPH_API_KEY"] + optional_vars = ["OPENAI_API_KEY"] # Needed for Hecate's LLM features + + missing_required = [var for var in required_vars if not os.getenv(var)] + missing_optional = [var for var in optional_vars if not os.getenv(var)] + + if missing_required: + print(f"Missing required environment variables: {', '.join(missing_required)}") + return False + + if missing_optional: + print(f"Warning: Missing optional environment variables: {', '.join(missing_optional)}") + print(" (Some Hecate features may be limited)") + + # Check if Hecate and required modules are importable + try: + from morphcloud._llm import ToolCall + from morphcloud.api import MorphCloudClient + from hecate.cli import run_tool, ExecutionContext + from rich.console import Console + return True + except Exception as e: + print(f"Hecate not available: {e}") + print(f"Make sure hecate is installed and MORPH_API_KEY is set.") + return False + +# Module-level initialization check +_requirements_met = check_hecate_requirements() + +if __name__ == "__main__": + """ + Simple test/demo when run directly + """ + print("Terminal Tool Module") + print("=" * 40) + + if not _requirements_met: + print("Requirements not met. Please check the messages above.") + exit(1) + + print("All requirements met!") + print("\nAvailable Tool:") + print(" - terminal_tool: Execute commands with optional interactive session support") + + print("\nUsage Examples:") + print(" # Execute a command") + print(" result = terminal_tool(command='ls -la')") + print(" ") + print(" # Start an interactive session") + print(" result = terminal_tool(command='python3')") + print(" session_data = json.loads(result)") + print(" session_id = session_data['session_id']") + print(" ") + print(" # Send input to the session") + print(" result = terminal_tool(") + print(" input_keys='print(\"Hello\")\\\\n',") + print(" session_id=session_id") + print(" )") + print(" ") + print(" # Run a background task") + print(" result = terminal_tool(command='sleep 60', background=True)") + + print("\nEnvironment Variables:") + print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}") + print(f" OPENAI_API_KEY: {'Set' if os.getenv('OPENAI_API_KEY') else 'Not set (optional)'}") + print(f" HECATE_VM_TTL_SECONDS: {os.getenv('HECATE_VM_TTL_SECONDS', '1200')} (default: 1200 / 20 minutes)") + print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300 / 5 minutes)") + print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_defv9tjg')} (default: snapshot_defv9tjg)") diff --git a/vision_tools.py b/tools/vision_tools.py similarity index 66% rename from vision_tools.py rename to tools/vision_tools.py index 3183713bd..83f9a36a4 100644 --- a/vision_tools.py +++ b/tools/vision_tools.py @@ -1,346 +1,471 @@ -#!/usr/bin/env python3 -""" -Vision Tools Module - -This module provides vision analysis tools that work with image URLs. -Uses Gemini Flash via Nous Research API for intelligent image understanding. - -Available tools: -- vision_analyze_tool: Analyze images from URLs with custom prompts - -Features: -- Comprehensive image description -- Context-aware analysis based on user queries -- Proper error handling and validation -- Debug logging support - -Usage: - from vision_tools import vision_analyze_tool - import asyncio - - # Analyze an image - result = await vision_analyze_tool( - image_url="https://example.com/image.jpg", - user_prompt="What architectural style is this building?" - ) -""" - -import json -import os -import asyncio -import uuid -import datetime -from pathlib import Path -from typing import Dict, Any, Optional -from openai import AsyncOpenAI - -# Initialize Nous Research API client for vision processing -nous_client = AsyncOpenAI( - api_key=os.getenv("NOUS_API_KEY"), - base_url="https://inference-api.nousresearch.com/v1" -) - -# Configuration for vision processing -DEFAULT_VISION_MODEL = "gemini-2.5-flash" - -# Debug mode configuration -DEBUG_MODE = os.getenv("VISION_TOOLS_DEBUG", "false").lower() == "true" -DEBUG_SESSION_ID = str(uuid.uuid4()) -DEBUG_LOG_PATH = Path("./logs") -DEBUG_DATA = { - "session_id": DEBUG_SESSION_ID, - "start_time": datetime.datetime.now().isoformat(), - "debug_enabled": DEBUG_MODE, - "tool_calls": [] -} if DEBUG_MODE else None - -# Create logs directory if debug mode is enabled -if DEBUG_MODE: - DEBUG_LOG_PATH.mkdir(exist_ok=True) - print(f"šŸ› Vision debug mode enabled - Session ID: {DEBUG_SESSION_ID}") - - -def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None: - """ - Log a debug call entry to the global debug data structure. - - Args: - tool_name (str): Name of the tool being called - call_data (Dict[str, Any]): Data about the call including parameters and results - """ - if not DEBUG_MODE or not DEBUG_DATA: - return - - call_entry = { - "timestamp": datetime.datetime.now().isoformat(), - "tool_name": tool_name, - **call_data - } - - DEBUG_DATA["tool_calls"].append(call_entry) - - -def _save_debug_log() -> None: - """ - Save the current debug data to a JSON file in the logs directory. - """ - if not DEBUG_MODE or not DEBUG_DATA: - return - - try: - debug_filename = f"vision_tools_debug_{DEBUG_SESSION_ID}.json" - debug_filepath = DEBUG_LOG_PATH / debug_filename - - # Update end time - DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat() - DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"]) - - with open(debug_filepath, 'w', encoding='utf-8') as f: - json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False) - - print(f"šŸ› Vision debug log saved: {debug_filepath}") - - except Exception as e: - print(f"āŒ Error saving vision debug log: {str(e)}") - - -def _validate_image_url(url: str) -> bool: - """ - Basic validation of image URL format. - - Args: - url (str): The URL to validate - - Returns: - bool: True if URL appears to be valid, False otherwise - """ - if not url or not isinstance(url, str): - return False - - # Check if it's a valid URL format - if not (url.startswith('http://') or url.startswith('https://')): - return False - - # Check for common image extensions (optional, as URLs may not have extensions) - image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg'] - - return True # Allow all HTTP/HTTPS URLs for flexibility - - -async def vision_analyze_tool( - image_url: str, - user_prompt: str, - model: str = DEFAULT_VISION_MODEL -) -> str: - """ - Analyze an image from a URL using vision AI. - - This tool processes images using Gemini Flash via Nous Research API. - The user_prompt parameter is expected to be pre-formatted by the calling - function (typically model_tools.py) to include both full description - requests and specific questions. - - Args: - image_url (str): The URL of the image to analyze - user_prompt (str): The pre-formatted prompt for the vision model - model (str): The vision model to use (default: gemini-2.5-flash) - - Returns: - str: JSON string containing the analysis results with the following structure: - { - "success": bool, - "analysis": str (defaults to error message if None) - } - - Raises: - Exception: If analysis fails or API key is not set - """ - debug_call_data = { - "parameters": { - "image_url": image_url, - "user_prompt": user_prompt, - "model": model - }, - "error": None, - "success": False, - "analysis_length": 0, - "model_used": model - } - - try: - print(f"šŸ” Analyzing image from URL: {image_url[:60]}{'...' if len(image_url) > 60 else ''}") - print(f"šŸ“ User prompt: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}") - - # Validate image URL - if not _validate_image_url(image_url): - raise ValueError("Invalid image URL format. Must start with http:// or https://") - - # Check API key availability - if not os.getenv("NOUS_API_KEY"): - raise ValueError("NOUS_API_KEY environment variable not set") - - # Use the prompt as provided (model_tools.py now handles full description formatting) - comprehensive_prompt = user_prompt - - # Prepare the message with image URL format - messages = [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": comprehensive_prompt - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - } - ] - } - ] - - print(f"🧠 Processing image with {model}...") - - # Call the vision API - response = await nous_client.chat.completions.create( - model=model, - messages=messages, - temperature=0.1, # Low temperature for consistent analysis - max_tokens=2000 # Generous limit for detailed analysis - ) - - # Extract the analysis - analysis = response.choices[0].message.content.strip() - analysis_length = len(analysis) - - print(f"āœ… Image analysis completed ({analysis_length} characters)") - - # Prepare successful response - result = { - "success": True, - "analysis": analysis or "There was a problem with the request and the image could not be analyzed." - } - - debug_call_data["success"] = True - debug_call_data["analysis_length"] = analysis_length - - # Log debug information - _log_debug_call("vision_analyze_tool", debug_call_data) - _save_debug_log() - - return json.dumps(result, indent=2) - - except Exception as e: - error_msg = f"Error analyzing image: {str(e)}" - print(f"āŒ {error_msg}") - - # Prepare error response - result = { - "success": False, - "analysis": "There was a problem with the request and the image could not be analyzed." - } - - debug_call_data["error"] = error_msg - _log_debug_call("vision_analyze_tool", debug_call_data) - _save_debug_log() - - return json.dumps(result, indent=2) - - -def check_nous_api_key() -> bool: - """ - Check if the Nous Research API key is available in environment variables. - - Returns: - bool: True if API key is set, False otherwise - """ - return bool(os.getenv("NOUS_API_KEY")) - - -def check_vision_requirements() -> bool: - """ - Check if all requirements for vision tools are met. - - Returns: - bool: True if requirements are met, False otherwise - """ - return check_nous_api_key() - - -def get_debug_session_info() -> Dict[str, Any]: - """ - Get information about the current debug session. - - Returns: - Dict[str, Any]: Dictionary containing debug session information - """ - if not DEBUG_MODE or not DEBUG_DATA: - return { - "enabled": False, - "session_id": None, - "log_path": None, - "total_calls": 0 - } - - return { - "enabled": True, - "session_id": DEBUG_SESSION_ID, - "log_path": str(DEBUG_LOG_PATH / f"vision_tools_debug_{DEBUG_SESSION_ID}.json"), - "total_calls": len(DEBUG_DATA["tool_calls"]) - } - - -if __name__ == "__main__": - """ - Simple test/demo when run directly - """ - print("šŸ‘ļø Vision Tools Module") - print("=" * 40) - - # Check if API key is available - api_available = check_nous_api_key() - - if not api_available: - print("āŒ NOUS_API_KEY environment variable not set") - print("Please set your API key: export NOUS_API_KEY='your-key-here'") - print("Get API key at: https://inference-api.nousresearch.com/") - exit(1) - else: - print("āœ… Nous Research API key found") - - print("šŸ› ļø Vision tools ready for use!") - print(f"🧠 Using model: {DEFAULT_VISION_MODEL}") - - # Show debug mode status - if DEBUG_MODE: - print(f"šŸ› Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}") - print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{DEBUG_SESSION_ID}.json") - else: - print("šŸ› Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)") - - print("\nBasic usage:") - print(" from vision_tools import vision_analyze_tool") - print(" import asyncio") - print("") - print(" async def main():") - print(" result = await vision_analyze_tool(") - print(" image_url='https://example.com/image.jpg',") - print(" user_prompt='What do you see in this image?'") - print(" )") - print(" print(result)") - print(" asyncio.run(main())") - - print("\nExample prompts:") - print(" - 'What architectural style is this building?'") - print(" - 'Describe the emotions and mood in this image'") - print(" - 'What text can you read in this image?'") - print(" - 'Identify any safety hazards visible'") - print(" - 'What products or brands are shown?'") - - print("\nDebug mode:") - print(" # Enable debug logging") - print(" export VISION_TOOLS_DEBUG=true") - print(" # Debug logs capture all vision analysis calls and results") - print(" # Logs saved to: ./logs/vision_tools_debug_UUID.json") +#!/usr/bin/env python3 +""" +Vision Tools Module + +This module provides vision analysis tools that work with image URLs. +Uses Gemini Flash via Nous Research API for intelligent image understanding. + +Available tools: +- vision_analyze_tool: Analyze images from URLs with custom prompts + +Features: +- Downloads images from URLs and converts to base64 for API compatibility +- Comprehensive image description +- Context-aware analysis based on user queries +- Automatic temporary file cleanup +- Proper error handling and validation +- Debug logging support + +Usage: + from vision_tools import vision_analyze_tool + import asyncio + + # Analyze an image + result = await vision_analyze_tool( + image_url="https://example.com/image.jpg", + user_prompt="What architectural style is this building?" + ) +""" + +import json +import os +import asyncio +import uuid +import datetime +import base64 +from pathlib import Path +from typing import Dict, Any, Optional +from openai import AsyncOpenAI +import httpx # Use httpx for async HTTP requests + +# Initialize Nous Research API client for vision processing +nous_client = AsyncOpenAI( + api_key=os.getenv("NOUS_API_KEY"), + base_url="https://inference-api.nousresearch.com/v1" +) + +# Configuration for vision processing +DEFAULT_VISION_MODEL = "gemini-2.5-flash" + +# Debug mode configuration +DEBUG_MODE = os.getenv("VISION_TOOLS_DEBUG", "false").lower() == "true" +DEBUG_SESSION_ID = str(uuid.uuid4()) +DEBUG_LOG_PATH = Path("./logs") +DEBUG_DATA = { + "session_id": DEBUG_SESSION_ID, + "start_time": datetime.datetime.now().isoformat(), + "debug_enabled": DEBUG_MODE, + "tool_calls": [] +} if DEBUG_MODE else None + +# Create logs directory if debug mode is enabled +if DEBUG_MODE: + DEBUG_LOG_PATH.mkdir(exist_ok=True) + print(f"šŸ› Vision debug mode enabled - Session ID: {DEBUG_SESSION_ID}") + + +def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None: + """ + Log a debug call entry to the global debug data structure. + + Args: + tool_name (str): Name of the tool being called + call_data (Dict[str, Any]): Data about the call including parameters and results + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + call_entry = { + "timestamp": datetime.datetime.now().isoformat(), + "tool_name": tool_name, + **call_data + } + + DEBUG_DATA["tool_calls"].append(call_entry) + + +def _save_debug_log() -> None: + """ + Save the current debug data to a JSON file in the logs directory. + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + try: + debug_filename = f"vision_tools_debug_{DEBUG_SESSION_ID}.json" + debug_filepath = DEBUG_LOG_PATH / debug_filename + + # Update end time + DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat() + DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"]) + + with open(debug_filepath, 'w', encoding='utf-8') as f: + json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False) + + print(f"šŸ› Vision debug log saved: {debug_filepath}") + + except Exception as e: + print(f"āŒ Error saving vision debug log: {str(e)}") + + +def _validate_image_url(url: str) -> bool: + """ + Basic validation of image URL format. + + Args: + url (str): The URL to validate + + Returns: + bool: True if URL appears to be valid, False otherwise + """ + if not url or not isinstance(url, str): + return False + + # Check if it's a valid URL format + if not (url.startswith('http://') or url.startswith('https://')): + return False + + # Check for common image extensions (optional, as URLs may not have extensions) + image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg'] + + return True # Allow all HTTP/HTTPS URLs for flexibility + + +async def _download_image(image_url: str, destination: Path) -> Path: + """ + Download an image from a URL to a local destination (async). + + Args: + image_url (str): The URL of the image to download + destination (Path): The path where the image should be saved + + Returns: + Path: The path to the downloaded image + + Raises: + Exception: If download fails or response is invalid + """ + # Create parent directories if they don't exist + destination.parent.mkdir(parents=True, exist_ok=True) + + # Download the image with appropriate headers using async httpx + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get( + image_url, + headers={"User-Agent": "hermes-agent-vision/1.0"}, + ) + response.raise_for_status() + + # Save the image content + destination.write_bytes(response.content) + + return destination + + +def _determine_mime_type(image_path: Path) -> str: + """ + Determine the MIME type of an image based on its file extension. + + Args: + image_path (Path): Path to the image file + + Returns: + str: The MIME type (defaults to image/jpeg if unknown) + """ + extension = image_path.suffix.lower() + mime_types = { + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.png': 'image/png', + '.gif': 'image/gif', + '.bmp': 'image/bmp', + '.webp': 'image/webp', + '.svg': 'image/svg+xml' + } + return mime_types.get(extension, 'image/jpeg') + + +def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None) -> str: + """ + Convert an image file to a base64-encoded data URL. + + Args: + image_path (Path): Path to the image file + mime_type (Optional[str]): MIME type of the image (auto-detected if None) + + Returns: + str: Base64-encoded data URL (e.g., "data:image/jpeg;base64,...") + """ + # Read the image as bytes + data = image_path.read_bytes() + + # Encode to base64 + encoded = base64.b64encode(data).decode("ascii") + + # Determine MIME type + mime = mime_type or _determine_mime_type(image_path) + + # Create data URL + data_url = f"data:{mime};base64,{encoded}" + + return data_url + + +async def vision_analyze_tool( + image_url: str, + user_prompt: str, + model: str = DEFAULT_VISION_MODEL +) -> str: + """ + Analyze an image from a URL using vision AI. + + This tool downloads images from URLs, converts them to base64, and processes + them using Gemini Flash via Nous Research API. The image is downloaded to a + temporary location and automatically cleaned up after processing. + + The user_prompt parameter is expected to be pre-formatted by the calling + function (typically model_tools.py) to include both full description + requests and specific questions. + + Args: + image_url (str): The URL of the image to analyze (must be http:// or https://) + user_prompt (str): The pre-formatted prompt for the vision model + model (str): The vision model to use (default: gemini-2.5-flash) + + Returns: + str: JSON string containing the analysis results with the following structure: + { + "success": bool, + "analysis": str (defaults to error message if None) + } + + Raises: + Exception: If download fails, analysis fails, or API key is not set + + Note: + - Temporary images are stored in ./temp_vision_images/ + - Images are automatically deleted after processing + - Supports common image formats (JPEG, PNG, GIF, WebP, etc.) + """ + debug_call_data = { + "parameters": { + "image_url": image_url, + "user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt, + "model": model + }, + "error": None, + "success": False, + "analysis_length": 0, + "model_used": model, + "image_size_bytes": 0 + } + + temp_image_path = None + + try: + print(f"šŸ” Analyzing image from URL: {image_url[:60]}{'...' if len(image_url) > 60 else ''}", flush=True) + print(f"šŸ“ User prompt: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}", flush=True) + + # Validate image URL + if not _validate_image_url(image_url): + raise ValueError("Invalid image URL format. Must start with http:// or https://") + + # Check API key availability + if not os.getenv("NOUS_API_KEY"): + raise ValueError("NOUS_API_KEY environment variable not set") + + # Download the image to a temporary location + print(f"ā¬‡ļø Downloading image from URL...", flush=True) + temp_dir = Path("./temp_vision_images") + temp_image_path = temp_dir / f"temp_image_{uuid.uuid4()}.jpg" + + await _download_image(image_url, temp_image_path) + + # Get image file size for logging + image_size_bytes = temp_image_path.stat().st_size + image_size_kb = image_size_bytes / 1024 + print(f"āœ… Image downloaded successfully ({image_size_kb:.1f} KB)", flush=True) + + # Convert image to base64 data URL + print(f"šŸ”„ Converting image to base64...", flush=True) + image_data_url = _image_to_base64_data_url(temp_image_path) + # Calculate size in KB for better readability + data_size_kb = len(image_data_url) / 1024 + print(f"āœ… Image converted to base64 ({data_size_kb:.1f} KB)", flush=True) + + debug_call_data["image_size_bytes"] = image_size_bytes + + # Use the prompt as provided (model_tools.py now handles full description formatting) + comprehensive_prompt = user_prompt + + # Prepare the message with base64-encoded image + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": comprehensive_prompt + }, + { + "type": "image_url", + "image_url": { + "url": image_data_url + } + } + ] + } + ] + + print(f"🧠 Processing image with {model}...", flush=True) + + # Call the vision API + response = await nous_client.chat.completions.create( + model=model, + messages=messages, + temperature=0.1, # Low temperature for consistent analysis + max_tokens=2000 # Generous limit for detailed analysis + ) + + # Extract the analysis + analysis = response.choices[0].message.content.strip() + analysis_length = len(analysis) + + print(f"āœ… Image analysis completed ({analysis_length} characters)", flush=True) + + # Prepare successful response + result = { + "success": True, + "analysis": analysis or "There was a problem with the request and the image could not be analyzed." + } + + debug_call_data["success"] = True + debug_call_data["analysis_length"] = analysis_length + + # Log debug information + _log_debug_call("vision_analyze_tool", debug_call_data) + _save_debug_log() + + return json.dumps(result, indent=2, ensure_ascii=False) + + except Exception as e: + error_msg = f"Error analyzing image: {str(e)}" + print(f"āŒ {error_msg}", flush=True) + + # Prepare error response + result = { + "success": False, + "analysis": "There was a problem with the request and the image could not be analyzed." + } + + debug_call_data["error"] = error_msg + _log_debug_call("vision_analyze_tool", debug_call_data) + _save_debug_log() + + return json.dumps(result, indent=2, ensure_ascii=False) + + finally: + # Clean up temporary image file + if temp_image_path and temp_image_path.exists(): + try: + temp_image_path.unlink() + print(f"🧹 Cleaned up temporary image file", flush=True) + except Exception as cleanup_error: + print(f"āš ļø Warning: Could not delete temporary file: {cleanup_error}", flush=True) + + +def check_nous_api_key() -> bool: + """ + Check if the Nous Research API key is available in environment variables. + + Returns: + bool: True if API key is set, False otherwise + """ + return bool(os.getenv("NOUS_API_KEY")) + + +def check_vision_requirements() -> bool: + """ + Check if all requirements for vision tools are met. + + Returns: + bool: True if requirements are met, False otherwise + """ + return check_nous_api_key() + + +def get_debug_session_info() -> Dict[str, Any]: + """ + Get information about the current debug session. + + Returns: + Dict[str, Any]: Dictionary containing debug session information + """ + if not DEBUG_MODE or not DEBUG_DATA: + return { + "enabled": False, + "session_id": None, + "log_path": None, + "total_calls": 0 + } + + return { + "enabled": True, + "session_id": DEBUG_SESSION_ID, + "log_path": str(DEBUG_LOG_PATH / f"vision_tools_debug_{DEBUG_SESSION_ID}.json"), + "total_calls": len(DEBUG_DATA["tool_calls"]) + } + + +if __name__ == "__main__": + """ + Simple test/demo when run directly + """ + print("šŸ‘ļø Vision Tools Module") + print("=" * 40) + + # Check if API key is available + api_available = check_nous_api_key() + + if not api_available: + print("āŒ NOUS_API_KEY environment variable not set") + print("Please set your API key: export NOUS_API_KEY='your-key-here'") + print("Get API key at: https://inference-api.nousresearch.com/") + exit(1) + else: + print("āœ… Nous Research API key found") + + print("šŸ› ļø Vision tools ready for use!") + print(f"🧠 Using model: {DEFAULT_VISION_MODEL}") + + # Show debug mode status + if DEBUG_MODE: + print(f"šŸ› Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}") + print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{DEBUG_SESSION_ID}.json") + else: + print("šŸ› Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)") + + print("\nBasic usage:") + print(" from vision_tools import vision_analyze_tool") + print(" import asyncio") + print("") + print(" async def main():") + print(" result = await vision_analyze_tool(") + print(" image_url='https://example.com/image.jpg',") + print(" user_prompt='What do you see in this image?'") + print(" )") + print(" print(result)") + print(" asyncio.run(main())") + + print("\nExample prompts:") + print(" - 'What architectural style is this building?'") + print(" - 'Describe the emotions and mood in this image'") + print(" - 'What text can you read in this image?'") + print(" - 'Identify any safety hazards visible'") + print(" - 'What products or brands are shown?'") + + print("\nDebug mode:") + print(" # Enable debug logging") + print(" export VISION_TOOLS_DEBUG=true") + print(" # Debug logs capture all vision analysis calls and results") + print(" # Logs saved to: ./logs/vision_tools_debug_UUID.json") diff --git a/web_tools.py b/tools/web_tools.py similarity index 94% rename from web_tools.py rename to tools/web_tools.py index 706eb1ff1..3f7df9f43 100644 --- a/web_tools.py +++ b/tools/web_tools.py @@ -1,1009 +1,1032 @@ -#!/usr/bin/env python3 -""" -Standalone Web Tools Module - -This module provides generic web tools that work with multiple backend providers. -Currently uses Firecrawl as the backend, and the interface makes it easy to swap -providers without changing the function signatures. - -Available tools: -- web_search_tool: Search the web for information -- web_extract_tool: Extract content from specific web pages -- web_crawl_tool: Crawl websites with specific instructions - -Backend compatibility: -- Firecrawl: https://docs.firecrawl.dev/introduction - -LLM Processing: -- Uses Nous Research API with Gemini 2.5 Flash for intelligent content extraction -- Extracts key excerpts and creates markdown summaries to reduce token usage - -Debug Mode: -- Set WEB_TOOLS_DEBUG=true to enable detailed logging -- Creates web_tools_debug_UUID.json in ./logs directory -- Captures all tool calls, results, and compression metrics - -Usage: - from web_tools import web_search_tool, web_extract_tool, web_crawl_tool - - # Search the web - results = web_search_tool("Python machine learning libraries", limit=3) - - # Extract content from URLs - content = web_extract_tool(["https://example.com"], format="markdown") - - # Crawl a website - crawl_data = web_crawl_tool("example.com", "Find contact information") -""" - -#TODO: Search Capabilities over the scraped pages -#TODO: Store the pages in something -#TODO: Tool to see what pages are available/saved to search over - -import json -import os -import re -import asyncio -import uuid -import datetime -from pathlib import Path -from typing import List, Dict, Any, Optional -from firecrawl import Firecrawl -from openai import AsyncOpenAI - -# Initialize Firecrawl client once at module level -firecrawl_client = Firecrawl(api_key=os.getenv("FIRECRAWL_API_KEY")) - -# Initialize Nous Research API client for LLM processing (async) -nous_client = AsyncOpenAI( - api_key=os.getenv("NOUS_API_KEY"), - base_url="https://inference-api.nousresearch.com/v1" -) - -# Configuration for LLM processing -DEFAULT_SUMMARIZER_MODEL = "gemini-2.5-flash" -DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000 - -# Debug mode configuration -DEBUG_MODE = os.getenv("WEB_TOOLS_DEBUG", "false").lower() == "true" -DEBUG_SESSION_ID = str(uuid.uuid4()) -DEBUG_LOG_PATH = Path("./logs") -DEBUG_DATA = { - "session_id": DEBUG_SESSION_ID, - "start_time": datetime.datetime.now().isoformat(), - "debug_enabled": DEBUG_MODE, - "tool_calls": [] -} if DEBUG_MODE else None - -# Create logs directory if debug mode is enabled -if DEBUG_MODE: - DEBUG_LOG_PATH.mkdir(exist_ok=True) - print(f"šŸ› Debug mode enabled - Session ID: {DEBUG_SESSION_ID}") - - -def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None: - """ - Log a debug call entry to the global debug data structure. - - Args: - tool_name (str): Name of the tool being called - call_data (Dict[str, Any]): Data about the call including parameters and results - """ - if not DEBUG_MODE or not DEBUG_DATA: - return - - call_entry = { - "timestamp": datetime.datetime.now().isoformat(), - "tool_name": tool_name, - **call_data - } - - DEBUG_DATA["tool_calls"].append(call_entry) - - -def _save_debug_log() -> None: - """ - Save the current debug data to a JSON file in the logs directory. - """ - if not DEBUG_MODE or not DEBUG_DATA: - return - - try: - debug_filename = f"web_tools_debug_{DEBUG_SESSION_ID}.json" - debug_filepath = DEBUG_LOG_PATH / debug_filename - - # Update end time - DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat() - DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"]) - - with open(debug_filepath, 'w', encoding='utf-8') as f: - json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False) - - print(f"šŸ› Debug log saved: {debug_filepath}") - - except Exception as e: - print(f"āŒ Error saving debug log: {str(e)}") - - -async def process_content_with_llm( - content: str, - url: str = "", - title: str = "", - model: str = DEFAULT_SUMMARIZER_MODEL, - min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION -) -> Optional[str]: - """ - Process web content using LLM to create intelligent summaries with key excerpts. - - This function uses Gemini 2.5 Flash (or specified model) via Nous Research API - to intelligently extract key information and create markdown summaries, - significantly reducing token usage while preserving all important information. - - Args: - content (str): The raw content to process - url (str): The source URL (for context, optional) - title (str): The page title (for context, optional) - model (str): The model to use for processing (default: gemini-2.5-flash) - min_length (int): Minimum content length to trigger processing (default: 5000) - - Returns: - Optional[str]: Processed markdown content, or None if content too short or processing fails - """ - try: - # Skip processing if content is too short - if len(content) < min_length: - print(f"šŸ“ Content too short ({len(content)} < {min_length} chars), skipping LLM processing") - return None - - print(f"🧠 Processing content with LLM ({len(content)} characters)") - - # Create context information - context_info = [] - if title: - context_info.append(f"Title: {title}") - if url: - context_info.append(f"Source: {url}") - - context_str = "\n".join(context_info) + "\n\n" if context_info else "" - - # Simplified prompt for better quality markdown output - system_prompt = """You are an expert content analyst. Your job is to process web content and create a comprehensive yet concise summary that preserves all important information while dramatically reducing bulk. - -Create a well-structured markdown summary that includes: -1. Key excerpts (quotes, code snippets, important facts) in their original format -2. Comprehensive summary of all other important information -3. Proper markdown formatting with headers, bullets, and emphasis - -Your goal is to preserve ALL important information while reducing length. Never lose key facts, figures, insights, or actionable information. Make it scannable and well-organized.""" - - user_prompt = f"""Please process this web content and create a comprehensive markdown summary: - -{context_str}CONTENT TO PROCESS: -{content} - -Create a markdown summary that captures all key information in a well-organized, scannable format. Include important quotes and code snippets in their original formatting. Focus on actionable information, specific details, and unique insights.""" - - # Call the LLM asynchronously - response = await nous_client.chat.completions.create( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ], - temperature=0.1, # Low temperature for consistent extraction - max_tokens=4000 # Generous limit for comprehensive processing - ) - - # Get the markdown response directly - processed_content = response.choices[0].message.content.strip() - - # Calculate compression metrics for logging - original_length = len(content) - processed_length = len(processed_content) - compression_ratio = processed_length / original_length if original_length > 0 else 1.0 - - print(f"āœ… Content processed: {original_length} → {processed_length} chars ({compression_ratio:.1%})") - - return processed_content - - except Exception as e: - print(f"āŒ Error processing content with LLM: {str(e)}") - return None - - -def clean_base64_images(text: str) -> str: - """ - Remove base64 encoded images from text to reduce token count and clutter. - - This function finds and removes base64 encoded images in various formats: - - (data:image/png;base64,...) - - (data:image/jpeg;base64,...) - - (data:image/svg+xml;base64,...) - - data:image/[type];base64,... (without parentheses) - - Args: - text: The text content to clean - - Returns: - Cleaned text with base64 images replaced with placeholders - """ - # Pattern to match base64 encoded images wrapped in parentheses - # Matches: (data:image/[type];base64,[base64-string]) - base64_with_parens_pattern = r'\(data:image/[^;]+;base64,[A-Za-z0-9+/=]+\)' - - # Pattern to match base64 encoded images without parentheses - # Matches: data:image/[type];base64,[base64-string] - base64_pattern = r'data:image/[^;]+;base64,[A-Za-z0-9+/=]+' - - # Replace parentheses-wrapped images first - cleaned_text = re.sub(base64_with_parens_pattern, '[BASE64_IMAGE_REMOVED]', text) - - # Then replace any remaining non-parentheses images - cleaned_text = re.sub(base64_pattern, '[BASE64_IMAGE_REMOVED]', cleaned_text) - - return cleaned_text - - -def web_search_tool(query: str, limit: int = 5) -> str: - """ - Search the web for information using available search API backend. - - This function provides a generic interface for web search that can work - with multiple backends. Currently uses Firecrawl. - - Note: This function returns search result metadata only (URLs, titles, descriptions). - Use web_extract_tool to get full content from specific URLs. - - Args: - query (str): The search query to look up - limit (int): Maximum number of results to return (default: 5) - - Returns: - str: JSON string containing search results with the following structure: - { - "success": bool, - "data": { - "web": [ - { - "title": str, - "url": str, - "description": str, - "position": int - }, - ... - ] - } - } - - Raises: - Exception: If search fails or API key is not set - """ - debug_call_data = { - "parameters": { - "query": query, - "limit": limit - }, - "error": None, - "results_count": 0, - "original_response_size": 0, - "final_response_size": 0 - } - - try: - print(f"šŸ” Searching the web for: '{query}' (limit: {limit})") - - # Use Firecrawl's v2 search functionality WITHOUT scraping - # We only want search result metadata, not scraped content - # Docs: https://docs.firecrawl.dev/features/search - response = firecrawl_client.search( - query=query, - limit=limit - ) - - # The response is a SearchData object with web, news, and images attributes - # When not scraping, the results are directly in these attributes - web_results = [] - - # Check if response has web attribute (SearchData object) - if hasattr(response, 'web'): - # Response is a SearchData object with web attribute - if response.web: - # Convert each SearchResultWeb object to dict - for result in response.web: - if hasattr(result, 'model_dump'): - # Pydantic model - use model_dump - web_results.append(result.model_dump()) - elif hasattr(result, '__dict__'): - # Regular object - use __dict__ - web_results.append(result.__dict__) - elif isinstance(result, dict): - # Already a dict - web_results.append(result) - elif hasattr(response, 'model_dump'): - # Response has model_dump method - use it to get dict - response_dict = response.model_dump() - if 'web' in response_dict and response_dict['web']: - web_results = response_dict['web'] - elif isinstance(response, dict): - # Response is already a dictionary - if 'web' in response and response['web']: - web_results = response['web'] - - results_count = len(web_results) - print(f"āœ… Found {results_count} search results") - - # Build response with just search metadata (URLs, titles, descriptions) - response_data = { - "success": True, - "data": { - "web": web_results - } - } - - # Capture debug information - debug_call_data["results_count"] = results_count - - # Convert to JSON - result_json = json.dumps(response_data, indent=2) - - debug_call_data["final_response_size"] = len(result_json) - - # Log debug information - _log_debug_call("web_search_tool", debug_call_data) - _save_debug_log() - - return result_json - - except Exception as e: - error_msg = f"Error searching web: {str(e)}" - print(f"āŒ {error_msg}") - - debug_call_data["error"] = error_msg - _log_debug_call("web_search_tool", debug_call_data) - _save_debug_log() - - return json.dumps({"error": error_msg}) - - -async def web_extract_tool( - urls: List[str], - format: str = None, - use_llm_processing: bool = True, - model: str = DEFAULT_SUMMARIZER_MODEL, - min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION -) -> str: - """ - Extract content from specific web pages using available extraction API backend. - - This function provides a generic interface for web content extraction that - can work with multiple backends. Currently uses Firecrawl. - - Args: - urls (List[str]): List of URLs to extract content from - format (str): Desired output format ("markdown" or "html", optional) - use_llm_processing (bool): Whether to process content with LLM for summarization (default: True) - model (str): The model to use for LLM processing (default: gemini-2.5-flash) - min_length (int): Minimum content length to trigger LLM processing (default: 5000) - - Returns: - str: JSON string containing extracted content. If LLM processing is enabled and successful, - the 'content' field will contain the processed markdown summary instead of raw content. - - Raises: - Exception: If extraction fails or API key is not set - """ - debug_call_data = { - "parameters": { - "urls": urls, - "format": format, - "use_llm_processing": use_llm_processing, - "model": model, - "min_length": min_length - }, - "error": None, - "pages_extracted": 0, - "pages_processed_with_llm": 0, - "original_response_size": 0, - "final_response_size": 0, - "compression_metrics": [], - "processing_applied": [] - } - - try: - print(f"šŸ“„ Extracting content from {len(urls)} URL(s)") - - # Determine requested formats for Firecrawl v2 - formats: List[str] = [] - if format == "markdown": - formats = ["markdown"] - elif format == "html": - formats = ["html"] - else: - # Default: request markdown for LLM-readiness and include html as backup - formats = ["markdown", "html"] - - # Always use individual scraping for simplicity and reliability - # Batch scraping adds complexity without much benefit for small numbers of URLs - results: List[Dict[str, Any]] = [] - - for url in urls: - try: - print(f" šŸ“„ Scraping: {url}") - scrape_result = firecrawl_client.scrape( - url=url, - formats=formats - ) - - # Process the result - properly handle object serialization - metadata = {} - title = "" - content_markdown = None - content_html = None - - # Extract data from the scrape result - if hasattr(scrape_result, 'model_dump'): - # Pydantic model - use model_dump to get dict - result_dict = scrape_result.model_dump() - content_markdown = result_dict.get('markdown') - content_html = result_dict.get('html') - metadata = result_dict.get('metadata', {}) - elif hasattr(scrape_result, '__dict__'): - # Regular object with attributes - content_markdown = getattr(scrape_result, 'markdown', None) - content_html = getattr(scrape_result, 'html', None) - - # Handle metadata - convert to dict if it's an object - metadata_obj = getattr(scrape_result, 'metadata', {}) - if hasattr(metadata_obj, 'model_dump'): - metadata = metadata_obj.model_dump() - elif hasattr(metadata_obj, '__dict__'): - metadata = metadata_obj.__dict__ - elif isinstance(metadata_obj, dict): - metadata = metadata_obj - else: - metadata = {} - elif isinstance(scrape_result, dict): - # Already a dictionary - content_markdown = scrape_result.get('markdown') - content_html = scrape_result.get('html') - metadata = scrape_result.get('metadata', {}) - - # Ensure metadata is a dict (not an object) - if not isinstance(metadata, dict): - if hasattr(metadata, 'model_dump'): - metadata = metadata.model_dump() - elif hasattr(metadata, '__dict__'): - metadata = metadata.__dict__ - else: - metadata = {} - - # Get title from metadata - title = metadata.get("title", "") - - # Choose content based on requested format - chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or "" - - results.append({ - "url": metadata.get("sourceURL", url), - "title": title, - "content": chosen_content, - "raw_content": chosen_content, - "metadata": metadata # Now guaranteed to be a dict - }) - - except Exception as scrape_err: - print(f" āŒ Error scraping {url}: {str(scrape_err)}") - results.append({ - "url": url, - "title": "", - "content": "", - "raw_content": "", - "error": str(scrape_err) - }) - - response = {"results": results} - - pages_extracted = len(response.get('results', [])) - print(f"āœ… Extracted content from {pages_extracted} pages") - - debug_call_data["pages_extracted"] = pages_extracted - debug_call_data["original_response_size"] = len(json.dumps(response)) - - # Process each result with LLM if enabled - if use_llm_processing and os.getenv("NOUS_API_KEY"): - print("🧠 Processing extracted content with LLM...") - debug_call_data["processing_applied"].append("llm_processing") - - for result in response.get('results', []): - url = result.get('url', 'Unknown URL') - title = result.get('title', '') - raw_content = result.get('raw_content', '') or result.get('content', '') - - if raw_content: - original_size = len(raw_content) - - # Process content with LLM - processed = await process_content_with_llm( - raw_content, url, title, model, min_length - ) - - if processed: - processed_size = len(processed) - compression_ratio = processed_size / original_size if original_size > 0 else 1.0 - - # Capture compression metrics - debug_call_data["compression_metrics"].append({ - "url": url, - "original_size": original_size, - "processed_size": processed_size, - "compression_ratio": compression_ratio, - "model_used": model - }) - - # Replace content with processed version - result['content'] = processed - # Keep raw content in separate field for reference - result['raw_content'] = raw_content - debug_call_data["pages_processed_with_llm"] += 1 - print(f" šŸ“ {url} (processed)") - else: - debug_call_data["compression_metrics"].append({ - "url": url, - "original_size": original_size, - "processed_size": original_size, - "compression_ratio": 1.0, - "model_used": None, - "reason": "content_too_short" - }) - print(f" šŸ“ {url} (no processing - content too short)") - else: - print(f" āš ļø {url} (no content to process)") - else: - if use_llm_processing and not os.getenv("NOUS_API_KEY"): - print("āš ļø LLM processing requested but NOUS_API_KEY not set, returning raw content") - debug_call_data["processing_applied"].append("llm_processing_unavailable") - - # Print summary of extracted pages for debugging (original behavior) - for result in response.get('results', []): - url = result.get('url', 'Unknown URL') - content_length = len(result.get('raw_content', '')) - print(f" šŸ“ {url} ({content_length} characters)") - - # Trim output to minimal fields per entry: title, content, error - trimmed_results = [ - { - "title": r.get("title", ""), - "content": r.get("content", ""), - "error": r.get("error") - } - for r in response.get("results", []) - ] - trimmed_response = {"results": trimmed_results} - - result_json = json.dumps(trimmed_response, indent=2) - # Clean base64 images from extracted content - cleaned_result = clean_base64_images(result_json) - - debug_call_data["final_response_size"] = len(cleaned_result) - debug_call_data["processing_applied"].append("base64_image_removal") - - # Log debug information - _log_debug_call("web_extract_tool", debug_call_data) - _save_debug_log() - - return cleaned_result - - except Exception as e: - error_msg = f"Error extracting content: {str(e)}" - print(f"āŒ {error_msg}") - - debug_call_data["error"] = error_msg - _log_debug_call("web_extract_tool", debug_call_data) - _save_debug_log() - - return json.dumps({"error": error_msg}) - - -async def web_crawl_tool( - url: str, - instructions: str = None, - depth: str = "basic", - use_llm_processing: bool = True, - model: str = DEFAULT_SUMMARIZER_MODEL, - min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION -) -> str: - """ - Crawl a website with specific instructions using available crawling API backend. - - This function provides a generic interface for web crawling that can work - with multiple backends. Currently uses Firecrawl. - - Args: - url (str): The base URL to crawl (can include or exclude https://) - instructions (str): Instructions for what to crawl/extract using LLM intelligence (optional) - depth (str): Depth of extraction ("basic" or "advanced", default: "basic") - use_llm_processing (bool): Whether to process content with LLM for summarization (default: True) - model (str): The model to use for LLM processing (default: gemini-2.5-flash) - min_length (int): Minimum content length to trigger LLM processing (default: 5000) - - Returns: - str: JSON string containing crawled content. If LLM processing is enabled and successful, - the 'content' field will contain the processed markdown summary instead of raw content. - Each page is processed individually. - - Raises: - Exception: If crawling fails or API key is not set - """ - debug_call_data = { - "parameters": { - "url": url, - "instructions": instructions, - "depth": depth, - "use_llm_processing": use_llm_processing, - "model": model, - "min_length": min_length - }, - "error": None, - "pages_crawled": 0, - "pages_processed_with_llm": 0, - "original_response_size": 0, - "final_response_size": 0, - "compression_metrics": [], - "processing_applied": [] - } - - try: - # Ensure URL has protocol - if not url.startswith(('http://', 'https://')): - url = f'https://{url}' - print(f" šŸ“ Added https:// prefix to URL: {url}") - - instructions_text = f" with instructions: '{instructions}'" if instructions else "" - print(f"šŸ•·ļø Crawling {url}{instructions_text}") - - # Use Firecrawl's v2 crawl functionality - # Docs: https://docs.firecrawl.dev/features/crawl - # The crawl() method automatically waits for completion and returns all data - - # Build crawl parameters - keep it simple - crawl_params = { - "limit": 20, # Limit number of pages to crawl - "scrape_options": { - "formats": ["markdown"] # Just markdown for simplicity - } - } - - # Note: The 'prompt' parameter is not documented for crawl - # Instructions are typically used with the Extract endpoint, not Crawl - if instructions: - print(f" ā„¹ļø Note: Instructions parameter ignored (not supported in crawl API)") - - # Use the crawl method which waits for completion automatically - try: - crawl_result = firecrawl_client.crawl( - url=url, - **crawl_params - ) - except Exception as e: - print(f" āŒ Crawl API call failed: {e}") - raise - - pages: List[Dict[str, Any]] = [] - - # Process crawl results - the crawl method returns a CrawlJob object with data attribute - data_list = [] - - # The crawl_result is a CrawlJob object with a 'data' attribute containing list of Document objects - if hasattr(crawl_result, 'data'): - data_list = crawl_result.data if crawl_result.data else [] - print(f" šŸ“Š Status: {getattr(crawl_result, 'status', 'unknown')}") - print(f" šŸ“„ Retrieved {len(data_list)} pages") - - # Debug: Check other attributes if no data - if not data_list: - print(f" šŸ” Debug - CrawlJob attributes: {[attr for attr in dir(crawl_result) if not attr.startswith('_')]}") - print(f" šŸ” Debug - Status: {getattr(crawl_result, 'status', 'N/A')}") - print(f" šŸ” Debug - Total: {getattr(crawl_result, 'total', 'N/A')}") - print(f" šŸ” Debug - Completed: {getattr(crawl_result, 'completed', 'N/A')}") - - elif isinstance(crawl_result, dict) and 'data' in crawl_result: - data_list = crawl_result.get("data", []) - else: - print(" āš ļø Unexpected crawl result type") - print(f" šŸ” Debug - Result type: {type(crawl_result)}") - if hasattr(crawl_result, '__dict__'): - print(f" šŸ” Debug - Result attributes: {list(crawl_result.__dict__.keys())}") - - for item in data_list: - # Process each crawled page - properly handle object serialization - page_url = "Unknown URL" - title = "" - content_markdown = None - content_html = None - metadata = {} - - # Extract data from the item - if hasattr(item, 'model_dump'): - # Pydantic model - use model_dump to get dict - item_dict = item.model_dump() - content_markdown = item_dict.get('markdown') - content_html = item_dict.get('html') - metadata = item_dict.get('metadata', {}) - elif hasattr(item, '__dict__'): - # Regular object with attributes - content_markdown = getattr(item, 'markdown', None) - content_html = getattr(item, 'html', None) - - # Handle metadata - convert to dict if it's an object - metadata_obj = getattr(item, 'metadata', {}) - if hasattr(metadata_obj, 'model_dump'): - metadata = metadata_obj.model_dump() - elif hasattr(metadata_obj, '__dict__'): - metadata = metadata_obj.__dict__ - elif isinstance(metadata_obj, dict): - metadata = metadata_obj - else: - metadata = {} - elif isinstance(item, dict): - # Already a dictionary - content_markdown = item.get('markdown') - content_html = item.get('html') - metadata = item.get('metadata', {}) - - # Ensure metadata is a dict (not an object) - if not isinstance(metadata, dict): - if hasattr(metadata, 'model_dump'): - metadata = metadata.model_dump() - elif hasattr(metadata, '__dict__'): - metadata = metadata.__dict__ - else: - metadata = {} - - # Extract URL and title from metadata - page_url = metadata.get("sourceURL", metadata.get("url", "Unknown URL")) - title = metadata.get("title", "") - - # Choose content (prefer markdown) - content = content_markdown or content_html or "" - - pages.append({ - "url": page_url, - "title": title, - "content": content, - "raw_content": content, - "metadata": metadata # Now guaranteed to be a dict - }) - - response = {"results": pages} - - pages_crawled = len(response.get('results', [])) - print(f"āœ… Crawled {pages_crawled} pages") - - debug_call_data["pages_crawled"] = pages_crawled - debug_call_data["original_response_size"] = len(json.dumps(response)) - - # Process each result with LLM if enabled - if use_llm_processing and os.getenv("NOUS_API_KEY"): - print("🧠 Processing crawled content with LLM...") - debug_call_data["processing_applied"].append("llm_processing") - - for result in response.get('results', []): - page_url = result.get('url', 'Unknown URL') - title = result.get('title', '') - content = result.get('content', '') - - if content: - original_size = len(content) - - # Process content with LLM - processed = await process_content_with_llm( - content, page_url, title, model, min_length - ) - - if processed: - processed_size = len(processed) - compression_ratio = processed_size / original_size if original_size > 0 else 1.0 - - # Capture compression metrics - debug_call_data["compression_metrics"].append({ - "url": page_url, - "original_size": original_size, - "processed_size": processed_size, - "compression_ratio": compression_ratio, - "model_used": model - }) - - # Keep original content in raw_content field - result['raw_content'] = content - # Replace content with processed version - result['content'] = processed - debug_call_data["pages_processed_with_llm"] += 1 - print(f" 🌐 {page_url} (processed)") - else: - debug_call_data["compression_metrics"].append({ - "url": page_url, - "original_size": original_size, - "processed_size": original_size, - "compression_ratio": 1.0, - "model_used": None, - "reason": "content_too_short" - }) - print(f" 🌐 {page_url} (no processing - content too short)") - else: - print(f" āš ļø {page_url} (no content to process)") - else: - if use_llm_processing and not os.getenv("NOUS_API_KEY"): - print("āš ļø LLM processing requested but NOUS_API_KEY not set, returning raw content") - debug_call_data["processing_applied"].append("llm_processing_unavailable") - - # Print summary of crawled pages for debugging (original behavior) - for result in response.get('results', []): - page_url = result.get('url', 'Unknown URL') - content_length = len(result.get('content', '')) - print(f" 🌐 {page_url} ({content_length} characters)") - - # Trim output to minimal fields per entry: title, content, error - trimmed_results = [ - { - "title": r.get("title", ""), - "content": r.get("content", ""), - "error": r.get("error") - } - for r in response.get("results", []) - ] - trimmed_response = {"results": trimmed_results} - - result_json = json.dumps(trimmed_response, indent=2) - # Clean base64 images from crawled content - cleaned_result = clean_base64_images(result_json) - - debug_call_data["final_response_size"] = len(cleaned_result) - debug_call_data["processing_applied"].append("base64_image_removal") - - # Log debug information - _log_debug_call("web_crawl_tool", debug_call_data) - _save_debug_log() - - return cleaned_result - - except Exception as e: - error_msg = f"Error crawling website: {str(e)}" - print(f"āŒ {error_msg}") - - debug_call_data["error"] = error_msg - _log_debug_call("web_crawl_tool", debug_call_data) - _save_debug_log() - - return json.dumps({"error": error_msg}) - - -# Convenience function to check if API key is available -def check_firecrawl_api_key() -> bool: - """ - Check if the Firecrawl API key is available in environment variables. - - Returns: - bool: True if API key is set, False otherwise - """ - return bool(os.getenv("FIRECRAWL_API_KEY")) - - -def check_nous_api_key() -> bool: - """ - Check if the Nous Research API key is available in environment variables. - - Returns: - bool: True if API key is set, False otherwise - """ - return bool(os.getenv("NOUS_API_KEY")) - - -def get_debug_session_info() -> Dict[str, Any]: - """ - Get information about the current debug session. - - Returns: - Dict[str, Any]: Dictionary containing debug session information: - - enabled: Whether debug mode is enabled - - session_id: Current session UUID (if enabled) - - log_path: Path where debug logs are saved (if enabled) - - total_calls: Number of tool calls logged so far (if enabled) - """ - if not DEBUG_MODE or not DEBUG_DATA: - return { - "enabled": False, - "session_id": None, - "log_path": None, - "total_calls": 0 - } - - return { - "enabled": True, - "session_id": DEBUG_SESSION_ID, - "log_path": str(DEBUG_LOG_PATH / f"web_tools_debug_{DEBUG_SESSION_ID}.json"), - "total_calls": len(DEBUG_DATA["tool_calls"]) - } - - -if __name__ == "__main__": - """ - Simple test/demo when run directly - """ - print("🌐 Standalone Web Tools Module") - print("=" * 40) - - # Check if API keys are available - firecrawl_available = check_firecrawl_api_key() - nous_available = check_nous_api_key() - - if not firecrawl_available: - print("āŒ FIRECRAWL_API_KEY environment variable not set") - print("Please set your API key: export FIRECRAWL_API_KEY='your-key-here'") - print("Get API key at: https://firecrawl.dev/") - else: - print("āœ… Firecrawl API key found") - - if not nous_available: - print("āŒ NOUS_API_KEY environment variable not set") - print("Please set your API key: export NOUS_API_KEY='your-key-here'") - print("Get API key at: https://inference-api.nousresearch.com/") - print("āš ļø Without Nous API key, LLM content processing will be disabled") - else: - print("āœ… Nous Research API key found") - - if not firecrawl_available: - exit(1) - - print("šŸ› ļø Web tools ready for use!") - - if nous_available: - print("🧠 LLM content processing available with Gemini 2.5 Flash") - print(f" Default min length for processing: {DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION} chars") - - # Show debug mode status - if DEBUG_MODE: - print(f"šŸ› Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}") - print(f" Debug logs will be saved to: ./logs/web_tools_debug_{DEBUG_SESSION_ID}.json") - else: - print("šŸ› Debug mode disabled (set WEB_TOOLS_DEBUG=true to enable)") - - print("\nBasic usage:") - print(" from web_tools import web_search_tool, web_extract_tool, web_crawl_tool") - print(" import asyncio") - print("") - print(" # Search (synchronous)") - print(" results = web_search_tool('Python tutorials')") - print("") - print(" # Extract and crawl (asynchronous)") - print(" async def main():") - print(" content = await web_extract_tool(['https://example.com'])") - print(" crawl_data = await web_crawl_tool('example.com', 'Find docs')") - print(" asyncio.run(main())") - - if nous_available: - print("\nLLM-enhanced usage:") - print(" # Content automatically processed for pages >5000 chars (default)") - print(" content = await web_extract_tool(['https://python.org/about/'])") - print("") - print(" # Customize processing parameters") - print(" crawl_data = await web_crawl_tool(") - print(" 'docs.python.org',") - print(" 'Find key concepts',") - print(" model='gemini-2.5-flash',") - print(" min_length=3000") - print(" )") - print("") - print(" # Disable LLM processing") - print(" raw_content = await web_extract_tool(['https://example.com'], use_llm_processing=False)") - - print("\nDebug mode:") - print(" # Enable debug logging") - print(" export WEB_TOOLS_DEBUG=true") - print(" # Debug logs capture:") - print(" # - All tool calls with parameters") - print(" # - Original API responses") - print(" # - LLM compression metrics") - print(" # - Final processed results") - print(" # Logs saved to: ./logs/web_tools_debug_UUID.json") - - print(f"\nšŸ“ Run 'python test_web_tools_llm.py' to test LLM processing capabilities") +#!/usr/bin/env python3 +""" +Standalone Web Tools Module + +This module provides generic web tools that work with multiple backend providers. +Currently uses Firecrawl as the backend, and the interface makes it easy to swap +providers without changing the function signatures. + +Available tools: +- web_search_tool: Search the web for information +- web_extract_tool: Extract content from specific web pages +- web_crawl_tool: Crawl websites with specific instructions + +Backend compatibility: +- Firecrawl: https://docs.firecrawl.dev/introduction + +LLM Processing: +- Uses Nous Research API with Gemini 2.5 Flash for intelligent content extraction +- Extracts key excerpts and creates markdown summaries to reduce token usage + +Debug Mode: +- Set WEB_TOOLS_DEBUG=true to enable detailed logging +- Creates web_tools_debug_UUID.json in ./logs directory +- Captures all tool calls, results, and compression metrics + +Usage: + from web_tools import web_search_tool, web_extract_tool, web_crawl_tool + + # Search the web + results = web_search_tool("Python machine learning libraries", limit=3) + + # Extract content from URLs + content = web_extract_tool(["https://example.com"], format="markdown") + + # Crawl a website + crawl_data = web_crawl_tool("example.com", "Find contact information") +""" + +#TODO: Search Capabilities over the scraped pages +#TODO: Store the pages in something +#TODO: Tool to see what pages are available/saved to search over + +import json +import os +import re +import asyncio +import uuid +import datetime +from pathlib import Path +from typing import List, Dict, Any, Optional +from firecrawl import Firecrawl +from openai import AsyncOpenAI + +# Initialize Firecrawl client once at module level +firecrawl_client = Firecrawl(api_key=os.getenv("FIRECRAWL_API_KEY")) + +# Initialize Nous Research API client for LLM processing (async) +nous_client = AsyncOpenAI( + api_key=os.getenv("NOUS_API_KEY"), + base_url="https://inference-api.nousresearch.com/v1" +) + +# Configuration for LLM processing +DEFAULT_SUMMARIZER_MODEL = "gemini-2.5-flash" +DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000 + +# Debug mode configuration +DEBUG_MODE = os.getenv("WEB_TOOLS_DEBUG", "false").lower() == "true" +DEBUG_SESSION_ID = str(uuid.uuid4()) +DEBUG_LOG_PATH = Path("./logs") +DEBUG_DATA = { + "session_id": DEBUG_SESSION_ID, + "start_time": datetime.datetime.now().isoformat(), + "debug_enabled": DEBUG_MODE, + "tool_calls": [] +} if DEBUG_MODE else None + +# Create logs directory if debug mode is enabled +if DEBUG_MODE: + DEBUG_LOG_PATH.mkdir(exist_ok=True) + print(f"šŸ› Debug mode enabled - Session ID: {DEBUG_SESSION_ID}") + + +def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None: + """ + Log a debug call entry to the global debug data structure. + + Args: + tool_name (str): Name of the tool being called + call_data (Dict[str, Any]): Data about the call including parameters and results + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + call_entry = { + "timestamp": datetime.datetime.now().isoformat(), + "tool_name": tool_name, + **call_data + } + + DEBUG_DATA["tool_calls"].append(call_entry) + + +def _save_debug_log() -> None: + """ + Save the current debug data to a JSON file in the logs directory. + """ + if not DEBUG_MODE or not DEBUG_DATA: + return + + try: + debug_filename = f"web_tools_debug_{DEBUG_SESSION_ID}.json" + debug_filepath = DEBUG_LOG_PATH / debug_filename + + # Update end time + DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat() + DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"]) + + with open(debug_filepath, 'w', encoding='utf-8') as f: + json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False) + + print(f"šŸ› Debug log saved: {debug_filepath}") + + except Exception as e: + print(f"āŒ Error saving debug log: {str(e)}") + + +async def process_content_with_llm( + content: str, + url: str = "", + title: str = "", + model: str = DEFAULT_SUMMARIZER_MODEL, + min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION +) -> Optional[str]: + """ + Process web content using LLM to create intelligent summaries with key excerpts. + + This function uses Gemini 2.5 Flash (or specified model) via Nous Research API + to intelligently extract key information and create markdown summaries, + significantly reducing token usage while preserving all important information. + + Args: + content (str): The raw content to process + url (str): The source URL (for context, optional) + title (str): The page title (for context, optional) + model (str): The model to use for processing (default: gemini-2.5-flash) + min_length (int): Minimum content length to trigger processing (default: 5000) + + Returns: + Optional[str]: Processed markdown content, or None if content too short or processing fails + """ + try: + # Skip processing if content is too short + if len(content) < min_length: + print(f"šŸ“ Content too short ({len(content)} < {min_length} chars), skipping LLM processing") + return None + + print(f"🧠 Processing content with LLM ({len(content)} characters)") + + # Create context information + context_info = [] + if title: + context_info.append(f"Title: {title}") + if url: + context_info.append(f"Source: {url}") + + context_str = "\n".join(context_info) + "\n\n" if context_info else "" + + # Simplified prompt for better quality markdown output + system_prompt = """You are an expert content analyst. Your job is to process web content and create a comprehensive yet concise summary that preserves all important information while dramatically reducing bulk. + +Create a well-structured markdown summary that includes: +1. Key excerpts (quotes, code snippets, important facts) in their original format +2. Comprehensive summary of all other important information +3. Proper markdown formatting with headers, bullets, and emphasis + +Your goal is to preserve ALL important information while reducing length. Never lose key facts, figures, insights, or actionable information. Make it scannable and well-organized.""" + + user_prompt = f"""Please process this web content and create a comprehensive markdown summary: + +{context_str}CONTENT TO PROCESS: +{content} + +Create a markdown summary that captures all key information in a well-organized, scannable format. Include important quotes and code snippets in their original formatting. Focus on actionable information, specific details, and unique insights.""" + + # Call the LLM asynchronously with retry logic for flaky API + max_retries = 6 + retry_delay = 2 # Start with 2 seconds + last_error = None + + for attempt in range(max_retries): + try: + response = await nous_client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + temperature=0.1, # Low temperature for consistent extraction + max_tokens=4000 # Generous limit for comprehensive processing + ) + break # Success, exit retry loop + except Exception as api_error: + last_error = api_error + if attempt < max_retries - 1: + print(f"āš ļø LLM API call failed (attempt {attempt + 1}/{max_retries}): {str(api_error)[:100]}") + print(f" Retrying in {retry_delay}s...") + await asyncio.sleep(retry_delay) + retry_delay = min(retry_delay * 2, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s + else: + # All retries exhausted + raise last_error + + # Get the markdown response directly + processed_content = response.choices[0].message.content.strip() + + # Calculate compression metrics for logging + original_length = len(content) + processed_length = len(processed_content) + compression_ratio = processed_length / original_length if original_length > 0 else 1.0 + + print(f"āœ… Content processed: {original_length} → {processed_length} chars ({compression_ratio:.1%})") + + return processed_content + + except Exception as e: + print(f"āŒ Error processing content with LLM: {str(e)}") + return None + + +def clean_base64_images(text: str) -> str: + """ + Remove base64 encoded images from text to reduce token count and clutter. + + This function finds and removes base64 encoded images in various formats: + - (data:image/png;base64,...) + - (data:image/jpeg;base64,...) + - (data:image/svg+xml;base64,...) + - data:image/[type];base64,... (without parentheses) + + Args: + text: The text content to clean + + Returns: + Cleaned text with base64 images replaced with placeholders + """ + # Pattern to match base64 encoded images wrapped in parentheses + # Matches: (data:image/[type];base64,[base64-string]) + base64_with_parens_pattern = r'\(data:image/[^;]+;base64,[A-Za-z0-9+/=]+\)' + + # Pattern to match base64 encoded images without parentheses + # Matches: data:image/[type];base64,[base64-string] + base64_pattern = r'data:image/[^;]+;base64,[A-Za-z0-9+/=]+' + + # Replace parentheses-wrapped images first + cleaned_text = re.sub(base64_with_parens_pattern, '[BASE64_IMAGE_REMOVED]', text) + + # Then replace any remaining non-parentheses images + cleaned_text = re.sub(base64_pattern, '[BASE64_IMAGE_REMOVED]', cleaned_text) + + return cleaned_text + + +def web_search_tool(query: str, limit: int = 5) -> str: + """ + Search the web for information using available search API backend. + + This function provides a generic interface for web search that can work + with multiple backends. Currently uses Firecrawl. + + Note: This function returns search result metadata only (URLs, titles, descriptions). + Use web_extract_tool to get full content from specific URLs. + + Args: + query (str): The search query to look up + limit (int): Maximum number of results to return (default: 5) + + Returns: + str: JSON string containing search results with the following structure: + { + "success": bool, + "data": { + "web": [ + { + "title": str, + "url": str, + "description": str, + "position": int + }, + ... + ] + } + } + + Raises: + Exception: If search fails or API key is not set + """ + debug_call_data = { + "parameters": { + "query": query, + "limit": limit + }, + "error": None, + "results_count": 0, + "original_response_size": 0, + "final_response_size": 0 + } + + try: + print(f"šŸ” Searching the web for: '{query}' (limit: {limit})") + + # Use Firecrawl's v2 search functionality WITHOUT scraping + # We only want search result metadata, not scraped content + # Docs: https://docs.firecrawl.dev/features/search + response = firecrawl_client.search( + query=query, + limit=limit + ) + + # The response is a SearchData object with web, news, and images attributes + # When not scraping, the results are directly in these attributes + web_results = [] + + # Check if response has web attribute (SearchData object) + if hasattr(response, 'web'): + # Response is a SearchData object with web attribute + if response.web: + # Convert each SearchResultWeb object to dict + for result in response.web: + if hasattr(result, 'model_dump'): + # Pydantic model - use model_dump + web_results.append(result.model_dump()) + elif hasattr(result, '__dict__'): + # Regular object - use __dict__ + web_results.append(result.__dict__) + elif isinstance(result, dict): + # Already a dict + web_results.append(result) + elif hasattr(response, 'model_dump'): + # Response has model_dump method - use it to get dict + response_dict = response.model_dump() + if 'web' in response_dict and response_dict['web']: + web_results = response_dict['web'] + elif isinstance(response, dict): + # Response is already a dictionary + if 'web' in response and response['web']: + web_results = response['web'] + + results_count = len(web_results) + print(f"āœ… Found {results_count} search results") + + # Build response with just search metadata (URLs, titles, descriptions) + response_data = { + "success": True, + "data": { + "web": web_results + } + } + + # Capture debug information + debug_call_data["results_count"] = results_count + + # Convert to JSON + result_json = json.dumps(response_data, indent=2, ensure_ascii=False) + + debug_call_data["final_response_size"] = len(result_json) + + # Log debug information + _log_debug_call("web_search_tool", debug_call_data) + _save_debug_log() + + return result_json + + except Exception as e: + error_msg = f"Error searching web: {str(e)}" + print(f"āŒ {error_msg}") + + debug_call_data["error"] = error_msg + _log_debug_call("web_search_tool", debug_call_data) + _save_debug_log() + + return json.dumps({"error": error_msg}, ensure_ascii=False) + + +async def web_extract_tool( + urls: List[str], + format: str = None, + use_llm_processing: bool = True, + model: str = DEFAULT_SUMMARIZER_MODEL, + min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION +) -> str: + """ + Extract content from specific web pages using available extraction API backend. + + This function provides a generic interface for web content extraction that + can work with multiple backends. Currently uses Firecrawl. + + Args: + urls (List[str]): List of URLs to extract content from + format (str): Desired output format ("markdown" or "html", optional) + use_llm_processing (bool): Whether to process content with LLM for summarization (default: True) + model (str): The model to use for LLM processing (default: gemini-2.5-flash) + min_length (int): Minimum content length to trigger LLM processing (default: 5000) + + Returns: + str: JSON string containing extracted content. If LLM processing is enabled and successful, + the 'content' field will contain the processed markdown summary instead of raw content. + + Raises: + Exception: If extraction fails or API key is not set + """ + debug_call_data = { + "parameters": { + "urls": urls, + "format": format, + "use_llm_processing": use_llm_processing, + "model": model, + "min_length": min_length + }, + "error": None, + "pages_extracted": 0, + "pages_processed_with_llm": 0, + "original_response_size": 0, + "final_response_size": 0, + "compression_metrics": [], + "processing_applied": [] + } + + try: + print(f"šŸ“„ Extracting content from {len(urls)} URL(s)") + + # Determine requested formats for Firecrawl v2 + formats: List[str] = [] + if format == "markdown": + formats = ["markdown"] + elif format == "html": + formats = ["html"] + else: + # Default: request markdown for LLM-readiness and include html as backup + formats = ["markdown", "html"] + + # Always use individual scraping for simplicity and reliability + # Batch scraping adds complexity without much benefit for small numbers of URLs + results: List[Dict[str, Any]] = [] + + for url in urls: + try: + print(f" šŸ“„ Scraping: {url}") + scrape_result = firecrawl_client.scrape( + url=url, + formats=formats + ) + + # Process the result - properly handle object serialization + metadata = {} + title = "" + content_markdown = None + content_html = None + + # Extract data from the scrape result + if hasattr(scrape_result, 'model_dump'): + # Pydantic model - use model_dump to get dict + result_dict = scrape_result.model_dump() + content_markdown = result_dict.get('markdown') + content_html = result_dict.get('html') + metadata = result_dict.get('metadata', {}) + elif hasattr(scrape_result, '__dict__'): + # Regular object with attributes + content_markdown = getattr(scrape_result, 'markdown', None) + content_html = getattr(scrape_result, 'html', None) + + # Handle metadata - convert to dict if it's an object + metadata_obj = getattr(scrape_result, 'metadata', {}) + if hasattr(metadata_obj, 'model_dump'): + metadata = metadata_obj.model_dump() + elif hasattr(metadata_obj, '__dict__'): + metadata = metadata_obj.__dict__ + elif isinstance(metadata_obj, dict): + metadata = metadata_obj + else: + metadata = {} + elif isinstance(scrape_result, dict): + # Already a dictionary + content_markdown = scrape_result.get('markdown') + content_html = scrape_result.get('html') + metadata = scrape_result.get('metadata', {}) + + # Ensure metadata is a dict (not an object) + if not isinstance(metadata, dict): + if hasattr(metadata, 'model_dump'): + metadata = metadata.model_dump() + elif hasattr(metadata, '__dict__'): + metadata = metadata.__dict__ + else: + metadata = {} + + # Get title from metadata + title = metadata.get("title", "") + + # Choose content based on requested format + chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or "" + + results.append({ + "url": metadata.get("sourceURL", url), + "title": title, + "content": chosen_content, + "raw_content": chosen_content, + "metadata": metadata # Now guaranteed to be a dict + }) + + except Exception as scrape_err: + print(f" āŒ Error scraping {url}: {str(scrape_err)}") + results.append({ + "url": url, + "title": "", + "content": "", + "raw_content": "", + "error": str(scrape_err) + }) + + response = {"results": results} + + pages_extracted = len(response.get('results', [])) + print(f"āœ… Extracted content from {pages_extracted} pages") + + debug_call_data["pages_extracted"] = pages_extracted + debug_call_data["original_response_size"] = len(json.dumps(response)) + + # Process each result with LLM if enabled + if use_llm_processing and os.getenv("NOUS_API_KEY"): + print("🧠 Processing extracted content with LLM...") + debug_call_data["processing_applied"].append("llm_processing") + + for result in response.get('results', []): + url = result.get('url', 'Unknown URL') + title = result.get('title', '') + raw_content = result.get('raw_content', '') or result.get('content', '') + + if raw_content: + original_size = len(raw_content) + + # Process content with LLM + processed = await process_content_with_llm( + raw_content, url, title, model, min_length + ) + + if processed: + processed_size = len(processed) + compression_ratio = processed_size / original_size if original_size > 0 else 1.0 + + # Capture compression metrics + debug_call_data["compression_metrics"].append({ + "url": url, + "original_size": original_size, + "processed_size": processed_size, + "compression_ratio": compression_ratio, + "model_used": model + }) + + # Replace content with processed version + result['content'] = processed + # Keep raw content in separate field for reference + result['raw_content'] = raw_content + debug_call_data["pages_processed_with_llm"] += 1 + print(f" šŸ“ {url} (processed)") + else: + debug_call_data["compression_metrics"].append({ + "url": url, + "original_size": original_size, + "processed_size": original_size, + "compression_ratio": 1.0, + "model_used": None, + "reason": "content_too_short" + }) + print(f" šŸ“ {url} (no processing - content too short)") + else: + print(f" āš ļø {url} (no content to process)") + else: + if use_llm_processing and not os.getenv("NOUS_API_KEY"): + print("āš ļø LLM processing requested but NOUS_API_KEY not set, returning raw content") + debug_call_data["processing_applied"].append("llm_processing_unavailable") + + # Print summary of extracted pages for debugging (original behavior) + for result in response.get('results', []): + url = result.get('url', 'Unknown URL') + content_length = len(result.get('raw_content', '')) + print(f" šŸ“ {url} ({content_length} characters)") + + # Trim output to minimal fields per entry: title, content, error + trimmed_results = [ + { + "title": r.get("title", ""), + "content": r.get("content", ""), + "error": r.get("error"), + } + for r in response.get("results", []) + ] + trimmed_response = {"results": trimmed_results} + + if trimmed_response.get("results") == []: + result_json = json.dumps({"error": "Content was inaccessible or not found"}, ensure_ascii=False) + + cleaned_result = clean_base64_images(result_json) + + else: + result_json = json.dumps(trimmed_response, indent=2, ensure_ascii=False) + + cleaned_result = clean_base64_images(result_json) + + debug_call_data["final_response_size"] = len(cleaned_result) + debug_call_data["processing_applied"].append("base64_image_removal") + + # Log debug information + _log_debug_call("web_extract_tool", debug_call_data) + _save_debug_log() + + return cleaned_result + + except Exception as e: + error_msg = f"Error extracting content: {str(e)}" + print(f"āŒ {error_msg}") + + debug_call_data["error"] = error_msg + _log_debug_call("web_extract_tool", debug_call_data) + _save_debug_log() + + return json.dumps({"error": error_msg}, ensure_ascii=False) + + +async def web_crawl_tool( + url: str, + instructions: str = None, + depth: str = "basic", + use_llm_processing: bool = True, + model: str = DEFAULT_SUMMARIZER_MODEL, + min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION +) -> str: + """ + Crawl a website with specific instructions using available crawling API backend. + + This function provides a generic interface for web crawling that can work + with multiple backends. Currently uses Firecrawl. + + Args: + url (str): The base URL to crawl (can include or exclude https://) + instructions (str): Instructions for what to crawl/extract using LLM intelligence (optional) + depth (str): Depth of extraction ("basic" or "advanced", default: "basic") + use_llm_processing (bool): Whether to process content with LLM for summarization (default: True) + model (str): The model to use for LLM processing (default: gemini-2.5-flash) + min_length (int): Minimum content length to trigger LLM processing (default: 5000) + + Returns: + str: JSON string containing crawled content. If LLM processing is enabled and successful, + the 'content' field will contain the processed markdown summary instead of raw content. + Each page is processed individually. + + Raises: + Exception: If crawling fails or API key is not set + """ + debug_call_data = { + "parameters": { + "url": url, + "instructions": instructions, + "depth": depth, + "use_llm_processing": use_llm_processing, + "model": model, + "min_length": min_length + }, + "error": None, + "pages_crawled": 0, + "pages_processed_with_llm": 0, + "original_response_size": 0, + "final_response_size": 0, + "compression_metrics": [], + "processing_applied": [] + } + + try: + # Ensure URL has protocol + if not url.startswith(('http://', 'https://')): + url = f'https://{url}' + print(f" šŸ“ Added https:// prefix to URL: {url}") + + instructions_text = f" with instructions: '{instructions}'" if instructions else "" + print(f"šŸ•·ļø Crawling {url}{instructions_text}") + + # Use Firecrawl's v2 crawl functionality + # Docs: https://docs.firecrawl.dev/features/crawl + # The crawl() method automatically waits for completion and returns all data + + # Build crawl parameters - keep it simple + crawl_params = { + "limit": 20, # Limit number of pages to crawl + "scrape_options": { + "formats": ["markdown"] # Just markdown for simplicity + } + } + + # Note: The 'prompt' parameter is not documented for crawl + # Instructions are typically used with the Extract endpoint, not Crawl + if instructions: + print(f" ā„¹ļø Note: Instructions parameter ignored (not supported in crawl API)") + + # Use the crawl method which waits for completion automatically + try: + crawl_result = firecrawl_client.crawl( + url=url, + **crawl_params + ) + except Exception as e: + print(f" āŒ Crawl API call failed: {e}") + raise + + pages: List[Dict[str, Any]] = [] + + # Process crawl results - the crawl method returns a CrawlJob object with data attribute + data_list = [] + + # The crawl_result is a CrawlJob object with a 'data' attribute containing list of Document objects + if hasattr(crawl_result, 'data'): + data_list = crawl_result.data if crawl_result.data else [] + print(f" šŸ“Š Status: {getattr(crawl_result, 'status', 'unknown')}") + print(f" šŸ“„ Retrieved {len(data_list)} pages") + + # Debug: Check other attributes if no data + if not data_list: + print(f" šŸ” Debug - CrawlJob attributes: {[attr for attr in dir(crawl_result) if not attr.startswith('_')]}") + print(f" šŸ” Debug - Status: {getattr(crawl_result, 'status', 'N/A')}") + print(f" šŸ” Debug - Total: {getattr(crawl_result, 'total', 'N/A')}") + print(f" šŸ” Debug - Completed: {getattr(crawl_result, 'completed', 'N/A')}") + + elif isinstance(crawl_result, dict) and 'data' in crawl_result: + data_list = crawl_result.get("data", []) + else: + print(" āš ļø Unexpected crawl result type") + print(f" šŸ” Debug - Result type: {type(crawl_result)}") + if hasattr(crawl_result, '__dict__'): + print(f" šŸ” Debug - Result attributes: {list(crawl_result.__dict__.keys())}") + + for item in data_list: + # Process each crawled page - properly handle object serialization + page_url = "Unknown URL" + title = "" + content_markdown = None + content_html = None + metadata = {} + + # Extract data from the item + if hasattr(item, 'model_dump'): + # Pydantic model - use model_dump to get dict + item_dict = item.model_dump() + content_markdown = item_dict.get('markdown') + content_html = item_dict.get('html') + metadata = item_dict.get('metadata', {}) + elif hasattr(item, '__dict__'): + # Regular object with attributes + content_markdown = getattr(item, 'markdown', None) + content_html = getattr(item, 'html', None) + + # Handle metadata - convert to dict if it's an object + metadata_obj = getattr(item, 'metadata', {}) + if hasattr(metadata_obj, 'model_dump'): + metadata = metadata_obj.model_dump() + elif hasattr(metadata_obj, '__dict__'): + metadata = metadata_obj.__dict__ + elif isinstance(metadata_obj, dict): + metadata = metadata_obj + else: + metadata = {} + elif isinstance(item, dict): + # Already a dictionary + content_markdown = item.get('markdown') + content_html = item.get('html') + metadata = item.get('metadata', {}) + + # Ensure metadata is a dict (not an object) + if not isinstance(metadata, dict): + if hasattr(metadata, 'model_dump'): + metadata = metadata.model_dump() + elif hasattr(metadata, '__dict__'): + metadata = metadata.__dict__ + else: + metadata = {} + + # Extract URL and title from metadata + page_url = metadata.get("sourceURL", metadata.get("url", "Unknown URL")) + title = metadata.get("title", "") + + # Choose content (prefer markdown) + content = content_markdown or content_html or "" + + pages.append({ + "url": page_url, + "title": title, + "content": content, + "raw_content": content, + "metadata": metadata # Now guaranteed to be a dict + }) + + response = {"results": pages} + + pages_crawled = len(response.get('results', [])) + print(f"āœ… Crawled {pages_crawled} pages") + + debug_call_data["pages_crawled"] = pages_crawled + debug_call_data["original_response_size"] = len(json.dumps(response)) + + # Process each result with LLM if enabled + if use_llm_processing and os.getenv("NOUS_API_KEY"): + print("🧠 Processing crawled content with LLM...") + debug_call_data["processing_applied"].append("llm_processing") + + for result in response.get('results', []): + page_url = result.get('url', 'Unknown URL') + title = result.get('title', '') + content = result.get('content', '') + + if content: + original_size = len(content) + + # Process content with LLM + processed = await process_content_with_llm( + content, page_url, title, model, min_length + ) + + if processed: + processed_size = len(processed) + compression_ratio = processed_size / original_size if original_size > 0 else 1.0 + + # Capture compression metrics + debug_call_data["compression_metrics"].append({ + "url": page_url, + "original_size": original_size, + "processed_size": processed_size, + "compression_ratio": compression_ratio, + "model_used": model + }) + + # Keep original content in raw_content field + result['raw_content'] = content + # Replace content with processed version + result['content'] = processed + debug_call_data["pages_processed_with_llm"] += 1 + print(f" 🌐 {page_url} (processed)") + else: + debug_call_data["compression_metrics"].append({ + "url": page_url, + "original_size": original_size, + "processed_size": original_size, + "compression_ratio": 1.0, + "model_used": None, + "reason": "content_too_short" + }) + print(f" 🌐 {page_url} (no processing - content too short)") + else: + print(f" āš ļø {page_url} (no content to process)") + else: + if use_llm_processing and not os.getenv("NOUS_API_KEY"): + print("āš ļø LLM processing requested but NOUS_API_KEY not set, returning raw content") + debug_call_data["processing_applied"].append("llm_processing_unavailable") + + # Print summary of crawled pages for debugging (original behavior) + for result in response.get('results', []): + page_url = result.get('url', 'Unknown URL') + content_length = len(result.get('content', '')) + print(f" 🌐 {page_url} ({content_length} characters)") + + # Trim output to minimal fields per entry: title, content, error + trimmed_results = [ + { + "title": r.get("title", ""), + "content": r.get("content", ""), + "error": r.get("error") + } + for r in response.get("results", []) + ] + trimmed_response = {"results": trimmed_results} + + result_json = json.dumps(trimmed_response, indent=2, ensure_ascii=False) + # Clean base64 images from crawled content + cleaned_result = clean_base64_images(result_json) + + debug_call_data["final_response_size"] = len(cleaned_result) + debug_call_data["processing_applied"].append("base64_image_removal") + + # Log debug information + _log_debug_call("web_crawl_tool", debug_call_data) + _save_debug_log() + + return cleaned_result + + except Exception as e: + error_msg = f"Error crawling website: {str(e)}" + print(f"āŒ {error_msg}") + + debug_call_data["error"] = error_msg + _log_debug_call("web_crawl_tool", debug_call_data) + _save_debug_log() + + return json.dumps({"error": error_msg}, ensure_ascii=False) + + +# Convenience function to check if API key is available +def check_firecrawl_api_key() -> bool: + """ + Check if the Firecrawl API key is available in environment variables. + + Returns: + bool: True if API key is set, False otherwise + """ + return bool(os.getenv("FIRECRAWL_API_KEY")) + + +def check_nous_api_key() -> bool: + """ + Check if the Nous Research API key is available in environment variables. + + Returns: + bool: True if API key is set, False otherwise + """ + return bool(os.getenv("NOUS_API_KEY")) + + +def get_debug_session_info() -> Dict[str, Any]: + """ + Get information about the current debug session. + + Returns: + Dict[str, Any]: Dictionary containing debug session information: + - enabled: Whether debug mode is enabled + - session_id: Current session UUID (if enabled) + - log_path: Path where debug logs are saved (if enabled) + - total_calls: Number of tool calls logged so far (if enabled) + """ + if not DEBUG_MODE or not DEBUG_DATA: + return { + "enabled": False, + "session_id": None, + "log_path": None, + "total_calls": 0 + } + + return { + "enabled": True, + "session_id": DEBUG_SESSION_ID, + "log_path": str(DEBUG_LOG_PATH / f"web_tools_debug_{DEBUG_SESSION_ID}.json"), + "total_calls": len(DEBUG_DATA["tool_calls"]) + } + + +if __name__ == "__main__": + """ + Simple test/demo when run directly + """ + print("🌐 Standalone Web Tools Module") + print("=" * 40) + + # Check if API keys are available + firecrawl_available = check_firecrawl_api_key() + nous_available = check_nous_api_key() + + if not firecrawl_available: + print("āŒ FIRECRAWL_API_KEY environment variable not set") + print("Please set your API key: export FIRECRAWL_API_KEY='your-key-here'") + print("Get API key at: https://firecrawl.dev/") + else: + print("āœ… Firecrawl API key found") + + if not nous_available: + print("āŒ NOUS_API_KEY environment variable not set") + print("Please set your API key: export NOUS_API_KEY='your-key-here'") + print("Get API key at: https://inference-api.nousresearch.com/") + print("āš ļø Without Nous API key, LLM content processing will be disabled") + else: + print("āœ… Nous Research API key found") + + if not firecrawl_available: + exit(1) + + print("šŸ› ļø Web tools ready for use!") + + if nous_available: + print("🧠 LLM content processing available with Gemini 2.5 Flash") + print(f" Default min length for processing: {DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION} chars") + + # Show debug mode status + if DEBUG_MODE: + print(f"šŸ› Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}") + print(f" Debug logs will be saved to: ./logs/web_tools_debug_{DEBUG_SESSION_ID}.json") + else: + print("šŸ› Debug mode disabled (set WEB_TOOLS_DEBUG=true to enable)") + + print("\nBasic usage:") + print(" from web_tools import web_search_tool, web_extract_tool, web_crawl_tool") + print(" import asyncio") + print("") + print(" # Search (synchronous)") + print(" results = web_search_tool('Python tutorials')") + print("") + print(" # Extract and crawl (asynchronous)") + print(" async def main():") + print(" content = await web_extract_tool(['https://example.com'])") + print(" crawl_data = await web_crawl_tool('example.com', 'Find docs')") + print(" asyncio.run(main())") + + if nous_available: + print("\nLLM-enhanced usage:") + print(" # Content automatically processed for pages >5000 chars (default)") + print(" content = await web_extract_tool(['https://python.org/about/'])") + print("") + print(" # Customize processing parameters") + print(" crawl_data = await web_crawl_tool(") + print(" 'docs.python.org',") + print(" 'Find key concepts',") + print(" model='gemini-2.5-flash',") + print(" min_length=3000") + print(" )") + print("") + print(" # Disable LLM processing") + print(" raw_content = await web_extract_tool(['https://example.com'], use_llm_processing=False)") + + print("\nDebug mode:") + print(" # Enable debug logging") + print(" export WEB_TOOLS_DEBUG=true") + print(" # Debug logs capture:") + print(" # - All tool calls with parameters") + print(" # - Original API responses") + print(" # - LLM compression metrics") + print(" # - Final processed results") + print(" # Logs saved to: ./logs/web_tools_debug_UUID.json") + + print(f"\nšŸ“ Run 'python test_web_tools_llm.py' to test LLM processing capabilities") diff --git a/toolset_distributions.py b/toolset_distributions.py new file mode 100644 index 000000000..079619478 --- /dev/null +++ b/toolset_distributions.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +""" +Toolset Distributions Module + +This module defines distributions of toolsets for data generation runs. +Each distribution specifies which toolsets should be used and their probability +of being selected for any given prompt during the batch processing. + +A distribution is a dictionary mapping toolset names to their selection probability (%). +Probabilities should sum to 100, but the system will normalize if they don't. + +Usage: + from toolset_distributions import get_distribution, list_distributions + + # Get a specific distribution + dist = get_distribution("image_gen") + + # List all available distributions + all_dists = list_distributions() +""" + +from typing import Dict, List, Optional +import random +from toolsets import validate_toolset + + +# Distribution definitions +# Each key is a distribution name, and the value is a dict of toolset_name: probability_percentage +DISTRIBUTIONS = { + # Default: All tools available 100% of the time + "default": { + "description": "All available tools, all the time", + "toolsets": { + "web": 100, + "vision": 100, + "image_gen": 100, + "terminal": 100, + "moa": 100 + } + }, + + # Image generation focused distribution + "image_gen": { + "description": "Heavy focus on image generation with vision and web support", + "toolsets": { + "image_gen": 90, # 80% chance of image generation tools + "vision": 90, # 60% chance of vision tools + "web": 55, # 40% chance of web tools + "terminal": 45, + "moa": 10 # 20% chance of reasoning tools + } + }, + + # Research-focused distribution + "research": { + "description": "Web research with vision analysis and reasoning", + "toolsets": { + "web": 90, # 90% chance of web tools + "vision": 50, # 50% chance of vision tools + "moa": 40, # 40% chance of reasoning tools + "terminal": 10 # 10% chance of terminal tools + } + }, + + # Scientific problem solving focused distribution + "science": { + "description": "Web research with vision analysis and reasoning", + "toolsets": { + "web": 94, # 90% chance of web tools + "vision": 65, # 50% chance of vision tools + "moa": 10, # 40% chance of reasoning tools + "terminal": 94, # 10% chance of terminal tools + "image_gen": 15 # 80% chance of image generation tools + } + }, + + # Development-focused distribution + "development": { + "description": "Terminal and reasoning with occasional web lookup", + "toolsets": { + "terminal": 80, # 80% chance of terminal tools + "moa": 60, # 60% chance of reasoning tools + "web": 30, # 30% chance of web tools + "vision": 10 # 10% chance of vision tools + } + }, + + # Safe mode (no terminal) + "safe": { + "description": "All tools except terminal for safety", + "toolsets": { + "web": 80, + "vision": 60, + "image_gen": 60, + "moa": 50 + } + }, + + # Balanced distribution + "balanced": { + "description": "Equal probability of all toolsets", + "toolsets": { + "web": 50, + "vision": 50, + "image_gen": 50, + "terminal": 50, + "moa": 50 + } + }, + + # Minimal (web only) + "minimal": { + "description": "Only web tools for basic research", + "toolsets": { + "web": 100 + } + }, + + # Creative (vision + image generation) + "creative": { + "description": "Image generation and vision analysis focus", + "toolsets": { + "image_gen": 90, + "vision": 90, + "web": 30 + } + }, + + # Reasoning heavy + "reasoning": { + "description": "Heavy mixture of agents usage with minimal other tools", + "toolsets": { + "moa": 90, + "web": 30, + "terminal": 20 + } + } +} + + +def get_distribution(name: str) -> Optional[Dict[str, any]]: + """ + Get a toolset distribution by name. + + Args: + name (str): Name of the distribution + + Returns: + Dict: Distribution definition with description and toolsets + None: If distribution not found + """ + return DISTRIBUTIONS.get(name) + + +def list_distributions() -> Dict[str, Dict]: + """ + List all available distributions. + + Returns: + Dict: All distribution definitions + """ + return DISTRIBUTIONS.copy() + + +def sample_toolsets_from_distribution(distribution_name: str) -> List[str]: + """ + Sample toolsets based on a distribution's probabilities. + + Each toolset in the distribution has a % chance of being included. + This allows multiple toolsets to be active simultaneously. + + Args: + distribution_name (str): Name of the distribution to sample from + + Returns: + List[str]: List of sampled toolset names + + Raises: + ValueError: If distribution name is not found + """ + dist = get_distribution(distribution_name) + if not dist: + raise ValueError(f"Unknown distribution: {distribution_name}") + + # Sample each toolset independently based on its probability + selected_toolsets = [] + + for toolset_name, probability in dist["toolsets"].items(): + # Validate toolset exists + if not validate_toolset(toolset_name): + print(f"āš ļø Warning: Toolset '{toolset_name}' in distribution '{distribution_name}' is not valid") + continue + + # Roll the dice - if random value is less than probability, include this toolset + if random.random() * 100 < probability: + selected_toolsets.append(toolset_name) + + # If no toolsets were selected (can happen with low probabilities), + # ensure at least one toolset is selected by picking the highest probability one + if not selected_toolsets and dist["toolsets"]: + # Find toolset with highest probability + highest_prob_toolset = max(dist["toolsets"].items(), key=lambda x: x[1])[0] + if validate_toolset(highest_prob_toolset): + selected_toolsets.append(highest_prob_toolset) + + return selected_toolsets + + +def validate_distribution(distribution_name: str) -> bool: + """ + Check if a distribution name is valid. + + Args: + distribution_name (str): Distribution name to validate + + Returns: + bool: True if valid, False otherwise + """ + return distribution_name in DISTRIBUTIONS + + +def print_distribution_info(distribution_name: str) -> None: + """ + Print detailed information about a distribution. + + Args: + distribution_name (str): Distribution name + """ + dist = get_distribution(distribution_name) + if not dist: + print(f"āŒ Unknown distribution: {distribution_name}") + return + + print(f"\nšŸ“Š Distribution: {distribution_name}") + print(f" Description: {dist['description']}") + print(f" Toolsets:") + for toolset, prob in sorted(dist["toolsets"].items(), key=lambda x: x[1], reverse=True): + print(f" • {toolset:15} : {prob:3}% chance") + + +if __name__ == "__main__": + """ + Demo and testing of the distributions system + """ + print("šŸ“Š Toolset Distributions Demo") + print("=" * 60) + + # List all distributions + print("\nšŸ“‹ Available Distributions:") + print("-" * 40) + for name, dist in list_distributions().items(): + print(f"\n {name}:") + print(f" {dist['description']}") + toolset_list = ", ".join([f"{ts}({p}%)" for ts, p in dist["toolsets"].items()]) + print(f" Toolsets: {toolset_list}") + + # Demo sampling + print("\n\nšŸŽ² Sampling Examples:") + print("-" * 40) + + test_distributions = ["image_gen", "research", "balanced", "default"] + + for dist_name in test_distributions: + print(f"\n{dist_name}:") + # Sample 5 times to show variability + samples = [] + for _ in range(5): + sampled = sample_toolsets_from_distribution(dist_name) + samples.append(sorted(sampled)) + + print(f" Sample 1: {samples[0]}") + print(f" Sample 2: {samples[1]}") + print(f" Sample 3: {samples[2]}") + print(f" Sample 4: {samples[3]}") + print(f" Sample 5: {samples[4]}") + + # Show detailed info + print("\n\nšŸ“Š Detailed Distribution Info:") + print("-" * 40) + print_distribution_info("image_gen") + print_distribution_info("research") + diff --git a/toolsets.py b/toolsets.py index 4ed474dae..058abbe4a 100644 --- a/toolsets.py +++ b/toolsets.py @@ -110,6 +110,16 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]: if visited is None: visited = set() + # Special aliases that represent all tools across every toolset + # This ensures future toolsets are automatically included without changes. + if name in {"all", "*"}: + all_tools: Set[str] = set() + for toolset_name in get_toolset_names(): + # Use a fresh visited set per branch to avoid cross-branch contamination + resolved = resolve_toolset(toolset_name, visited.copy()) + all_tools.update(resolved) + return list(all_tools) + # Check for cycles if name in visited: print(f"āš ļø Circular dependency detected in toolset '{name}'") @@ -184,6 +194,9 @@ def validate_toolset(name: str) -> bool: Returns: bool: True if valid, False otherwise """ + # Accept special alias names for convenience + if name in {"all", "*"}: + return True return name in TOOLSETS